From bc7a8419f9cb0998058aae04ebad3abb45c60dd0 Mon Sep 17 00:00:00 2001 From: Louis Date: Sun, 7 Dec 2025 11:39:19 +0100 Subject: [PATCH] Various fixes for V4 --- modelling R + PyVBMC V5.R | 1235 ++++++++++++++++++++++--------------- 1 file changed, 745 insertions(+), 490 deletions(-) diff --git a/modelling R + PyVBMC V5.R b/modelling R + PyVBMC V5.R index 97076e5..8e00ddd 100644 --- a/modelling R + PyVBMC V5.R +++ b/modelling R + PyVBMC V5.R @@ -5,18 +5,20 @@ library(tidyverse) library(DEoptim) library(numDeriv) -library(reticulate) # Pour interfacer avec PyVBMC +library(future) +library(doFuture) +library(reticulate) # Pour interfacer avec PyVBMC -# Configuration optionnelle de l'environnement Python -# use_python("/usr/bin/python3") # Ajuster selon votre installation -# py_install("pyvbmc") # Installer PyVBMC si nécessaire +# Configuration de l'environnement Python +# Important : Assurez-vous que PyVBMC est installé dans cet environnement +# use_python("/chemin/vers/python") # Optionnel: spécifier l'installation Python +# reticulate::py_install("pyvbmc") # Décommenter pour installer PyVBMC si nécessaire # ============================================================================ # FONCTION GÉNÉRIQUE DE Q-LEARNING # ============================================================================ qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) { - # Conversion des choix en indices numériques if (is.factor(data$choice) || is.character(data$choice)) { choice_levels <- c("antifragile", "fragile", "robuste", "vulnerable") @@ -24,13 +26,13 @@ qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) { } else { data$choice_idx <- data$choice } - + n_arms <- 4 n_trials <- nrow(data) - + # Extraction des paramètres selon la configuration du modèle param_idx <- 1 - + # ALPHA(S) if (model_config$n_alpha == 1) { alpha_loss <- alpha_gain <- alpha_BS <- alpha_JP <- plogis(params[param_idx]) @@ -48,7 +50,7 @@ qlearning_generic <- 
function(params, data, model_config, return_negLL = TRUE) { alpha_JP <- plogis(params[param_idx + 3]) param_idx <- param_idx + 4 } - + # FORGET(S) if (model_config$n_forget == 1) { forget <- rep(plogis(params[param_idx]), n_arms) @@ -57,7 +59,7 @@ qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) { forget <- plogis(params[param_idx:(param_idx + 3)]) param_idx <- param_idx + 4 } - + # LAMBDA(S) if (model_config$n_lambda == 1) { lambda <- rep(exp(params[param_idx]), n_arms) @@ -66,50 +68,50 @@ qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) { lambda <- exp(params[param_idx:(param_idx + 3)]) param_idx <- param_idx + 4 } - + # RHO(S) - Biais pour événements rares if (model_config$has_rho) { - rho_BS <- params[param_idx] # BS avoidance - rho_JP <- params[param_idx + 1] # JP seeking + rho_BS <- params[param_idx] # BS avoidance + rho_JP <- params[param_idx + 1] # JP seeking param_idx <- param_idx + 2 } else { rho_BS <- rho_JP <- 0 } - + # Initialisation des Q-values Q <- rep(0, n_arms) log_lik <- 0 - + for (t in 1:n_trials) { choice <- data$choice_idx[t] reward <- data$reward[t] - + # Calcul des valeurs subjectives V(t) V <- lambda * Q - + # Ajout des biais pour événements rares si le modèle le permet if (model_config$has_rho) { # Identification des options susceptibles de produire BS/JP # antifragile (1) = JP possible, fragile (2) = BS possible # vulnerable (4) = BS et JP possibles - V[1] <- V[1] + rho_JP # antifragile - V[2] <- V[2] + rho_BS # fragile - V[4] <- V[4] + rho_BS + rho_JP # vulnerable + V[1] <- V[1] + rho_JP # antifragile + V[2] <- V[2] + rho_BS # fragile + V[4] <- V[4] + rho_BS + rho_JP # vulnerable } - + # Softmax V_max <- max(V) exp_V <- exp(V - V_max) probs <- exp_V / sum(exp_V) probs <- pmax(probs, 1e-10) probs <- probs / sum(probs) - + # Log-likelihood log_lik <- log_lik + log(probs[choice]) - + # Mise à jour Q-learning Q_new <- Q - + # Choix de l'alpha approprié if (reward == -3000) { 
alpha_used <- alpha_BS @@ -120,17 +122,17 @@ qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) { } else { alpha_used <- alpha_gain } - + # Option choisie : Q(t+1) = Q(t) + alpha * (r(t) - Q(t)) Q_new[choice] <- Q[choice] + alpha_used * (reward - Q[choice]) - + # Options non choisies : Q(t+1) = Q(t) * (1 - f) not_chosen <- setdiff(1:n_arms, choice) Q_new[not_chosen] <- Q[not_chosen] * (1 - forget[not_chosen]) - + Q <- Q_new } - + if (return_negLL) { return(-log_lik) } else { @@ -155,7 +157,6 @@ get_model_configs <- function() { lower = c(-5, -5, -3), upper = c(5, 5, 3) ), - GAIN_LOSS = list( name = "GAIN_LOSS", n_alpha = 2, @@ -167,7 +168,6 @@ get_model_configs <- function() { lower = c(-5, -5, -5, -3), upper = c(5, 5, 5, 3) ), - BIASED = list( name = "BIASED", n_alpha = 2, @@ -175,13 +175,14 @@ get_model_configs <- function() { n_lambda = 4, has_rho = FALSE, n_params = 10, - param_names = c("alpha_loss", "alpha_gain", - "forget_1", "forget_2", "forget_3", "forget_4", - "lambda_1", "lambda_2", "lambda_3", "lambda_4"), + param_names = c( + "alpha_loss", "alpha_gain", + "forget_1", "forget_2", "forget_3", "forget_4", + "lambda_1", "lambda_2", "lambda_3", "lambda_4" + ), lower = c(-5, -5, rep(-5, 4), rep(-3, 4)), upper = c(5, 5, rep(5, 4), rep(3, 4)) ), - REE_BIASED_SIMPLE = list( name = "REE_BIASED_SIMPLE", n_alpha = 2, @@ -189,12 +190,13 @@ get_model_configs <- function() { n_lambda = 1, has_rho = TRUE, n_params = 6, - param_names = c("alpha_loss", "alpha_gain", "forget", "lambda", - "rho_BS", "rho_JP"), + param_names = c( + "alpha_loss", "alpha_gain", "forget", "lambda", + "rho_BS", "rho_JP" + ), lower = c(-5, -5, -5, -3, -10, -10), upper = c(5, 5, 5, 3, 10, 10) ), - REE_BIASED_COMPLEX = list( name = "REE_BIASED_COMPLEX", n_alpha = 2, @@ -202,14 +204,15 @@ get_model_configs <- function() { n_lambda = 4, has_rho = TRUE, n_params = 12, - param_names = c("alpha_loss", "alpha_gain", - "forget_1", "forget_2", "forget_3", "forget_4", - 
"lambda_1", "lambda_2", "lambda_3", "lambda_4", - "rho_BS", "rho_JP"), + param_names = c( + "alpha_loss", "alpha_gain", + "forget_1", "forget_2", "forget_3", "forget_4", + "lambda_1", "lambda_2", "lambda_3", "lambda_4", + "rho_BS", "rho_JP" + ), lower = c(-5, -5, rep(-5, 4), rep(-3, 4), -10, -10), upper = c(5, 5, rep(5, 4), rep(3, 4), 10, 10) ), - REE_LEARNING_SIMPLE = list( name = "REE_LEARNING_SIMPLE", n_alpha = 4, @@ -217,12 +220,13 @@ get_model_configs <- function() { n_lambda = 1, has_rho = FALSE, n_params = 6, - param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", - "forget", "lambda"), + param_names = c( + "alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", + "forget", "lambda" + ), lower = c(-5, -5, -5, -5, -5, -3), upper = c(5, 5, 5, 5, 5, 3) ), - REE_LEARNING_COMPLEX = list( name = "REE_LEARNING_COMPLEX", n_alpha = 4, @@ -230,13 +234,14 @@ get_model_configs <- function() { n_lambda = 4, has_rho = FALSE, n_params = 12, - param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", - "forget_1", "forget_2", "forget_3", "forget_4", - "lambda_1", "lambda_2", "lambda_3", "lambda_4"), + param_names = c( + "alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", + "forget_1", "forget_2", "forget_3", "forget_4", + "lambda_1", "lambda_2", "lambda_3", "lambda_4" + ), lower = c(-5, -5, -5, -5, rep(-5, 4), rep(-3, 4)), upper = c(5, 5, 5, 5, rep(5, 4), rep(3, 4)) ), - REE_LEARNING_BIASED_SIMPLE = list( name = "REE_LEARNING_BIASED_SIMPLE", n_alpha = 4, @@ -244,12 +249,13 @@ get_model_configs <- function() { n_lambda = 1, has_rho = TRUE, n_params = 8, - param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", - "forget", "lambda", "rho_BS", "rho_JP"), + param_names = c( + "alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", + "forget", "lambda", "rho_BS", "rho_JP" + ), lower = c(-5, -5, -5, -5, -5, -3, -10, -10), upper = c(5, 5, 5, 5, 5, 3, 10, 10) ), - REE_LEARNING_BIASED_COMPLEX = list( name = "REE_LEARNING_BIASED_COMPLEX", n_alpha = 4, @@ 
-257,375 +263,28 @@ get_model_configs <- function() { n_lambda = 4, has_rho = TRUE, n_params = 14, - param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", - "forget_1", "forget_2", "forget_3", "forget_4", - "lambda_1", "lambda_2", "lambda_3", "lambda_4", - "rho_BS", "rho_JP"), + param_names = c( + "alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP", + "forget_1", "forget_2", "forget_3", "forget_4", + "lambda_1", "lambda_2", "lambda_3", "lambda_4", + "rho_BS", "rho_JP" + ), lower = c(-5, -5, -5, -5, rep(-5, 4), rep(-3, 4), -10, -10), upper = c(5, 5, 5, 5, rep(5, 4), rep(3, 4), 10, 10) ) ) } -# ============================================================================ -# ESTIMATION AVEC VBMC (MÉTHODE BAYÉSIENNE) -# ============================================================================ +fit_participant <- function(participant_data, model_config, n_runs = 5) { + # Detect presence of rare events for this participant + has_BS_seen <- any(participant_data$button_value == -3000, na.rm = TRUE) + has_JP_seen <- any(participant_data$button_value == 3000, na.rm = TRUE) -fit_participant_vbmc <- function(participant_data, model_config, use_python_vbmc = TRUE) { - - if (!use_python_vbmc) { - warning("VBMC nécessite Python. Utilisation de DEoptim à la place.") - return(fit_participant(participant_data, model_config)) - } - - # Import de PyVBMC - tryCatch({ - pyvbmc <- import("pyvbmc") - }, error = function(e) { - stop("PyVBMC n'est pas installé. 
Installez-le avec: py_install('pyvbmc')") - }) - - # Wrapper pour la fonction log-posterior - log_posterior_wrapper <- function(params_array) { - # PyVBMC passe un array numpy, on le convertit en vecteur R - params_vec <- as.vector(params_array) - - # Retourne la log-vraisemblance (négatif de negLL) - negLL <- qlearning_generic(params_vec, participant_data, model_config, return_negLL = TRUE) - return(-negLL) # VBMC maximise, donc on retourne -negLL - } - - # Point de départ (milieu des bornes plausibles) - x0 <- (model_config$lower + model_config$upper) / 2 - - # Bornes plausibles (25%-75% de la plage) - plb <- model_config$lower + 0.25 * (model_config$upper - model_config$lower) - pub <- model_config$upper - 0.25 * (model_config$upper - model_config$lower) - - # Initialisation de VBMC - vbmc_obj <- pyvbmc$VBMC( - log_density = log_posterior_wrapper, - x0 = x0, - lower_bounds = model_config$lower, - upper_bounds = model_config$upper, - plausible_lower_bounds = plb, - plausible_upper_bounds = pub - ) - - # Optimisation - vbmc_result <- vbmc_obj$optimize() - vp <- vbmc_result[[1]] # Variational posterior - results_dict <- vbmc_result[[2]] # Résultats - - # Extraction des statistiques - posterior_stats <- vp$moments() - posterior_mean <- as.vector(posterior_stats[[1]]) - posterior_sd <- sqrt(diag(posterior_stats[[2]])) - - elbo <- results_dict$elbo - elbo_sd <- results_dict$elbo_sd - - # Calcul du BIC avec la posterior mean - negLL <- qlearning_generic(posterior_mean, participant_data, model_config, return_negLL = TRUE) - - # Création du tibble de résultats - result_df <- tibble( - model = model_config$name, - n_params = model_config$n_params, - negLL = negLL, - ELBO = elbo, - ELBO_SD = elbo_sd, - AIC = 2 * negLL + 2 * model_config$n_params, - BIC = 2 * negLL + model_config$n_params * log(nrow(participant_data)), - method = "VBMC", - n_iterations = results_dict$iterations, - converged = TRUE # VBMC a ses propres critères - ) - - # Ajout des paramètres (posterior mean 
et SD) - for (i in 1:model_config$n_params) { - result_df[[model_config$param_names[i]]] <- posterior_mean[i] - result_df[[paste0("sd_", model_config$param_names[i])]] <- posterior_sd[i] - } - - # Stockage de la posterior complète pour visualisations - result_df$vp <- list(vp) - result_df$vbmc_results <- list(results_dict) - - return(result_df) -} - -# ============================================================================ -# VISUALISATIONS DE CONVERGENCE AVANCÉES -# ============================================================================ - -plot_convergence_detailed <- function(fit_results, participant_id = NULL) { - require(ggplot2) - require(patchwork) - - if (!is.null(participant_id)) { - fit_results <- fit_results %>% filter(participant_id == !!participant_id) - } - - plots <- list() - - # 1. Trace de convergence (negLL) - if ("convergence_range" %in% names(fit_results)) { - p1 <- fit_results %>% - mutate(participant_id = factor(participant_id)) %>% - ggplot(aes(x = participant_id, y = convergence_range, fill = converged)) + - geom_col() + - scale_y_log10() + - geom_hline(yintercept = 1, linetype = "dashed", color = "red") + - coord_flip() + - theme_minimal() + - labs(title = "Range de convergence par participant", - subtitle = "Seuil = 1.0 (ligne rouge)", - y = "Range negLL (log scale)") - plots$convergence_range <- p1 - } - - # 2. 
Distribution des paramètres estimés - param_cols <- grep("^alpha|^forget|^lambda|^rho", names(fit_results), value = TRUE) - param_cols <- setdiff(param_cols, grep("^se_|^sd_", param_cols, value = TRUE)) - - if (length(param_cols) > 0) { - p2 <- fit_results %>% - select(participant_id, all_of(param_cols)) %>% - pivot_longer(-participant_id, names_to = "parameter", values_to = "value") %>% - ggplot(aes(x = value, fill = parameter)) + - geom_histogram(bins = 30, alpha = 0.7) + - facet_wrap(~parameter, scales = "free", ncol = 3) + - theme_minimal() + - theme(legend.position = "none") + - labs(title = "Distribution des paramètres estimés", - x = "Valeur", y = "Fréquence") - plots$param_distribution <- p2 - } - - # 3. Corrélation paramètres vs convergence - if ("convergence_range" %in% names(fit_results) && length(param_cols) > 0) { - # Sélectionner quelques paramètres clés - key_params <- param_cols[1:min(4, length(param_cols))] - - p3 <- fit_results %>% - select(convergence_range, all_of(key_params)) %>% - pivot_longer(-convergence_range, names_to = "parameter", values_to = "value") %>% - ggplot(aes(x = value, y = convergence_range)) + - geom_point(alpha = 0.5) + - geom_smooth(method = "loess", se = FALSE, color = "red") + - scale_y_log10() + - facet_wrap(~parameter, scales = "free_x") + - theme_minimal() + - labs(title = "Convergence vs valeur des paramètres", - y = "Range convergence (log scale)") - plots$param_vs_convergence <- p3 - } - - # 4. 
Heatmap Hessienne (si disponible) - if ("hessian_positive_definite" %in% names(fit_results)) { - p4 <- fit_results %>% - mutate(participant_id = factor(participant_id)) %>% - ggplot(aes(x = participant_id, y = 1, fill = hessian_positive_definite)) + - geom_tile() + - scale_fill_manual(values = c("FALSE" = "red", "TRUE" = "green"), - na.value = "grey") + - coord_flip() + - theme_minimal() + - theme(axis.text.x = element_blank(), - axis.ticks.x = element_blank()) + - labs(title = "Qualité de la Hessienne", - subtitle = "Vert = définie positive, Rouge = problème", - fill = "Hessienne OK", x = "Participant", y = "") - plots$hessian_quality <- p4 - } - - # 5. Incertitude des paramètres (erreurs standard) - se_cols <- grep("^se_|^sd_", names(fit_results), value = TRUE) - if (length(se_cols) > 0) { - p5 <- fit_results %>% - select(participant_id, all_of(se_cols)) %>% - pivot_longer(-participant_id, names_to = "parameter", values_to = "se") %>% - mutate(parameter = str_remove(parameter, "^se_|^sd_")) %>% - ggplot(aes(x = se, fill = parameter)) + - geom_histogram(bins = 30, alpha = 0.7) + - facet_wrap(~parameter, scales = "free", ncol = 3) + - theme_minimal() + - theme(legend.position = "none") + - labs(title = "Distribution des erreurs standard", - subtitle = "Plus petit = meilleure précision", - x = "Erreur standard", y = "Fréquence") - plots$se_distribution <- p5 - } - - return(plots) -} - -# Visualisation spécifique VBMC (posteriors bayésiennes) -plot_vbmc_diagnostics <- function(vbmc_fit_results, participant_id) { - require(ggplot2) - require(patchwork) - - participant_data <- vbmc_fit_results %>% - filter(participant_id == !!participant_id) - - if (nrow(participant_data) == 0) { - stop("Participant non trouvé") - } - - if (!"vp" %in% names(participant_data)) { - stop("Pas de résultats VBMC disponibles (vp manquant)") - } - - vp <- participant_data$vp[[1]] - vbmc_results <- participant_data$vbmc_results[[1]] - - # Échantillonnage de la posterior - n_samples <- 10000 - 
samples <- vp$sample(n_samples) - - # Conversion en dataframe - param_names <- vbmc_fit_results %>% - select(starts_with("alpha"), starts_with("forget"), - starts_with("lambda"), starts_with("rho")) %>% - select(-starts_with("se_"), -starts_with("sd_")) %>% - names() - - samples_df <- as.data.frame(samples) - names(samples_df) <- param_names - - plots <- list() - - # 1. Marginal posteriors - p1 <- samples_df %>% - pivot_longer(everything(), names_to = "parameter", values_to = "value") %>% - ggplot(aes(x = value)) + - geom_histogram(aes(y = after_stat(density)), bins = 50, - fill = "steelblue", alpha = 0.7) + - geom_density(color = "red", linewidth = 1) + - facet_wrap(~parameter, scales = "free", ncol = 3) + - theme_minimal() + - labs(title = paste("Posterior distributions - Participant", participant_id), - subtitle = "Histogramme + densité estimée", - x = "Valeur", y = "Densité") - plots$marginal_posteriors <- p1 - - # 2. Pairwise correlations (pour les 4 premiers paramètres) - if (length(param_names) >= 2) { - n_plot <- min(4, length(param_names)) - pairs_data <- samples_df[, 1:n_plot] - - p2 <- GGally::ggpairs(pairs_data, - lower = list(continuous = "points"), - diag = list(continuous = "densityDiag"), - upper = list(continuous = "cor")) + - theme_minimal() + - labs(title = "Corrélations entre paramètres") - plots$pairwise_correlations <- p2 - } - - # 3. Trace de l'ELBO - if (!is.null(vbmc_results$elbo_trace)) { - elbo_trace <- vbmc_results$elbo_trace - p3 <- tibble( - iteration = seq_along(elbo_trace), - ELBO = elbo_trace - ) %>% - ggplot(aes(x = iteration, y = ELBO)) + - geom_line(color = "darkblue", linewidth = 1) + - geom_point(size = 2, alpha = 0.5) + - theme_minimal() + - labs(title = "Convergence de l'ELBO", - subtitle = "Evidence Lower Bound", - x = "Itération", y = "ELBO") - plots$elbo_trace <- p3 - } - - # 4. 
Incertitude des paramètres (credible intervals) - posterior_mean <- colMeans(samples_df) - posterior_sd <- apply(samples_df, 2, sd) - ci_lower <- apply(samples_df, 2, quantile, probs = 0.025) - ci_upper <- apply(samples_df, 2, quantile, probs = 0.975) - - p4 <- tibble( - parameter = param_names, - mean = posterior_mean, - sd = posterior_sd, - ci_lower = ci_lower, - ci_upper = ci_upper - ) %>% - mutate(parameter = fct_reorder(parameter, mean)) %>% - ggplot(aes(x = parameter, y = mean)) + - geom_point(size = 3) + - geom_errorbar(aes(ymin = ci_lower, ymax = ci_upper), width = 0.2) + - coord_flip() + - theme_minimal() + - labs(title = "Estimations postérieures avec intervalles de crédibilité à 95%", - x = "Paramètre", y = "Valeur") - plots$credible_intervals <- p4 - - return(plots) -} - -# Fonction pour comparer DEoptim vs VBMC -compare_optimization_methods <- function(data, participant_ids = NULL, - model_config, n_participants = 5) { - - if (is.null(participant_ids)) { - participant_ids <- sample(unique(data$participant_id), - min(n_participants, length(unique(data$participant_id)))) - } - - comparison_results <- map_df(participant_ids, function(pid) { - cat("Participant", pid, "\n") - - participant_data <- data %>% - filter(participant_id == pid) %>% - arrange(trial) - - # Méthode 1: DEoptim - fit_deoptim <- fit_participant(participant_data, model_config, n_runs = 5) - - # Méthode 2: VBMC - fit_vbmc <- tryCatch({ - fit_participant_vbmc(participant_data, model_config, use_python_vbmc = TRUE) - }, error = function(e) { - cat(" VBMC échoué pour participant", pid, ":", e$message, "\n") - return(NULL) - }) - - if (!is.null(fit_vbmc)) { - bind_rows( - fit_deoptim %>% mutate(method = "DEoptim", participant_id = pid), - fit_vbmc %>% mutate(participant_id = pid) - ) - } else { - fit_deoptim %>% mutate(method = "DEoptim", participant_id = pid) - } - }) - - # Visualisation de la comparaison - p <- comparison_results %>% - select(participant_id, method, negLL, BIC) %>% - 
pivot_longer(c(negLL, BIC), names_to = "metric", values_to = "value") %>% - ggplot(aes(x = method, y = value, fill = method)) + - geom_boxplot() + - facet_wrap(~metric, scales = "free_y") + - theme_minimal() + - labs(title = "Comparaison DEoptim vs VBMC", - subtitle = paste("N =", length(unique(comparison_results$participant_id)), "participants"), - y = "Valeur") - - print(p) - - return(comparison_results) - - all_results <- vector("list", n_runs) - - for (run in 1:n_runs) { + + all_results <- foreach(run = 1:n_runs, .options.future = list(seed = TRUE)) %dofuture% { set.seed(1000 * as.numeric(factor(model_config$name)) + run) - + result <- DEoptim( fn = qlearning_generic, lower = model_config$lower, @@ -639,37 +298,41 @@ compare_optimization_methods <- function(data, participant_ids = NULL, NP = max(50, model_config$n_params * 10) ) ) - - all_results[[run]] <- list( + + list( params = result$optim$bestmem, negLL = result$optim$bestval ) } - + # Sélection du meilleur run all_negLL <- sapply(all_results, function(x) x$negLL) best_run <- which.min(all_negLL) - + negLL_sd <- sd(all_negLL) negLL_range <- max(all_negLL) - min(all_negLL) - + params <- all_results[[best_run]]$params negLL <- all_results[[best_run]]$negLL - + # Calcul de la Hessienne - hessian_result <- tryCatch({ - numDeriv::hessian(qlearning_generic, params, - data = participant_data, - model_config = model_config) - }, error = function(e) NULL) - + hessian_result <- tryCatch( + { + numDeriv::hessian(qlearning_generic, params, + data = participant_data, + model_config = model_config + ) + }, + error = function(e) NULL + ) + hessian_positive_definite <- FALSE param_se <- rep(NA, model_config$n_params) - + if (!is.null(hessian_result)) { eigenvalues <- eigen(hessian_result, only.values = TRUE)$values hessian_positive_definite <- all(eigenvalues > 0) - + if (hessian_positive_definite) { param_vcov <- tryCatch(solve(hessian_result), error = function(e) NULL) if (!is.null(param_vcov)) { @@ -677,7 +340,7 @@ 
compare_optimization_methods <- function(data, participant_ids = NULL, } } } - + # Création du tibble de résultats result_df <- tibble( model = model_config$name, @@ -690,13 +353,17 @@ compare_optimization_methods <- function(data, participant_ids = NULL, hessian_positive_definite = hessian_positive_definite, converged = negLL_range < 1 ) - + + # Indicateurs d'événements rares observés (utile pour interprétation des rhos/alphas) + result_df$has_BS_seen <- has_BS_seen + result_df$has_JP_seen <- has_JP_seen + # Ajout des paramètres estimés for (i in 1:model_config$n_params) { result_df[[model_config$param_names[i]]] <- params[i] result_df[[paste0("se_", model_config$param_names[i])]] <- param_se[i] } - + return(result_df) } @@ -705,36 +372,35 @@ compare_optimization_methods <- function(data, participant_ids = NULL, # ============================================================================ fit_all_participants_all_models <- function(data, models_to_fit = NULL) { - model_configs <- get_model_configs() - + if (!is.null(models_to_fit)) { model_configs <- model_configs[models_to_fit] } - + participants <- unique(data$participant_id) - + all_results <- list() - + for (model_name in names(model_configs)) { cat("\n=== Fitting model:", model_name, "===\n") - + model_config <- model_configs[[model_name]] - + model_results <- map_df(participants, function(pid) { cat(" Participant", pid, "\n") - - participant_data <- data %>% + + participant_data <- data %>% filter(participant_id == pid) %>% arrange(trial) - + fit <- fit_participant(participant_data, model_config) fit %>% mutate(participant_id = pid, .before = 1) }) - + all_results[[model_name]] <- model_results } - + return(all_results) } @@ -743,11 +409,10 @@ fit_all_participants_all_models <- function(data, models_to_fit = NULL) { # ============================================================================ compare_nested_models <- function(all_results) { - # Comparaison globale global_comparison <- 
map_df(names(all_results), function(model_name) { results <- all_results[[model_name]] - + tibble( model = model_name, n_params = unique(results$n_params), @@ -760,10 +425,10 @@ compare_nested_models <- function(all_results) { ) }) %>% arrange(total_BIC) - + cat("\n=== COMPARAISON GLOBALE DES MODÈLES ===\n") print(global_comparison) - + # Meilleur modèle par participant (BIC) best_models_per_participant <- map_df(names(all_results), function(model_name) { all_results[[model_name]] %>% @@ -772,13 +437,13 @@ compare_nested_models <- function(all_results) { group_by(participant_id) %>% slice_min(BIC, n = 1) %>% ungroup() - + cat("\n=== MEILLEUR MODÈLE PAR PARTICIPANT (BIC) ===\n") print(table(best_models_per_participant$model)) - + # Comparaison par paires de modèles emboîtés cat("\n=== TESTS LRT POUR MODÈLES EMBOÎTÉS ===\n") - + # Exemples de paires emboîtées nested_pairs <- list( c("HOMOGENEOUS", "GAIN_LOSS"), @@ -790,15 +455,15 @@ compare_nested_models <- function(all_results) { c("REE_LEARNING_SIMPLE", "REE_LEARNING_BIASED_SIMPLE"), c("REE_LEARNING_BIASED_SIMPLE", "REE_LEARNING_BIASED_COMPLEX") ) - + lrt_results <- map_df(nested_pairs, function(pair) { simple_model <- pair[1] complex_model <- pair[2] - + if (simple_model %in% names(all_results) && complex_model %in% names(all_results)) { simple_res <- all_results[[simple_model]] complex_res <- all_results[[complex_model]] - + # LRT par participant lrt_df <- simple_res %>% select(participant_id, negLL_simple = negLL, converged_simple = converged) %>% @@ -813,7 +478,7 @@ compare_nested_models <- function(all_results) { p_value = pchisq(LR_stat, df = df_diff, lower.tail = FALSE), significant = p_value < 0.05 ) - + tibble( simple_model = simple_model, complex_model = complex_model, @@ -824,9 +489,563 @@ compare_nested_models <- function(all_results) { ) } }) - + print(lrt_results) - + + list( + global_comparison = global_comparison, + best_models_per_participant = best_models_per_participant, + lrt_results = 
lrt_results, + all_results = all_results + ) +} + +# ============================================================================ +# ESTIMATION AVEC VBMC (MÉTHODE BAYÉSIENNE) +# ============================================================================ + +fit_participant_vbmc <- function(participant_data, model_config, use_python_vbmc = TRUE) { + if (!use_python_vbmc) { + warning("VBMC nécessite Python. Utilisation de DEoptim à la place.") + return(fit_participant(participant_data, model_config)) + } + + # Import de PyVBMC via reticulate + tryCatch( + { + pyvbmc <- reticulate::import("pyvbmc") + }, + error = function(e) { + stop("PyVBMC n'est pas installé. Installez-le avec: pip install pyvbmc") + } + ) + + # Wrapper pour la fonction log-posterior - DOIT retourner un scalaire Python + log_posterior_wrapper <- function(params_array) { + # params_array vient de Python, le convertir en vecteur R + params_vec <- as.numeric(params_array) + + # Calcul de negLL (log-vraisemblance négative) + negLL <- qlearning_generic(params_vec, participant_data, model_config, return_negLL = TRUE) + + # VBMC maximise, donc retourner -negLL (log-vraisemblance) + return(-negLL) + } + + # Point de départ (milieu des bornes) + x0 <- (model_config$lower + model_config$upper) / 2 + + # Bornes plausibles (25%-75% de la plage) + plb <- model_config$lower + 0.25 * (model_config$upper - model_config$lower) + pub <- model_config$upper - 0.25 * (model_config$upper - model_config$lower) + + # Conversion en listes Python pour PyVBMC + lower_list <- as.list(model_config$lower) + upper_list <- as.list(model_config$upper) + plb_list <- as.list(plb) + pub_list <- as.list(pub) + x0_list <- as.list(x0) + + # Initialisation de VBMC avec paramètres nommés + vbmc_obj <- pyvbmc$VBMC( + log_density = log_posterior_wrapper, + x0 = x0, + lower_bounds = model_config$lower, + upper_bounds = model_config$upper, + plausible_lower_bounds = plb, + plausible_upper_bounds = pub, + options = list( + verbose = 0, + 
display = "off" + ) + ) + + # Optimisation + result_list <- vbmc_obj$optimize() + + # Extraction des résultats + # result_list est une liste Python : [vp, results] + vp <- result_list[[1]] # Variational posterior + results_dict <- result_list[[2]] # Dictionnaire des résultats + + # Extraction des statistiques de la posterior + posterior_moments <- vp$moments() + posterior_mean <- as.numeric(posterior_moments[[1]]) + posterior_cov <- posterior_moments[[2]] + posterior_sd <- sqrt(diag(posterior_cov)) + + # Extraction des valeurs d'optimisation + elbo <- results_dict[["elbo"]] + elbo_sd <- results_dict[["elbo_sd"]] + n_iterations <- results_dict[["iterations"]] + + # Calcul du negLL avec la posterior mean + negLL <- qlearning_generic(posterior_mean, participant_data, model_config, return_negLL = TRUE) + n_obs <- nrow(participant_data) + + # Création du tibble de résultats + result_df <- tibble( + model = model_config$name, + n_params = model_config$n_params, + negLL = negLL, + ELBO = as.numeric(elbo), + ELBO_SD = as.numeric(elbo_sd), + AIC = 2 * negLL + 2 * model_config$n_params, + BIC = 2 * negLL + model_config$n_params * log(n_obs), + method = "VBMC", + n_iterations = as.numeric(n_iterations), + converged = TRUE + ) + + # Ajout des paramètres (posterior mean et SD) + for (i in 1:model_config$n_params) { + result_df[[model_config$param_names[i]]] <- posterior_mean[i] + result_df[[paste0("sd_", model_config$param_names[i])]] <- posterior_sd[i] + } + + # Stockage de la posterior et des résultats pour visualisations ultérieures + result_df$vp <- list(vp) + result_df$vbmc_results <- list(results_dict) + + return(result_df) +} + +# ============================================================================ +# VISUALISATIONS DE CONVERGENCE AVANCÉES +# ============================================================================ + +plot_convergence_detailed <- function(fit_results, participant_id = NULL) { + require(ggplot2) + require(patchwork) + + if 
(!is.null(participant_id)) { + fit_results <- fit_results %>% filter(participant_id == !!participant_id) + } + + plots <- list() + + # 1. Trace de convergence (negLL) + if ("convergence_range" %in% names(fit_results)) { + p1 <- fit_results %>% + mutate(participant_id = factor(participant_id)) %>% + ggplot(aes(x = participant_id, y = convergence_range, fill = converged)) + + geom_col() + + scale_y_log10() + + geom_hline(yintercept = 1, linetype = "dashed", color = "red") + + coord_flip() + + theme_minimal() + + labs( + title = "Range de convergence par participant", + subtitle = "Seuil = 1.0 (ligne rouge)", + y = "Range negLL (log scale)" + ) + plots$convergence_range <- p1 + } + + # 2. Distribution des paramètres estimés + param_cols <- grep("^alpha|^forget|^lambda|^rho", names(fit_results), value = TRUE) + param_cols <- setdiff(param_cols, grep("^se_|^sd_", param_cols, value = TRUE)) + + if (length(param_cols) > 0) { + p2 <- fit_results %>% + select(participant_id, all_of(param_cols)) %>% + pivot_longer(-participant_id, names_to = "parameter", values_to = "value") %>% + ggplot(aes(x = value, fill = parameter)) + + geom_histogram(bins = 30, alpha = 0.7) + + facet_wrap(~parameter, scales = "free", ncol = 3) + + theme_minimal() + + theme(legend.position = "none") + + labs( + title = "Distribution des paramètres estimés", + x = "Valeur", y = "Fréquence" + ) + plots$param_distribution <- p2 + } + + # 3. 
Corrélation paramètres vs convergence + if ("convergence_range" %in% names(fit_results) && length(param_cols) > 0) { + # Sélectionner quelques paramètres clés + key_params <- param_cols[1:min(4, length(param_cols))] + + p3 <- fit_results %>% + select(convergence_range, all_of(key_params)) %>% + pivot_longer(-convergence_range, names_to = "parameter", values_to = "value") %>% + ggplot(aes(x = value, y = convergence_range)) + + geom_point(alpha = 0.5) + + geom_smooth(method = "loess", se = FALSE, color = "red") + + scale_y_log10() + + facet_wrap(~parameter, scales = "free_x") + + theme_minimal() + + labs( + title = "Convergence vs valeur des paramètres", + y = "Range convergence (log scale)" + ) + plots$param_vs_convergence <- p3 + } + + # 4. Heatmap Hessienne (si disponible) + if ("hessian_positive_definite" %in% names(fit_results)) { + p4 <- fit_results %>% + mutate(participant_id = factor(participant_id)) %>% + ggplot(aes(x = participant_id, y = 1, fill = hessian_positive_definite)) + + geom_tile() + + scale_fill_manual( + values = c("FALSE" = "red", "TRUE" = "green"), + na.value = "grey" + ) + + coord_flip() + + theme_minimal() + + theme( + axis.text.x = element_blank(), + axis.ticks.x = element_blank() + ) + + labs( + title = "Qualité de la Hessienne", + subtitle = "Vert = définie positive, Rouge = problème", + fill = "Hessienne OK", x = "Participant", y = "" + ) + plots$hessian_quality <- p4 + } + + # 5. 
Incertitude des paramètres (erreurs standard) + se_cols <- grep("^se_|^sd_", names(fit_results), value = TRUE) + if (length(se_cols) > 0) { + p5 <- fit_results %>% + select(participant_id, all_of(se_cols)) %>% + pivot_longer(-participant_id, names_to = "parameter", values_to = "se") %>% + mutate(parameter = str_remove(parameter, "^se_|^sd_")) %>% + ggplot(aes(x = se, fill = parameter)) + + geom_histogram(bins = 30, alpha = 0.7) + + facet_wrap(~parameter, scales = "free", ncol = 3) + + theme_minimal() + + theme(legend.position = "none") + + labs( + title = "Distribution des erreurs standard", + subtitle = "Plus petit = meilleure précision", + x = "Erreur standard", y = "Fréquence" + ) + plots$se_distribution <- p5 + } + + return(plots) +} +

