Adding plots and sigmoid for relevant params

This commit is contained in:
Louis 2025-12-10 22:19:03 +01:00
parent 28efe2e515
commit 8aa1d27675

View file

@ -348,7 +348,7 @@ def qlearning_generic(
def fit_participant_pyvbmc(
participant_data: pd.DataFrame, model_config: Dict, verbose: bool = True
participant_data: pd.DataFrame, model_config: Dict, verbose: bool = True, plot : bool = True
) -> Dict:
"""
Optimise les paramètres du modèle pour un participant utilisant PyVBMC.
@ -399,6 +399,8 @@ def fit_participant_pyvbmc(
options={
# "verbose": 0 if not verbose else 1,
"display": "off",
"plot": plot,
"log_file_name": None,
},
prior=UniformBox(
a=model_config["lower"], b=model_config["upper"], D=model_config["n_params"]
@ -446,8 +448,13 @@ def fit_participant_pyvbmc(
# Ajout des paramètres estimés
for i, param_name in enumerate(model_config["param_names"]):
result[param_name] = posterior_mean[i]
result[f"sd_{param_name}"] = posterior_sd[i]
# Here we use expit for parameters that were originally bounded between 0 and 1
if param_name.startswith("alpha") or param_name.startswith("forget"):
result[param_name] = expit(posterior_mean[i])
result[f"sd_{param_name}"] = expit(posterior_sd[i])
else:
result[param_name] = posterior_mean[i]
result[f"sd_{param_name}"] = posterior_sd[i]
return result
@ -530,9 +537,13 @@ def fit_participant_deoptim(
"posterior_mean": posterior_mean,
}
# Ajout des paramètres estimés
# Ajout des paramètres estimés après les avoir renvoyés dans par logis
for i, param_name in enumerate(model_config["param_names"]):
result_dict[param_name] = posterior_mean[i]
# Here we expit alpha, forget
if param_name.startswith("alpha") or param_name.startswith("forget"):
result_dict[param_name] = expit(posterior_mean[i])
else:
result_dict[param_name] = posterior_mean[i]
return result_dict
@ -548,6 +559,7 @@ def fit_all_participants(
method: str = "VBMC",
n_participants: Optional[int] = None,
verbose: bool = True,
plot: bool = True,
) -> Dict[str, List[Dict]]:
"""
Ajuste tous les modèles pour tous les participants.
@ -588,7 +600,7 @@ def fit_all_participants(
try:
if method == "VBMC":
result = fit_participant_pyvbmc(
participant_data, model_config, verbose=False
participant_data, model_config, verbose=False, plot=plot
)
else:
result = fit_participant_deoptim(
@ -849,14 +861,14 @@ if __name__ == "__main__":
# Ajustement de quelques modèles pour test
models_to_fit = [
"HOMOGENEOUS",
"GAIN_LOSS",
"REE_BIASED_SIMPLE",
"REE_BIASED_COMPLEX",
"REE_LEARNING_SIMPLE",
"REE_LEARNING_COMPLEX",
# "HOMOGENEOUS",
# "GAIN_LOSS",
# "REE_BIASED_SIMPLE",
# "REE_BIASED_COMPLEX",
# "REE_LEARNING_SIMPLE",
# "REE_LEARNING_COMPLEX",
"REE_LEARNING_BIASED_SIMPLE",
"REE_LEARNING_BIASED_COMPLEX",
# "REE_LEARNING_BIASED_COMPLEX",
]
# all_results = fit_all_participants_both_methods(
@ -874,6 +886,7 @@ if __name__ == "__main__":
method=method,
n_participants=1, # Set to a number to limit for testing
verbose=True,
plot=True
)
# Comparaison des modèles