Rewriting with parallelization

This commit is contained in:
Louis Lacoste 2025-12-01 21:36:14 +01:00
parent 57b53629c7
commit 2670a1ecd2

View file

@ -5,27 +5,60 @@
library(tidyverse) library(tidyverse)
library(DEoptim) library(DEoptim)
library(numDeriv) library(numDeriv)
# foreach future
library(foreach)
library(doFuture)
library(future.callr)
plan(callr, workers = future::availableCores(omit = 1L))
# ============================================================================ # ============================================================================
# FONCTION GÉNÉRIQUE DE Q-LEARNING # FONCTION GÉNÉRIQUE DE Q-LEARNING
# ============================================================================ # ============================================================================
qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) { qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) {
# Normalise noms de colonnes et conversion des choix en indices numériques
# Conversion des choix en indices numériques if (!("button_value" %in% names(data)) && ("reward" %in% names(data))) {
if (is.factor(data$button_name) || is.character(data$button_name)) { data$button_value <- data$reward
}
if (!("button_name" %in% names(data))) {
if ("option" %in% names(data)) {
data$button_name <- data$option
} else if (("choice" %in% names(data)) && is.character(data$choice)) {
data$button_name <- data$choice
}
}
# If 'choice' is numeric (prepared data), use it directly
if (("choice" %in% names(data)) && is.numeric(data$choice)) {
data$choice_idx <- data$choice
} else if (is.factor(data$button_name) || is.character(data$button_name)) {
choice_levels <- c("antifragile", "fragile", "robuste", "vulnerable") choice_levels <- c("antifragile", "fragile", "robuste", "vulnerable")
data$choice_idx <- match(as.character(data$button_name), choice_levels) data$choice_idx <- match(as.character(data$button_name), choice_levels)
} else { } else if ("button_name" %in% names(data)) {
data$choice_idx <- data$button_name data$choice_idx <- data$button_name
} }
# Robustness: if mapping produced NAs, try remapping or fail fast with large penalty
if (any(is.na(data$choice_idx))) {
known_levels <- c("antifragile", "fragile", "robuste", "vulnerable")
if (!any(is.na(match(as.character(data$button_name), known_levels)))) {
data$choice_idx <- match(as.character(data$button_name), known_levels)
} else {
if (return_negLL) {
return(1e6)
} else {
return(-1e6)
}
}
}
n_arms <- 4 n_arms <- 4
n_trials <- nrow(data) n_trials <- nrow(data)
# Extraction des paramètres selon la configuration du modèle # Extraction des paramètres selon la configuration du modèle
param_idx <- 1 param_idx <- 1
# ALPHA(S) # ALPHA(S)
if (model_config$n_alpha == 1) { if (model_config$n_alpha == 1) {
alpha_loss <- alpha_gain <- alpha_BS <- alpha_JP <- plogis(params[param_idx]) alpha_loss <- alpha_gain <- alpha_BS <- alpha_JP <- plogis(params[param_idx])
@ -43,7 +76,7 @@ qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) {
alpha_JP <- plogis(params[param_idx + 3]) alpha_JP <- plogis(params[param_idx + 3])
param_idx <- param_idx + 4 param_idx <- param_idx + 4
} }
# FORGET(S) # FORGET(S)
if (model_config$n_forget == 1) { if (model_config$n_forget == 1) {
forget <- rep(plogis(params[param_idx]), n_arms) forget <- rep(plogis(params[param_idx]), n_arms)
@ -52,7 +85,7 @@ qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) {
forget <- plogis(params[param_idx:(param_idx + 3)]) forget <- plogis(params[param_idx:(param_idx + 3)])
param_idx <- param_idx + 4 param_idx <- param_idx + 4
} }
# LAMBDA(S) # LAMBDA(S)
if (model_config$n_lambda == 1) { if (model_config$n_lambda == 1) {
lambda <- rep(exp(params[param_idx]), n_arms) lambda <- rep(exp(params[param_idx]), n_arms)
@ -61,52 +94,76 @@ qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) {
lambda <- exp(params[param_idx:(param_idx + 3)]) lambda <- exp(params[param_idx:(param_idx + 3)])
param_idx <- param_idx + 4 param_idx <- param_idx + 4
} }
# RHO(S) - Biais pour événements rares # RHO(S) - Biais pour événements rares
if (model_config$has_rho) { if (model_config$has_rho) {
rho_BS <- params[param_idx] # BS avoidance rho_BS <- params[param_idx] # BS avoidance
rho_JP <- params[param_idx + 1] # JP seeking rho_JP <- params[param_idx + 1] # JP seeking
param_idx <- param_idx + 2 param_idx <- param_idx + 2
} else { } else {
rho_BS <- rho_JP <- 0 rho_BS <- rho_JP <- 0
} }
# Detect if rare events actually occur in this participant's data
has_BS_seen <- any(data$button_value == -3000, na.rm = TRUE)
has_JP_seen <- any(data$button_value == 3000, na.rm = TRUE)
# If an REE type was never observed, neutralize its rho to avoid non-identifiability
if (model_config$has_rho) {
if (!has_BS_seen) rho_BS <- 0
if (!has_JP_seen) rho_JP <- 0
}
# Initialisation des Q-values # Initialisation des Q-values
Q <- rep(0, n_arms) Q <- rep(0, n_arms)
log_lik <- 0 log_lik <- 0
for (t in 1:n_trials) { for (t in 1:n_trials) {
choice <- data$choice_idx[t] choice <- data$choice_idx[t]
reward <- data$button_value[t] reward <- data$button_value[t]
# Calcul des valeurs subjectives V(t) # Calcul des valeurs subjectives V(t)
V <- lambda * Q V <- lambda * Q
# Ajout des biais pour événements rares si le modèle le permet # Ajout des biais pour événements rares si le modèle le permet
if (model_config$has_rho) { if (model_config$has_rho) {
# Identification des options susceptibles de produire BS/JP # Identification des options susceptibles de produire BS/JP
# antifragile (1) = JP possible, fragile (2) = BS possible # antifragile (1) = JP possible, fragile (2) = BS possible
# vulnerable (4) = BS et JP possibles # vulnerable (4) = BS et JP possibles
V[1] <- V[1] + rho_JP # antifragile V[1] <- V[1] + rho_JP # antifragile
V[2] <- V[2] + rho_BS # fragile V[2] <- V[2] + rho_BS # fragile
V[4] <- V[4] + rho_BS + rho_JP # vulnerable V[4] <- V[4] + rho_BS + rho_JP # vulnerable
} }
# Softmax # Softmax
V_max <- max(V) V_max <- max(V)
exp_V <- exp(V - V_max) exp_V <- exp(V - V_max)
probs <- exp_V / sum(exp_V) probs <- exp_V / sum(exp_V)
probs <- pmax(probs, 1e-10) probs <- pmax(probs, 1e-10)
probs <- probs / sum(probs) probs <- probs / sum(probs)
# Log-likelihood # Log-likelihood
log_lik <- log_lik + log(probs[choice]) log_lik <- log_lik + log(probs[choice])
# Mise à jour Q-learning # Mise à jour Q-learning
Q_new <- Q Q_new <- Q
# Choix de l'alpha approprié # Choix de l'alpha approprié
if (reward == -3000) { # if (reward == -3000) {
# alpha_used <- alpha_BS
# } else if (reward == 3000) {
# alpha_used <- alpha_JP
# } else if (reward < 0) {
# alpha_used <- alpha_loss
# } else {
# alpha_used <- alpha_gain
# }
# Fix when there are no extreme rewards while taking them into account
if (is.na(reward)) {
# skip trials with missing reward
next
} else if (reward == -3000) {
alpha_used <- alpha_BS alpha_used <- alpha_BS
} else if (reward == 3000) { } else if (reward == 3000) {
alpha_used <- alpha_JP alpha_used <- alpha_JP
@ -115,17 +172,17 @@ qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) {
} else { } else {
alpha_used <- alpha_gain alpha_used <- alpha_gain
} }
# Option choisie : Q(t+1) = Q(t) + alpha * (r(t) - Q(t)) # Option choisie : Q(t+1) = Q(t) + alpha * (r(t) - Q(t))
Q_new[choice] <- Q[choice] + alpha_used * (reward - Q[choice]) Q_new[choice] <- Q[choice] + alpha_used * (reward - Q[choice])
# Options non choisies : Q(t+1) = Q(t) * (1 - f) # Options non choisies : Q(t+1) = Q(t) * (1 - f)
not_chosen <- setdiff(1:n_arms, choice) not_chosen <- setdiff(1:n_arms, choice)
Q_new[not_chosen] <- Q[not_chosen] * (1 - forget[not_chosen]) Q_new[not_chosen] <- Q[not_chosen] * (1 - forget[not_chosen])
Q <- Q_new Q <- Q_new
} }
if (return_negLL) { if (return_negLL) {
return(-log_lik) return(-log_lik)
} else { } else {
@ -150,7 +207,6 @@ get_model_configs <- function() {
lower = c(-5, -5, -3), lower = c(-5, -5, -3),
upper = c(5, 5, 3) upper = c(5, 5, 3)
), ),
GAIN_LOSS = list( GAIN_LOSS = list(
name = "GAIN_LOSS", name = "GAIN_LOSS",
n_alpha = 2, n_alpha = 2,
@ -162,7 +218,6 @@ get_model_configs <- function() {
lower = c(-5, -5, -5, -3), lower = c(-5, -5, -5, -3),
upper = c(5, 5, 5, 3) upper = c(5, 5, 5, 3)
), ),
BIASED = list( BIASED = list(
name = "BIASED", name = "BIASED",
n_alpha = 2, n_alpha = 2,
@ -170,13 +225,14 @@ get_model_configs <- function() {
n_lambda = 4, n_lambda = 4,
has_rho = FALSE, has_rho = FALSE,
n_params = 10, n_params = 10,
param_names = c("alpha_loss", "alpha_gain", param_names = c(
"forget_1", "forget_2", "forget_3", "forget_4", "alpha_loss", "alpha_gain",
"lambda_1", "lambda_2", "lambda_3", "lambda_4"), "forget_1", "forget_2", "forget_3", "forget_4",
"lambda_1", "lambda_2", "lambda_3", "lambda_4"
),
lower = c(-5, -5, rep(-5, 4), rep(-3, 4)), lower = c(-5, -5, rep(-5, 4), rep(-3, 4)),
upper = c(5, 5, rep(5, 4), rep(3, 4)) upper = c(5, 5, rep(5, 4), rep(3, 4))
), ),
REE_BIASED_SIMPLE = list( REE_BIASED_SIMPLE = list(
name = "REE_BIASED_SIMPLE", name = "REE_BIASED_SIMPLE",
n_alpha = 2, n_alpha = 2,
@ -184,12 +240,13 @@ get_model_configs <- function() {
n_lambda = 1, n_lambda = 1,
has_rho = TRUE, has_rho = TRUE,
n_params = 6, n_params = 6,
param_names = c("alpha_loss", "alpha_gain", "forget", "lambda", param_names = c(
"rho_BS", "rho_JP"), "alpha_loss", "alpha_gain", "forget", "lambda",
"rho_BS", "rho_JP"
),
lower = c(-5, -5, -5, -3, -10, -10), lower = c(-5, -5, -5, -3, -10, -10),
upper = c(5, 5, 5, 3, 10, 10) upper = c(5, 5, 5, 3, 10, 10)
), ),
REE_BIASED_COMPLEX = list( REE_BIASED_COMPLEX = list(
name = "REE_BIASED_COMPLEX", name = "REE_BIASED_COMPLEX",
n_alpha = 2, n_alpha = 2,
@ -197,14 +254,15 @@ get_model_configs <- function() {
n_lambda = 4, n_lambda = 4,
has_rho = TRUE, has_rho = TRUE,
n_params = 12, n_params = 12,
param_names = c("alpha_loss", "alpha_gain", param_names = c(
"forget_1", "forget_2", "forget_3", "forget_4", "alpha_loss", "alpha_gain",
"lambda_1", "lambda_2", "lambda_3", "lambda_4", "forget_1", "forget_2", "forget_3", "forget_4",
"rho_BS", "rho_JP"), "lambda_1", "lambda_2", "lambda_3", "lambda_4",
"rho_BS", "rho_JP"
),
lower = c(-5, -5, rep(-5, 4), rep(-3, 4), -10, -10), lower = c(-5, -5, rep(-5, 4), rep(-3, 4), -10, -10),
upper = c(5, 5, rep(5, 4), rep(3, 4), 10, 10) upper = c(5, 5, rep(5, 4), rep(3, 4), 10, 10)
), ),
REE_LEARNING_SIMPLE = list( REE_LEARNING_SIMPLE = list(
name = "REE_LEARNING_SIMPLE", name = "REE_LEARNING_SIMPLE",
n_alpha = 4, n_alpha = 4,
@ -212,12 +270,13 @@ get_model_configs <- function() {
n_lambda = 1, n_lambda = 1,
has_rho = FALSE, has_rho = FALSE,
n_params = 6, n_params = 6,
param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", param_names = c(
"forget", "lambda"), "alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP",
"forget", "lambda"
),
lower = c(-5, -5, -5, -5, -5, -3), lower = c(-5, -5, -5, -5, -5, -3),
upper = c(5, 5, 5, 5, 5, 3) upper = c(5, 5, 5, 5, 5, 3)
), ),
REE_LEARNING_COMPLEX = list( REE_LEARNING_COMPLEX = list(
name = "REE_LEARNING_COMPLEX", name = "REE_LEARNING_COMPLEX",
n_alpha = 4, n_alpha = 4,
@ -225,13 +284,14 @@ get_model_configs <- function() {
n_lambda = 4, n_lambda = 4,
has_rho = FALSE, has_rho = FALSE,
n_params = 12, n_params = 12,
param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", param_names = c(
"forget_1", "forget_2", "forget_3", "forget_4", "alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP",
"lambda_1", "lambda_2", "lambda_3", "lambda_4"), "forget_1", "forget_2", "forget_3", "forget_4",
"lambda_1", "lambda_2", "lambda_3", "lambda_4"
),
lower = c(-5, -5, -5, -5, rep(-5, 4), rep(-3, 4)), lower = c(-5, -5, -5, -5, rep(-5, 4), rep(-3, 4)),
upper = c(5, 5, 5, 5, rep(5, 4), rep(3, 4)) upper = c(5, 5, 5, 5, rep(5, 4), rep(3, 4))
), ),
REE_LEARNING_BIASED_SIMPLE = list( REE_LEARNING_BIASED_SIMPLE = list(
name = "REE_LEARNING_BIASED_SIMPLE", name = "REE_LEARNING_BIASED_SIMPLE",
n_alpha = 4, n_alpha = 4,
@ -239,12 +299,13 @@ get_model_configs <- function() {
n_lambda = 1, n_lambda = 1,
has_rho = TRUE, has_rho = TRUE,
n_params = 8, n_params = 8,
param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", param_names = c(
"forget", "lambda", "rho_BS", "rho_JP"), "alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP",
"forget", "lambda", "rho_BS", "rho_JP"
),
lower = c(-5, -5, -5, -5, -5, -3, -10, -10), lower = c(-5, -5, -5, -5, -5, -3, -10, -10),
upper = c(5, 5, 5, 5, 5, 3, 10, 10) upper = c(5, 5, 5, 5, 5, 3, 10, 10)
), ),
REE_LEARNING_BIASED_COMPLEX = list( REE_LEARNING_BIASED_COMPLEX = list(
name = "REE_LEARNING_BIASED_COMPLEX", name = "REE_LEARNING_BIASED_COMPLEX",
n_alpha = 4, n_alpha = 4,
@ -252,10 +313,12 @@ get_model_configs <- function() {
n_lambda = 4, n_lambda = 4,
has_rho = TRUE, has_rho = TRUE,
n_params = 14, n_params = 14,
param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", param_names = c(
"forget_1", "forget_2", "forget_3", "forget_4", "alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP",
"lambda_1", "lambda_2", "lambda_3", "lambda_4", "forget_1", "forget_2", "forget_3", "forget_4",
"rho_BS", "rho_JP"), "lambda_1", "lambda_2", "lambda_3", "lambda_4",
"rho_BS", "rho_JP"
),
lower = c(-5, -5, -5, -5, rep(-5, 4), rep(-3, 4), -10, -10), lower = c(-5, -5, -5, -5, rep(-5, 4), rep(-3, 4), -10, -10),
upper = c(5, 5, 5, 5, rep(5, 4), rep(3, 4), 10, 10) upper = c(5, 5, 5, 5, rep(5, 4), rep(3, 4), 10, 10)
) )
@ -267,12 +330,15 @@ get_model_configs <- function() {
# ============================================================================ # ============================================================================
fit_participant <- function(participant_data, model_config, n_runs = 5) { fit_participant <- function(participant_data, model_config, n_runs = 5) {
# Detect presence of rare events for this participant
has_BS_seen <- any(participant_data$button_value == -3000, na.rm = TRUE)
has_JP_seen <- any(participant_data$button_value == 3000, na.rm = TRUE)
all_results <- vector("list", n_runs) all_results <- vector("list", n_runs)
for (run in 1:n_runs) { all_results <- foreach(run = 1:n_runs) %dofuture% {
set.seed(1000 * as.numeric(factor(model_config$name)) + run) set.seed(1000 * as.numeric(factor(model_config$name)) + run)
result <- DEoptim( result <- DEoptim(
fn = qlearning_generic, fn = qlearning_generic,
lower = model_config$lower, lower = model_config$lower,
@ -286,37 +352,41 @@ fit_participant <- function(participant_data, model_config, n_runs = 5) {
NP = max(50, model_config$n_params * 10) NP = max(50, model_config$n_params * 10)
) )
) )
all_results[[run]] <- list( list(
params = result$optim$bestmem, params = result$optim$bestmem,
negLL = result$optim$bestval negLL = result$optim$bestval
) )
} }
# Sélection du meilleur run # Sélection du meilleur run
all_negLL <- sapply(all_results, function(x) x$negLL) all_negLL <- sapply(all_results, function(x) x$negLL)
best_run <- which.min(all_negLL) best_run <- which.min(all_negLL)
negLL_sd <- sd(all_negLL) negLL_sd <- sd(all_negLL)
negLL_range <- max(all_negLL) - min(all_negLL) negLL_range <- max(all_negLL) - min(all_negLL)
params <- all_results[[best_run]]$params params <- all_results[[best_run]]$params
negLL <- all_results[[best_run]]$negLL negLL <- all_results[[best_run]]$negLL
# Calcul de la Hessienne # Calcul de la Hessienne
hessian_result <- tryCatch({ hessian_result <- tryCatch(
numDeriv::hessian(qlearning_generic, params, {
data = participant_data, numDeriv::hessian(qlearning_generic, params,
model_config = model_config) data = participant_data,
}, error = function(e) NULL) model_config = model_config
)
},
error = function(e) NULL
)
hessian_positive_definite <- FALSE hessian_positive_definite <- FALSE
param_se <- rep(NA, model_config$n_params) param_se <- rep(NA, model_config$n_params)
if (!is.null(hessian_result)) { if (!is.null(hessian_result)) {
eigenvalues <- eigen(hessian_result, only.values = TRUE)$values eigenvalues <- eigen(hessian_result, only.values = TRUE)$values
hessian_positive_definite <- all(eigenvalues > 0) hessian_positive_definite <- all(eigenvalues > 0)
if (hessian_positive_definite) { if (hessian_positive_definite) {
param_vcov <- tryCatch(solve(hessian_result), error = function(e) NULL) param_vcov <- tryCatch(solve(hessian_result), error = function(e) NULL)
if (!is.null(param_vcov)) { if (!is.null(param_vcov)) {
@ -324,7 +394,7 @@ fit_participant <- function(participant_data, model_config, n_runs = 5) {
} }
} }
} }
# Création du tibble de résultats # Création du tibble de résultats
result_df <- tibble( result_df <- tibble(
model = model_config$name, model = model_config$name,
@ -337,13 +407,17 @@ fit_participant <- function(participant_data, model_config, n_runs = 5) {
hessian_positive_definite = hessian_positive_definite, hessian_positive_definite = hessian_positive_definite,
converged = negLL_range < 1 converged = negLL_range < 1
) )
# Indicateurs d'événements rares observés (utile pour interprétation des rhos/alphas)
result_df$has_BS_seen <- has_BS_seen
result_df$has_JP_seen <- has_JP_seen
# Ajout des paramètres estimés # Ajout des paramètres estimés
for (i in 1:model_config$n_params) { for (i in 1:model_config$n_params) {
result_df[[model_config$param_names[i]]] <- params[i] result_df[[model_config$param_names[i]]] <- params[i]
result_df[[paste0("se_", model_config$param_names[i])]] <- param_se[i] result_df[[paste0("se_", model_config$param_names[i])]] <- param_se[i]
} }
return(result_df) return(result_df)
} }
@ -352,36 +426,35 @@ fit_participant <- function(participant_data, model_config, n_runs = 5) {
# ============================================================================ # ============================================================================
fit_all_participants_all_models <- function(data, models_to_fit = NULL) { fit_all_participants_all_models <- function(data, models_to_fit = NULL) {
model_configs <- get_model_configs() model_configs <- get_model_configs()
if (!is.null(models_to_fit)) { if (!is.null(models_to_fit)) {
model_configs <- model_configs[models_to_fit] model_configs <- model_configs[models_to_fit]
} }
participants <- unique(data$participant_id) participants <- unique(data$participant_id)
all_results <- list() all_results <- list()
for (model_name in names(model_configs)) { for (model_name in names(model_configs)) {
cat("\n=== Fitting model:", model_name, "===\n") cat("\n=== Fitting model:", model_name, "===\n")
model_config <- model_configs[[model_name]] model_config <- model_configs[[model_name]]
model_results <- map_df(participants, function(pid) { model_results <- map_df(participants, function(pid) {
cat(" Participant", pid, "\n") cat(" Participant", pid, "\n")
participant_data <- data %>% participant_data <- data %>%
filter(participant_id == pid) %>% filter(participant_id == pid) %>%
arrange(trial) arrange(trial)
fit <- fit_participant(participant_data, model_config) fit <- fit_participant(participant_data, model_config)
fit %>% mutate(participant_id = pid, .before = 1) fit %>% mutate(participant_id = pid, .before = 1)
}) })
all_results[[model_name]] <- model_results all_results[[model_name]] <- model_results
} }
return(all_results) return(all_results)
} }
@ -390,11 +463,10 @@ fit_all_participants_all_models <- function(data, models_to_fit = NULL) {
# ============================================================================ # ============================================================================
compare_nested_models <- function(all_results) { compare_nested_models <- function(all_results) {
# Comparaison globale # Comparaison globale
global_comparison <- map_df(names(all_results), function(model_name) { global_comparison <- map_df(names(all_results), function(model_name) {
results <- all_results[[model_name]] results <- all_results[[model_name]]
tibble( tibble(
model = model_name, model = model_name,
n_params = unique(results$n_params), n_params = unique(results$n_params),
@ -407,10 +479,10 @@ compare_nested_models <- function(all_results) {
) )
}) %>% }) %>%
arrange(total_BIC) arrange(total_BIC)
cat("\n=== COMPARAISON GLOBALE DES MODÈLES ===\n") cat("\n=== COMPARAISON GLOBALE DES MODÈLES ===\n")
print(global_comparison) print(global_comparison)
# Meilleur modèle par participant (BIC) # Meilleur modèle par participant (BIC)
best_models_per_participant <- map_df(names(all_results), function(model_name) { best_models_per_participant <- map_df(names(all_results), function(model_name) {
all_results[[model_name]] %>% all_results[[model_name]] %>%
@ -419,13 +491,13 @@ compare_nested_models <- function(all_results) {
group_by(participant_id) %>% group_by(participant_id) %>%
slice_min(BIC, n = 1) %>% slice_min(BIC, n = 1) %>%
ungroup() ungroup()
cat("\n=== MEILLEUR MODÈLE PAR PARTICIPANT (BIC) ===\n") cat("\n=== MEILLEUR MODÈLE PAR PARTICIPANT (BIC) ===\n")
print(table(best_models_per_participant$model)) print(table(best_models_per_participant$model))
# Comparaison par paires de modèles emboîtés # Comparaison par paires de modèles emboîtés
cat("\n=== TESTS LRT POUR MODÈLES EMBOÎTÉS ===\n") cat("\n=== TESTS LRT POUR MODÈLES EMBOÎTÉS ===\n")
# Exemples de paires emboîtées # Exemples de paires emboîtées
nested_pairs <- list( nested_pairs <- list(
c("HOMOGENEOUS", "GAIN_LOSS"), c("HOMOGENEOUS", "GAIN_LOSS"),
@ -437,15 +509,15 @@ compare_nested_models <- function(all_results) {
c("REE_LEARNING_SIMPLE", "REE_LEARNING_BIASED_SIMPLE"), c("REE_LEARNING_SIMPLE", "REE_LEARNING_BIASED_SIMPLE"),
c("REE_LEARNING_BIASED_SIMPLE", "REE_LEARNING_BIASED_COMPLEX") c("REE_LEARNING_BIASED_SIMPLE", "REE_LEARNING_BIASED_COMPLEX")
) )
lrt_results <- map_df(nested_pairs, function(pair) { lrt_results <- map_df(nested_pairs, function(pair) {
simple_model <- pair[1] simple_model <- pair[1]
complex_model <- pair[2] complex_model <- pair[2]
if (simple_model %in% names(all_results) && complex_model %in% names(all_results)) { if (simple_model %in% names(all_results) && complex_model %in% names(all_results)) {
simple_res <- all_results[[simple_model]] simple_res <- all_results[[simple_model]]
complex_res <- all_results[[complex_model]] complex_res <- all_results[[complex_model]]
# LRT par participant # LRT par participant
lrt_df <- simple_res %>% lrt_df <- simple_res %>%
select(participant_id, negLL_simple = negLL, converged_simple = converged) %>% select(participant_id, negLL_simple = negLL, converged_simple = converged) %>%
@ -460,7 +532,7 @@ compare_nested_models <- function(all_results) {
p_value = pchisq(LR_stat, df = df_diff, lower.tail = FALSE), p_value = pchisq(LR_stat, df = df_diff, lower.tail = FALSE),
significant = p_value < 0.05 significant = p_value < 0.05
) )
tibble( tibble(
simple_model = simple_model, simple_model = simple_model,
complex_model = complex_model, complex_model = complex_model,
@ -471,9 +543,9 @@ compare_nested_models <- function(all_results) {
) )
} }
}) })
print(lrt_results) print(lrt_results)
list( list(
global_comparison = global_comparison, global_comparison = global_comparison,
best_models_per_participant = best_models_per_participant, best_models_per_participant = best_models_per_participant,
@ -489,7 +561,7 @@ compare_nested_models <- function(all_results) {
plot_model_comparison <- function(comparison_results) { plot_model_comparison <- function(comparison_results) {
require(ggplot2) require(ggplot2)
require(patchwork) require(patchwork)
# 1. Comparaison BIC globale # 1. Comparaison BIC globale
p1 <- comparison_results$global_comparison %>% p1 <- comparison_results$global_comparison %>%
mutate(model = fct_reorder(model, total_BIC)) %>% mutate(model = fct_reorder(model, total_BIC)) %>%
@ -497,10 +569,12 @@ plot_model_comparison <- function(comparison_results) {
geom_col() + geom_col() +
coord_flip() + coord_flip() +
theme_minimal() + theme_minimal() +
labs(title = "Comparaison globale des modèles (BIC)", labs(
subtitle = "Plus bas = meilleur") + title = "Comparaison globale des modèles (BIC)",
subtitle = "Plus bas = meilleur"
) +
theme(legend.position = "none") theme(legend.position = "none")
# 2. Meilleur modèle par participant # 2. Meilleur modèle par participant
p2 <- comparison_results$best_models_per_participant %>% p2 <- comparison_results$best_models_per_participant %>%
count(model) %>% count(model) %>%
@ -509,27 +583,34 @@ plot_model_comparison <- function(comparison_results) {
geom_col() + geom_col() +
coord_flip() + coord_flip() +
theme_minimal() + theme_minimal() +
labs(title = "Meilleur modèle par participant", labs(
y = "Nombre de participants") + title = "Meilleur modèle par participant",
y = "Nombre de participants"
) +
theme(legend.position = "none") theme(legend.position = "none")
# 3. Tests LRT # 3. Tests LRT
if (nrow(comparison_results$lrt_results) > 0) { if (nrow(comparison_results$lrt_results) > 0) {
p3 <- comparison_results$lrt_results %>% p3 <- comparison_results$lrt_results %>%
mutate(comparison = paste(simple_model, "→", complex_model)) %>% mutate(comparison = paste(simple_model, "→", complex_model)) %>%
ggplot(aes(x = fct_reorder(comparison, pct_significant), ggplot(aes(
y = pct_significant)) + x = fct_reorder(comparison, pct_significant),
y = pct_significant
)) +
geom_col(fill = "steelblue") + geom_col(fill = "steelblue") +
geom_hline(yintercept = 50, linetype = "dashed", color = "red") + geom_hline(yintercept = 50, linetype = "dashed", color = "red") +
coord_flip() + coord_flip() +
theme_minimal() + theme_minimal() +
labs(title = "Tests LRT entre modèles emboîtés", labs(
y = "% participants avec p < 0.05", title = "Tests LRT entre modèles emboîtés",
x = "Comparaison") y = "% participants avec p < 0.05",
x = "Comparaison"
)
} else { } else {
p3 <- ggplot() + theme_void() p3 <- ggplot() +
theme_void()
} }
# 4. Convergence par modèle # 4. Convergence par modèle
p4 <- map_df(names(comparison_results$all_results), function(model_name) { p4 <- map_df(names(comparison_results$all_results), function(model_name) {
comparison_results$all_results[[model_name]] %>% comparison_results$all_results[[model_name]] %>%
@ -540,9 +621,11 @@ plot_model_comparison <- function(comparison_results) {
scale_y_log10() + scale_y_log10() +
coord_flip() + coord_flip() +
theme_minimal() + theme_minimal() +
labs(title = "Convergence par modèle", labs(
y = "Range negLL (log scale)") title = "Convergence par modèle",
y = "Range negLL (log scale)"
)
(p1 | p2) / (p3 | p4) (p1 | p2) / (p3 | p4)
} }
@ -553,14 +636,15 @@ plot_model_comparison <- function(comparison_results) {
# Charger vos données # Charger vos données
# data <- read_csv("votre_fichier.csv") # data <- read_csv("votre_fichier.csv")
# Colonnes requises: participant_id, trial, choice, reward # Colonnes requises: participant_id, trial, choice, reward
source("load_data.R")
# Estimation de tous les modèles # Estimation de tous les modèles
# all_results <- fit_all_participants_all_models(data) # all_results <- fit_all_participants_all_models(data)
fit_all_participants_all_models(data %>% filter(participant_id == "qfmtmjjy"))
# Ou seulement certains modèles # Ou seulement certains modèles
# all_results <- fit_all_participants_all_models( # all_results <- fit_all_participants_all_models(
# data, # data,
# models_to_fit = c("HOMOGENEOUS", "GAIN_LOSS", "REE_BIASED_SIMPLE", # models_to_fit = c("HOMOGENEOUS", "GAIN_LOSS", "REE_BIASED_SIMPLE",
# "REE_LEARNING_SIMPLE", "REE_LEARNING_BIASED_SIMPLE") # "REE_LEARNING_SIMPLE", "REE_LEARNING_BIASED_SIMPLE")
# ) # )
@ -572,8 +656,8 @@ plot_model_comparison <- function(comparison_results) {
# Sauvegarder les résultats # Sauvegarder les résultats
# for (model_name in names(all_results)) { # for (model_name in names(all_results)) {
# write_csv(all_results[[model_name]], # write_csv(all_results[[model_name]],
# paste0("results_", model_name, ".csv")) # paste0("results_", model_name, ".csv"))
# } # }
# write_csv(comparison$global_comparison, "global_comparison.csv") # write_csv(comparison$global_comparison, "global_comparison.csv")
# write_csv(comparison$best_models_per_participant, "best_models.csv") # write_csv(comparison$best_models_per_participant, "best_models.csv")