diff --git a/modelling V4.R b/modelling V4.R index 3901990..bf6a6b8 100644 --- a/modelling V4.R +++ b/modelling V4.R @@ -5,27 +5,60 @@ library(tidyverse) library(DEoptim) library(numDeriv) +# foreach future +library(foreach) +library(doFuture) +library(future.callr) + +plan(callr, workers = future::availableCores(omit = 1L)) # ============================================================================ # FONCTION GÉNÉRIQUE DE Q-LEARNING # ============================================================================ qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) { - - # Conversion des choix en indices numériques - if (is.factor(data$button_name) || is.character(data$button_name)) { + # Normalise noms de colonnes et conversion des choix en indices numériques + if (!("button_value" %in% names(data)) && ("reward" %in% names(data))) { + data$button_value <- data$reward + } + if (!("button_name" %in% names(data))) { + if ("option" %in% names(data)) { + data$button_name <- data$option + } else if (("choice" %in% names(data)) && is.character(data$choice)) { + data$button_name <- data$choice + } + } + + # If 'choice' is numeric (prepared data), use it directly + if (("choice" %in% names(data)) && is.numeric(data$choice)) { + data$choice_idx <- data$choice + } else if (is.factor(data$button_name) || is.character(data$button_name)) { choice_levels <- c("antifragile", "fragile", "robuste", "vulnerable") data$choice_idx <- match(as.character(data$button_name), choice_levels) - } else { + } else if ("button_name" %in% names(data)) { data$choice_idx <- data$button_name } - + + # Robustness: if mapping produced NAs, try remapping or fail fast with large penalty + if (any(is.na(data$choice_idx))) { + known_levels <- c("antifragile", "fragile", "robuste", "vulnerable") + if (!any(is.na(match(as.character(data$button_name), known_levels)))) { + data$choice_idx <- match(as.character(data$button_name), known_levels) + } else { + if (return_negLL) { + return(1e6) + } else { + return(-1e6) + } + } + } + n_arms <- 4 n_trials <- nrow(data) - + # Extraction des paramètres selon la configuration du modèle param_idx <- 1 - + # ALPHA(S) if (model_config$n_alpha == 1) { alpha_loss <- alpha_gain <- alpha_BS <- alpha_JP <- plogis(params[param_idx]) @@ -43,7 +76,7 @@ qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) { alpha_JP <- plogis(params[param_idx + 3]) param_idx <- param_idx + 4 } - + # FORGET(S) if (model_config$n_forget == 1) { forget <- rep(plogis(params[param_idx]), n_arms) @@ -52,7 +85,7 @@ qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) { forget <- plogis(params[param_idx:(param_idx + 3)]) param_idx <- param_idx + 4 } - + # LAMBDA(S) if (model_config$n_lambda == 1) { lambda <- rep(exp(params[param_idx]), n_arms) @@ -61,52 +94,76 @@ qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) { lambda <- exp(params[param_idx:(param_idx + 3)]) param_idx <- param_idx + 4 } - + # RHO(S) - Biais pour événements rares if (model_config$has_rho) { - rho_BS <- params[param_idx] # BS avoidance - rho_JP <- params[param_idx + 1] # JP seeking + rho_BS <- params[param_idx] # BS avoidance + rho_JP <- params[param_idx + 1] # JP seeking param_idx <- param_idx + 2 } else { rho_BS <- rho_JP <- 0 } - + + # Detect if rare events actually occur in this participant's data + has_BS_seen <- any(data$button_value == -3000, na.rm = TRUE) + has_JP_seen <- any(data$button_value == 3000, na.rm = TRUE) + + # If an REE type was never observed, neutralize its rho to avoid non-identifiability + if (model_config$has_rho) { + if (!has_BS_seen) rho_BS <- 0 + if (!has_JP_seen) rho_JP <- 0 + } + # Initialisation des Q-values Q <- rep(0, n_arms) log_lik <- 0 - + for (t in 1:n_trials) { choice <- data$choice_idx[t] reward <- data$button_value[t] - + # Calcul des valeurs subjectives V(t) V <- lambda * Q - + # Ajout des biais pour événements rares si le modèle le permet if (model_config$has_rho) { # Identification des options susceptibles de produire BS/JP # antifragile (1) = JP possible, fragile (2) = BS possible # vulnerable (4) = BS et JP possibles - V[1] <- V[1] + rho_JP # antifragile - V[2] <- V[2] + rho_BS # fragile - V[4] <- V[4] + rho_BS + rho_JP # vulnerable + V[1] <- V[1] + rho_JP # antifragile + V[2] <- V[2] + rho_BS # fragile + V[4] <- V[4] + rho_BS + rho_JP # vulnerable } - + # Softmax V_max <- max(V) exp_V <- exp(V - V_max) probs <- exp_V / sum(exp_V) probs <- pmax(probs, 1e-10) probs <- probs / sum(probs) - + # Log-likelihood log_lik <- log_lik + log(probs[choice]) - + # Mise à jour Q-learning Q_new <- Q - + # Choix de l'alpha approprié - if (reward == -3000) { + # if (reward == -3000) { + # alpha_used <- alpha_BS + # } else if (reward == 3000) { + # alpha_used <- alpha_JP + # } else if (reward < 0) { + # alpha_used <- alpha_loss + # } else { + # alpha_used <- alpha_gain + # } + + # Fix when there are no extreme rewards while taking them into account + if (is.na(reward)) { + # skip trials with missing reward + next + } else if (reward == -3000) { alpha_used <- alpha_BS } else if (reward == 3000) { alpha_used <- alpha_JP @@ -115,17 +172,17 @@ qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) { } else { alpha_used <- alpha_gain } - + # Option choisie : Q(t+1) = Q(t) + alpha * (r(t) - Q(t)) Q_new[choice] <- Q[choice] + alpha_used * (reward - Q[choice]) - + # Options non choisies : Q(t+1) = Q(t) * (1 - f) not_chosen <- setdiff(1:n_arms, choice) Q_new[not_chosen] <- Q[not_chosen] * (1 - forget[not_chosen]) - + Q <- Q_new } - + if (return_negLL) { return(-log_lik) } else { @@ -150,7 +207,6 @@ get_model_configs <- function() { lower = c(-5, -5, -3), upper = c(5, 5, 3) ), - GAIN_LOSS = list( name = "GAIN_LOSS", n_alpha = 2, @@ -162,7 +218,6 @@ get_model_configs <- function() { lower = c(-5, -5, -5, -3), upper = c(5, 5, 5, 3) ), - BIASED = list( name = "BIASED", n_alpha = 2, @@ -170,13 +225,14 @@ get_model_configs <- function() { n_lambda = 4, has_rho = FALSE, n_params = 10, - param_names = c("alpha_loss", "alpha_gain", - "forget_1", "forget_2", "forget_3", "forget_4", - "lambda_1", "lambda_2", "lambda_3", "lambda_4"), + param_names = c( + "alpha_loss", "alpha_gain", + "forget_1", "forget_2", "forget_3", "forget_4", + "lambda_1", "lambda_2", "lambda_3", "lambda_4" + ), lower = c(-5, -5, rep(-5, 4), rep(-3, 4)), upper = c(5, 5, rep(5, 4), rep(3, 4)) ), - REE_BIASED_SIMPLE = list( name = "REE_BIASED_SIMPLE", n_alpha = 2, @@ -184,12 +240,13 @@ get_model_configs <- function() { n_lambda = 1, has_rho = TRUE, n_params = 6, - param_names = c("alpha_loss", "alpha_gain", "forget", "lambda", - "rho_BS", "rho_JP"), + param_names = c( + "alpha_loss", "alpha_gain", "forget", "lambda", + "rho_BS", "rho_JP" + ), lower = c(-5, -5, -5, -3, -10, -10), upper = c(5, 5, 5, 3, 10, 10) ), - REE_BIASED_COMPLEX = list( name = "REE_BIASED_COMPLEX", n_alpha = 2, @@ -197,14 +254,15 @@ get_model_configs <- function() { n_lambda = 4, has_rho = TRUE, n_params = 12, - param_names = c("alpha_loss", "alpha_gain", - "forget_1", "forget_2", "forget_3", "forget_4", - "lambda_1", "lambda_2", "lambda_3", "lambda_4", - "rho_BS", "rho_JP"), + param_names = c( + "alpha_loss", "alpha_gain", + "forget_1", "forget_2", "forget_3", "forget_4", + "lambda_1", "lambda_2", "lambda_3", "lambda_4", + "rho_BS", "rho_JP" + ), lower = c(-5, -5, rep(-5, 4), rep(-3, 4), -10, -10), upper = c(5, 5, rep(5, 4), rep(3, 4), 10, 10) ), - REE_LEARNING_SIMPLE = list( name = "REE_LEARNING_SIMPLE", n_alpha = 4, @@ -212,12 +270,13 @@ get_model_configs <- function() { n_lambda = 1, has_rho = FALSE, n_params = 6, - param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", - "forget", "lambda"), + param_names = c( + "alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", + "forget", "lambda" + ), lower = c(-5, -5, -5, -5, -5, -3), upper = c(5, 5, 5, 5, 5, 3) ), - REE_LEARNING_COMPLEX = list( name = "REE_LEARNING_COMPLEX", n_alpha = 4, @@ -225,13 +284,14 @@ get_model_configs <- function() { n_lambda = 4, has_rho = FALSE, n_params = 12, - param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", - "forget_1", "forget_2", "forget_3", "forget_4", - "lambda_1", "lambda_2", "lambda_3", "lambda_4"), + param_names = c( + "alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", + "forget_1", "forget_2", "forget_3", "forget_4", + "lambda_1", "lambda_2", "lambda_3", "lambda_4" + ), lower = c(-5, -5, -5, -5, rep(-5, 4), rep(-3, 4)), upper = c(5, 5, 5, 5, rep(5, 4), rep(3, 4)) ), - REE_LEARNING_BIASED_SIMPLE = list( name = "REE_LEARNING_BIASED_SIMPLE", n_alpha = 4, @@ -239,12 +299,13 @@ get_model_configs <- function() { n_lambda = 1, has_rho = TRUE, n_params = 8, - param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", - "forget", "lambda", "rho_BS", "rho_JP"), + param_names = c( + "alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", + "forget", "lambda", "rho_BS", "rho_JP" + ), lower = c(-5, -5, -5, -5, -5, -3, -10, -10), upper = c(5, 5, 5, 5, 5, 3, 10, 10) ), - REE_LEARNING_BIASED_COMPLEX = list( name = "REE_LEARNING_BIASED_COMPLEX", n_alpha = 4, @@ -252,10 +313,12 @@ get_model_configs <- function() { n_lambda = 4, has_rho = TRUE, n_params = 14, - param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", - "forget_1", "forget_2", "forget_3", "forget_4", - "lambda_1", "lambda_2", "lambda_3", "lambda_4", - "rho_BS", "rho_JP"), + param_names = c( + "alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", + "forget_1", "forget_2", "forget_3", "forget_4", + "lambda_1", "lambda_2", "lambda_3", "lambda_4", + "rho_BS", "rho_JP" + ), lower = c(-5, -5, -5, -5, rep(-5, 4), rep(-3, 4), -10, -10), upper = c(5, 5, 5, 5, rep(5, 4), rep(3, 4), 10, 10) ) @@ -267,12 +330,15 @@ get_model_configs <- function() { # ============================================================================ fit_participant <- function(participant_data, model_config, n_runs = 5) { - + # Detect presence of rare events for this participant + has_BS_seen <- any(participant_data$button_value == -3000, na.rm = TRUE) + has_JP_seen <- any(participant_data$button_value == 3000, na.rm = TRUE) + all_results <- vector("list", n_runs) - - for (run in 1:n_runs) { + + all_results <- foreach(run = 1:n_runs) %dofuture% { set.seed(1000 * as.numeric(factor(model_config$name)) + run) - + result <- DEoptim( fn = qlearning_generic, lower = model_config$lower, @@ -286,37 +352,41 @@ fit_participant <- function(participant_data, model_config, n_runs = 5) { NP = max(50, model_config$n_params * 10) ) ) - - all_results[[run]] <- list( + + list( params = result$optim$bestmem, negLL = result$optim$bestval ) } - + # Sélection du meilleur run all_negLL <- sapply(all_results, function(x) x$negLL) best_run <- which.min(all_negLL) - + negLL_sd <- sd(all_negLL) negLL_range <- max(all_negLL) - min(all_negLL) - + params <- all_results[[best_run]]$params negLL <- all_results[[best_run]]$negLL - + # Calcul de la Hessienne - hessian_result <- tryCatch({ - numDeriv::hessian(qlearning_generic, params, - data = participant_data, - model_config = model_config) - }, error = function(e) NULL) - + hessian_result <- tryCatch( + { + numDeriv::hessian(qlearning_generic, params, + data = participant_data, + model_config = model_config + ) + }, + error = function(e) NULL + ) + hessian_positive_definite <- FALSE param_se <- rep(NA, model_config$n_params) - + if (!is.null(hessian_result)) { eigenvalues <- eigen(hessian_result, only.values = TRUE)$values hessian_positive_definite <- all(eigenvalues > 0) - + if (hessian_positive_definite) { param_vcov <- tryCatch(solve(hessian_result), error = function(e) NULL) if (!is.null(param_vcov)) { @@ -324,7 +394,7 @@ fit_participant <- function(participant_data, model_config, n_runs = 5) { } } } - + # Création du tibble de résultats result_df <- tibble( model = model_config$name, @@ -337,13 +407,17 @@ fit_participant <- function(participant_data, model_config, n_runs = 5) { hessian_positive_definite = hessian_positive_definite, converged = negLL_range < 1 ) - + + # Indicateurs d'événements rares observés (utile pour interprétation des rhos/alphas) + result_df$has_BS_seen <- has_BS_seen + result_df$has_JP_seen <- has_JP_seen + # Ajout des paramètres estimés for (i in 1:model_config$n_params) { result_df[[model_config$param_names[i]]] <- params[i] result_df[[paste0("se_", model_config$param_names[i])]] <- param_se[i] } - + return(result_df) } @@ -352,36 +426,35 @@ fit_participant <- function(participant_data, model_config, n_runs = 5) { # ============================================================================ fit_all_participants_all_models <- function(data, models_to_fit = NULL) { - model_configs <- get_model_configs() - + if (!is.null(models_to_fit)) { model_configs <- model_configs[models_to_fit] } - + participants <- unique(data$participant_id) - + all_results <- list() - + for (model_name in names(model_configs)) { cat("\n=== Fitting model:", model_name, "===\n") - + model_config <- model_configs[[model_name]] - + model_results <- map_df(participants, function(pid) { cat(" Participant", pid, "\n") - - participant_data <- data %>% + + participant_data <- data %>% filter(participant_id == pid) %>% arrange(trial) - + fit <- fit_participant(participant_data, model_config) fit %>% mutate(participant_id = pid, .before = 1) }) - + all_results[[model_name]] <- model_results } - + return(all_results) } @@ -390,11 +463,10 @@ fit_all_participants_all_models <- function(data, models_to_fit = NULL) { # ============================================================================ compare_nested_models <- function(all_results) { - # Comparaison globale global_comparison <- map_df(names(all_results), function(model_name) { results <- all_results[[model_name]] - + tibble( model = model_name, n_params = unique(results$n_params), @@ -407,10 +479,10 @@ compare_nested_models <- function(all_results) { ) }) %>% arrange(total_BIC) - + cat("\n=== COMPARAISON GLOBALE DES MODÈLES ===\n") print(global_comparison) - + # Meilleur modèle par participant (BIC) best_models_per_participant <- map_df(names(all_results), function(model_name) { all_results[[model_name]] %>% @@ -419,13 +491,13 @@ compare_nested_models <- function(all_results) { group_by(participant_id) %>% slice_min(BIC, n = 1) %>% ungroup() - + cat("\n=== MEILLEUR MODÈLE PAR PARTICIPANT (BIC) ===\n") print(table(best_models_per_participant$model)) - + # Comparaison par paires de modèles emboîtés cat("\n=== TESTS LRT POUR MODÈLES EMBOÎTÉS ===\n") - + # Exemples de paires emboîtées nested_pairs <- list( c("HOMOGENEOUS", "GAIN_LOSS"), @@ -437,15 +509,15 @@ compare_nested_models <- function(all_results) { c("REE_LEARNING_SIMPLE", "REE_LEARNING_BIASED_SIMPLE"), c("REE_LEARNING_BIASED_SIMPLE", "REE_LEARNING_BIASED_COMPLEX") ) - + lrt_results <- map_df(nested_pairs, function(pair) { simple_model <- pair[1] complex_model <- pair[2] - + if (simple_model %in% names(all_results) && complex_model %in% names(all_results)) { simple_res <- all_results[[simple_model]] complex_res <- all_results[[complex_model]] - + # LRT par participant lrt_df <- simple_res %>% select(participant_id, negLL_simple = negLL, converged_simple = converged) %>% @@ -460,7 +532,7 @@ compare_nested_models <- function(all_results) { p_value = pchisq(LR_stat, df = df_diff, lower.tail = FALSE), significant = p_value < 0.05 ) - + tibble( simple_model = simple_model, complex_model = complex_model, @@ -471,9 +543,9 @@ compare_nested_models <- function(all_results) { ) } }) - + print(lrt_results) - + list( global_comparison = global_comparison, best_models_per_participant = best_models_per_participant, @@ -489,7 +561,7 @@ compare_nested_models <- function(all_results) { plot_model_comparison <- function(comparison_results) { require(ggplot2) require(patchwork) - + # 1. Comparaison BIC globale p1 <- comparison_results$global_comparison %>% mutate(model = fct_reorder(model, total_BIC)) %>% @@ -497,10 +569,12 @@ plot_model_comparison <- function(comparison_results) { geom_col() + coord_flip() + theme_minimal() + - labs(title = "Comparaison globale des modèles (BIC)", - subtitle = "Plus bas = meilleur") + + labs( + title = "Comparaison globale des modèles (BIC)", + subtitle = "Plus bas = meilleur" + ) + theme(legend.position = "none") - + # 2. Meilleur modèle par participant p2 <- comparison_results$best_models_per_participant %>% count(model) %>% @@ -509,27 +583,34 @@ plot_model_comparison <- function(comparison_results) { geom_col() + coord_flip() + theme_minimal() + - labs(title = "Meilleur modèle par participant", - y = "Nombre de participants") + + labs( + title = "Meilleur modèle par participant", + y = "Nombre de participants" + ) + theme(legend.position = "none") - + # 3. Tests LRT if (nrow(comparison_results$lrt_results) > 0) { p3 <- comparison_results$lrt_results %>% mutate(comparison = paste(simple_model, "→", complex_model)) %>% - ggplot(aes(x = fct_reorder(comparison, pct_significant), - y = pct_significant)) + + ggplot(aes( + x = fct_reorder(comparison, pct_significant), + y = pct_significant + )) + geom_col(fill = "steelblue") + geom_hline(yintercept = 50, linetype = "dashed", color = "red") + coord_flip() + theme_minimal() + - labs(title = "Tests LRT entre modèles emboîtés", - y = "% participants avec p < 0.05", - x = "Comparaison") + labs( + title = "Tests LRT entre modèles emboîtés", + y = "% participants avec p < 0.05", + x = "Comparaison" + ) } else { - p3 <- ggplot() + theme_void() + p3 <- ggplot() + + theme_void() } - + # 4. Convergence par modèle p4 <- map_df(names(comparison_results$all_results), function(model_name) { comparison_results$all_results[[model_name]] %>% @@ -540,9 +621,11 @@ plot_model_comparison <- function(comparison_results) { scale_y_log10() + coord_flip() + theme_minimal() + - labs(title = "Convergence par modèle", - y = "Range negLL (log scale)") - + labs( + title = "Convergence par modèle", + y = "Range negLL (log scale)" + ) + (p1 | p2) / (p3 | p4) } @@ -553,14 +636,15 @@ plot_model_comparison <- function(comparison_results) { # Charger vos données # data <- read_csv("votre_fichier.csv") # Colonnes requises: participant_id, trial, choice, reward +source("load_data.R") # Estimation de tous les modèles # all_results <- fit_all_participants_all_models(data) - +fit_all_participants_all_models(data %>% filter(participant_id == "qfmtmjjy")) # Ou seulement certains modèles # all_results <- fit_all_participants_all_models( -# data, -# models_to_fit = c("HOMOGENEOUS", "GAIN_LOSS", "REE_BIASED_SIMPLE", +# data, +# models_to_fit = c("HOMOGENEOUS", "GAIN_LOSS", "REE_BIASED_SIMPLE", # "REE_LEARNING_SIMPLE", "REE_LEARNING_BIASED_SIMPLE") # ) @@ -572,8 +656,8 @@ plot_model_comparison <- function(comparison_results) { # Sauvegarder les résultats # for (model_name in names(all_results)) { -# write_csv(all_results[[model_name]], +# write_csv(all_results[[model_name]], # paste0("results_", model_name, ".csv")) # } # write_csv(comparison$global_comparison, "global_comparison.csv") -# write_csv(comparison$best_models_per_participant, "best_models.csv") \ No newline at end of file +# write_csv(comparison$best_models_per_participant, "best_models.csv")