#' VBMC-specific diagnostics (Bayesian posteriors) for one participant.
#'
#' Samples the variational posterior stored in the `vp` list-column and builds
#' a named list of plots: `marginal_posteriors`, `pairwise_correlations`
#' (only with >= 2 parameters and GGally installed), `elbo_trace` (only when an
#' "elbo_trace" entry can be read from `vbmc_results`) and `credible_intervals`.
#'
#' @param vbmc_fit_results Data frame of fits with a `participant_id` column
#'   and a `vp` list-column; a `vbmc_results` list-column is optional.
#' @param participant_id Identifier of the participant to plot.
#' @param n_samples Number of posterior draws (default 10000, the value that
#'   was previously hard-coded).
#' @return A named list of plot objects.
plot_vbmc_diagnostics <- function(vbmc_fit_results, participant_id,
                                  n_samples = 10000L) {
  participant_data <- vbmc_fit_results %>%
    filter(participant_id == !!participant_id)

  if (nrow(participant_data) == 0) {
    stop("Participant non trouvé")
  }

  if (!"vp" %in% names(participant_data)) {
    stop("Pas de résultats VBMC disponibles (vp manquant)")
  }

  # Convert only genuine Python handles: depending on how the Python module
  # was imported, reticulate may already have produced native R values.
  as_r <- function(x) {
    if (inherits(x, "python.builtin.object")) reticulate::py_to_r(x) else x
  }

  vp <- participant_data$vp[[1]]
  vbmc_results <- if ("vbmc_results" %in% names(participant_data)) {
    participant_data$vbmc_results[[1]]
  } else {
    NULL
  }

  # Draw from the variational posterior and bring the draws into R.
  samples <- as_r(vp$sample(as.integer(n_samples)))
  # NOTE(review): PyVBMC documents `sample()` as returning a tuple
  # (samples, component indices); keep only the sample matrix in that case.
  if (is.list(samples) && !is.data.frame(samples)) {
    samples <- samples[[1]]
  }
  samples_df <- as.data.frame(samples)

  # Candidate parameter names, taken from the fit table's columns.
  # NOTE(review): names are matched to sample columns by position, which
  # assumes the table's column order is the VBMC parameter order - confirm.
  param_names <- vbmc_fit_results %>%
    select(
      starts_with("alpha"), starts_with("forget"),
      starts_with("lambda"), starts_with("rho")
    ) %>%
    select(-starts_with("se_"), -starts_with("sd_")) %>%
    names()

  n_par <- ncol(samples_df)
  labels <- param_names[seq_len(min(n_par, length(param_names)))]
  if (length(labels) < n_par) {
    # Never produce NA column names: fall back to generic labels.
    labels <- c(labels, paste0("param_", seq.int(length(labels) + 1L, n_par)))
  }
  names(samples_df) <- labels

  plots <- list()

  # 1. Marginal posteriors
  plots$marginal_posteriors <- samples_df %>%
    pivot_longer(everything(), names_to = "parameter", values_to = "value") %>%
    ggplot(aes(x = value)) +
    geom_histogram(aes(y = after_stat(density)),
      bins = 50,
      fill = "steelblue", alpha = 0.7
    ) +
    geom_density(color = "red", linewidth = 1) +
    facet_wrap(~parameter, scales = "free", ncol = 3) +
    theme_minimal() +
    labs(
      title = paste("Posterior distributions - Participant", participant_id),
      subtitle = "Histogramme + densité estimée",
      x = "Valeur", y = "Densité"
    )

  # 2. Pairwise correlations (first four parameters at most)
  if (n_par >= 2) {
    if (requireNamespace("GGally", quietly = TRUE)) {
      n_plot <- min(4, n_par)
      pairs_data <- samples_df[, seq_len(n_plot), drop = FALSE]

      plots$pairwise_correlations <- GGally::ggpairs(pairs_data,
        lower = list(continuous = "points"),
        diag = list(continuous = "densityDiag"),
        upper = list(continuous = "cor")
      ) +
        theme_minimal() +
        labs(title = "Corrélations entre paramètres")
    } else {
      warning(
        "Package 'GGally' non installé : corrélations par paires ignorées",
        call. = FALSE
      )
    }
  }

  # 3. ELBO trace; any failure to read it simply skips this plot.
  elbo_trace <- tryCatch(
    as_r(vbmc_results[["elbo_trace"]]),
    error = function(e) NULL
  )

  if (!is.null(elbo_trace) && length(elbo_trace) > 0) {
    elbo_values <- as.numeric(unlist(elbo_trace))
    plots$elbo_trace <- tibble(
      iteration = seq_along(elbo_values),
      ELBO = elbo_values
    ) %>%
      ggplot(aes(x = iteration, y = ELBO)) +
      geom_line(color = "darkblue", linewidth = 1) +
      geom_point(size = 2, alpha = 0.5) +
      theme_minimal() +
      labs(
        title = "Convergence de l'ELBO",
        subtitle = "Evidence Lower Bound",
        x = "Itération", y = "ELBO"
      )
  }

  # 4. Parameter uncertainty: posterior mean with 95% credible interval.
  posterior_mean <- colMeans(samples_df)
  posterior_sd <- vapply(samples_df, sd, numeric(1))
  ci_lower <- vapply(samples_df, quantile, numeric(1),
    probs = 0.025, names = FALSE
  )
  ci_upper <- vapply(samples_df, quantile, numeric(1),
    probs = 0.975, names = FALSE
  )

  plots$credible_intervals <- tibble(
    parameter = names(posterior_mean),
    mean = posterior_mean,
    sd = posterior_sd,
    ci_lower = ci_lower,
    ci_upper = ci_upper
  ) %>%
    mutate(parameter = fct_reorder(parameter, mean)) %>%
    ggplot(aes(x = parameter, y = mean)) +
    geom_point(size = 3) +
    geom_errorbar(aes(ymin = ci_lower, ymax = ci_upper), width = 0.2) +
    coord_flip() +
    theme_minimal() +
    labs(
      title = "Estimations postérieures avec intervalles de crédibilité à 95%",
      x = "Paramètre", y = "Valeur"
    )

  plots
}

#' Compare DEoptim and VBMC fits on a subset of participants.
#'
#' Fits every selected participant with `fit_participant()` (DEoptim, 5 runs)
#' and with `fit_participant_vbmc()`, stacks the results, prints a boxplot of
#' negLL and BIC per method and returns the stacked table. A participant whose
#' VBMC fit raises an error keeps only its DEoptim row.
#'
#' @param data Trial-level data with `participant_id` and `trial` columns.
#' @param participant_ids Participants to fit; when `NULL`, `n_participants`
#'   of them are drawn at random.
#' @param model_config Model configuration forwarded to both fitters.
#' @param n_participants Number of participants drawn when `participant_ids`
#'   is `NULL`.
#' @return Data frame with one row per participant and method.
compare_optimization_methods <- function(data, participant_ids = NULL,
                                         model_config, n_participants = 5) {
  if (is.null(participant_ids)) {
    all_ids <- unique(data$participant_id)
    n_draw <- min(n_participants, length(all_ids))
    # Index-based draw: `sample(x, k)` on a single numeric id would sample
    # from 1:x instead of returning that id.
    participant_ids <- all_ids[sample.int(length(all_ids), n_draw)]
  }

  comparison_results <- map_df(participant_ids, function(pid) {
    cat("Participant", pid, "\n")

    participant_data <- data %>%
      filter(participant_id == pid) %>%
      arrange(trial)

    # Method 1: DEoptim
    fit_deoptim <- fit_participant(participant_data, model_config, n_runs = 5) %>%
      mutate(method = "DEoptim", participant_id = pid)

    # Method 2: VBMC (best effort: a failure is reported and skipped)
    fit_vbmc <- tryCatch(
      fit_participant_vbmc(participant_data, model_config, use_python_vbmc = TRUE),
      error = function(e) {
        cat(" VBMC échoué pour participant", pid, ":", conditionMessage(e), "\n")
        NULL
      }
    )

    if (is.null(fit_vbmc)) {
      return(fit_deoptim)
    }

    fit_vbmc <- fit_vbmc %>% mutate(participant_id = pid)
    if (!"method" %in% names(fit_vbmc)) {
      # Label the rows so the per-method plot below has no NA group.
      fit_vbmc$method <- "VBMC"
    }
    bind_rows(fit_deoptim, fit_vbmc)
  })

  # Visual comparison of both methods
  p <- comparison_results %>%
    select(participant_id, method, negLL, BIC) %>%
    pivot_longer(c(negLL, BIC), names_to = "metric", values_to = "value") %>%
    ggplot(aes(x = method, y = value, fill = method)) +
    geom_boxplot() +
    facet_wrap(~metric, scales = "free_y") +
    theme_minimal() +
    labs(
      title = "Comparaison DEoptim vs VBMC",
      subtitle = paste("N =", length(unique(comparison_results$participant_id)), "participants"),
      y = "Valeur"
    )

  print(p)

  comparison_results
}

# ============================================================================
# ESTIMATION FOR ALL PARTICIPANTS
# ============================================================================

#' Fit every requested model to every requested participant with VBMC.
#'
#' @param data Trial-level data with `participant_id` and `trial` columns.
#' @param nb_participants Number of participants to keep, in order of first
#'   appearance; `NULL` keeps all. Values above the number available are
#'   clamped with a warning.
#' @param models_to_fit Subset of the configurations returned by
#'   `get_model_configs()`; `NULL` fits all of them. Unknown names are an error.
#' @return Named list (one entry per model) of per-participant fit tables whose
#'   first column is `participant_id`.
fit_all_participants_all_models <- function(data, nb_participants = NULL, models_to_fit = NULL) {
  model_configs <- get_model_configs()

  if (!is.null(models_to_fit)) {
    if (is.character(models_to_fit)) {
      unknown <- setdiff(models_to_fit, names(model_configs))
      if (length(unknown) > 0) {
        stop("Modèles inconnus : ", paste(unknown, collapse = ", "), call. = FALSE)
      }
    }
    model_configs <- model_configs[models_to_fit]
  }

  participants <- unique(data$participant_id)

  if (!is.null(nb_participants)) {
    if (nb_participants > length(participants)) {
      warning(
        "nb_participants (", nb_participants,
        ") dépasse le nombre de participants disponibles (",
        length(participants), ") ; tous sont utilisés",
        call. = FALSE
      )
    }
    # seq_len() on a clamped count: never indexes past the end (NA ids) and
    # yields an empty selection for 0 instead of c(1, 0).
    n_keep <- max(0L, min(as.integer(nb_participants), length(participants)))
    participants <- participants[seq_len(n_keep)]
  }

  all_results <- list()

  for (model_name in names(model_configs)) {
    cat("\n=== Fitting model:", model_name, "===\n")

    model_config <- model_configs[[model_name]]

    all_results[[model_name]] <- map_df(participants, function(pid) {
      cat(" Participant", pid, "\n")

      participant_data <- data %>%
        filter(participant_id == pid) %>%
        arrange(trial)

      fit_participant_vbmc(participant_data, model_config) %>%
        mutate(participant_id = pid, .before = 1)
    })
  }

  all_results
}

+# ============================================================================ +# COMPARAISON DES MODÈLES EMBOÎTÉS +# ============================================================================ + +compare_nested_models <- function(all_results) { + # Comparaison globale + global_comparison <- map_df(names(all_results), function(model_name) { + results <- all_results[[model_name]] +
+ tibble( + model = model_name, + n_params = unique(results$n_params), + n_converged = sum(results$converged), + mean_negLL = mean(results$negLL), + total_negLL = sum(results$negLL), + total_AIC = sum(results$AIC), + total_BIC = sum(results$BIC), + mean_convergence_range = mean(results$convergence_range) + ) + }) %>% + arrange(total_BIC) + + cat("\n=== COMPARAISON GLOBALE DES MODÈLES ===\n") + print(global_comparison) + + # Meilleur modèle par participant (BIC) + best_models_per_participant <- map_df(names(all_results), function(model_name) { + all_results[[model_name]] %>% + select(participant_id, model, BIC, converged) + }) %>% + group_by(participant_id) %>% + slice_min(BIC, n = 1) %>% + ungroup() + + cat("\n=== MEILLEUR MODÈLE PAR PARTICIPANT (BIC) ===\n") + print(table(best_models_per_participant$model)) + + # Comparaison par paires de modèles emboîtés + cat("\n=== TESTS LRT POUR MODÈLES EMBOÎTÉS ===\n") + + # Exemples de paires emboîtées + nested_pairs <- list( + c("HOMOGENEOUS", "GAIN_LOSS"), + c("GAIN_LOSS", "BIASED"), + c("GAIN_LOSS", "REE_BIASED_SIMPLE"), + c("REE_BIASED_SIMPLE", "REE_BIASED_COMPLEX"), + c("GAIN_LOSS", "REE_LEARNING_SIMPLE"), + c("REE_LEARNING_SIMPLE", "REE_LEARNING_COMPLEX"), + c("REE_LEARNING_SIMPLE", "REE_LEARNING_BIASED_SIMPLE"), + c("REE_LEARNING_BIASED_SIMPLE", "REE_LEARNING_BIASED_COMPLEX") + ) + + lrt_results <- map_df(nested_pairs, function(pair) { + simple_model <- pair[1] + complex_model <- pair[2] + + if (simple_model %in% names(all_results) && complex_model %in% names(all_results)) { + simple_res <- all_results[[simple_model]] + complex_res <- all_results[[complex_model]] + + # LRT par participant + lrt_df <- simple_res %>% + select(participant_id, negLL_simple = negLL, converged_simple = converged) %>% + left_join( + complex_res %>% select(participant_id, negLL_complex = negLL, converged_complex = converged), + by = "participant_id" + ) %>% + filter(converged_simple, converged_complex) %>% + mutate( + LR_stat = 2 * 
(negLL_simple - negLL_complex), + df_diff = unique(complex_res$n_params) - unique(simple_res$n_params), + p_value = pchisq(LR_stat, df = df_diff, lower.tail = FALSE), + significant = p_value < 0.05 + ) + + tibble( + simple_model = simple_model, + complex_model = complex_model, + n_participants = nrow(lrt_df), + pct_significant = mean(lrt_df$significant) * 100, + mean_LR = mean(lrt_df$LR_stat), + median_p = median(lrt_df$p_value) + ) + } + }) + + print(lrt_results) + list( global_comparison = global_comparison, best_models_per_participant = best_models_per_participant, @@ -842,7 +1061,7 @@ compare_nested_models <- function(all_results) { plot_model_comparison <- function(comparison_results) { require(ggplot2) require(patchwork) - + # 1. Comparaison BIC globale p1 <- comparison_results$global_comparison %>% mutate(model = fct_reorder(model, total_BIC)) %>% @@ -850,10 +1069,12 @@ plot_model_comparison <- function(comparison_results) { geom_col() + coord_flip() + theme_minimal() + - labs(title = "Comparaison globale des modèles (BIC)", - subtitle = "Plus bas = meilleur") + + labs( + title = "Comparaison globale des modèles (BIC)", + subtitle = "Plus bas = meilleur" + ) + theme(legend.position = "none") - + # 2. Meilleur modèle par participant p2 <- comparison_results$best_models_per_participant %>% count(model) %>% @@ -862,27 +1083,34 @@ plot_model_comparison <- function(comparison_results) { geom_col() + coord_flip() + theme_minimal() + - labs(title = "Meilleur modèle par participant", - y = "Nombre de participants") + + labs( + title = "Meilleur modèle par participant", + y = "Nombre de participants" + ) + theme(legend.position = "none") - + # 3. 
Tests LRT if (nrow(comparison_results$lrt_results) > 0) { p3 <- comparison_results$lrt_results %>% mutate(comparison = paste(simple_model, "→", complex_model)) %>% - ggplot(aes(x = fct_reorder(comparison, pct_significant), - y = pct_significant)) + + ggplot(aes( + x = fct_reorder(comparison, pct_significant), + y = pct_significant + )) + geom_col(fill = "steelblue") + geom_hline(yintercept = 50, linetype = "dashed", color = "red") + coord_flip() + theme_minimal() + - labs(title = "Tests LRT entre modèles emboîtés", - y = "% participants avec p < 0.05", - x = "Comparaison") + labs( + title = "Tests LRT entre modèles emboîtés", + y = "% participants avec p < 0.05", + x = "Comparaison" + ) } else { - p3 <- ggplot() + theme_void() + p3 <- ggplot() + + theme_void() } - + # 4. Convergence par modèle p4 <- map_df(names(comparison_results$all_results), function(model_name) { comparison_results$all_results[[model_name]] %>% @@ -893,9 +1121,11 @@ plot_model_comparison <- function(comparison_results) { scale_y_log10() + coord_flip() + theme_minimal() + - labs(title = "Convergence par modèle", - y = "Range negLL (log scale)") - + labs( + title = "Convergence par modèle", + y = "Range negLL (log scale)" + ) + (p1 | p2) / (p3 | p4) } @@ -903,30 +1133,55 @@ plot_model_comparison <- function(comparison_results) { # EXEMPLE D'UTILISATION # ============================================================================ -# Charger vos données -# data <- read_csv("votre_fichier.csv") -# Colonnes requises: participant_id, trial, choice, reward +library("future.callr") +plan("callr") -# Estimation de tous les modèles -# all_results <- fit_all_participants_all_models(data) +# Configuration de l'environnement Python +Sys.unsetenv("RETICULATE_PYTHON") +Sys.setenv(RETICULATE_PYTHON = "/home/louis/Documents/Autre/REE-RL-Lola/.venv/bin/python") -# Ou seulement certains modèles -# all_results <- fit_all_participants_all_models( -# data, -# models_to_fit = c("HOMOGENEOUS", "GAIN_LOSS", 
"REE_BIASED_SIMPLE", -# "REE_LEARNING_SIMPLE", "REE_LEARNING_BIASED_SIMPLE") -# ) +# Charger les données directement +data <- read.csv("data/data_fourchoices.csv") %>% + rename(participant_id = participant, reward = button_value, choice = button_name) %>% + mutate( + # Mapper les choix en indices numériques + choice = case_when( + choice == "antifragile" ~ 1L, + choice == "fragile" ~ 2L, + choice == "robuste" ~ 3L, + choice == "vulnerable" ~ 4L, + TRUE ~ NA_integer_ + ), + trial = row_number() + ) %>% + select(participant_id, trial, choice, reward) %>% + arrange(participant_id, trial) + +cat("Data loaded:\n") +cat(" Participants:", n_distinct(data$participant_id), "\n") +cat(" Total trials:", nrow(data), "\n\n") + +# Test avec 2 participants et 2 modèles +all_results <- fit_all_participants_all_models( + data, + nb_participants = 10, + models_to_fit = c("HOMOGENEOUS", "GAIN_LOSS") +) # Comparaison -# comparison <- compare_nested_models(all_results) +comparison <- compare_nested_models(all_results) # Visualisation -# plot_model_comparison(comparison) +plot_model_comparison(comparison) # Sauvegarder les résultats -# for (model_name in names(all_results)) { -# write_csv(all_results[[model_name]], -# paste0("results_", model_name, ".csv")) -# } -# write_csv(comparison$global_comparison, "global_comparison.csv") -# write_csv(comparison$best_models_per_participant, "best_models.csv") +for (model_name in names(all_results)) { + write_csv( + all_results[[model_name]], + paste0("results/results_", model_name, ".csv") + ) +} +write_csv(comparison$global_comparison, "results/global_comparison.csv") +write_csv(comparison$best_models_per_participant, "results/best_models.csv") + +cat("\nDone! Results saved to results/ directory\n")