Refactored pyvbmc code
This commit is contained in:
parent
ba4e88d54a
commit
a3e9eb68eb
1 changed files with 396 additions and 162 deletions
|
|
@ -1,3 +1,4 @@
|
||||||
|
# %%
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# OPTIMISATION PYVBMC POUR MODÈLES Q-LEARNING AVEC ÉVÉNEMENTS RARES
|
# OPTIMISATION PYVBMC POUR MODÈLES Q-LEARNING AVEC ÉVÉNEMENTS RARES
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
@ -14,6 +15,8 @@ from pathlib import Path
|
||||||
# Tentative d'import PyVBMC
|
# Tentative d'import PyVBMC
|
||||||
try:
|
try:
|
||||||
from pyvbmc import VBMC
|
from pyvbmc import VBMC
|
||||||
|
from pyvbmc.priors import UniformBox
|
||||||
|
|
||||||
PYVBMC_AVAILABLE = True
|
PYVBMC_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
PYVBMC_AVAILABLE = False
|
PYVBMC_AVAILABLE = False
|
||||||
|
|
@ -26,6 +29,8 @@ from load_data import all_participant_data, unique_participants
|
||||||
# CONFIGURATIONS DES MODÈLES EMBOÎTÉS
|
# CONFIGURATIONS DES MODÈLES EMBOÎTÉS
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
def get_model_configs() -> Dict:
|
def get_model_configs() -> Dict:
|
||||||
"""Retourne les configurations des différents modèles."""
|
"""Retourne les configurations des différents modèles."""
|
||||||
return {
|
return {
|
||||||
|
|
@ -38,7 +43,7 @@ def get_model_configs() -> Dict:
|
||||||
"n_params": 3,
|
"n_params": 3,
|
||||||
"param_names": ["alpha", "forget", "lambda"],
|
"param_names": ["alpha", "forget", "lambda"],
|
||||||
"lower": np.array([-5, -5, -3]),
|
"lower": np.array([-5, -5, -3]),
|
||||||
"upper": np.array([5, 5, 3])
|
"upper": np.array([5, 5, 3]),
|
||||||
},
|
},
|
||||||
"GAIN_LOSS": {
|
"GAIN_LOSS": {
|
||||||
"name": "GAIN_LOSS",
|
"name": "GAIN_LOSS",
|
||||||
|
|
@ -49,7 +54,7 @@ def get_model_configs() -> Dict:
|
||||||
"n_params": 4,
|
"n_params": 4,
|
||||||
"param_names": ["alpha_loss", "alpha_gain", "forget", "lambda"],
|
"param_names": ["alpha_loss", "alpha_gain", "forget", "lambda"],
|
||||||
"lower": np.array([-5, -5, -5, -3]),
|
"lower": np.array([-5, -5, -5, -3]),
|
||||||
"upper": np.array([5, 5, 5, 3])
|
"upper": np.array([5, 5, 5, 3]),
|
||||||
},
|
},
|
||||||
"BIASED": {
|
"BIASED": {
|
||||||
"name": "BIASED",
|
"name": "BIASED",
|
||||||
|
|
@ -59,12 +64,19 @@ def get_model_configs() -> Dict:
|
||||||
"has_rho": False,
|
"has_rho": False,
|
||||||
"n_params": 10,
|
"n_params": 10,
|
||||||
"param_names": [
|
"param_names": [
|
||||||
"alpha_loss", "alpha_gain",
|
"alpha_loss",
|
||||||
"forget_1", "forget_2", "forget_3", "forget_4",
|
"alpha_gain",
|
||||||
"lambda_1", "lambda_2", "lambda_3", "lambda_4"
|
"forget_1",
|
||||||
|
"forget_2",
|
||||||
|
"forget_3",
|
||||||
|
"forget_4",
|
||||||
|
"lambda_1",
|
||||||
|
"lambda_2",
|
||||||
|
"lambda_3",
|
||||||
|
"lambda_4",
|
||||||
],
|
],
|
||||||
"lower": np.concatenate([[-5, -5], np.full(4, -5), np.full(4, -3)]),
|
"lower": np.concatenate([[-5, -5], np.full(4, -5), np.full(4, -3)]),
|
||||||
"upper": np.concatenate([[5, 5], np.full(4, 5), np.full(4, 3)])
|
"upper": np.concatenate([[5, 5], np.full(4, 5), np.full(4, 3)]),
|
||||||
},
|
},
|
||||||
"REE_BIASED_SIMPLE": {
|
"REE_BIASED_SIMPLE": {
|
||||||
"name": "REE_BIASED_SIMPLE",
|
"name": "REE_BIASED_SIMPLE",
|
||||||
|
|
@ -74,11 +86,15 @@ def get_model_configs() -> Dict:
|
||||||
"has_rho": True,
|
"has_rho": True,
|
||||||
"n_params": 6,
|
"n_params": 6,
|
||||||
"param_names": [
|
"param_names": [
|
||||||
"alpha_loss", "alpha_gain", "forget", "lambda",
|
"alpha_loss",
|
||||||
"rho_BS", "rho_JP"
|
"alpha_gain",
|
||||||
|
"forget",
|
||||||
|
"lambda",
|
||||||
|
"rho_BS",
|
||||||
|
"rho_JP",
|
||||||
],
|
],
|
||||||
"lower": np.array([-5, -5, -5, -3, -10, -10]),
|
"lower": np.array([-5, -5, -5, -3, -10, -10]),
|
||||||
"upper": np.array([5, 5, 5, 3, 10, 10])
|
"upper": np.array([5, 5, 5, 3, 10, 10]),
|
||||||
},
|
},
|
||||||
"REE_BIASED_COMPLEX": {
|
"REE_BIASED_COMPLEX": {
|
||||||
"name": "REE_BIASED_COMPLEX",
|
"name": "REE_BIASED_COMPLEX",
|
||||||
|
|
@ -88,13 +104,23 @@ def get_model_configs() -> Dict:
|
||||||
"has_rho": True,
|
"has_rho": True,
|
||||||
"n_params": 12,
|
"n_params": 12,
|
||||||
"param_names": [
|
"param_names": [
|
||||||
"alpha_loss", "alpha_gain",
|
"alpha_loss",
|
||||||
"forget_1", "forget_2", "forget_3", "forget_4",
|
"alpha_gain",
|
||||||
"lambda_1", "lambda_2", "lambda_3", "lambda_4",
|
"forget_1",
|
||||||
"rho_BS", "rho_JP"
|
"forget_2",
|
||||||
|
"forget_3",
|
||||||
|
"forget_4",
|
||||||
|
"lambda_1",
|
||||||
|
"lambda_2",
|
||||||
|
"lambda_3",
|
||||||
|
"lambda_4",
|
||||||
|
"rho_BS",
|
||||||
|
"rho_JP",
|
||||||
],
|
],
|
||||||
"lower": np.concatenate([[-5, -5], np.full(4, -5), np.full(4, -3), [-10, -10]]),
|
"lower": np.concatenate(
|
||||||
"upper": np.concatenate([[5, 5], np.full(4, 5), np.full(4, 3), [10, 10]])
|
[[-5, -5], np.full(4, -5), np.full(4, -3), [-10, -10]]
|
||||||
|
),
|
||||||
|
"upper": np.concatenate([[5, 5], np.full(4, 5), np.full(4, 3), [10, 10]]),
|
||||||
},
|
},
|
||||||
"REE_LEARNING_SIMPLE": {
|
"REE_LEARNING_SIMPLE": {
|
||||||
"name": "REE_LEARNING_SIMPLE",
|
"name": "REE_LEARNING_SIMPLE",
|
||||||
|
|
@ -104,11 +130,15 @@ def get_model_configs() -> Dict:
|
||||||
"has_rho": False,
|
"has_rho": False,
|
||||||
"n_params": 6,
|
"n_params": 6,
|
||||||
"param_names": [
|
"param_names": [
|
||||||
"alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP",
|
"alpha_loss",
|
||||||
"forget", "lambda"
|
"alpha_gain",
|
||||||
|
"alpha_BS",
|
||||||
|
"alpha_JP",
|
||||||
|
"forget",
|
||||||
|
"lambda",
|
||||||
],
|
],
|
||||||
"lower": np.array([-5, -5, -5, -5, -5, -3]),
|
"lower": np.array([-5, -5, -5, -5, -5, -3]),
|
||||||
"upper": np.array([5, 5, 5, 5, 5, 3])
|
"upper": np.array([5, 5, 5, 5, 5, 3]),
|
||||||
},
|
},
|
||||||
"REE_LEARNING_COMPLEX": {
|
"REE_LEARNING_COMPLEX": {
|
||||||
"name": "REE_LEARNING_COMPLEX",
|
"name": "REE_LEARNING_COMPLEX",
|
||||||
|
|
@ -118,12 +148,21 @@ def get_model_configs() -> Dict:
|
||||||
"has_rho": False,
|
"has_rho": False,
|
||||||
"n_params": 12,
|
"n_params": 12,
|
||||||
"param_names": [
|
"param_names": [
|
||||||
"alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP",
|
"alpha_loss",
|
||||||
"forget_1", "forget_2", "forget_3", "forget_4",
|
"alpha_gain",
|
||||||
"lambda_1", "lambda_2", "lambda_3", "lambda_4"
|
"alpha_BS",
|
||||||
|
"alpha_JP",
|
||||||
|
"forget_1",
|
||||||
|
"forget_2",
|
||||||
|
"forget_3",
|
||||||
|
"forget_4",
|
||||||
|
"lambda_1",
|
||||||
|
"lambda_2",
|
||||||
|
"lambda_3",
|
||||||
|
"lambda_4",
|
||||||
],
|
],
|
||||||
"lower": np.concatenate([[-5, -5, -5, -5], np.full(4, -5), np.full(4, -3)]),
|
"lower": np.concatenate([[-5, -5, -5, -5], np.full(4, -5), np.full(4, -3)]),
|
||||||
"upper": np.concatenate([[5, 5, 5, 5], np.full(4, 5), np.full(4, 3)])
|
"upper": np.concatenate([[5, 5, 5, 5], np.full(4, 5), np.full(4, 3)]),
|
||||||
},
|
},
|
||||||
"REE_LEARNING_BIASED_SIMPLE": {
|
"REE_LEARNING_BIASED_SIMPLE": {
|
||||||
"name": "REE_LEARNING_BIASED_SIMPLE",
|
"name": "REE_LEARNING_BIASED_SIMPLE",
|
||||||
|
|
@ -133,11 +172,17 @@ def get_model_configs() -> Dict:
|
||||||
"has_rho": True,
|
"has_rho": True,
|
||||||
"n_params": 8,
|
"n_params": 8,
|
||||||
"param_names": [
|
"param_names": [
|
||||||
"alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP",
|
"alpha_loss",
|
||||||
"forget", "lambda", "rho_BS", "rho_JP"
|
"alpha_gain",
|
||||||
|
"alpha_BS",
|
||||||
|
"alpha_JP",
|
||||||
|
"forget",
|
||||||
|
"lambda",
|
||||||
|
"rho_BS",
|
||||||
|
"rho_JP",
|
||||||
],
|
],
|
||||||
"lower": np.array([-5, -5, -5, -5, -5, -3, -10, -10]),
|
"lower": np.array([-5, -5, -5, -5, -5, -3, -10, -10]),
|
||||||
"upper": np.array([5, 5, 5, 5, 5, 3, 10, 10])
|
"upper": np.array([5, 5, 5, 5, 5, 3, 10, 10]),
|
||||||
},
|
},
|
||||||
"REE_LEARNING_BIASED_COMPLEX": {
|
"REE_LEARNING_BIASED_COMPLEX": {
|
||||||
"name": "REE_LEARNING_BIASED_COMPLEX",
|
"name": "REE_LEARNING_BIASED_COMPLEX",
|
||||||
|
|
@ -147,14 +192,28 @@ def get_model_configs() -> Dict:
|
||||||
"has_rho": True,
|
"has_rho": True,
|
||||||
"n_params": 14,
|
"n_params": 14,
|
||||||
"param_names": [
|
"param_names": [
|
||||||
"alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP",
|
"alpha_loss",
|
||||||
"forget_1", "forget_2", "forget_3", "forget_4",
|
"alpha_gain",
|
||||||
"lambda_1", "lambda_2", "lambda_3", "lambda_4",
|
"alpha_BS",
|
||||||
"rho_BS", "rho_JP"
|
"alpha_JP",
|
||||||
|
"forget_1",
|
||||||
|
"forget_2",
|
||||||
|
"forget_3",
|
||||||
|
"forget_4",
|
||||||
|
"lambda_1",
|
||||||
|
"lambda_2",
|
||||||
|
"lambda_3",
|
||||||
|
"lambda_4",
|
||||||
|
"rho_BS",
|
||||||
|
"rho_JP",
|
||||||
],
|
],
|
||||||
"lower": np.concatenate([[-5, -5, -5, -5], np.full(4, -5), np.full(4, -3), [-10, -10]]),
|
"lower": np.concatenate(
|
||||||
"upper": np.concatenate([[5, 5, 5, 5], np.full(4, 5), np.full(4, 3), [10, 10]])
|
[[-5, -5, -5, -5], np.full(4, -5), np.full(4, -3), [-10, -10]]
|
||||||
}
|
),
|
||||||
|
"upper": np.concatenate(
|
||||||
|
[[5, 5, 5, 5], np.full(4, 5), np.full(4, 3), [10, 10]]
|
||||||
|
),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -162,26 +221,31 @@ def get_model_configs() -> Dict:
|
||||||
# MODÈLE Q-LEARNING GÉNÉRIQUE
|
# MODÈLE Q-LEARNING GÉNÉRIQUE
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
def qlearning_generic(params: np.ndarray, data: pd.DataFrame, model_config: Dict,
|
|
||||||
return_negLL: bool = True) -> float:
|
def qlearning_generic(
|
||||||
|
params: np.ndarray,
|
||||||
|
data: pd.DataFrame,
|
||||||
|
model_config: Dict,
|
||||||
|
return_negLL: bool = True,
|
||||||
|
) -> float:
|
||||||
"""
|
"""
|
||||||
Modèle Q-learning générique avec support pour différentes architectures de paramètres.
|
Modèle Q-learning générique avec support pour différentes architectures de paramètres.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
params: Vecteur de paramètres
|
params: Vecteur de paramètres
|
||||||
data: DataFrame avec colonnes 'choice', 'reward'
|
data: DataFrame avec colonnes 'choice', 'reward'
|
||||||
model_config: Configuration du modèle
|
model_config: Configuration du modèle
|
||||||
return_negLL: Si True, retourne -log-vraisemblance; sinon retourne log-vraisemblance
|
return_negLL: Si True, retourne -log-vraisemblance; sinon retourne log-vraisemblance
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Valeur de la log-vraisemblance négative (ou positive selon return_negLL)
|
Valeur de la log-vraisemblance négative (ou positive selon return_negLL)
|
||||||
"""
|
"""
|
||||||
n_arms = 4
|
n_arms = 4
|
||||||
n_trials = len(data)
|
n_trials = len(data)
|
||||||
|
|
||||||
# Extraction des paramètres selon la configuration du modèle
|
# Extraction des paramètres selon la configuration du modèle
|
||||||
param_idx = 0
|
param_idx = 0
|
||||||
|
|
||||||
# ALPHA(S)
|
# ALPHA(S)
|
||||||
if model_config["n_alpha"] == 1:
|
if model_config["n_alpha"] == 1:
|
||||||
alpha_loss = alpha_gain = alpha_BS = alpha_JP = expit(params[param_idx])
|
alpha_loss = alpha_gain = alpha_BS = alpha_JP = expit(params[param_idx])
|
||||||
|
|
@ -198,60 +262,60 @@ def qlearning_generic(params: np.ndarray, data: pd.DataFrame, model_config: Dict
|
||||||
alpha_BS = expit(params[param_idx + 2])
|
alpha_BS = expit(params[param_idx + 2])
|
||||||
alpha_JP = expit(params[param_idx + 3])
|
alpha_JP = expit(params[param_idx + 3])
|
||||||
param_idx += 4
|
param_idx += 4
|
||||||
|
|
||||||
# FORGET(S)
|
# FORGET(S)
|
||||||
if model_config["n_forget"] == 1:
|
if model_config["n_forget"] == 1:
|
||||||
forget = np.full(n_arms, expit(params[param_idx]))
|
forget = np.full(n_arms, expit(params[param_idx]))
|
||||||
param_idx += 1
|
param_idx += 1
|
||||||
elif model_config["n_forget"] == 4:
|
elif model_config["n_forget"] == 4:
|
||||||
forget = expit(params[param_idx:(param_idx + 4)])
|
forget = expit(params[param_idx : (param_idx + 4)])
|
||||||
param_idx += 4
|
param_idx += 4
|
||||||
|
|
||||||
# LAMBDA(S)
|
# LAMBDA(S)
|
||||||
if model_config["n_lambda"] == 1:
|
if model_config["n_lambda"] == 1:
|
||||||
lambda_vals = np.full(n_arms, np.exp(params[param_idx]))
|
lambda_vals = np.full(n_arms, np.exp(params[param_idx]))
|
||||||
param_idx += 1
|
param_idx += 1
|
||||||
elif model_config["n_lambda"] == 4:
|
elif model_config["n_lambda"] == 4:
|
||||||
lambda_vals = np.exp(params[param_idx:(param_idx + 4)])
|
lambda_vals = np.exp(params[param_idx : (param_idx + 4)])
|
||||||
param_idx += 4
|
param_idx += 4
|
||||||
|
|
||||||
# RHO(S) - Biais pour événements rares
|
# RHO(S) - Biais pour événements rares
|
||||||
if model_config["has_rho"]:
|
if model_config["has_rho"]:
|
||||||
rho_BS = params[param_idx]
|
rho_BS = params[param_idx]
|
||||||
rho_JP = params[param_idx + 1]
|
rho_JP = params[param_idx + 1]
|
||||||
else:
|
else:
|
||||||
rho_BS = rho_JP = 0
|
rho_BS = rho_JP = 0
|
||||||
|
|
||||||
# Initialisation des Q-values
|
# Initialisation des Q-values
|
||||||
Q = np.zeros(n_arms)
|
Q = np.zeros(n_arms)
|
||||||
log_lik = 0.0
|
log_lik = 0.0
|
||||||
|
|
||||||
for t in range(n_trials):
|
for t in range(n_trials):
|
||||||
choice = int(data.iloc[t]["choice"])
|
choice = int(data.iloc[t]["choice"])
|
||||||
reward = data.iloc[t]["reward"]
|
reward = data.iloc[t]["reward"]
|
||||||
|
|
||||||
# Calcul des valeurs subjectives V(t)
|
# Calcul des valeurs subjectives V(t)
|
||||||
V = lambda_vals * Q
|
V = lambda_vals * Q
|
||||||
|
|
||||||
# Ajout des biais pour événements rares si le modèle le permet
|
# Ajout des biais pour événements rares si le modèle le permet
|
||||||
if model_config["has_rho"]:
|
if model_config["has_rho"]:
|
||||||
V[0] += rho_JP # antifragile
|
V[0] += rho_JP # antifragile
|
||||||
V[1] += rho_BS # fragile
|
V[1] += rho_BS # fragile
|
||||||
V[3] += rho_BS + rho_JP # vulnerable
|
V[3] += rho_BS + rho_JP # vulnerable
|
||||||
|
|
||||||
# Softmax
|
# Softmax
|
||||||
V_max = np.max(V)
|
V_max = np.max(V)
|
||||||
exp_V = np.exp(V - V_max)
|
exp_V = np.exp(V - V_max)
|
||||||
probs = exp_V / np.sum(exp_V)
|
probs = exp_V / np.sum(exp_V)
|
||||||
probs = np.maximum(probs, 1e-10)
|
probs = np.maximum(probs, 1e-10)
|
||||||
probs = probs / np.sum(probs)
|
probs = probs / np.sum(probs)
|
||||||
|
|
||||||
# Log-likelihood
|
# Log-likelihood
|
||||||
log_lik += np.log(probs[choice])
|
log_lik += np.log(probs[choice])
|
||||||
|
|
||||||
# Mise à jour Q-learning
|
# Mise à jour Q-learning
|
||||||
Q_new = Q.copy()
|
Q_new = Q.copy()
|
||||||
|
|
||||||
# Choix de l'alpha approprié
|
# Choix de l'alpha approprié
|
||||||
if reward == -3000:
|
if reward == -3000:
|
||||||
alpha_used = alpha_BS
|
alpha_used = alpha_BS
|
||||||
|
|
@ -261,96 +325,108 @@ def qlearning_generic(params: np.ndarray, data: pd.DataFrame, model_config: Dict
|
||||||
alpha_used = alpha_loss
|
alpha_used = alpha_loss
|
||||||
else:
|
else:
|
||||||
alpha_used = alpha_gain
|
alpha_used = alpha_gain
|
||||||
|
|
||||||
# Option choisie : Q(t+1) = Q(t) + alpha * (r(t) - Q(t))
|
# Option choisie : Q(t+1) = Q(t) + alpha * (r(t) - Q(t))
|
||||||
Q_new[choice] = Q[choice] + alpha_used * (reward - Q[choice])
|
Q_new[choice] = Q[choice] + alpha_used * (reward - Q[choice])
|
||||||
|
|
||||||
# Options non choisies : Q(t+1) = Q(t) * (1 - f)
|
# Options non choisies : Q(t+1) = Q(t) * (1 - f)
|
||||||
not_chosen = np.setdiff1d(np.arange(n_arms), [choice])
|
not_chosen = np.setdiff1d(np.arange(n_arms), [choice])
|
||||||
Q_new[not_chosen] = Q[not_chosen] * (1 - forget[not_chosen])
|
Q_new[not_chosen] = Q[not_chosen] * (1 - forget[not_chosen])
|
||||||
|
|
||||||
Q = Q_new
|
Q = Q_new
|
||||||
|
|
||||||
if return_negLL:
|
if return_negLL:
|
||||||
return -log_lik
|
return -log_lik
|
||||||
else:
|
else:
|
||||||
return log_lik
|
return log_lik
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# OPTIMISATION AVEC PYVBMC
|
# OPTIMISATION AVEC PYVBMC
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
def fit_participant_pyvbmc(participant_data: pd.DataFrame, model_config: Dict,
|
|
||||||
verbose: bool = True) -> Dict:
|
def fit_participant_pyvbmc(
|
||||||
|
participant_data: pd.DataFrame, model_config: Dict, verbose: bool = True
|
||||||
|
) -> Dict:
|
||||||
"""
|
"""
|
||||||
Optimise les paramètres du modèle pour un participant utilisant PyVBMC.
|
Optimise les paramètres du modèle pour un participant utilisant PyVBMC.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
participant_data: Données du participant
|
participant_data: Données du participant
|
||||||
model_config: Configuration du modèle
|
model_config: Configuration du modèle
|
||||||
verbose: Affiche les progressions
|
verbose: Affiche les progressions
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionnaire avec les résultats d'optimisation
|
Dictionnaire avec les résultats d'optimisation
|
||||||
"""
|
"""
|
||||||
if not PYVBMC_AVAILABLE:
|
if not PYVBMC_AVAILABLE:
|
||||||
raise RuntimeError("PyVBMC n'est pas installé. Installez avec: pip install pyvbmc")
|
raise RuntimeError(
|
||||||
|
"PyVBMC n'est pas installé. Installez avec: pip install pyvbmc"
|
||||||
|
)
|
||||||
|
|
||||||
# Définition de la fonction de log-densité pour PyVBMC
|
# Définition de la fonction de log-densité pour PyVBMC
|
||||||
def log_posterior(params_array):
|
def log_likelihood(params_array):
|
||||||
"""PyVBMC maximise, donc on retourne -negLL."""
|
"""PyVBMC maximise, donc on retourne -negLL."""
|
||||||
params = np.asarray(params_array).flatten()
|
params = np.asarray(params_array).flatten()
|
||||||
negLL = qlearning_generic(params, participant_data, model_config, return_negLL=True)
|
negLL = qlearning_generic(
|
||||||
|
params, participant_data, model_config, return_negLL=True
|
||||||
|
)
|
||||||
return -negLL
|
return -negLL
|
||||||
|
|
||||||
# Point de départ (milieu des bornes)
|
# Point de départ (milieu des bornes)
|
||||||
x0 = (model_config["lower"] + model_config["upper"]) / 2
|
x0 = (model_config["lower"] + model_config["upper"]) / 2
|
||||||
|
|
||||||
# Bornes plausibles (25%-75% de la plage)
|
# Bornes plausibles (25%-75% de la plage)
|
||||||
plb = model_config["lower"] + 0.25 * (model_config["upper"] - model_config["lower"])
|
plb = model_config["lower"] + 0.25 * (model_config["upper"] - model_config["lower"])
|
||||||
pub = model_config["upper"] - 0.25 * (model_config["upper"] - model_config["lower"])
|
pub = model_config["upper"] - 0.25 * (model_config["upper"] - model_config["lower"])
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f" Starting VBMC optimization...")
|
print(f" Starting VBMC optimization...")
|
||||||
print(f" Initial parameters: {x0}")
|
print(f" Initial parameters: {x0}")
|
||||||
print(f" Lower bounds: {model_config['lower']}")
|
print(f" Lower bounds: {model_config['lower']}")
|
||||||
print(f" Upper bounds: {model_config['upper']}")
|
print(f" Upper bounds: {model_config['upper']}")
|
||||||
|
|
||||||
# Initialisation et optimisation de VBMC
|
# Initialisation et optimisation de VBMC
|
||||||
vbmc = VBMC(
|
vbmc = VBMC(
|
||||||
log_posterior,
|
log_likelihood,
|
||||||
x0,
|
x0,
|
||||||
model_config["lower"],
|
model_config["lower"],
|
||||||
model_config["upper"],
|
model_config["upper"],
|
||||||
plb,
|
plb,
|
||||||
pub,
|
pub,
|
||||||
options={
|
options={
|
||||||
"verbose": 0 if not verbose else 1,
|
# "verbose": 0 if not verbose else 1,
|
||||||
"display": "off",
|
"display": "off",
|
||||||
}
|
},
|
||||||
|
prior=UniformBox(
|
||||||
|
a=model_config["lower"], b=model_config["upper"], D=model_config["n_params"]
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
vp, results = vbmc.optimize()
|
vp, results = vbmc.optimize()
|
||||||
|
|
||||||
# Extraction des statistiques
|
# Extraction des statistiques
|
||||||
posterior_mean, posterior_cov = vp.moments()
|
posterior_mean, posterior_cov = vp.moments(orig_flag=True, cov_flag=True)
|
||||||
posterior_mean = np.asarray(posterior_mean).flatten()
|
posterior_mean = np.asarray(posterior_mean).flatten()
|
||||||
posterior_sd = np.sqrt(np.diag(posterior_cov))
|
posterior_sd = np.sqrt(np.diag(posterior_cov))
|
||||||
|
|
||||||
# ELBO et autres métriques
|
# ELBO et autres métriques
|
||||||
elbo = results["elbo"]
|
elbo = results["elbo"]
|
||||||
elbo_sd = results.get("elbo_sd", np.nan)
|
elbo_sd = results.get("elbo_sd", np.nan)
|
||||||
n_iterations = results.get("iterations", np.nan)
|
n_iterations = results.get("iterations", np.nan)
|
||||||
|
|
||||||
# Calcul du negLL avec la posterior mean
|
# Calcul du negLL avec la posterior mean
|
||||||
negLL = qlearning_generic(posterior_mean, participant_data, model_config, return_negLL=True)
|
negLL = qlearning_generic(
|
||||||
|
posterior_mean, participant_data, model_config, return_negLL=True
|
||||||
|
)
|
||||||
n_obs = len(participant_data)
|
n_obs = len(participant_data)
|
||||||
|
|
||||||
# Calcul des critères d'information
|
# Calcul des critères d'information
|
||||||
aic = 2 * negLL + 2 * model_config["n_params"]
|
aic = 2 * negLL + 2 * model_config["n_params"]
|
||||||
bic = 2 * negLL + model_config["n_params"] * np.log(n_obs)
|
bic = 2 * negLL + model_config["n_params"] * np.log(n_obs)
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"model": model_config["name"],
|
"model": model_config["name"],
|
||||||
"n_params": model_config["n_params"],
|
"n_params": model_config["n_params"],
|
||||||
|
|
@ -365,74 +441,81 @@ def fit_participant_pyvbmc(participant_data: pd.DataFrame, model_config: Dict,
|
||||||
"posterior_mean": posterior_mean,
|
"posterior_mean": posterior_mean,
|
||||||
"posterior_sd": posterior_sd,
|
"posterior_sd": posterior_sd,
|
||||||
"vp": vp,
|
"vp": vp,
|
||||||
"results": results
|
"results": results,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Ajout des paramètres estimés
|
# Ajout des paramètres estimés
|
||||||
for i, param_name in enumerate(model_config["param_names"]):
|
for i, param_name in enumerate(model_config["param_names"]):
|
||||||
result[param_name] = posterior_mean[i]
|
result[param_name] = posterior_mean[i]
|
||||||
result[f"sd_{param_name}"] = posterior_sd[i]
|
result[f"sd_{param_name}"] = posterior_sd[i]
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def fit_participant_deoptim(participant_data: pd.DataFrame, model_config: Dict,
|
def fit_participant_deoptim(
|
||||||
n_runs: int = 5, verbose: bool = True) -> Dict:
|
participant_data: pd.DataFrame,
|
||||||
|
model_config: Dict,
|
||||||
|
n_runs: int = 5,
|
||||||
|
verbose: bool = True,
|
||||||
|
n_workers: int = 1,
|
||||||
|
) -> Dict:
|
||||||
"""
|
"""
|
||||||
Optimise les paramètres du modèle pour un participant utilisant minimisation scipy.
|
Optimise les paramètres du modèle pour un participant utilisant minimisation scipy.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
participant_data: Données du participant
|
participant_data: Données du participant
|
||||||
model_config: Configuration du modèle
|
model_config: Configuration du modèle
|
||||||
n_runs: Nombre de runs avec différents points de départ
|
n_runs: Nombre de runs avec différents points de départ
|
||||||
verbose: Affiche les progressions
|
verbose: Affiche les progressions
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionnaire avec les résultats d'optimisation
|
Dictionnaire avec les résultats d'optimisation
|
||||||
"""
|
"""
|
||||||
from scipy.optimize import differential_evolution
|
from scipy.optimize import differential_evolution
|
||||||
|
|
||||||
best_result = None
|
best_result = None
|
||||||
best_negLL = np.inf
|
best_negLL = np.inf
|
||||||
all_negLLs = []
|
all_negLLs = []
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f" Running {n_runs} optimization runs...")
|
print(f" Running {n_runs} optimization runs...")
|
||||||
|
|
||||||
for run in range(n_runs):
|
for run in range(n_runs):
|
||||||
np.random.seed(1000 * hash(model_config["name"]) % (2**31) + run)
|
np.random.seed(1000 * hash(model_config["name"]) % (2**31) + run)
|
||||||
|
|
||||||
def objective(params):
|
def objective(params):
|
||||||
return qlearning_generic(params, participant_data, model_config, return_negLL=True)
|
return qlearning_generic(
|
||||||
|
params, participant_data, model_config, return_negLL=True
|
||||||
|
)
|
||||||
|
|
||||||
result = differential_evolution(
|
result = differential_evolution(
|
||||||
objective,
|
objective,
|
||||||
bounds=list(zip(model_config["lower"], model_config["upper"])),
|
bounds=list(zip(model_config["lower"], model_config["upper"])),
|
||||||
maxiter=200,
|
maxiter=200,
|
||||||
popsize=max(50, model_config["n_params"] * 10),
|
popsize=max(50, model_config["n_params"] * 10),
|
||||||
seed=1000 * hash(model_config["name"]) % (2**31) + run,
|
rng=1000 * hash(model_config["name"]) % (2**31) + run,
|
||||||
workers=1,
|
workers=n_workers,
|
||||||
updating="deferred"
|
updating="deferred",
|
||||||
)
|
)
|
||||||
|
|
||||||
all_negLLs.append(result.fun)
|
all_negLLs.append(result.fun)
|
||||||
|
|
||||||
if result.fun < best_negLL:
|
if result.fun < best_negLL:
|
||||||
best_negLL = result.fun
|
best_negLL = result.fun
|
||||||
best_result = result
|
best_result = result
|
||||||
|
|
||||||
posterior_mean = best_result.x
|
posterior_mean = best_result.x
|
||||||
negLL = best_negLL
|
negLL = best_negLL
|
||||||
n_obs = len(participant_data)
|
n_obs = len(participant_data)
|
||||||
|
|
||||||
# Calcul des critères d'information
|
# Calcul des critères d'information
|
||||||
aic = 2 * negLL + 2 * model_config["n_params"]
|
aic = 2 * negLL + 2 * model_config["n_params"]
|
||||||
bic = 2 * negLL + model_config["n_params"] * np.log(n_obs)
|
bic = 2 * negLL + model_config["n_params"] * np.log(n_obs)
|
||||||
|
|
||||||
# Statistiques de convergence
|
# Statistiques de convergence
|
||||||
convergence_sd = np.std(all_negLLs)
|
convergence_sd = np.std(all_negLLs)
|
||||||
convergence_range = np.max(all_negLLs) - np.min(all_negLLs)
|
convergence_range = np.max(all_negLLs) - np.min(all_negLLs)
|
||||||
|
|
||||||
result_dict = {
|
result_dict = {
|
||||||
"model": model_config["name"],
|
"model": model_config["name"],
|
||||||
"n_params": model_config["n_params"],
|
"n_params": model_config["n_params"],
|
||||||
|
|
@ -446,11 +529,11 @@ def fit_participant_deoptim(participant_data: pd.DataFrame, model_config: Dict,
|
||||||
"method": "Differential Evolution",
|
"method": "Differential Evolution",
|
||||||
"posterior_mean": posterior_mean,
|
"posterior_mean": posterior_mean,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Ajout des paramètres estimés
|
# Ajout des paramètres estimés
|
||||||
for i, param_name in enumerate(model_config["param_names"]):
|
for i, param_name in enumerate(model_config["param_names"]):
|
||||||
result_dict[param_name] = posterior_mean[i]
|
result_dict[param_name] = posterior_mean[i]
|
||||||
|
|
||||||
return result_dict
|
return result_dict
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -458,64 +541,72 @@ def fit_participant_deoptim(participant_data: pd.DataFrame, model_config: Dict,
|
||||||
# OPTIMISATION POUR TOUS LES PARTICIPANTS ET MODÈLES
|
# OPTIMISATION POUR TOUS LES PARTICIPANTS ET MODÈLES
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
def fit_all_participants(data: pd.DataFrame, models_to_fit: Optional[List[str]] = None,
|
|
||||||
method: str = "VBMC", n_participants: Optional[int] = None,
|
def fit_all_participants(
|
||||||
verbose: bool = True) -> Dict[str, List[Dict]]:
|
data: pd.DataFrame,
|
||||||
|
models_to_fit: Optional[List[str]] = None,
|
||||||
|
method: str = "VBMC",
|
||||||
|
n_participants: Optional[int] = None,
|
||||||
|
verbose: bool = True,
|
||||||
|
) -> Dict[str, List[Dict]]:
|
||||||
"""
|
"""
|
||||||
Ajuste tous les modèles pour tous les participants.
|
Ajuste tous les modèles pour tous les participants.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: DataFrame avec les données de tous les participants
|
data: DataFrame avec les données de tous les participants
|
||||||
models_to_fit: Liste des noms de modèles à ajuster (None = tous)
|
models_to_fit: Liste des noms de modèles à ajuster (None = tous)
|
||||||
method: Méthode d'optimisation ("VBMC" ou "differential_evolution")
|
method: Méthode d'optimisation ("VBMC" ou "differential_evolution")
|
||||||
n_participants: Nombre de participants à traiter (None = tous)
|
n_participants: Nombre de participants à traiter (None = tous)
|
||||||
verbose: Affiche les progressions
|
verbose: Affiche les progressions
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionnaire avec les résultats par modèle
|
Dictionnaire avec les résultats par modèle
|
||||||
"""
|
"""
|
||||||
model_configs = get_model_configs()
|
model_configs = get_model_configs()
|
||||||
|
|
||||||
if models_to_fit is not None:
|
if models_to_fit is not None:
|
||||||
model_configs = {k: v for k, v in model_configs.items() if k in models_to_fit}
|
model_configs = {k: v for k, v in model_configs.items() if k in models_to_fit}
|
||||||
|
|
||||||
participants = data["participant"].unique()
|
participants = data["participant"].unique()
|
||||||
if n_participants is not None:
|
if n_participants is not None:
|
||||||
participants = participants[:n_participants]
|
participants = participants[:n_participants]
|
||||||
|
|
||||||
all_results = {}
|
all_results = {}
|
||||||
|
|
||||||
for model_name, model_config in model_configs.items():
|
for model_name, model_config in model_configs.items():
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"\n=== Fitting model: {model_name} ===")
|
print(f"\n=== Fitting model: {model_name} ===")
|
||||||
|
|
||||||
model_results = []
|
model_results = []
|
||||||
|
|
||||||
for participant_id in participants:
|
for participant_id in participants:
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f" Participant: {participant_id}")
|
print(f" Participant: {participant_id}")
|
||||||
|
|
||||||
participant_data = data[data["participant"] == participant_id].copy()
|
participant_data = data[data["participant"] == participant_id].copy()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if method == "VBMC":
|
if method == "VBMC":
|
||||||
result = fit_participant_pyvbmc(participant_data, model_config, verbose=False)
|
result = fit_participant_pyvbmc(
|
||||||
|
participant_data, model_config, verbose=False
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
result = fit_participant_deoptim(participant_data, model_config,
|
result = fit_participant_deoptim(
|
||||||
n_runs=5, verbose=False)
|
participant_data, model_config, n_runs=5, verbose=False
|
||||||
|
)
|
||||||
|
|
||||||
result["participant"] = participant_id
|
result["participant"] = participant_id
|
||||||
model_results.append(result)
|
model_results.append(result)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f" negLL: {result['negLL']:.2f}, BIC: {result['BIC']:.2f}")
|
print(f" negLL: {result['negLL']:.2f}, BIC: {result['BIC']:.2f}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" ERROR: {str(e)}")
|
print(f" ERROR: {str(e)}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
all_results[model_name] = model_results
|
all_results[model_name] = model_results
|
||||||
|
|
||||||
return all_results
|
return all_results
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -523,25 +614,26 @@ def fit_all_participants(data: pd.DataFrame, models_to_fit: Optional[List[str]]
|
||||||
# COMPARAISON DES MODÈLES
|
# COMPARAISON DES MODÈLES
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
def compare_models(all_results: Dict[str, List[Dict]]) -> Dict:
|
def compare_models(all_results: Dict[str, List[Dict]]) -> Dict:
|
||||||
"""
|
"""
|
||||||
Compare les modèles et sélectionne les meilleurs par participant.
|
Compare les modèles et sélectionne les meilleurs par participant.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
all_results: Résultats de l'ajustement de tous les modèles
|
all_results: Résultats de l'ajustement de tous les modèles
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionnaire avec comparaisons globales et par participant
|
Dictionnaire avec comparaisons globales et par participant
|
||||||
"""
|
"""
|
||||||
# Comparaison globale
|
# Comparaison globale
|
||||||
global_comparison = []
|
global_comparison = []
|
||||||
|
|
||||||
for model_name, results in all_results.items():
|
for model_name, results in all_results.items():
|
||||||
if len(results) == 0:
|
if len(results) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
results_df = pd.DataFrame(results)
|
results_df = pd.DataFrame(results)
|
||||||
|
|
||||||
comparison_row = {
|
comparison_row = {
|
||||||
"model": model_name,
|
"model": model_name,
|
||||||
"n_params": results[0]["n_params"],
|
"n_params": results[0]["n_params"],
|
||||||
|
|
@ -553,34 +645,38 @@ def compare_models(all_results: Dict[str, List[Dict]]) -> Dict:
|
||||||
"total_BIC": results_df["BIC"].sum(),
|
"total_BIC": results_df["BIC"].sum(),
|
||||||
}
|
}
|
||||||
global_comparison.append(comparison_row)
|
global_comparison.append(comparison_row)
|
||||||
|
|
||||||
global_comparison_df = pd.DataFrame(global_comparison).sort_values("total_BIC")
|
global_comparison_df = pd.DataFrame(global_comparison).sort_values("total_BIC")
|
||||||
|
|
||||||
print("\n=== GLOBAL MODEL COMPARISON ===")
|
print("\n=== GLOBAL MODEL COMPARISON ===")
|
||||||
print(global_comparison_df.to_string(index=False))
|
print(global_comparison_df.to_string(index=False))
|
||||||
|
|
||||||
# Meilleur modèle par participant
|
# Meilleur modèle par participant
|
||||||
all_results_list = []
|
all_results_list = []
|
||||||
for model_name, results in all_results.items():
|
for model_name, results in all_results.items():
|
||||||
for result in results:
|
for result in results:
|
||||||
all_results_list.append({
|
all_results_list.append(
|
||||||
"participant": result["participant"],
|
{
|
||||||
"model": model_name,
|
"participant": result["participant"],
|
||||||
"BIC": result["BIC"],
|
"model": model_name,
|
||||||
"AIC": result["AIC"],
|
"BIC": result["BIC"],
|
||||||
"negLL": result["negLL"]
|
"AIC": result["AIC"],
|
||||||
})
|
"negLL": result["negLL"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
all_results_df = pd.DataFrame(all_results_list)
|
all_results_df = pd.DataFrame(all_results_list)
|
||||||
best_per_participant = all_results_df.loc[all_results_df.groupby("participant")["BIC"].idxmin()]
|
best_per_participant = all_results_df.loc[
|
||||||
|
all_results_df.groupby("participant")["BIC"].idxmin()
|
||||||
|
]
|
||||||
|
|
||||||
print("\n=== BEST MODELS PER PARTICIPANT ===")
|
print("\n=== BEST MODELS PER PARTICIPANT ===")
|
||||||
print(best_per_participant["model"].value_counts())
|
print(best_per_participant["model"].value_counts())
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"global_comparison": global_comparison_df,
|
"global_comparison": global_comparison_df,
|
||||||
"best_per_participant": best_per_participant,
|
"best_per_participant": best_per_participant,
|
||||||
"all_results": all_results
|
"all_results": all_results,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -588,43 +684,161 @@ def compare_models(all_results: Dict[str, List[Dict]]) -> Dict:
|
||||||
# SAUVEGARDE DES RÉSULTATS
|
# SAUVEGARDE DES RÉSULTATS
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
def save_results(all_results: Dict[str, List[Dict]], output_dir: str = "results") -> None:
|
|
||||||
|
def save_results(
|
||||||
|
all_results: Dict[str, List[Dict]], output_dir: str = "results"
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Sauvegarde les résultats d'optimisation en CSV.
|
Sauvegarde les résultats d'optimisation en CSV.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
all_results: Résultats de l'ajustement
|
all_results: Résultats de l'ajustement
|
||||||
output_dir: Répertoire de sortie
|
output_dir: Répertoire de sortie
|
||||||
"""
|
"""
|
||||||
output_path = Path(output_dir)
|
output_path = Path(output_dir)
|
||||||
output_path.mkdir(exist_ok=True)
|
output_path.mkdir(exist_ok=True)
|
||||||
|
|
||||||
for model_name, results in all_results.items():
|
for model_name, results in all_results.items():
|
||||||
results_df = pd.DataFrame(results)
|
results_df = pd.DataFrame(results)
|
||||||
|
|
||||||
# Garder seulement les colonnes numériques pour le CSV
|
# Garder seulement les colonnes numériques pour le CSV
|
||||||
cols_to_keep = [col for col in results_df.columns
|
cols_to_keep = [
|
||||||
if col not in ["vp", "results", "posterior_mean", "posterior_sd"]]
|
col
|
||||||
|
for col in results_df.columns
|
||||||
|
if col not in ["vp", "results", "posterior_mean", "posterior_sd"]
|
||||||
|
]
|
||||||
results_df[cols_to_keep].to_csv(
|
results_df[cols_to_keep].to_csv(
|
||||||
output_path / f"results_{model_name}.csv",
|
output_path / f"results_{model_name}.csv", index=False
|
||||||
index=False
|
|
||||||
)
|
)
|
||||||
print(f"Saved: results_{model_name}.csv")
|
print(f"Saved: results_{model_name}.csv")
|
||||||
|
|
||||||
|
|
||||||
|
def fit_vbmc_and_diffEvol(
|
||||||
|
participant_data: pd.DataFrame,
|
||||||
|
model_config: Dict,
|
||||||
|
n_deoptim_runs: int = 5,
|
||||||
|
n_workers: int = 1,
|
||||||
|
verbose: bool = True,
|
||||||
|
) -> Tuple[Dict, Dict]:
|
||||||
|
"""
|
||||||
|
Ajuste un modèle à l'aide de PyVBMC et Differential Evolution pour comparaison.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
participant_data: Données du participant
|
||||||
|
model_config: Configuration du modèle
|
||||||
|
n_deoptim_runs: Nombre de runs pour Differential Evolution
|
||||||
|
verbose: Affiche les progressions
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple avec les résultats VBMC et Differential Evolution
|
||||||
|
"""
|
||||||
|
if verbose:
|
||||||
|
print(f" Fitting with VBMC")
|
||||||
|
vbmc_result = fit_participant_pyvbmc(
|
||||||
|
participant_data, model_config, verbose=verbose
|
||||||
|
)
|
||||||
|
if verbose:
|
||||||
|
print(f" Fitting with Differential Evolution")
|
||||||
|
deoptim_result = fit_participant_deoptim(
|
||||||
|
participant_data,
|
||||||
|
model_config,
|
||||||
|
n_runs=n_deoptim_runs,
|
||||||
|
n_workers=n_workers,
|
||||||
|
verbose=verbose,
|
||||||
|
)
|
||||||
|
return vbmc_result, deoptim_result
|
||||||
|
|
||||||
|
|
||||||
|
def fit_all_participants_both_methods(
|
||||||
|
data: pd.DataFrame,
|
||||||
|
models_to_fit: Optional[List[str]] = None,
|
||||||
|
n_participants: Optional[int] = None,
|
||||||
|
n_deoptim_runs: int = 5,
|
||||||
|
n_workers: int = 1,
|
||||||
|
verbose: bool = True,
|
||||||
|
) -> Dict[str, List[Dict]]:
|
||||||
|
"""
|
||||||
|
Ajuste tous les modèles pour tous les participants avec les deux méthodes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: DataFrame avec les données de tous les participants
|
||||||
|
models_to_fit: Liste des noms de modèles à ajuster (None = tous)
|
||||||
|
n_participants: Nombre de participants à traiter (None = tous)
|
||||||
|
n_deoptim_runs: Nombre de runs pour Differential Evolution
|
||||||
|
verbose: Affiche les progressions
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionnaire avec les résultats par modèle et méthode
|
||||||
|
"""
|
||||||
|
model_configs = get_model_configs()
|
||||||
|
|
||||||
|
if models_to_fit is not None:
|
||||||
|
model_configs = {k: v for k, v in model_configs.items() if k in models_to_fit}
|
||||||
|
|
||||||
|
participants = data["participant"].unique()
|
||||||
|
if n_participants is not None:
|
||||||
|
participants = participants[:n_participants]
|
||||||
|
|
||||||
|
all_results = {}
|
||||||
|
|
||||||
|
for model_name, model_config in model_configs.items():
|
||||||
|
if verbose:
|
||||||
|
print(f"\n=== Fitting model: {model_name} ===")
|
||||||
|
|
||||||
|
model_results = []
|
||||||
|
|
||||||
|
for participant_id in participants:
|
||||||
|
if verbose:
|
||||||
|
print(f" Participant: {participant_id}")
|
||||||
|
|
||||||
|
participant_data = data[data["participant"] == participant_id].copy()
|
||||||
|
|
||||||
|
try:
|
||||||
|
vbmc_result, deoptim_result = fit_vbmc_and_diffEvol(
|
||||||
|
participant_data,
|
||||||
|
model_config,
|
||||||
|
n_deoptim_runs=n_deoptim_runs,
|
||||||
|
n_workers=n_workers,
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
vbmc_result["participant"] = participant_id
|
||||||
|
deoptim_result["participant"] = participant_id
|
||||||
|
|
||||||
|
model_results.append(
|
||||||
|
{"VBMC": vbmc_result, "Differential_Evolution": deoptim_result}
|
||||||
|
)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(
|
||||||
|
f" VBMC negLL: {vbmc_result['negLL']:.2f}, BIC: {vbmc_result['BIC']:.2f}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" DE negLL: {deoptim_result['negLL']:.2f}, BIC: {deoptim_result['BIC']:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ERROR: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
all_results[model_name] = model_results
|
||||||
|
return all_results
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# EXEMPLE D'UTILISATION
|
# EXEMPLE D'UTILISATION
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print("=== PyVBMC Optimization for Q-Learning Models ===\n")
|
print("=== Optimization for Q-Learning Models ===\n")
|
||||||
|
|
||||||
# Préparation des données
|
# Préparation des données
|
||||||
print("Loading data...")
|
print("Loading data...")
|
||||||
data_for_fitting = all_participant_data[["participant", "choice", "reward"]].copy()
|
data_for_fitting = all_participant_data[["participant", "choice", "reward"]].copy()
|
||||||
print(f" Total participants: {data_for_fitting['participant'].nunique()}")
|
print(f" Total participants: {data_for_fitting['participant'].nunique()}")
|
||||||
print(f" Total trials: {len(data_for_fitting)}")
|
print(f" Total trials: {len(data_for_fitting)}")
|
||||||
|
|
||||||
# Ajustement des modèles
|
# Ajustement des modèles
|
||||||
method = "differential_evolution" # "VBMC" ou "differential_evolution"
|
method = "differential_evolution" # "VBMC" ou "differential_evolution"
|
||||||
if PYVBMC_AVAILABLE:
|
if PYVBMC_AVAILABLE:
|
||||||
|
|
@ -632,25 +846,45 @@ if __name__ == "__main__":
|
||||||
print(f" PyVBMC available - using {method}")
|
print(f" PyVBMC available - using {method}")
|
||||||
else:
|
else:
|
||||||
print(f" PyVBMC not available - using {method}")
|
print(f" PyVBMC not available - using {method}")
|
||||||
|
|
||||||
# Ajustement de quelques modèles pour test
|
# Ajustement de quelques modèles pour test
|
||||||
models_to_fit = ["HOMOGENEOUS", "GAIN_LOSS", "REE_BIASED_SIMPLE"]
|
models_to_fit = [
|
||||||
|
"HOMOGENEOUS",
|
||||||
|
"GAIN_LOSS",
|
||||||
|
"REE_BIASED_SIMPLE",
|
||||||
|
"REE_BIASED_COMPLEX",
|
||||||
|
"REE_LEARNING_SIMPLE",
|
||||||
|
"REE_LEARNING_COMPLEX",
|
||||||
|
"REE_LEARNING_BIASED_SIMPLE",
|
||||||
|
"REE_LEARNING_BIASED_COMPLEX",
|
||||||
|
]
|
||||||
|
|
||||||
|
# all_results = fit_all_participants_both_methods(
|
||||||
|
# data_for_fitting,
|
||||||
|
# models_to_fit=models_to_fit,
|
||||||
|
# # method=method,
|
||||||
|
# n_participants=2, # Set to a number to limit for testing
|
||||||
|
# n_workers=1,
|
||||||
|
# verbose=True,
|
||||||
|
# )
|
||||||
|
|
||||||
all_results = fit_all_participants(
|
all_results = fit_all_participants(
|
||||||
data_for_fitting,
|
data_for_fitting,
|
||||||
models_to_fit=models_to_fit,
|
models_to_fit=models_to_fit,
|
||||||
method=method,
|
method=method,
|
||||||
n_participants=2, # Set to a number to limit for testing
|
n_participants=1, # Set to a number to limit for testing
|
||||||
verbose=True
|
verbose=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Comparaison des modèles
|
# Comparaison des modèles
|
||||||
comparison = compare_models(all_results)
|
comparison = compare_models(all_results)
|
||||||
|
|
||||||
# Sauvegarde des résultats
|
# Sauvegarde des résultats
|
||||||
print("\nSaving results...")
|
print("\nSaving results...")
|
||||||
save_results(all_results)
|
save_results(all_results)
|
||||||
comparison["global_comparison"].to_csv("results/global_comparison.csv", index=False)
|
comparison["global_comparison"].to_csv("results/global_comparison.csv", index=False)
|
||||||
comparison["best_per_participant"].to_csv("results/best_models.csv", index=False)
|
comparison["best_per_participant"].to_csv("results/best_models.csv", index=False)
|
||||||
|
|
||||||
print("\nDone!")
|
print("\nDone!")
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue