commit 57b53629c7269dd5773b60ee98a64bd0b24c5b53 Author: Louis Date: Mon Dec 1 21:07:40 2025 +0100 Init of repo diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a3f4f87 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +.venv +data +__pycache__ \ No newline at end of file diff --git a/load_data.R b/load_data.R new file mode 100644 index 0000000..87c13bf --- /dev/null +++ b/load_data.R @@ -0,0 +1,20 @@ +library(dplyr) +full_data <- read.csv("data/data_fourchoices.csv") + +# mapping of button names to choice numbers +button_mapping <- c( + "antifragile" = 1, + "robuste" = 2, + "fragile" = 3, + "vulnerable" = 4 +) + +data <- full_data[, c("participant", "click_number", "button_name", "button_value")] %>% + rename(participant_id = participant, trial = click_number, choice = button_name, reward = button_value) %>% + mutate( + option = choice, + choice = button_mapping[choice] + ) %>% + select(participant_id, trial, choice, reward, option) -> data + +write.csv(data, file = "data/prepared_data.csv", row.names = FALSE) diff --git a/load_data.py b/load_data.py new file mode 100644 index 0000000..c4c67dc --- /dev/null +++ b/load_data.py @@ -0,0 +1,22 @@ +# %% +import pandas as pd + +full_data = pd.read_csv("data/data_fourchoices.csv") + +# %% +all_participant_data = full_data[["participant", "button_name", "button_value"]] +# %% +dict_mapping_button_name_to_index = { + "antifragile": 0, + "fragile": 1, + "robuste": 2, + "vulnerable": 3 + } +all_participant_data["choice"] = all_participant_data["button_name"].map(dict_mapping_button_name_to_index) +all_participant_data = all_participant_data.rename(columns={"button_value": "reward"}) + +all_participant_data["rescaled_reward"] = (all_participant_data["reward"]) / 3000 # rescale rewards to [-1,1] +all_participant_data["rescaled_reward"] = all_participant_data["rescaled_reward"].round(4) + +unique_participants = all_participant_data["participant"].unique() +# %% diff --git a/model0.py b/model0.py new file mode 100644 index 0000000..d081af8 --- /dev/null +++ b/model0.py @@ -0,0 +1,116 @@ +# MLE_fit.py +# %% +import numpy as np +import pandas as pd +from scipy.optimize import minimize +from scipy.special import logsumexp + +# %% +def simulate_model(theta, choices=None, rewards=None, n_arms=4, T=None, seed=None): + """Simple simulator to generate choices+rewards if not provided. + theta = (alpha, lam) + """ + alpha, f = theta + rng = np.random.RandomState(seed) + if T is None: + if choices is None: + raise ValueError("Provide T or choices") + T = len(choices) + Q = np.zeros(n_arms) + sim_choices = [] + sim_rewards = [] + # example reward distributions: Bernoulli with different p + ps = np.linspace(0.2, 0.8, n_arms) + for t in range(T): + V = Q + logits = V - logsumexp(V) # numerically stable + probs = np.exp(logits) + a = rng.choice(n_arms, p=probs) + r = rng.binomial(1, ps[a]) # example reward + # updates + Q[a] = Q[a] + alpha * (r - Q[a]) + non_chosen = [i for i in range(n_arms) if i != a] + Q[non_chosen] = (1-f) * Q[non_chosen] + sim_choices.append(a) + sim_rewards.append(r) + return np.array(sim_choices), np.array(sim_rewards) + +def neg_log_likelihood_raw(params, choices, rewards, n_arms=4): + """Params are in unconstrained real space; we'll transform inside the function.""" + # params vector: [logit_alpha, log_lambda, logit_f] + la, lb = params + # transforms + alpha = 1/(1+np.exp(-la)) # sigmoid -> (0,1) + f = 1/(1+np.exp(-lb)) # sigmoid -> (0,1) + T = len(choices) + Q = np.zeros(n_arms) + loglik = 0.0 + for t in range(T): + V = Q + # compute stable log-softmax + logp = V - logsumexp(V) + a = choices[t] + loglik += logp[a] + r = rewards[t] + # update chosen + Q[a] = Q[a] + alpha * (r - Q[a]) + # update non-chosen via forgetting + non_chosen = [i for i in range(n_arms) if i != a] + Q[non_chosen] = (1-f) * Q[non_chosen] # no forgetting in this model + return -loglik # negative log-likelihood for minimization + +def fit_mle(choices, rewards, n_arms=4, n_starts=10): + best = None + for s in range(n_starts): + x0 = np.random.normal(size=2) # la, lb + res = minimize(neg_log_likelihood_raw, x0, + args=(choices, rewards, n_arms), + method='L-BFGS-B', + options={'maxiter':5000}) + if not res.success: + continue + if best is None or res.fun < best['fun']: + best = res + if best is None: + raise RuntimeError("MLE failed") + la, lb = best.x + # transform back + alpha = 1 / (1+np.exp(-la)) + f = 1 / (1 + np.exp(-lb)) + return {'alpha':alpha, 'f':f, 'nll':best.fun, 'opt_result':best} + +# %% +# --- Example usage with synthetic data --- +if __name__ == "__main__": + # generate synthetic participant + true_theta = (0.2, 0.05) # alpha, f + choices, rewards = simulate_model(true_theta, T=200, seed=123) + res = fit_mle(choices, rewards, n_starts=20) + print("True theta:", true_theta) + print("MLE estimate:", res) + +# %% +# Loading participants data +from load_data import all_participant_data, unique_participants + +# %% +import os + +save_dir = "results/model0" +if not os.path.exists(save_dir): + os.makedirs(save_dir) +save_path = os.path.join(save_dir, "mle.csv") + +for pid in unique_participants: + pdata = all_participant_data[all_participant_data["participant"] == pid] + choices = pdata["choice"].values + rewards = pdata["rescaled_reward"].values + res = fit_mle(choices, rewards, n_starts=20) + print(f"Participant {pid} MLE estimate:", res) + pd_res = pd.DataFrame([res]) + pd_res["participant"] = pid + pd_res = pd_res.reindex(columns=["participant", "alpha", "f", "nll"]) + # Save results + pd_res.to_csv(save_path, index=False, mode="a", header=not os.path.exists(save_path)) + +# %% diff --git a/model1_simplest.py b/model1_simplest.py new file mode 100644 index 0000000..b3c8c6b --- /dev/null +++ b/model1_simplest.py @@ -0,0 +1,117 @@ +# MLE_fit.py +# %% +import numpy as np +import pandas as pd +from scipy.optimize import minimize +from scipy.special import logsumexp + +# %% +def simulate_model(theta, choices=None, rewards=None, n_arms=4, T=None, seed=None): + """Simple simulator to generate choices+rewards if not provided. + theta = (alpha, lam, f) + """ + alpha, lam, f = theta + rng = np.random.RandomState(seed) + if T is None: + if choices is None: + raise ValueError("Provide T or choices") + T = len(choices) + Q = np.zeros(n_arms) + sim_choices = [] + sim_rewards = [] + # example reward distributions: Bernoulli with different p + ps = np.linspace(0.2, 0.8, n_arms) + for t in range(T): + V = lam * Q + logits = V - logsumexp(V) # numerically stable + probs = np.exp(logits) + a = rng.choice(n_arms, p=probs) + r = rng.binomial(1, ps[a]) # example reward + # updates + Q[a] = Q[a] + alpha * (r - Q[a]) + non_chosen = [i for i in range(n_arms) if i != a] + Q[non_chosen] = Q[non_chosen] * (1 - f) + sim_choices.append(a) + sim_rewards.append(r) + return np.array(sim_choices), np.array(sim_rewards) + +def neg_log_likelihood_raw(params, choices, rewards, n_arms=4): + """Params are in unconstrained real space; we'll transform inside the function.""" + # params vector: [logit_alpha, log_lambda, logit_f] + la, lb, lc = params + # transforms + alpha = 1/(1+np.exp(-la)) # sigmoid -> (0,1) + lam = np.exp(lb) # positive + f = 1/(1+np.exp(-lc)) # sigmoid -> (0,1) + T = len(choices) + Q = np.zeros(n_arms) + loglik = 0.0 + for t in range(T): + V = lam * Q + # compute stable log-softmax + logp = V - logsumexp(V) + a = choices[t] + loglik += logp[a] + r = rewards[t] + # update chosen + Q[a] = Q[a] + alpha * (r - Q[a]) + # update non-chosen via forgetting + non_chosen = [i for i in range(n_arms) if i != a] + Q[non_chosen] = Q[non_chosen] * (1 - f) + return -loglik # negative log-likelihood for minimization + +def fit_mle(choices, rewards, n_arms=4, n_starts=10): + best = None + for s in range(n_starts): + x0 = np.random.normal(size=3) + res = minimize(neg_log_likelihood_raw, x0, + args=(choices, rewards, n_arms), + method='L-BFGS-B', + options={'maxiter':5000}) + if not res.success: + continue + if best is None or res.fun < best['fun']: + best = res + if best is None: + raise RuntimeError("MLE failed") + la, lb, lc = best.x + # transform back + alpha = 1/(1+np.exp(-la)) + lam = np.exp(lb) + f = 1/(1+np.exp(-lc)) + return {'alpha':alpha, 'lambda':lam, 'f':f, 'nll':best.fun, 'opt_result':best} + +# %% +# --- Example usage with synthetic data --- +if __name__ == "__main__": + # generate synthetic participant + true_theta = (0.2, 2.0, 0.05) # alpha, lambda, f + choices, rewards = simulate_model(true_theta, T=200, seed=123) + res = fit_mle(choices, rewards, n_starts=20) + print("True theta:", true_theta) + print("MLE estimate:", res) + +# %% +# Loading participants data +from load_data import all_participant_data, unique_participants + +# %% +import os +save_dir = "results/model1" +if not os.path.exists(save_dir): + os.makedirs(save_dir) +save_path = os.path.join(save_dir, "mle.csv") + +for pid in unique_participants: + pdata = all_participant_data[all_participant_data["participant"] == pid] + choices = pdata["choice"].values + rewards = pdata["rescaled_reward"].values + res = fit_mle(choices, rewards, n_starts=20) + print(f"Participant {pid} MLE estimate:", res) + pd_res = pd.DataFrame([res]) + pd_res["participant"] = pid + pd_res = pd_res.reindex(columns=["participant", "alpha", "lambda", "f", "nll"]) + # Save results + pd_res.to_csv(save_path, index=False, mode="a", header=not os.path.exists(save_path)) + +# %% diff --git a/model2_loss_gain.py b/model2_loss_gain.py new file mode 100644 index 0000000..015a875 --- /dev/null +++ b/model2_loss_gain.py @@ -0,0 +1,134 @@ +# MLE_fit.py +# %% +import numpy as np +import pandas as pd +from scipy.optimize import minimize +from scipy.special import logsumexp + + +# %% +N_PARAM = 2 # alpha_gain, alpha_loss +def simulate_model(theta, choices=None, rewards=None, n_arms=4, T=None, seed=None): + """Simple simulator to generate choices+rewards if not provided. + theta = (alpha, lam) + """ + alpha_gain, alpha_loss = theta + rng = np.random.RandomState(seed) + if T is None: + if choices is None: + raise ValueError("Provide T or choices") + T = len(choices) + Q = np.zeros(n_arms) + sim_choices = [] + sim_rewards = [] + # example reward distributions: Bernoulli with different p + ps = np.linspace(0.2, 0.8, n_arms) + for t in range(T): + V = Q + logits = V - logsumexp(V) # numerically stable + probs = np.exp(logits) + a = rng.choice(n_arms, p=probs) + r = rng.binomial(1, ps[a]) * (1 if rng.rand() < 0.5 else -1) # example reward with gain/loss + # updates + if r >= 0: + Q[a] = Q[a] + alpha_gain * (r - Q[a]) + else: + Q[a] = Q[a] + alpha_loss * (r - Q[a]) + non_chosen = [i for i in range(n_arms) if i != a] + Q[non_chosen] = Q[non_chosen] + sim_choices.append(a) + sim_rewards.append(r) + return np.array(sim_choices), np.array(sim_rewards) + + +def neg_log_likelihood_raw(params, choices, rewards, n_arms=4): + """Params are in unconstrained real space; we'll transform inside the function.""" + # params vector: [logit_alpha_gain, logit_alpha_loss] + lag, lal = params + # transforms + alpha_gain = 1 / (1 + np.exp(-lag)) # sigmoid -> (0,1) + alpha_loss = 1 / (1 + np.exp(-lal)) # sigmoid -> (0,1) + T = len(choices) + Q = np.zeros(n_arms) + loglik = 0.0 + for t in range(T): + V = Q + # compute stable log-softmax + logp = V - logsumexp(V) + a = choices[t] + loglik += logp[a] + r = rewards[t] + # update chosen + if r >= 0: + Q[a] = Q[a] + alpha_gain * (r - Q[a]) + else: + Q[a] = Q[a] + alpha_loss * (r - Q[a]) + # update non-chosen via forgetting + non_chosen = [i for i in range(n_arms) if i != a] + Q[non_chosen] = Q[non_chosen] # no forgetting in this model + return -loglik # negative log-likelihood for minimization + + +def fit_mle(choices, rewards, n_arms=4, x0=None, n_starts=10): + best = None + for s in range(n_starts): + if x0 is None: + x0 = np.random.normal(size=N_PARAM) # lag, lal + res = minimize( + neg_log_likelihood_raw, + x0, + args=(choices, rewards, n_arms), + method="L-BFGS-B", + options={"maxiter": 5000}, + ) + if not res.success: + continue + if best is None or res.fun < best["fun"]: + best = res + if best is None: + raise RuntimeError("MLE failed") + lag, lal = best.x + # transform back + alpha_gain = 1 / (1 + np.exp(-lag)) + alpha_loss = 1 / (1 + np.exp(-lal)) + return {"alpha_gain": alpha_gain, "alpha_loss": alpha_loss, "nll": best.fun, "opt_result": best} + +# %% +# --- Example usage with synthetic data --- +if __name__ == "__main__": + np.random.seed(42) + # generate synthetic participant + true_theta = (0.2, 0.9) # alpha_gain, alpha_loss + choices, rewards = simulate_model(true_theta, T=200, seed=123) + res = fit_mle(choices, rewards, n_starts=20) + print("True theta:", true_theta) + print("MLE estimate:", res) + +# %% +# Loading participants data +from load_data import all_participant_data, unique_participants + +# %% +import os + +save_dir = "results/model2" +if not os.path.exists(save_dir): + os.makedirs(save_dir) +save_path = os.path.join(save_dir, "mle.csv") + +for pid in unique_participants: + pdata = all_participant_data[all_participant_data["participant"] == pid] + choices = pdata["choice"].values + rewards = pdata["rescaled_reward"].values + res = fit_mle(choices, rewards, x0=np.array([1.0, 1.0]), n_starts=20) + print(f"Participant {pid} MLE estimate:", res) + pd_res = pd.DataFrame([res]) + pd_res["participant"] = pid + pd_res = pd_res.reindex(columns=["participant", "alpha_gain", "alpha_loss", "nll"]) + pd_res["model"] = "model2_loss_gain" + # Save results + pd_res.to_csv( + save_path, index=False, mode="a", header=not os.path.exists(save_path) + ) + +# %% diff --git a/modelling R + PyVBMC V5.R b/modelling R + PyVBMC V5.R new file mode 100644 index 0000000..97076e5 --- /dev/null +++ b/modelling R + PyVBMC V5.R @@ -0,0 +1,932 @@ +# ============================================================================ +# MODÈLES Q-LEARNING EMBOÎTÉS POUR DÉCISION AVEC ÉVÉNEMENTS RARES +# ============================================================================ + +library(tidyverse) +library(DEoptim) +library(numDeriv) +library(reticulate) # Pour interfacer avec PyVBMC + +# Configuration optionnelle de l'environnement Python +# use_python("/usr/bin/python3") # Ajuster selon votre installation +# py_install("pyvbmc") # Installer PyVBMC si nécessaire + +# ============================================================================ +# FONCTION GÉNÉRIQUE DE Q-LEARNING +# ============================================================================ + +qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) { + + # Conversion des choix en indices numériques + if (is.factor(data$choice) || is.character(data$choice)) { + choice_levels <- c("antifragile", "fragile", "robuste", "vulnerable") + data$choice_idx <- match(as.character(data$choice), choice_levels) + } else { + data$choice_idx <- data$choice + } + + n_arms <- 4 + n_trials <- nrow(data) + + # Extraction des paramètres selon la configuration du modèle + param_idx <- 1 + + # ALPHA(S) + if (model_config$n_alpha == 1) { + alpha_loss <- alpha_gain <- alpha_BS <- alpha_JP <- plogis(params[param_idx]) + param_idx <- param_idx + 1 + } else if (model_config$n_alpha == 2) { + alpha_loss <- plogis(params[param_idx]) + alpha_gain <- plogis(params[param_idx + 1]) + alpha_BS <- alpha_loss + alpha_JP <- alpha_gain + param_idx <- param_idx + 2 + } else if (model_config$n_alpha == 4) { + alpha_loss <- plogis(params[param_idx]) + alpha_gain <- plogis(params[param_idx + 1]) + alpha_BS <- plogis(params[param_idx + 2]) + alpha_JP <- plogis(params[param_idx + 3]) + param_idx <- param_idx + 4 + } + + # FORGET(S) + if (model_config$n_forget == 1) { + forget <- rep(plogis(params[param_idx]), n_arms) + param_idx <- param_idx + 1 + } else if (model_config$n_forget == 4) { + forget <- plogis(params[param_idx:(param_idx + 3)]) + param_idx <- param_idx + 4 + } + + # LAMBDA(S) + if (model_config$n_lambda == 1) { + lambda <- rep(exp(params[param_idx]), n_arms) + param_idx <- param_idx + 1 + } else if (model_config$n_lambda == 4) { + lambda <- exp(params[param_idx:(param_idx + 3)]) + param_idx <- param_idx + 4 + } + + # RHO(S) - Biais pour événements rares + if (model_config$has_rho) { + rho_BS <- params[param_idx] # BS avoidance + rho_JP <- params[param_idx + 1] # JP seeking + param_idx <- param_idx + 2 + } else { + rho_BS <- rho_JP <- 0 + } + + # Initialisation des Q-values + Q <- rep(0, n_arms) + log_lik <- 0 + + for (t in 1:n_trials) { + choice <- data$choice_idx[t] + reward <- data$reward[t] + + # Calcul des valeurs subjectives V(t) + V <- lambda * Q + + # Ajout des biais pour événements rares si le modèle le permet + if (model_config$has_rho) { + # Identification des options susceptibles de produire BS/JP + # antifragile (1) = JP possible, fragile (2) = BS possible + # vulnerable (4) = BS et JP possibles + V[1] <- V[1] + rho_JP # antifragile + V[2] <- V[2] + rho_BS # fragile + V[4] <- V[4] + rho_BS + rho_JP # vulnerable + } + + # Softmax + V_max <- max(V) + exp_V <- exp(V - V_max) + probs <- exp_V / sum(exp_V) + probs <- pmax(probs, 1e-10) + probs <- probs / sum(probs) + + # Log-likelihood + log_lik <- log_lik + log(probs[choice]) + + # Mise à jour Q-learning + Q_new <- Q + + # Choix de l'alpha approprié + if (reward == -3000) { + alpha_used <- alpha_BS + } else if (reward == 3000) { + alpha_used <- alpha_JP + } else if (reward < 0) { + alpha_used <- alpha_loss + } else { + alpha_used <- alpha_gain + } + + # Option choisie : Q(t+1) = Q(t) + alpha * (r(t) - Q(t)) + Q_new[choice] <- Q[choice] + alpha_used * (reward - Q[choice]) + + # Options non choisies : Q(t+1) = Q(t) * (1 - f) + not_chosen <- setdiff(1:n_arms, choice) + Q_new[not_chosen] <- Q[not_chosen] * (1 - forget[not_chosen]) + + Q <- Q_new + } + + if (return_negLL) { + return(-log_lik) + } else { + return(log_lik) + } +} + +# ============================================================================ +# CONFIGURATIONS DES MODÈLES EMBOÎTÉS +# ============================================================================ + +get_model_configs <- function() { + list( + HOMOGENEOUS = list( + name = "HOMOGENEOUS", + n_alpha = 1, + n_forget = 1, + n_lambda = 1, + has_rho = FALSE, + n_params = 3, + param_names = c("alpha", "forget", "lambda"), + lower = c(-5, -5, -3), + upper = c(5, 5, 3) + ), + + GAIN_LOSS = list( + name = "GAIN_LOSS", + n_alpha = 2, + n_forget = 1, + n_lambda = 1, + has_rho = FALSE, + n_params = 4, + param_names = c("alpha_loss", "alpha_gain", "forget", "lambda"), + lower = c(-5, -5, -5, -3), + upper = c(5, 5, 5, 3) + ), + + BIASED = list( + name = "BIASED", + n_alpha = 2, + n_forget = 4, + n_lambda = 4, + has_rho = FALSE, + n_params = 10, + param_names = c("alpha_loss", "alpha_gain", + "forget_1", "forget_2", "forget_3", "forget_4", + "lambda_1", "lambda_2", "lambda_3", "lambda_4"), + lower = c(-5, -5, rep(-5, 4), rep(-3, 4)), + upper = c(5, 5, rep(5, 4), rep(3, 4)) + ), + + REE_BIASED_SIMPLE = list( + name = "REE_BIASED_SIMPLE", + n_alpha = 2, + n_forget = 1, + n_lambda = 1, + has_rho = TRUE, + n_params = 6, + param_names = c("alpha_loss", "alpha_gain", "forget", "lambda", + "rho_BS", "rho_JP"), + lower = c(-5, -5, -5, -3, -10, -10), + upper = c(5, 5, 5, 3, 10, 10) + ), + + REE_BIASED_COMPLEX = list( + name = "REE_BIASED_COMPLEX", + n_alpha = 2, + n_forget = 4, + n_lambda = 4, + has_rho = TRUE, + n_params = 12, + param_names = c("alpha_loss", "alpha_gain", + "forget_1", "forget_2", "forget_3", "forget_4", + "lambda_1", "lambda_2", "lambda_3", "lambda_4", + "rho_BS", "rho_JP"), + lower = c(-5, -5, rep(-5, 4), rep(-3, 4), -10, -10), + upper = c(5, 5, rep(5, 4), rep(3, 4), 10, 10) + ), + + REE_LEARNING_SIMPLE = list( + name = "REE_LEARNING_SIMPLE", + n_alpha = 4, + n_forget = 1, + n_lambda = 1, + has_rho = FALSE, + n_params = 6, + param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", + "forget", "lambda"), + lower = c(-5, -5, -5, -5, -5, -3), + upper = c(5, 5, 5, 5, 5, 3) + ), + + REE_LEARNING_COMPLEX = list( + name = "REE_LEARNING_COMPLEX", + n_alpha = 4, + n_forget = 4, + n_lambda = 4, + has_rho = FALSE, + n_params = 12, + param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", + "forget_1", "forget_2", "forget_3", "forget_4", + "lambda_1", "lambda_2", "lambda_3", "lambda_4"), + lower = c(-5, -5, -5, -5, rep(-5, 4), rep(-3, 4)), + upper = c(5, 5, 5, 5, rep(5, 4), rep(3, 4)) + ), + + REE_LEARNING_BIASED_SIMPLE = list( + name = "REE_LEARNING_BIASED_SIMPLE", + n_alpha = 4, + n_forget = 1, + n_lambda = 1, + has_rho = TRUE, + n_params = 8, + param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", + "forget", "lambda", "rho_BS", "rho_JP"), + lower = c(-5, -5, -5, -5, -5, -3, -10, -10), + upper = c(5, 5, 5, 5, 5, 3, 10, 10) + ), + + REE_LEARNING_BIASED_COMPLEX = list( + name = "REE_LEARNING_BIASED_COMPLEX", + n_alpha = 4, + n_forget = 4, + n_lambda = 4, + has_rho = TRUE, + n_params = 14, + param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", + "forget_1", "forget_2", "forget_3", "forget_4", + "lambda_1", "lambda_2", "lambda_3", "lambda_4", + "rho_BS", "rho_JP"), + lower = c(-5, -5, -5, -5, rep(-5, 4), rep(-3, 4), -10, -10), + upper = c(5, 5, 5, 5, rep(5, 4), rep(3, 4), 10, 10) + ) + ) +} + +# ============================================================================ +# ESTIMATION AVEC VBMC (MÉTHODE BAYÉSIENNE) +# ============================================================================ + +fit_participant_vbmc <- function(participant_data, model_config, use_python_vbmc = TRUE) { + + if (!use_python_vbmc) { + warning("VBMC nécessite Python. Utilisation de DEoptim à la place.") + return(fit_participant(participant_data, model_config)) + } + + # Import de PyVBMC + tryCatch({ + pyvbmc <- import("pyvbmc") + }, error = function(e) { + stop("PyVBMC n'est pas installé. Installez-le avec: py_install('pyvbmc')") + }) + + # Wrapper pour la fonction log-posterior + log_posterior_wrapper <- function(params_array) { + # PyVBMC passe un array numpy, on le convertit en vecteur R + params_vec <- as.vector(params_array) + + # Retourne la log-vraisemblance (négatif de negLL) + negLL <- qlearning_generic(params_vec, participant_data, model_config, return_negLL = TRUE) + return(-negLL) # VBMC maximise, donc on retourne -negLL + } + + # Point de départ (milieu des bornes plausibles) + x0 <- (model_config$lower + model_config$upper) / 2 + + # Bornes plausibles (25%-75% de la plage) + plb <- model_config$lower + 0.25 * (model_config$upper - model_config$lower) + pub <- model_config$upper - 0.25 * (model_config$upper - model_config$lower) + + # Initialisation de VBMC + vbmc_obj <- pyvbmc$VBMC( + log_density = log_posterior_wrapper, + x0 = x0, + lower_bounds = model_config$lower, + upper_bounds = model_config$upper, + plausible_lower_bounds = plb, + plausible_upper_bounds = pub + ) + + # Optimisation + vbmc_result <- vbmc_obj$optimize() + vp <- vbmc_result[[1]] # Variational posterior + results_dict <- vbmc_result[[2]] # Résultats + + # Extraction des statistiques + posterior_stats <- vp$moments() + posterior_mean <- as.vector(posterior_stats[[1]]) + posterior_sd <- sqrt(diag(posterior_stats[[2]])) + + elbo <- results_dict$elbo + elbo_sd <- results_dict$elbo_sd + + # Calcul du BIC avec la posterior mean + negLL <- qlearning_generic(posterior_mean, participant_data, model_config, return_negLL = TRUE) + + # Création du tibble de résultats + result_df <- tibble( + model = model_config$name, + n_params = model_config$n_params, + negLL = negLL, + ELBO = elbo, + ELBO_SD = elbo_sd, + AIC = 2 * negLL + 2 * model_config$n_params, + BIC = 2 * negLL + model_config$n_params * log(nrow(participant_data)), + method = "VBMC", + n_iterations = results_dict$iterations, + converged = TRUE # VBMC a ses propres critères + ) + + # Ajout des paramètres (posterior mean et SD) + for (i in 1:model_config$n_params) { + result_df[[model_config$param_names[i]]] <- posterior_mean[i] + result_df[[paste0("sd_", model_config$param_names[i])]] <- posterior_sd[i] + } + + # Stockage de la posterior complète pour visualisations + result_df$vp <- list(vp) + result_df$vbmc_results <- list(results_dict) + + return(result_df) +} + +# ============================================================================ +# VISUALISATIONS DE CONVERGENCE AVANCÉES +# ============================================================================ + +plot_convergence_detailed <- function(fit_results, participant_id = NULL) { + require(ggplot2) + require(patchwork) + + if (!is.null(participant_id)) { + fit_results <- fit_results %>% filter(participant_id == !!participant_id) + } + + plots <- list() + + # 1. Trace de convergence (negLL) + if ("convergence_range" %in% names(fit_results)) { + p1 <- fit_results %>% + mutate(participant_id = factor(participant_id)) %>% + ggplot(aes(x = participant_id, y = convergence_range, fill = converged)) + + geom_col() + + scale_y_log10() + + geom_hline(yintercept = 1, linetype = "dashed", color = "red") + + coord_flip() + + theme_minimal() + + labs(title = "Range de convergence par participant", + subtitle = "Seuil = 1.0 (ligne rouge)", + y = "Range negLL (log scale)") + plots$convergence_range <- p1 + } + + # 2. Distribution des paramètres estimés + param_cols <- grep("^alpha|^forget|^lambda|^rho", names(fit_results), value = TRUE) + param_cols <- setdiff(param_cols, grep("^se_|^sd_", param_cols, value = TRUE)) + + if (length(param_cols) > 0) { + p2 <- fit_results %>% + select(participant_id, all_of(param_cols)) %>% + pivot_longer(-participant_id, names_to = "parameter", values_to = "value") %>% + ggplot(aes(x = value, fill = parameter)) + + geom_histogram(bins = 30, alpha = 0.7) + + facet_wrap(~parameter, scales = "free", ncol = 3) + + theme_minimal() + + theme(legend.position = "none") + + labs(title = "Distribution des paramètres estimés", + x = "Valeur", y = "Fréquence") + plots$param_distribution <- p2 + } + + # 3. Corrélation paramètres vs convergence + if ("convergence_range" %in% names(fit_results) && length(param_cols) > 0) { + # Sélectionner quelques paramètres clés + key_params <- param_cols[1:min(4, length(param_cols))] + + p3 <- fit_results %>% + select(convergence_range, all_of(key_params)) %>% + pivot_longer(-convergence_range, names_to = "parameter", values_to = "value") %>% + ggplot(aes(x = value, y = convergence_range)) + + geom_point(alpha = 0.5) + + geom_smooth(method = "loess", se = FALSE, color = "red") + + scale_y_log10() + + facet_wrap(~parameter, scales = "free_x") + + theme_minimal() + + labs(title = "Convergence vs valeur des paramètres", + y = "Range convergence (log scale)") + plots$param_vs_convergence <- p3 + } + + # 4. Heatmap Hessienne (si disponible) + if ("hessian_positive_definite" %in% names(fit_results)) { + p4 <- fit_results %>% + mutate(participant_id = factor(participant_id)) %>% + ggplot(aes(x = participant_id, y = 1, fill = hessian_positive_definite)) + + geom_tile() + + scale_fill_manual(values = c("FALSE" = "red", "TRUE" = "green"), + na.value = "grey") + + coord_flip() + + theme_minimal() + + theme(axis.text.x = element_blank(), + axis.ticks.x = element_blank()) + + labs(title = "Qualité de la Hessienne", + subtitle = "Vert = définie positive, Rouge = problème", + fill = "Hessienne OK", x = "Participant", y = "") + plots$hessian_quality <- p4 + } + + # 5. Incertitude des paramètres (erreurs standard) + se_cols <- grep("^se_|^sd_", names(fit_results), value = TRUE) + if (length(se_cols) > 0) { + p5 <- fit_results %>% + select(participant_id, all_of(se_cols)) %>% + pivot_longer(-participant_id, names_to = "parameter", values_to = "se") %>% + mutate(parameter = str_remove(parameter, "^se_|^sd_")) %>% + ggplot(aes(x = se, fill = parameter)) + + geom_histogram(bins = 30, alpha = 0.7) + + facet_wrap(~parameter, scales = "free", ncol = 3) + + theme_minimal() + + theme(legend.position = "none") + + labs(title = "Distribution des erreurs standard", + subtitle = "Plus petit = meilleure précision", + x = "Erreur standard", y = "Fréquence") + plots$se_distribution <- p5 + } + + return(plots) +} + +# Visualisation spécifique VBMC (posteriors bayésiennes) +plot_vbmc_diagnostics <- function(vbmc_fit_results, participant_id) { + require(ggplot2) + require(patchwork) + + participant_data <- vbmc_fit_results %>% + filter(participant_id == !!participant_id) + + if (nrow(participant_data) == 0) { + stop("Participant non trouvé") + } + + if (!"vp" %in% names(participant_data)) { + stop("Pas de résultats VBMC disponibles (vp manquant)") + } + + vp <- participant_data$vp[[1]] + vbmc_results <- participant_data$vbmc_results[[1]] + + # Échantillonnage de la posterior + n_samples <- 10000 + samples <- vp$sample(n_samples) + + # Conversion en dataframe + param_names <- vbmc_fit_results %>% + select(starts_with("alpha"), starts_with("forget"), + starts_with("lambda"), starts_with("rho")) %>% + select(-starts_with("se_"), -starts_with("sd_")) %>% + names() + + samples_df <- as.data.frame(samples) + names(samples_df) <- param_names + + plots <- list() + + # 1. Marginal posteriors + p1 <- samples_df %>% + pivot_longer(everything(), names_to = "parameter", values_to = "value") %>% + ggplot(aes(x = value)) + + geom_histogram(aes(y = after_stat(density)), bins = 50, + fill = "steelblue", alpha = 0.7) + + geom_density(color = "red", linewidth = 1) + + facet_wrap(~parameter, scales = "free", ncol = 3) + + theme_minimal() + + labs(title = paste("Posterior distributions - Participant", participant_id), + subtitle = "Histogramme + densité estimée", + x = "Valeur", y = "Densité") + plots$marginal_posteriors <- p1 + + # 2. Pairwise correlations (pour les 4 premiers paramètres) + if (length(param_names) >= 2) { + n_plot <- min(4, length(param_names)) + pairs_data <- samples_df[, 1:n_plot] + + p2 <- GGally::ggpairs(pairs_data, + lower = list(continuous = "points"), + diag = list(continuous = "densityDiag"), + upper = list(continuous = "cor")) + + theme_minimal() + + labs(title = "Corrélations entre paramètres") + plots$pairwise_correlations <- p2 + } + + # 3. Trace de l'ELBO + if (!is.null(vbmc_results$elbo_trace)) { + elbo_trace <- vbmc_results$elbo_trace + p3 <- tibble( + iteration = seq_along(elbo_trace), + ELBO = elbo_trace + ) %>% + ggplot(aes(x = iteration, y = ELBO)) + + geom_line(color = "darkblue", linewidth = 1) + + geom_point(size = 2, alpha = 0.5) + + theme_minimal() + + labs(title = "Convergence de l'ELBO", + subtitle = "Evidence Lower Bound", + x = "Itération", y = "ELBO") + plots$elbo_trace <- p3 + } + + # 4. Incertitude des paramètres (credible intervals) + posterior_mean <- colMeans(samples_df) + posterior_sd <- apply(samples_df, 2, sd) + ci_lower <- apply(samples_df, 2, quantile, probs = 0.025) + ci_upper <- apply(samples_df, 2, quantile, probs = 0.975) + + p4 <- tibble( + parameter = param_names, + mean = posterior_mean, + sd = posterior_sd, + ci_lower = ci_lower, + ci_upper = ci_upper + ) %>% + mutate(parameter = fct_reorder(parameter, mean)) %>% + ggplot(aes(x = parameter, y = mean)) + + geom_point(size = 3) + + geom_errorbar(aes(ymin = ci_lower, ymax = ci_upper), width = 0.2) + + coord_flip() + + theme_minimal() + + labs(title = "Estimations postérieures avec intervalles de crédibilité à 95%", + x = "Paramètre", y = "Valeur") + plots$credible_intervals <- p4 + + return(plots) +} + +# Fonction pour comparer DEoptim vs VBMC +compare_optimization_methods <- function(data, participant_ids = NULL, + model_config, n_participants = 5) { + + if (is.null(participant_ids)) { + participant_ids <- sample(unique(data$participant_id), + min(n_participants, length(unique(data$participant_id)))) + } + + comparison_results <- map_df(participant_ids, function(pid) { + cat("Participant", pid, "\n") + + participant_data <- data %>% + filter(participant_id == pid) %>% + arrange(trial) + + # Méthode 1: DEoptim + fit_deoptim <- fit_participant(participant_data, model_config, n_runs = 5) + + # Méthode 2: VBMC + fit_vbmc <- tryCatch({ + fit_participant_vbmc(participant_data, model_config, use_python_vbmc = TRUE) + }, error = function(e) { + cat(" VBMC échoué pour participant", pid, ":", e$message, "\n") + return(NULL) + }) + + if (!is.null(fit_vbmc)) { + bind_rows( + fit_deoptim %>% mutate(method = "DEoptim", participant_id = pid), + fit_vbmc %>% mutate(participant_id = pid) + ) + } else { + fit_deoptim %>% mutate(method = "DEoptim", participant_id = pid) + } + }) + + # Visualisation de la comparaison + p <- comparison_results %>% + select(participant_id, method, negLL, BIC) %>% + pivot_longer(c(negLL, BIC), names_to = "metric", values_to = "value") %>% + ggplot(aes(x = method, y = value, fill = method)) + + geom_boxplot() + + facet_wrap(~metric, scales = "free_y") + + theme_minimal() + + labs(title = "Comparaison DEoptim vs VBMC", + subtitle = paste("N =", length(unique(comparison_results$participant_id)), "participants"), + y = "Valeur") + + print(p) + + return(comparison_results) + + + all_results <- vector("list", n_runs) + + for (run in 1:n_runs) { + set.seed(1000 * as.numeric(factor(model_config$name)) + run) + + result <- DEoptim( + fn = qlearning_generic, + lower = model_config$lower, + upper = model_config$upper, + data = participant_data, + model_config = model_config, + control = DEoptim.control( + itermax = 200, + trace = FALSE, + parallelType = 0, + NP = max(50, model_config$n_params * 10) + ) + ) + + all_results[[run]] <- list( + params = result$optim$bestmem, + negLL = result$optim$bestval + ) + } + + # Sélection du meilleur run + all_negLL <- sapply(all_results, function(x) x$negLL) + best_run <- which.min(all_negLL) + + negLL_sd <- sd(all_negLL) + negLL_range <- max(all_negLL) - min(all_negLL) + + params <- all_results[[best_run]]$params + negLL <- all_results[[best_run]]$negLL + + # Calcul de la Hessienne + hessian_result <- tryCatch({ + numDeriv::hessian(qlearning_generic, params, + data = participant_data, + model_config = model_config) + }, error = function(e) NULL) + + hessian_positive_definite <- FALSE + param_se <- rep(NA, model_config$n_params) + + if (!is.null(hessian_result)) { + eigenvalues <- eigen(hessian_result, only.values = TRUE)$values + hessian_positive_definite <- all(eigenvalues > 0) + + if (hessian_positive_definite) { + param_vcov <- tryCatch(solve(hessian_result), error = function(e) NULL) + if (!is.null(param_vcov)) { + param_se <- sqrt(diag(param_vcov)) + } + } + } + + # Création du tibble de résultats + result_df <- tibble( + model = model_config$name, + n_params = model_config$n_params, + negLL = negLL, + AIC = 2 * negLL + 2 * model_config$n_params, + BIC = 2 * negLL + model_config$n_params * log(nrow(participant_data)), + convergence_sd = negLL_sd, + convergence_range = negLL_range, + hessian_positive_definite = hessian_positive_definite, + converged = negLL_range < 1 + ) + + # Ajout des paramètres estimés + for (i in 1:model_config$n_params) { + result_df[[model_config$param_names[i]]] <- params[i] + result_df[[paste0("se_", model_config$param_names[i])]] <- param_se[i] + } + + return(result_df) +} + +# ============================================================================ +# ESTIMATION POUR TOUS LES PARTICIPANTS +# ============================================================================ + +fit_all_participants_all_models <- function(data, models_to_fit = NULL) { + + model_configs <- get_model_configs() + + if (!is.null(models_to_fit)) { + model_configs <- model_configs[models_to_fit] + } + + participants <- unique(data$participant_id) + + all_results <- list() + + for (model_name in names(model_configs)) { + cat("\n=== Fitting model:", model_name, "===\n") + + model_config <- model_configs[[model_name]] + + model_results <- map_df(participants, function(pid) { + cat(" Participant", pid, "\n") + + participant_data <- data %>% + filter(participant_id == pid) %>% + arrange(trial) + + fit <- fit_participant(participant_data, model_config) + fit %>% mutate(participant_id = pid, .before = 1) + }) + + all_results[[model_name]] <- model_results + } + + return(all_results) +} + +# ============================================================================ +# COMPARAISON DES MODÈLES EMBOÎTÉS +# ============================================================================ + +compare_nested_models <- function(all_results) { + + # Comparaison globale + global_comparison <- map_df(names(all_results), function(model_name) { + results <- all_results[[model_name]] + + tibble( + model = model_name, + n_params = unique(results$n_params), + n_converged = sum(results$converged), + mean_negLL = mean(results$negLL), + total_negLL = sum(results$negLL), + total_AIC = sum(results$AIC), + total_BIC = sum(results$BIC), + mean_convergence_range = mean(results$convergence_range) + ) + }) %>% + arrange(total_BIC) + + cat("\n=== COMPARAISON GLOBALE DES MODÈLES ===\n") + print(global_comparison) + + # Meilleur modèle par participant (BIC) + best_models_per_participant <- map_df(names(all_results), function(model_name) { + all_results[[model_name]] %>% + select(participant_id, model, BIC, converged) + }) %>% + group_by(participant_id) %>% + slice_min(BIC, n = 1) %>% + ungroup() + + cat("\n=== MEILLEUR MODÈLE PAR PARTICIPANT (BIC) ===\n") + print(table(best_models_per_participant$model)) + + # Comparaison par paires de modèles emboîtés + cat("\n=== TESTS LRT POUR MODÈLES EMBOÎTÉS ===\n") + + # Exemples de paires emboîtées + nested_pairs <- list( + c("HOMOGENEOUS", "GAIN_LOSS"), + c("GAIN_LOSS", "BIASED"), + c("GAIN_LOSS", "REE_BIASED_SIMPLE"), + c("REE_BIASED_SIMPLE", "REE_BIASED_COMPLEX"), + c("GAIN_LOSS", "REE_LEARNING_SIMPLE"), + c("REE_LEARNING_SIMPLE", "REE_LEARNING_COMPLEX"), + c("REE_LEARNING_SIMPLE", "REE_LEARNING_BIASED_SIMPLE"), + c("REE_LEARNING_BIASED_SIMPLE", "REE_LEARNING_BIASED_COMPLEX") + ) + + lrt_results <- map_df(nested_pairs, function(pair) { + simple_model <- pair[1] + complex_model <- pair[2] + + if (simple_model %in% names(all_results) && complex_model %in% names(all_results)) { + simple_res <- all_results[[simple_model]] + complex_res <- all_results[[complex_model]] + + # LRT par participant + lrt_df <- simple_res %>% + select(participant_id, negLL_simple = negLL, converged_simple = converged) %>% + left_join( + complex_res %>% select(participant_id, negLL_complex = negLL, converged_complex = converged), + by = "participant_id" + ) %>% + filter(converged_simple, converged_complex) %>% + mutate( + LR_stat = 2 * (negLL_simple - negLL_complex), + df_diff = unique(complex_res$n_params) - unique(simple_res$n_params), + p_value = pchisq(LR_stat, df = df_diff, lower.tail = FALSE), + significant = p_value < 0.05 + ) + + tibble( + simple_model = simple_model, + complex_model = complex_model, + n_participants = nrow(lrt_df), + pct_significant = mean(lrt_df$significant) * 100, + mean_LR = mean(lrt_df$LR_stat), + median_p = median(lrt_df$p_value) + ) + } + }) + + print(lrt_results) + + list( + global_comparison = global_comparison, + best_models_per_participant = best_models_per_participant, + lrt_results = lrt_results, + all_results = all_results + ) +} + +# ============================================================================ +# VISUALISATION +# ============================================================================ + +plot_model_comparison <- function(comparison_results) { + require(ggplot2) + require(patchwork) + + # 1. Comparaison BIC globale + p1 <- comparison_results$global_comparison %>% + mutate(model = fct_reorder(model, total_BIC)) %>% + ggplot(aes(x = model, y = total_BIC, fill = model)) + + geom_col() + + coord_flip() + + theme_minimal() + + labs(title = "Comparaison globale des modèles (BIC)", + subtitle = "Plus bas = meilleur") + + theme(legend.position = "none") + + # 2. Meilleur modèle par participant + p2 <- comparison_results$best_models_per_participant %>% + count(model) %>% + mutate(model = fct_reorder(model, n)) %>% + ggplot(aes(x = model, y = n, fill = model)) + + geom_col() + + coord_flip() + + theme_minimal() + + labs(title = "Meilleur modèle par participant", + y = "Nombre de participants") + + theme(legend.position = "none") + + # 3. Tests LRT + if (nrow(comparison_results$lrt_results) > 0) { + p3 <- comparison_results$lrt_results %>% + mutate(comparison = paste(simple_model, "→", complex_model)) %>% + ggplot(aes(x = fct_reorder(comparison, pct_significant), + y = pct_significant)) + + geom_col(fill = "steelblue") + + geom_hline(yintercept = 50, linetype = "dashed", color = "red") + + coord_flip() + + theme_minimal() + + labs(title = "Tests LRT entre modèles emboîtés", + y = "% participants avec p < 0.05", + x = "Comparaison") + } else { + p3 <- ggplot() + theme_void() + } + + # 4. Convergence par modèle + p4 <- map_df(names(comparison_results$all_results), function(model_name) { + comparison_results$all_results[[model_name]] %>% + select(model, convergence_range, converged) + }) %>% + ggplot(aes(x = model, y = convergence_range, fill = converged)) + + geom_boxplot() + + scale_y_log10() + + coord_flip() + + theme_minimal() + + labs(title = "Convergence par modèle", + y = "Range negLL (log scale)") + + (p1 | p2) / (p3 | p4) +} + +# ============================================================================ +# EXEMPLE D'UTILISATION +# ============================================================================ + +# Charger vos données +# data <- read_csv("votre_fichier.csv") +# Colonnes requises: participant_id, trial, choice, reward + +# Estimation de tous les modèles +# all_results <- fit_all_participants_all_models(data) + +# Ou seulement certains modèles +# all_results <- fit_all_participants_all_models( +# data, +# models_to_fit = c("HOMOGENEOUS", "GAIN_LOSS", "REE_BIASED_SIMPLE", +# "REE_LEARNING_SIMPLE", "REE_LEARNING_BIASED_SIMPLE") +# ) + +# Comparaison +# comparison <- compare_nested_models(all_results) + +# Visualisation +# plot_model_comparison(comparison) + +# Sauvegarder les résultats +# for (model_name in names(all_results)) { +# write_csv(all_results[[model_name]], +# paste0("results_", model_name, ".csv")) +# } +# write_csv(comparison$global_comparison, "global_comparison.csv") +# write_csv(comparison$best_models_per_participant, "best_models.csv") diff --git a/modelling V4.R b/modelling V4.R new file mode 100644 index 0000000..3901990 --- /dev/null +++ b/modelling V4.R @@ -0,0 +1,579 @@ +# ============================================================================ +# MODÈLES Q-LEARNING EMBOÎTÉS POUR DÉCISION AVEC ÉVÉNEMENTS RARES +# ============================================================================ + +library(tidyverse) +library(DEoptim) +library(numDeriv) + +# ============================================================================ +# FONCTION GÉNÉRIQUE DE Q-LEARNING +# ============================================================================ + +qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) { + + # Conversion des choix en indices numériques + if (is.factor(data$button_name) || is.character(data$button_name)) { + choice_levels <- c("antifragile", "fragile", "robuste", "vulnerable") + data$choice_idx <- match(as.character(data$button_name), choice_levels) + } else { + data$choice_idx <- data$button_name + } + + n_arms <- 4 + n_trials <- nrow(data) + + # Extraction des paramètres selon la configuration du modèle + param_idx <- 1 + + # ALPHA(S) + if (model_config$n_alpha == 1) { + alpha_loss <- alpha_gain <- alpha_BS <- alpha_JP <- plogis(params[param_idx]) + param_idx <- param_idx + 1 + } else if (model_config$n_alpha == 2) { + alpha_loss <- plogis(params[param_idx]) + alpha_gain <- plogis(params[param_idx + 1]) + alpha_BS <- alpha_loss + alpha_JP <- alpha_gain + param_idx <- param_idx + 2 + } else if (model_config$n_alpha == 4) { + alpha_loss <- plogis(params[param_idx]) + alpha_gain <- plogis(params[param_idx + 1]) + alpha_BS <- plogis(params[param_idx + 2]) + alpha_JP <- plogis(params[param_idx + 3]) + param_idx <- param_idx + 4 + } + + # FORGET(S) + if (model_config$n_forget == 1) { + forget <- rep(plogis(params[param_idx]), n_arms) + param_idx <- param_idx + 1 + } else if (model_config$n_forget == 4) { + forget <- plogis(params[param_idx:(param_idx + 3)]) + param_idx <- param_idx + 4 + } + + # LAMBDA(S) + if (model_config$n_lambda == 1) { + lambda <- rep(exp(params[param_idx]), n_arms) + param_idx <- param_idx + 1 + } else if (model_config$n_lambda == 4) { + lambda <- exp(params[param_idx:(param_idx + 3)]) + param_idx <- param_idx + 4 + } + + # RHO(S) - Biais pour événements rares + if (model_config$has_rho) { + rho_BS <- params[param_idx] # BS avoidance + rho_JP <- params[param_idx + 1] # JP seeking + param_idx <- param_idx + 2 + } else { + rho_BS <- rho_JP <- 0 + } + + # Initialisation des Q-values + Q <- rep(0, n_arms) + log_lik <- 0 + + for (t in 1:n_trials) { + choice <- data$choice_idx[t] + reward <- data$button_value[t] + + # Calcul des valeurs subjectives V(t) + V <- lambda * Q + + # Ajout des biais pour événements rares si le modèle le permet + if (model_config$has_rho) { + # Identification des options susceptibles de produire BS/JP + # antifragile (1) = JP possible, fragile (2) = BS possible + # vulnerable (4) = BS et JP possibles + V[1] <- V[1] + rho_JP # antifragile + V[2] <- V[2] + rho_BS # fragile + V[4] <- V[4] + rho_BS + rho_JP # vulnerable + } + + # Softmax + V_max <- max(V) + exp_V <- exp(V - V_max) + probs <- exp_V / sum(exp_V) + probs <- pmax(probs, 1e-10) + probs <- probs / sum(probs) + + # Log-likelihood + log_lik <- log_lik + log(probs[choice]) + + # Mise à jour Q-learning + Q_new <- Q + + # Choix de l'alpha approprié + if (reward == -3000) { + alpha_used <- alpha_BS + } else if (reward == 3000) { + alpha_used <- alpha_JP + } else if (reward < 0) { + alpha_used <- alpha_loss + } else { + alpha_used <- alpha_gain + } + + # Option choisie : Q(t+1) = Q(t) + alpha * (r(t) - Q(t)) + Q_new[choice] <- Q[choice] + alpha_used * (reward - Q[choice]) + + # Options non choisies : Q(t+1) = Q(t) * (1 - f) + not_chosen <- setdiff(1:n_arms, choice) + Q_new[not_chosen] <- Q[not_chosen] * (1 - forget[not_chosen]) + + Q <- Q_new + } + + if (return_negLL) { + return(-log_lik) + } else { + return(log_lik) + } +} + +# ============================================================================ +# CONFIGURATIONS DES MODÈLES EMBOÎTÉS +# ============================================================================ + +get_model_configs <- function() { + list( + HOMOGENEOUS = list( + name = "HOMOGENEOUS", + n_alpha = 1, + n_forget = 1, + n_lambda = 1, + has_rho = FALSE, + n_params = 3, + param_names = c("alpha", "forget", "lambda"), + lower = c(-5, -5, -3), + upper = c(5, 5, 3) + ), + + GAIN_LOSS = list( + name = "GAIN_LOSS", + n_alpha = 2, + n_forget = 1, + n_lambda = 1, + has_rho = FALSE, + n_params = 4, + param_names = c("alpha_loss", "alpha_gain", "forget", "lambda"), + lower = c(-5, -5, -5, -3), + upper = c(5, 5, 5, 3) + ), + + BIASED = list( + name = "BIASED", + n_alpha = 2, + n_forget = 4, + n_lambda = 4, + has_rho = FALSE, + n_params = 10, + param_names = c("alpha_loss", "alpha_gain", + "forget_1", "forget_2", "forget_3", "forget_4", + "lambda_1", "lambda_2", "lambda_3", "lambda_4"), + lower = c(-5, -5, rep(-5, 4), rep(-3, 4)), + upper = c(5, 5, rep(5, 4), rep(3, 4)) + ), + + REE_BIASED_SIMPLE = list( + name = "REE_BIASED_SIMPLE", + n_alpha = 2, + n_forget = 1, + n_lambda = 1, + has_rho = TRUE, + n_params = 6, + param_names = c("alpha_loss", "alpha_gain", "forget", "lambda", + "rho_BS", "rho_JP"), + lower = c(-5, -5, -5, -3, -10, -10), + upper = c(5, 5, 5, 3, 10, 10) + ), + + REE_BIASED_COMPLEX = list( + name = "REE_BIASED_COMPLEX", + n_alpha = 2, + n_forget = 4, + n_lambda = 4, + has_rho = TRUE, + n_params = 12, + param_names = c("alpha_loss", "alpha_gain", + "forget_1", "forget_2", "forget_3", "forget_4", + "lambda_1", "lambda_2", "lambda_3", "lambda_4", + "rho_BS", "rho_JP"), + lower = c(-5, -5, rep(-5, 4), rep(-3, 4), -10, -10), + upper = c(5, 5, rep(5, 4), rep(3, 4), 10, 10) + ), + + REE_LEARNING_SIMPLE = list( + name = "REE_LEARNING_SIMPLE", + n_alpha = 4, + n_forget = 1, + n_lambda = 1, + has_rho = FALSE, + n_params = 6, + param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", + "forget", "lambda"), + lower = c(-5, -5, -5, -5, -5, -3), + upper = c(5, 5, 5, 5, 5, 3) + ), + + REE_LEARNING_COMPLEX = list( + name = "REE_LEARNING_COMPLEX", + n_alpha = 4, + n_forget = 4, + n_lambda = 4, + has_rho = FALSE, + n_params = 12, + param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", + "forget_1", "forget_2", "forget_3", "forget_4", + "lambda_1", "lambda_2", "lambda_3", "lambda_4"), + lower = c(-5, -5, -5, -5, rep(-5, 4), rep(-3, 4)), + upper = c(5, 5, 5, 5, rep(5, 4), rep(3, 4)) + ), + + REE_LEARNING_BIASED_SIMPLE = list( + name = "REE_LEARNING_BIASED_SIMPLE", + n_alpha = 4, + n_forget = 1, + n_lambda = 1, + has_rho = TRUE, + n_params = 8, + param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", + "forget", "lambda", "rho_BS", "rho_JP"), + lower = c(-5, -5, -5, -5, -5, -3, -10, -10), + upper = c(5, 5, 5, 5, 5, 3, 10, 10) + ), + + REE_LEARNING_BIASED_COMPLEX = list( + name = "REE_LEARNING_BIASED_COMPLEX", + n_alpha = 4, + n_forget = 4, + n_lambda = 4, + has_rho = TRUE, + n_params = 14, + param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", + "forget_1", "forget_2", "forget_3", "forget_4", + "lambda_1", "lambda_2", "lambda_3", "lambda_4", + "rho_BS", "rho_JP"), + lower = c(-5, -5, -5, -5, rep(-5, 4), rep(-3, 4), -10, -10), + upper = c(5, 5, 5, 5, rep(5, 4), rep(3, 4), 10, 10) + ) + ) +} + +# ============================================================================ +# ESTIMATION POUR UN PARTICIPANT +# ============================================================================ + +fit_participant <- function(participant_data, model_config, n_runs = 5) { + + all_results <- vector("list", n_runs) + + for (run in 1:n_runs) { + set.seed(1000 * as.numeric(factor(model_config$name)) + run) + + result <- DEoptim( + fn = qlearning_generic, + lower = model_config$lower, + upper = model_config$upper, + data = participant_data, + model_config = model_config, + control = DEoptim.control( + itermax = 200, + trace = FALSE, + parallelType = 0, + NP = max(50, model_config$n_params * 10) + ) + ) + + all_results[[run]] <- list( + params = result$optim$bestmem, + negLL = result$optim$bestval + ) + } + + # Sélection du meilleur run + all_negLL <- sapply(all_results, function(x) x$negLL) + best_run <- which.min(all_negLL) + + negLL_sd <- sd(all_negLL) + negLL_range <- max(all_negLL) - min(all_negLL) + + params <- all_results[[best_run]]$params + negLL <- all_results[[best_run]]$negLL + + # Calcul de la Hessienne + hessian_result <- tryCatch({ + numDeriv::hessian(qlearning_generic, params, + data = participant_data, + model_config = model_config) + }, error = function(e) NULL) + + hessian_positive_definite <- FALSE + param_se <- rep(NA, model_config$n_params) + + if (!is.null(hessian_result)) { + eigenvalues <- eigen(hessian_result, only.values = TRUE)$values + hessian_positive_definite <- all(eigenvalues > 0) + + if (hessian_positive_definite) { + param_vcov <- tryCatch(solve(hessian_result), error = function(e) NULL) + if (!is.null(param_vcov)) { + param_se <- sqrt(diag(param_vcov)) + } + } + } + + # Création du tibble de résultats + result_df <- tibble( + model = model_config$name, + n_params = model_config$n_params, + negLL = negLL, + AIC = 2 * negLL + 2 * model_config$n_params, + BIC = 2 * negLL + model_config$n_params * log(nrow(participant_data)), + convergence_sd = negLL_sd, + convergence_range = negLL_range, + hessian_positive_definite = hessian_positive_definite, + converged = negLL_range < 1 + ) + + # Ajout des paramètres estimés + for (i in 1:model_config$n_params) { + result_df[[model_config$param_names[i]]] <- params[i] + result_df[[paste0("se_", model_config$param_names[i])]] <- param_se[i] + } + + return(result_df) +} + +# ============================================================================ +# ESTIMATION POUR TOUS LES PARTICIPANTS +# ============================================================================ + +fit_all_participants_all_models <- function(data, models_to_fit = NULL) { + + model_configs <- get_model_configs() + + if (!is.null(models_to_fit)) { + model_configs <- model_configs[models_to_fit] + } + + participants <- unique(data$participant_id) + + all_results <- list() + + for (model_name in names(model_configs)) { + cat("\n=== Fitting model:", model_name, "===\n") + + model_config <- model_configs[[model_name]] + + model_results <- map_df(participants, function(pid) { + cat(" Participant", pid, "\n") + + participant_data <- data %>% + filter(participant_id == pid) %>% + arrange(trial) + + fit <- fit_participant(participant_data, model_config) + fit %>% mutate(participant_id = pid, .before = 1) + }) + + all_results[[model_name]] <- model_results + } + + return(all_results) +} + +# ============================================================================ +# COMPARAISON DES MODÈLES EMBOÎTÉS +# ============================================================================ + +compare_nested_models <- function(all_results) { + + # Comparaison globale + global_comparison <- map_df(names(all_results), function(model_name) { + results <- all_results[[model_name]] + + tibble( + model = model_name, + n_params = unique(results$n_params), + n_converged = sum(results$converged), + mean_negLL = mean(results$negLL), + total_negLL = sum(results$negLL), + total_AIC = sum(results$AIC), + total_BIC = sum(results$BIC), + mean_convergence_range = mean(results$convergence_range) + ) + }) %>% + arrange(total_BIC) + + cat("\n=== COMPARAISON GLOBALE DES MODÈLES ===\n") + print(global_comparison) + + # Meilleur modèle par participant (BIC) + best_models_per_participant <- map_df(names(all_results), function(model_name) { + all_results[[model_name]] %>% + select(participant_id, model, BIC, converged) + }) %>% + group_by(participant_id) %>% + slice_min(BIC, n = 1) %>% + ungroup() + + cat("\n=== MEILLEUR MODÈLE PAR PARTICIPANT (BIC) ===\n") + print(table(best_models_per_participant$model)) + + # Comparaison par paires de modèles emboîtés + cat("\n=== TESTS LRT POUR MODÈLES EMBOÎTÉS ===\n") + + # Exemples de paires emboîtées + nested_pairs <- list( + c("HOMOGENEOUS", "GAIN_LOSS"), + c("GAIN_LOSS", "BIASED"), + c("GAIN_LOSS", "REE_BIASED_SIMPLE"), + c("REE_BIASED_SIMPLE", "REE_BIASED_COMPLEX"), + c("GAIN_LOSS", "REE_LEARNING_SIMPLE"), + c("REE_LEARNING_SIMPLE", "REE_LEARNING_COMPLEX"), + c("REE_LEARNING_SIMPLE", "REE_LEARNING_BIASED_SIMPLE"), + c("REE_LEARNING_BIASED_SIMPLE", "REE_LEARNING_BIASED_COMPLEX") + ) + + lrt_results <- map_df(nested_pairs, function(pair) { + simple_model <- pair[1] + complex_model <- pair[2] + + if (simple_model %in% names(all_results) && complex_model %in% names(all_results)) { + simple_res <- all_results[[simple_model]] + complex_res <- all_results[[complex_model]] + + # LRT par participant + lrt_df <- simple_res %>% + select(participant_id, negLL_simple = negLL, converged_simple = converged) %>% + left_join( + complex_res %>% select(participant_id, negLL_complex = negLL, converged_complex = converged), + by = "participant_id" + ) %>% + filter(converged_simple, converged_complex) %>% + mutate( + LR_stat = 2 * (negLL_simple - negLL_complex), + df_diff = unique(complex_res$n_params) - unique(simple_res$n_params), + p_value = pchisq(LR_stat, df = df_diff, lower.tail = FALSE), + significant = p_value < 0.05 + ) + + tibble( + simple_model = simple_model, + complex_model = complex_model, + n_participants = nrow(lrt_df), + pct_significant = mean(lrt_df$significant) * 100, + mean_LR = mean(lrt_df$LR_stat), + median_p = median(lrt_df$p_value) + ) + } + }) + + print(lrt_results) + + list( + global_comparison = global_comparison, + best_models_per_participant = best_models_per_participant, + lrt_results = lrt_results, + all_results = all_results + ) +} + +# ============================================================================ +# VISUALISATION +# ============================================================================ + +plot_model_comparison <- function(comparison_results) { + require(ggplot2) + require(patchwork) + + # 1. Comparaison BIC globale + p1 <- comparison_results$global_comparison %>% + mutate(model = fct_reorder(model, total_BIC)) %>% + ggplot(aes(x = model, y = total_BIC, fill = model)) + + geom_col() + + coord_flip() + + theme_minimal() + + labs(title = "Comparaison globale des modèles (BIC)", + subtitle = "Plus bas = meilleur") + + theme(legend.position = "none") + + # 2. Meilleur modèle par participant + p2 <- comparison_results$best_models_per_participant %>% + count(model) %>% + mutate(model = fct_reorder(model, n)) %>% + ggplot(aes(x = model, y = n, fill = model)) + + geom_col() + + coord_flip() + + theme_minimal() + + labs(title = "Meilleur modèle par participant", + y = "Nombre de participants") + + theme(legend.position = "none") + + # 3. Tests LRT + if (nrow(comparison_results$lrt_results) > 0) { + p3 <- comparison_results$lrt_results %>% + mutate(comparison = paste(simple_model, "→", complex_model)) %>% + ggplot(aes(x = fct_reorder(comparison, pct_significant), + y = pct_significant)) + + geom_col(fill = "steelblue") + + geom_hline(yintercept = 50, linetype = "dashed", color = "red") + + coord_flip() + + theme_minimal() + + labs(title = "Tests LRT entre modèles emboîtés", + y = "% participants avec p < 0.05", + x = "Comparaison") + } else { + p3 <- ggplot() + theme_void() + } + + # 4. Convergence par modèle + p4 <- map_df(names(comparison_results$all_results), function(model_name) { + comparison_results$all_results[[model_name]] %>% + select(model, convergence_range, converged) + }) %>% + ggplot(aes(x = model, y = convergence_range, fill = converged)) + + geom_boxplot() + + scale_y_log10() + + coord_flip() + + theme_minimal() + + labs(title = "Convergence par modèle", + y = "Range negLL (log scale)") + + (p1 | p2) / (p3 | p4) +} + +# ============================================================================ +# EXEMPLE D'UTILISATION +# ============================================================================ + +# Charger vos données +# data <- read_csv("votre_fichier.csv") +# Colonnes requises: participant_id, trial, choice, reward + +# Estimation de tous les modèles +# all_results <- fit_all_participants_all_models(data) + +# Ou seulement certains modèles +# all_results <- fit_all_participants_all_models( +# data, +# models_to_fit = c("HOMOGENEOUS", "GAIN_LOSS", "REE_BIASED_SIMPLE", +# "REE_LEARNING_SIMPLE", "REE_LEARNING_BIASED_SIMPLE") +# ) + +# Comparaison +# comparison <- compare_nested_models(all_results) + +# Visualisation +# plot_model_comparison(comparison) + +# Sauvegarder les résultats +# for (model_name in names(all_results)) { +# write_csv(all_results[[model_name]], +# paste0("results_", model_name, ".csv")) +# } +# write_csv(comparison$global_comparison, "global_comparison.csv") +# write_csv(comparison$best_models_per_participant, "best_models.csv") \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..8d091ab --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +numpy +scipy +pandas \ No newline at end of file diff --git a/results/model0/mle.csv b/results/model0/mle.csv new file mode 100644 index 0000000..1c670f1 --- /dev/null +++ b/results/model0/mle.csv @@ -0,0 +1,61 @@ +participant,alpha,f,nll +qfmtmjjy,0.582784971895418,0.999999066293681,552.9655592003581 +l5hj4bqu,0.9999999999056042,5.46564082409206e-13,547.8314792895943 +ax6yt76a,0.031045765133638766,0.999999771173566,553.9560713683154 +n7zr7k22,0.06065557843979855,0.9999999663593895,553.8393776741615 +n274eufg,2.3664079880887658e-08,0.9967507965525135,554.5177445759583 +d3n1xm9e,0.04203999064289608,0.9999854195433653,554.2427744376569 +e41xqg7i,0.007942222018805265,8.0242615476356e-11,554.5121903640234 +qdl2kwne,1.0,6.826428647305449e-46,492.2241079443528 +8y0d23i5,0.777177246966759,7.402298460381941e-15,534.5235366710108 +1e1vfxfy,0.3330278516103173,0.0850735445577119,553.9630735526995 +4vfjhtfy,0.1122311594415092,1.871749156748348e-26,552.2074993081454 +kmbtnydb,0.13667778040270864,2.0884909923554726e-09,547.5307630064009 +a3whtqnm,3.629020342740599e-07,0.00328190336364656,554.5177410374322 +bqckq56i,0.03402038741621446,7.86918333059074e-08,552.9721272204469 +wtmppyq1,0.06906140463867033,9.607058911881154e-08,553.7477997790903 +gyboex2p,0.9999999999700879,6.2820427656554785e-12,545.1872533519312 +6v2kv490,4.235265828311803e-09,0.2648614681419073,554.5177444590045 +goe4ggy1,0.7254990519220672,1.9745911879265448e-08,548.0978286404306 +wjhrr05w,0.09471712411279669,0.8213415076416049,554.5092420387667 +97wdqw3s,0.9999999999995051,9.901870701857771e-13,500.4256381607084 +u8hygigo,0.05070763094787091,0.8321426728828714,551.7461707772945 +cf10378m,0.2095885843287755,1.6521145726363753e-08,553.3156063628437 +v5r7mq91,0.5253114772877784,0.0005372516169283673,551.7146877045956 +3x6nteue,0.2686405242216113,0.0003538486737042658,551.7640263917281 +xecqf0dd,1.0,2.3760270339932486e-14,531.703008010426 +m9i1kvu6,0.05114247893954163,1.3281881318099996e-12,553.2707089568645 +gp3vvd4k,0.9999999999999882,1.1337302888068697e-34,516.2411532388203 +2w7l1ve3,0.17190627716472015,5.340979032968921e-10,551.4238344441382 +inpc0pfu,3.4061084339887766e-11,0.9992991841689133,554.5177444479662 +x18jrbof,1.0,2.224445142889756e-43,509.74931877274776 +nao85d8o,0.009160817690823834,5.1534149290441666e-14,554.4567848446477 +owfwd7iz,0.05978783217663562,1.7357980024613605e-11,552.6129603387225 +iiupv9e6,0.09814668824802089,1.6943395241865024e-09,547.4859266833736 +rht8c10d,0.9478576700137822,0.9999999994412778,553.674648932509 +uc5u8lob,0.0578416087666933,0.9999999989655655,553.1812054538608 +idc0abhp,0.037067505840974964,1.928656015149807e-09,552.4401642516339 +gfy4u3y9,3.0467043647872075e-11,0.967600368549238,554.5177444480236 +vd2z262g,0.05823220877937151,3.868526723600325e-15,552.8274284304949 +gaq9bw7l,0.014725646990138265,1.085129359971333e-12,554.2958912874782 +92minocl,5.36098831541956e-09,0.37631916992506054,554.5177444683173 +5kyg1fo7,0.048134854427945,1.7177194138139772e-09,551.9780048138557 +a95e64ov,0.43712798452105045,2.5556448950394826e-09,545.509286092445 +yasfsrkr,0.21329794299173746,2.059294640948317e-245,546.9216420033767 +qfmtmjjy,1.0,6.331953602927268e-08,553.5640344190499 +p6p2uc66,0.09705693757894857,0.6449707555360135,551.7811547180477 +8mrt2jk8,0.9999999921780748,2.542845157662274e-11,527.4195785127077 +74ubi75c,0.37099733903488924,0.9999999964659654,554.2958824030542 +l5hj4bqu,0.42321676612719833,0.21940814540809023,546.7609390875771 +k2mcerli,0.0462287559453904,0.9999757375629387,553.3444107444034 +nrwlu1nf,2.5695476201442106e-08,0.005983916620578978,554.5177441485276 +ax6yt76a,7.09619494188235e-08,0.9999926916808033,554.3761235503644 +3t3sh2es,0.07493989361956131,2.9956578631889184e-11,549.3828475837245 +n7zr7k22,2.444167462227958e-08,0.9999999539800323,502.1996879300243 +7srzfhpq,0.3482551579565603,2.6477991358741416e-09,545.4042735754294 +4hfaww7m,0.02315510868424729,1.956763314467002e-09,553.3006552174948 +n274eufg,2.1898408386418047e-07,0.0873503997081469,552.8203462427547 +yk2obfqv,0.01618059281802409,1.2650645235924713e-09,553.7560186105892 +ng1u69iy,0.20445541057918212,1.868337206801065e-12,540.1518009334501 +tkngvs6x,0.195823881743137,1.0911950272637065e-09,548.2683590155018 +j1cshr0g,0.41032544288457223,1.8233454044592947e-40,553.1190812699057 diff --git a/results/model0/unrestrained-mle.csv b/results/model0/unrestrained-mle.csv new file mode 100644 index 0000000..423510c --- /dev/null +++ b/results/model0/unrestrained-mle.csv @@ -0,0 +1,211 @@ +alpha,lambda,nll,opt_result,participant +0.0019909996639928752,0.0784037743181824,524.515296073272," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 524.515296073272 + x: [-6.217e+00 -2.546e+00] + nit: 24 + jac: [ 1.137e-05 -3.411e-05] + nfev: 96 + njev: 32 + hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>",qfmtmjjy +0.0041258076334695195,0.17916292668471764,232.26294177560447," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 232.26294177560447 + x: [-5.486e+00 -1.719e+00] + nit: 16 + jac: [ 3.411e-05 3.411e-05] + nfev: 51 + njev: 17 + hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>",l5hj4bqu +8.580076240145284e-10,42804.74116720703,551.346218644511," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 551.346218644511 + x: [-2.088e+01 1.066e+01] + nit: 37 + jac: [-1.819e-04 -1.933e-04] + nfev: 129 + njev: 43 + hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>",ax6yt76a +9.183501708224334e-11,2938700.1970509873,504.4244914945921," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 504.4244914945921 + x: [-2.311e+01 1.489e+01] + nit: 34 + jac: [ 3.490e-03 3.490e-03] + nfev: 126 + njev: 42 + hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>",n7zr7k22 +1.3285404087992438e-12,150903313.1944923,526.8342605685835," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 526.8342605685835 + x: [-2.735e+01 1.883e+01] + nit: 36 + jac: [-6.821e-05 -6.821e-05] + nfev: 126 + njev: 42 + hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>",n274eufg +0.01649599700173123,0.01686233697780474,499.4082573033442," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 499.4082573033442 + x: [-4.088e+00 -4.083e+00] + nit: 23 + jac: [ 2.842e-05 -5.684e-05] + nfev: 81 + njev: 27 + hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>",d3n1xm9e +2.759963002574895e-06,0.975782284437965,554.5108100039844," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 554.5108100039844 + x: [-1.280e+01 -2.452e-02] + nit: 37 + jac: [-1.251e-04 -1.592e-04] + nfev: 147 + njev: 49 + hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>",e41xqg7i +0.8250744843164314,0.0010126000014369413,470.3492586229785," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 470.3492586229785 + x: [ 1.551e+00 -6.895e+00] + nit: 22 + jac: [-5.116e-05 -4.547e-05] + nfev: 69 + njev: 23 + hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>",qdl2kwne +0.0067738863236796905,0.03748214730335093,471.11474910271704," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 471.11474910271704 + x: [-4.988e+00 -3.284e+00] + nit: 21 + jac: [ 3.411e-05 2.274e-05] + nfev: 69 + njev: 23 + hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>",8y0d23i5 +0.3152175910708728,7.705164135712456e-05,554.438430475421," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 554.438430475421 + x: [-7.758e-01 -9.471e+00] + nit: 24 + jac: [ 6.821e-05 2.274e-05] + nfev: 75 + njev: 25 + hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>",1e1vfxfy +3.0097559768213683e-09,48217.30222264478,537.6960374952457," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 537.6960374952457 + x: [-1.962e+01 1.078e+01] + nit: 40 + jac: [ 2.069e-03 2.069e-03] + nfev: 129 + njev: 43 + hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>",4vfjhtfy +0.017080076809904806,0.01834882372168268,475.4223519403967," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 475.4223519403967 + x: [-4.053e+00 -3.998e+00] + nit: 17 + jac: [-1.705e-05 2.842e-05] + nfev: 63 + njev: 21 + hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>",kmbtnydb +6.437703242339829e-08,1111.8568124210485,551.9888338131033," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 551.9888338131033 + x: [-1.656e+01 7.014e+00] + nit: 44 + jac: [-6.594e-04 -6.594e-04] + nfev: 150 + njev: 50 + hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>",a3whtqnm +3.698305553137226e-11,10937793.673686024,448.24507173971307," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 448.24507173971307 + x: [-2.402e+01 1.621e+01] + nit: 36 + jac: [-1.046e-03 -1.052e-03] + nfev: 144 + njev: 48 + hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>",bqckq56i +2.3474015194739323e-09,27880.14281193776,550.037229129203," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 550.037229129203 + x: [-1.987e+01 1.024e+01] + nit: 37 + jac: [ 2.274e-05 2.274e-05] + nfev: 117 + njev: 39 + hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>",wtmppyq1 +0.006699227131566248,0.02396872654682817,498.04299030539954," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 498.04299030539954 + x: [-4.999e+00 -3.731e+00] + nit: 21 + jac: [-3.411e-05 -6.253e-05] + nfev: 84 + njev: 28 + hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>",gyboex2p +8.873267040119318e-08,2.6736837774273013e-06,554.5177445402869," message: CONVERGENCE: NORM OF PROJECTED GRADIENT <= PGTOL + success: True + status: 0 + fun: 554.5177445402869 + x: [-1.624e+01 -1.283e+01] + nit: 32 + jac: [ 0.000e+00 0.000e+00] + nfev: 105 + njev: 35 + hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>",6v2kv490 +2.6887086013618105e-11,4106709.344670519,527.0250476183091," message: CONVERGENCE: NORM OF PROJECTED GRADIENT <= PGTOL + success: True + status: 0 + fun: 527.0250476183091 + x: [-2.434e+01 1.523e+01] + nit: 38 + jac: [ 0.000e+00 0.000e+00] + nfev: 123 + njev: 41 + hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>",goe4ggy1 +5.703446820155234e-10,141678.15993334388,546.8526377885472," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 546.8526377885472 + x: [-2.128e+01 1.186e+01] + nit: 35 + jac: [ 1.251e-04 1.251e-04] + nfev: 117 + njev: 39 + hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>",wjhrr05w +6.942499761718846e-08,17717.856856409046,331.41259185330904," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 331.41259185330904 + x: [-1.648e+01 9.782e+00] + nit: 29 + jac: [-2.893e-03 -2.950e-03] + nfev: 93 + njev: 31 + hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>",97wdqw3s +1.37409628136384e-10,20060552.496022884,28.39754844462764," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 28.39754844462764 + x: [-2.271e+01 1.681e+01] + nit: 35 + jac: [-1.435e-04 -1.435e-04] + nfev: 129 + njev: 43 + hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>",u8hygigo diff --git a/results/model1/mle.csv b/results/model1/mle.csv new file mode 100644 index 0000000..b0270e4 --- /dev/null +++ b/results/model1/mle.csv @@ -0,0 +1,4 @@ +participant,alpha,lambda,f,nll +qfmtmjjy,1.1193176839005968e-16,3.4183388875710508e+16,1.0,469.4113688856362 +l5hj4bqu,0.0030442296624999866,928.7290018408265,0.026793671988813828,225.6172699416345 +ax6yt76a,9.376652676467416e-20,4.30661887267533e+19,1.0,441.8673631389228 diff --git a/results/model1/unrestrained-mle.csv b/results/model1/unrestrained-mle.csv new file mode 100644 index 0000000..5efdc56 --- /dev/null +++ b/results/model1/unrestrained-mle.csv @@ -0,0 +1,191 @@ +alpha,lambda,f,nll,opt_result,participant +8.871704778471776e-11,14376474.143789845,0.9999999984752381,469.33695814123934," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 469.33695814123934 + x: [-2.315e+01 1.648e+01 2.030e+01] + nit: 29 + jac: [ 8.072e-04 8.072e-04 0.000e+00] + nfev: 148 + njev: 37 + hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>",qfmtmjjy +0.0030548027746643705,0.3085218170752524,0.02685309882404673,225.54778324822848," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 225.54778324822848 + x: [-5.788e+00 -1.176e+00 -3.590e+00] + nit: 26 + jac: [ 2.558e-05 2.842e-06 1.421e-05] + nfev: 132 + njev: 33 + hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>",l5hj4bqu +3.2198476500118582e-15,417862591134.8987,0.9999999999999423,441.6555549359772," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 441.6555549359772 + x: [-3.337e+01 2.676e+01 3.048e+01] + nit: 21 + jac: [ 7.958e-05 7.958e-05 0.000e+00] + nfev: 132 + njev: 33 + hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>",ax6yt76a +1.3010725964224816e-09,1225860.6537195262,0.9999999937812794,466.43975387712266," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 466.43975387712266 + x: [-2.046e+01 1.402e+01 1.890e+01] + nit: 32 + jac: [-5.116e-05 -5.116e-05 0.000e+00] + nfev: 140 + njev: 35 + hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>",n7zr7k22 +1.160548917142861e-07,1.6863783919195292e-05,0.9016665769083677,554.5177444822273," message: CONVERGENCE: NORM OF PROJECTED GRADIENT <= PGTOL + success: True + status: 0 + fun: 554.5177444822273 + x: [-1.597e+01 -1.099e+01 2.216e+00] + nit: 29 + jac: [ 0.000e+00 0.000e+00 0.000e+00] + nfev: 120 + njev: 30 + hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>",n274eufg +9.579437532994249e-17,6131931277693.485,0.9999999998655666,539.3988636277782," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 539.3988636277782 + x: [-3.688e+01 2.944e+01 2.273e+01] + nit: 24 + jac: [-1.137e-05 0.000e+00 0.000e+00] + nfev: 136 + njev: 34 + hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>",d3n1xm9e +0.0002172028784807661,0.012459143465035454,5.585119121446817e-06,554.5108523711202," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 554.5108523711202 + x: [-8.434e+00 -4.385e+00 -1.210e+01] + nit: 37 + jac: [-3.411e-05 -1.137e-05 4.547e-05] + nfev: 248 + njev: 62 + hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>",e41xqg7i +0.8250800085276067,0.001012589017492184,1.2649395927827455e-16,470.3492586242423," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 470.3492586242423 + x: [ 1.551e+00 -6.895e+00 -3.661e+01] + nit: 28 + jac: [-2.274e-05 -3.752e-04 0.000e+00] + nfev: 148 + njev: 37 + hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>",qdl2kwne +1.2610421898523663e-18,494323555893264.56,0.04295594209878908,436.20220706531273," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 436.20220706531273 + x: [-4.121e+01 3.383e+01 -3.104e+00] + nit: 38 + jac: [-1.705e-04 -1.592e-04 -6.821e-05] + nfev: 232 + njev: 58 + hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>",8y0d23i5 +1.116371444999944e-07,2524.9268841535963,0.21851369775877533,553.2746446547244," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 553.2746446547244 + x: [-1.601e+01 7.834e+00 -1.274e+00] + nit: 40 + jac: [ 1.592e-04 1.592e-04 1.592e-04] + nfev: 172 + njev: 43 + hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>",1e1vfxfy +6.565360147334132e-11,6632821.084165682,0.012003731763428675,513.6175539157593," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 513.6175539157593 + x: [-2.345e+01 1.571e+01 -4.410e+00] + nit: 46 + jac: [-4.968e-03 -4.957e-03 1.223e-02] + nfev: 340 + njev: 85 + hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>",4vfjhtfy +0.017079822973559262,0.018348802029635628,9.9230306968482e-10,475.42235386434925," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 475.42235386434925 + x: [-4.053e+00 -3.998e+00 -2.073e+01] + nit: 36 + jac: [-7.560e-04 -1.177e-03 0.000e+00] + nfev: 204 + njev: 51 + hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>",kmbtnydb +1.3990137307531058e-06,51.16930938606098,4.982987068878401e-11,551.9888438686838," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 551.9888438686838 + x: [-1.348e+01 3.935e+00 -2.372e+01] + nit: 53 + jac: [ 2.274e-05 2.274e-05 0.000e+00] + nfev: 388 + njev: 97 + hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>",a3whtqnm +2.0976159715180872e-12,425089944.21903425,0.3185943891474155,405.3794868584693," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 405.3794868584693 + x: [-2.689e+01 1.987e+01 -7.602e-01] + nit: 24 + jac: [ 1.137e-05 1.137e-05 1.705e-05] + nfev: 140 + njev: 35 + hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>",bqckq56i +4.453063202131555e-14,1469667428.4008932,4.095080332704109e-15,550.0372290171044," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 550.0372290171044 + x: [-3.074e+01 2.111e+01 -3.313e+01] + nit: 46 + jac: [-3.411e-05 -3.411e-05 0.000e+00] + nfev: 236 + njev: 59 + hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>",wtmppyq1 +0.006699186509785594,0.02396884830138067,1.786102972042448e-13,498.042990305876," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 498.042990305876 + x: [-4.999e+00 -3.731e+00 -2.935e+01] + nit: 45 + jac: [-1.705e-05 5.116e-05 0.000e+00] + nfev: 324 + njev: 81 + hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>",gyboex2p +1.7876073951029804e-08,1.7329412122410143e-06,0.9719326721535358,554.5177444480751," message: CONVERGENCE: NORM OF PROJECTED GRADIENT <= PGTOL + success: True + status: 0 + fun: 554.5177444480751 + x: [-1.784e+01 -1.327e+01 3.545e+00] + nit: 28 + jac: [ 0.000e+00 0.000e+00 0.000e+00] + nfev: 124 + njev: 31 + hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>",6v2kv490 +1.307936563746064e-06,6.0156094409861654e-05,0.9112155489671344,554.5177444919409," message: CONVERGENCE: NORM OF PROJECTED GRADIENT <= PGTOL + success: True + status: 0 + fun: 554.5177444919409 + x: [-1.355e+01 -9.719e+00 2.329e+00] + nit: 28 + jac: [ 0.000e+00 0.000e+00 0.000e+00] + nfev: 124 + njev: 31 + hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>",goe4ggy1 +4.070711600689151e-05,0.9698096110426262,0.8087323463467103,554.5067963631062," message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH + success: True + status: 0 + fun: 554.5067963631062 + x: [-1.011e+01 -3.066e-02 1.442e+00] + nit: 33 + jac: [ 7.958e-05 9.095e-05 1.023e-04] + nfev: 144 + njev: 36 + hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>",wjhrr05w diff --git a/results/model2/mle.csv b/results/model2/mle.csv new file mode 100644 index 0000000..381f339 --- /dev/null +++ b/results/model2/mle.csv @@ -0,0 +1,82 @@ +participant,alpha_gain,alpha_loss,nll,model +qfmtmjjy,1.0,6.331953602927268e-08,553.5640344190499,model2_loss_gain +qfmtmjjy,1.0,6.331953602927268e-08,553.5640344190499,model2_loss_gain +l5hj4bqu,0.42321676612719833,0.21940814540809023,546.7609390875771,model2_loss_gain +ax6yt76a,7.09619494188235e-08,0.9999926916808033,554.3761235503644,model2_loss_gain +n7zr7k22,2.444167462227958e-08,0.9999999539800323,502.1996879300243,model2_loss_gain +n274eufg,2.1898408386418047e-07,0.0873503997081469,552.8203462427547,model2_loss_gain +d3n1xm9e,7.557448923824224e-10,0.49316262300646374,536.7090539494609,model2_loss_gain +e41xqg7i,3.620211097661501e-05,0.013882225967113861,554.5032044585822,model2_loss_gain +qdl2kwne,0.1488199986109225,0.999999915877716,490.03138247284346,model2_loss_gain +8y0d23i5,3.1353878914548706e-07,0.9999990782665764,508.52029649601076,model2_loss_gain +1e1vfxfy,2.6828870011444275e-06,0.8464278794999273,548.0820963166587,model2_loss_gain +4vfjhtfy,4.789637034959933e-07,0.9999996758398828,540.0816425215404,model2_loss_gain +kmbtnydb,2.5482845557987327e-07,0.2127432312542269,545.8392364682928,model2_loss_gain +a3whtqnm,5.853822199059074e-08,0.5990978930901857,532.5704878175919,model2_loss_gain +bqckq56i,0.45134623957080383,1.6739867943334115e-08,550.7972119494303,model2_loss_gain +wtmppyq1,4.506323507164776e-07,0.20132652284582794,551.7209953750038,model2_loss_gain +gyboex2p,0.7315951368264882,0.9999994624999908,544.9177728210519,model2_loss_gain +6v2kv490,7.176003008630198e-16,0.9996580030679997,554.8198842643062,model2_loss_gain +goe4ggy1,0.08928820972079568,0.923176664636646,545.2746923213128,model2_loss_gain +wjhrr05w,2.2875552027867153e-07,0.8334131430433657,551.3703828886879,model2_loss_gain +97wdqw3s,0.99999908618204,0.9999999293659182,500.42564088776396,model2_loss_gain +u8hygigo,0.061751261020003355,5.6914260715717614e-08,548.7283246258174,model2_loss_gain +cf10378m,0.053136759386351734,0.9999998878232516,551.7779144590984,model2_loss_gain +v5r7mq91,0.19107124279103968,4.14211806035215e-07,549.7314348984426,model2_loss_gain +3x6nteue,0.07757806051692856,1.1763894527174054e-07,549.2756329288749,model2_loss_gain +xecqf0dd,0.9996431490434584,0.9999999088133392,531.7030088087863,model2_loss_gain +m9i1kvu6,0.292942196627816,0.8319760919457827,553.0350995516853,model2_loss_gain +gp3vvd4k,0.6544092186694408,0.9999993531057779,515.8326659000236,model2_loss_gain +2w7l1ve3,0.049817764580400326,0.2096126310457826,551.2154842831172,model2_loss_gain +inpc0pfu,1.171384230140469e-06,0.14850086667882983,551.4114139367338,model2_loss_gain +x18jrbof,0.35420927180951134,0.9999998419223542,508.2621175941522,model2_loss_gain +nao85d8o,7.728752512680148e-08,0.9371913104055626,537.088205130901,model2_loss_gain +owfwd7iz,0.5467286785975776,3.957100499126383e-07,551.0093016982407,model2_loss_gain +iiupv9e6,6.305736445494951e-06,0.1553419191193692,546.2786469156417,model2_loss_gain +rht8c10d,0.999998604627353,6.331553332137362e-07,551.8406870812335,model2_loss_gain +uc5u8lob,8.413515847164314e-07,0.5090831419625563,542.2055749945935,model2_loss_gain +idc0abhp,0.0786066114237155,1.0542693009318418e-09,549.7369816839649,model2_loss_gain +gfy4u3y9,4.722094779504361e-08,0.999992910214223,555.7513263189353,model2_loss_gain +vd2z262g,6.551046215106814e-08,0.4171153617163101,542.3096018602956,model2_loss_gain +gaq9bw7l,1.1726121173821847e-05,0.4067721222705871,554.0727455527086,model2_loss_gain +92minocl,2.2726345955612876e-07,0.5290818745238947,545.3325320271144,model2_loss_gain +5kyg1fo7,1.2470648028542125e-07,0.13842924629322795,549.7303076122271,model2_loss_gain +a95e64ov,0.08742389212343478,0.5826743191349707,543.8881933459938,model2_loss_gain +yasfsrkr,0.052784864734305,0.3155278575042451,545.4196656737446,model2_loss_gain +p6p2uc66,0.9999999455603333,3.3458261911750984e-08,547.7628819441434,model2_loss_gain +8mrt2jk8,0.99999732297481,0.9999998785643968,527.4195825672078,model2_loss_gain +74ubi75c,1.5334485228066952e-07,0.9999997912052608,507.52748762945,model2_loss_gain +k2mcerli,4.08062572137162e-08,0.19025233105851136,548.0592429198056,model2_loss_gain +nrwlu1nf,3.0404252738935575e-08,0.31608280812005457,549.0292739069076,model2_loss_gain +3t3sh2es,0.8249784093592426,0.48173009573483305,550.8255484189891,model2_loss_gain +7srzfhpq,3.970915388175658e-07,0.5566323247988982,540.2455464988313,model2_loss_gain +4hfaww7m,1.1892101808930413e-06,0.06840288119107804,552.473445790118,model2_loss_gain +yk2obfqv,1.3807351068135839e-06,0.993205019802573,555.6478666383708,model2_loss_gain +ng1u69iy,1.1802499187075957e-08,0.35483082886603434,533.8987961562939,model2_loss_gain +tkngvs6x,2.593512222239559e-06,0.44609953046479717,545.4166798043648,model2_loss_gain +j1cshr0g,0.36040895629733516,0.5832026620438543,552.9572936835381,model2_loss_gain +o7om1jen,2.3329330146721812e-08,0.3028246183831722,549.653042653292,model2_loss_gain +5w1y4e6s,0.9999973367121154,0.9999999766450439,504.60432280755185,model2_loss_gain +8dxvwlpc,0.08995510228445398,1.7168811661739754e-07,550.8164507448334,model2_loss_gain +8ghzl9o9,0.6891432497828257,0.7859147951044387,549.908798836067,model2_loss_gain +mmh62k83,1.1634105601886457e-08,0.999999848031814,538.3616162789655,model2_loss_gain +167y1h5i,0.15854439946664312,0.9999997511544267,547.6961304214166,model2_loss_gain +rfqmmb1e,0.05873960866856111,1.1141297455053465e-08,551.9305982015765,model2_loss_gain +6fajoau3,8.843376612924671e-07,0.1990406633896888,550.4247946752638,model2_loss_gain +j84pqw2m,0.10134622377518297,0.9999999271154836,489.36197096912457,model2_loss_gain +pz3wg8u0,0.9999795583215213,0.9999999052839689,491.5285933344341,model2_loss_gain +kwvbd3su,4.032274791578176e-06,0.08023742791079048,553.452284290126,model2_loss_gain +uw0ebq3n,4.35018656332699e-07,0.9999772938158631,554.7977781728154,model2_loss_gain +3uodgqnr,0.9998263255919533,0.2905615093179284,546.9326572266009,model2_loss_gain +i0g5hocs,3.039034688388457e-06,0.2937809186233787,542.4067826459491,model2_loss_gain +vqptyp2x,0.1296311132473564,0.999999316589332,533.5635483449804,model2_loss_gain +jfir3t88,2.0940724443610376e-07,0.8842319795054664,556.2116376937038,model2_loss_gain +kp8bwxso,9.208005798129224e-09,0.9999996182780255,550.847509163659,model2_loss_gain +nm7b54pf,1.681812725386848e-07,0.24251287386386436,545.1105744763561,model2_loss_gain +rh8bt6c8,2.293740670538188e-07,0.9999999695126307,527.4634975679522,model2_loss_gain +yrvhkfwj,1.8022098166473515e-07,0.9999991347681113,507.4038877186516,model2_loss_gain +ahceacmk,7.868228047088838e-08,0.9999998157421053,513.7805503460337,model2_loss_gain +ga9ow9zr,0.3298128065195124,1.4256285650099293e-06,549.1969358882476,model2_loss_gain +p9hhh9iw,0.9999909345665606,1.0319072839939956e-07,550.3175663358292,model2_loss_gain +odu11wbc,0.13289158187773722,0.9999983760379403,554.5549374937649,model2_loss_gain +cjwflpfp,0.5597238185910073,0.9999986205440585,524.6505504348717,model2_loss_gain