Adding plots and sigmoid for relevant params
This commit is contained in:
parent
28efe2e515
commit
8aa1d27675
1 changed files with 26 additions and 13 deletions
|
|
@ -348,7 +348,7 @@ def qlearning_generic(
|
||||||
|
|
||||||
|
|
||||||
def fit_participant_pyvbmc(
|
def fit_participant_pyvbmc(
|
||||||
participant_data: pd.DataFrame, model_config: Dict, verbose: bool = True
|
participant_data: pd.DataFrame, model_config: Dict, verbose: bool = True, plot : bool = True
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""
|
"""
|
||||||
Optimise les paramètres du modèle pour un participant utilisant PyVBMC.
|
Optimise les paramètres du modèle pour un participant utilisant PyVBMC.
|
||||||
|
|
@ -399,6 +399,8 @@ def fit_participant_pyvbmc(
|
||||||
options={
|
options={
|
||||||
# "verbose": 0 if not verbose else 1,
|
# "verbose": 0 if not verbose else 1,
|
||||||
"display": "off",
|
"display": "off",
|
||||||
|
"plot": plot,
|
||||||
|
"log_file_name": None,
|
||||||
},
|
},
|
||||||
prior=UniformBox(
|
prior=UniformBox(
|
||||||
a=model_config["lower"], b=model_config["upper"], D=model_config["n_params"]
|
a=model_config["lower"], b=model_config["upper"], D=model_config["n_params"]
|
||||||
|
|
@ -446,6 +448,11 @@ def fit_participant_pyvbmc(
|
||||||
|
|
||||||
# Ajout des paramètres estimés
|
# Ajout des paramètres estimés
|
||||||
for i, param_name in enumerate(model_config["param_names"]):
|
for i, param_name in enumerate(model_config["param_names"]):
|
||||||
|
# Here we use expit for parameters that were originally bounded between 0 and 1
|
||||||
|
if param_name.startswith("alpha") or param_name.startswith("forget"):
|
||||||
|
result[param_name] = expit(posterior_mean[i])
|
||||||
|
result[f"sd_{param_name}"] = expit(posterior_sd[i])
|
||||||
|
else:
|
||||||
result[param_name] = posterior_mean[i]
|
result[param_name] = posterior_mean[i]
|
||||||
result[f"sd_{param_name}"] = posterior_sd[i]
|
result[f"sd_{param_name}"] = posterior_sd[i]
|
||||||
|
|
||||||
|
|
@ -530,8 +537,12 @@ def fit_participant_deoptim(
|
||||||
"posterior_mean": posterior_mean,
|
"posterior_mean": posterior_mean,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Ajout des paramètres estimés
|
# Ajout des paramètres estimés après les avoir renvoyés dans par logis
|
||||||
for i, param_name in enumerate(model_config["param_names"]):
|
for i, param_name in enumerate(model_config["param_names"]):
|
||||||
|
# Here we expit alpha, forget
|
||||||
|
if param_name.startswith("alpha") or param_name.startswith("forget"):
|
||||||
|
result_dict[param_name] = expit(posterior_mean[i])
|
||||||
|
else:
|
||||||
result_dict[param_name] = posterior_mean[i]
|
result_dict[param_name] = posterior_mean[i]
|
||||||
|
|
||||||
return result_dict
|
return result_dict
|
||||||
|
|
@ -548,6 +559,7 @@ def fit_all_participants(
|
||||||
method: str = "VBMC",
|
method: str = "VBMC",
|
||||||
n_participants: Optional[int] = None,
|
n_participants: Optional[int] = None,
|
||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
|
plot: bool = True,
|
||||||
) -> Dict[str, List[Dict]]:
|
) -> Dict[str, List[Dict]]:
|
||||||
"""
|
"""
|
||||||
Ajuste tous les modèles pour tous les participants.
|
Ajuste tous les modèles pour tous les participants.
|
||||||
|
|
@ -588,7 +600,7 @@ def fit_all_participants(
|
||||||
try:
|
try:
|
||||||
if method == "VBMC":
|
if method == "VBMC":
|
||||||
result = fit_participant_pyvbmc(
|
result = fit_participant_pyvbmc(
|
||||||
participant_data, model_config, verbose=False
|
participant_data, model_config, verbose=False, plot=plot
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
result = fit_participant_deoptim(
|
result = fit_participant_deoptim(
|
||||||
|
|
@ -849,14 +861,14 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# Ajustement de quelques modèles pour test
|
# Ajustement de quelques modèles pour test
|
||||||
models_to_fit = [
|
models_to_fit = [
|
||||||
"HOMOGENEOUS",
|
# "HOMOGENEOUS",
|
||||||
"GAIN_LOSS",
|
# "GAIN_LOSS",
|
||||||
"REE_BIASED_SIMPLE",
|
# "REE_BIASED_SIMPLE",
|
||||||
"REE_BIASED_COMPLEX",
|
# "REE_BIASED_COMPLEX",
|
||||||
"REE_LEARNING_SIMPLE",
|
# "REE_LEARNING_SIMPLE",
|
||||||
"REE_LEARNING_COMPLEX",
|
# "REE_LEARNING_COMPLEX",
|
||||||
"REE_LEARNING_BIASED_SIMPLE",
|
"REE_LEARNING_BIASED_SIMPLE",
|
||||||
"REE_LEARNING_BIASED_COMPLEX",
|
# "REE_LEARNING_BIASED_COMPLEX",
|
||||||
]
|
]
|
||||||
|
|
||||||
# all_results = fit_all_participants_both_methods(
|
# all_results = fit_all_participants_both_methods(
|
||||||
|
|
@ -874,6 +886,7 @@ if __name__ == "__main__":
|
||||||
method=method,
|
method=method,
|
||||||
n_participants=1, # Set to a number to limit for testing
|
n_participants=1, # Set to a number to limit for testing
|
||||||
verbose=True,
|
verbose=True,
|
||||||
|
plot=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Comparaison des modèles
|
# Comparaison des modèles
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue