# Listing metadata: 932 lines, 30 KiB, R
# ============================================================================
|
|
# MODÈLES Q-LEARNING EMBOÎTÉS POUR DÉCISION AVEC ÉVÉNEMENTS RARES
|
|
# ============================================================================
|
|
|
|
library(tidyverse)
|
|
library(DEoptim)
|
|
library(numDeriv)
|
|
library(reticulate) # Pour interfacer avec PyVBMC
|
|
|
|
# Configuration optionnelle de l'environnement Python
|
|
# use_python("/usr/bin/python3") # Ajuster selon votre installation
|
|
# py_install("pyvbmc") # Installer PyVBMC si nécessaire
|
|
|
|
# ============================================================================
|
|
# FONCTION GÉNÉRIQUE DE Q-LEARNING
|
|
# ============================================================================
|
|
|
|
#' Log-likelihood of a four-armed Q-learning model with forgetting.
#'
#' Walks through the trials in `data` in row order, computes softmax choice
#' probabilities from the current Q-values, accumulates the log-probability
#' of each observed choice, then updates the Q-values.
#'
#' @param params Numeric vector of unconstrained parameters, consumed in the
#'   order alpha(s), forget(s), lambda(s), then rho_BS and rho_JP when
#'   `model_config$has_rho` is TRUE. Alphas and forgets are mapped to (0, 1)
#'   with `plogis()`, lambdas to (0, Inf) with `exp()`; rhos are used as is.
#' @param data Data frame with columns `choice` (one of "antifragile",
#'   "fragile", "robuste", "vulnerable", or an arm index 1..4) and `reward`.
#' @param model_config List with `n_alpha` (1, 2 or 4), `n_forget` (1 or 4),
#'   `n_lambda` (1 or 4) and `has_rho` (logical).
#' @param return_negLL If TRUE (default) return the negative log-likelihood,
#'   otherwise the log-likelihood.
#' @return A single numeric value; 0 when `data` has no rows.
qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) {
  n_arms <- 4L
  choice_levels <- c("antifragile", "fragile", "robuste", "vulnerable")

  # Convert the choices to arm indices 1..4 once, outside the trial loop.
  if (is.factor(data$choice) || is.character(data$choice)) {
    choice_idx <- match(as.character(data$choice), choice_levels)
  } else {
    choice_idx <- data$choice
  }
  rewards <- data$reward
  n_trials <- nrow(data)

  # Fail early with a clear message instead of propagating NA through the
  # likelihood (unknown labels make match() return NA).
  if (length(choice_idx) != n_trials || anyNA(choice_idx) ||
      !all(choice_idx %in% seq_len(n_arms))) {
    stop(
      "`data$choice` doit valoir ", paste(choice_levels, collapse = ", "),
      " ou un indice entre 1 et ", n_arms, ".",
      call. = FALSE
    )
  }
  if (length(rewards) != n_trials || anyNA(rewards)) {
    stop("`data$reward` est manquant ou contient des NA.", call. = FALSE)
  }

  # Validate the model structure so that an unsupported value is reported
  # here rather than as an "object not found" error further down.
  n_alpha <- model_config$n_alpha
  n_forget <- model_config$n_forget
  n_lambda <- model_config$n_lambda
  has_rho <- isTRUE(model_config$has_rho)
  if (!isTRUE(n_alpha %in% c(1, 2, 4))) {
    stop("`model_config$n_alpha` doit valoir 1, 2 ou 4.", call. = FALSE)
  }
  if (!isTRUE(n_forget %in% c(1, 4))) {
    stop("`model_config$n_forget` doit valoir 1 ou 4.", call. = FALSE)
  }
  if (!isTRUE(n_lambda %in% c(1, 4))) {
    stop("`model_config$n_lambda` doit valoir 1 ou 4.", call. = FALSE)
  }

  # Drop names (e.g. the "par1", ... names set by optimisers) so that the
  # returned likelihood is a plain unnamed number.
  params <- as.numeric(params)
  n_expected <- n_alpha + n_forget + n_lambda + 2 * has_rho
  if (length(params) < n_expected) {
    stop(
      "`params` contient ", length(params), " valeurs, ",
      n_expected, " attendues.",
      call. = FALSE
    )
  }

  # --- Unpack the parameters in the order documented above ----------------
  pos <- 1L

  # Learning rates, stored as (loss, gain, BS, JP). With two alphas the
  # extreme outcomes reuse the loss / gain rates.
  alpha_raw <- plogis(params[pos:(pos + n_alpha - 1L)])
  pos <- pos + n_alpha
  alpha <- switch(
    as.character(n_alpha),
    "1" = rep(alpha_raw, 4L),
    "2" = c(alpha_raw, alpha_raw),
    "4" = alpha_raw
  )
  alpha_loss <- alpha[1L]
  alpha_gain <- alpha[2L]
  alpha_BS <- alpha[3L]
  alpha_JP <- alpha[4L]

  # Forgetting rates, one per arm (recycled when shared).
  forget <- plogis(params[pos:(pos + n_forget - 1L)])
  pos <- pos + n_forget
  if (n_forget == 1) {
    forget <- rep(forget, n_arms)
  }

  # Per-arm multipliers applied to Q before the softmax.
  lambda <- exp(params[pos:(pos + n_lambda - 1L)])
  pos <- pos + n_lambda
  if (n_lambda == 1) {
    lambda <- rep(lambda, n_arms)
  }

  # Additive rare-event biases: arm 1 (antifragile) gets rho_JP, arm 2
  # (fragile) gets rho_BS, arm 4 (vulnerable) gets both, arm 3 gets none.
  if (has_rho) {
    rho_BS <- params[pos]
    rho_JP <- params[pos + 1L]
    rho_bias <- c(rho_JP, rho_BS, 0, rho_BS + rho_JP)
  } else {
    rho_bias <- numeric(n_arms)
  }

  # --- Trial loop ----------------------------------------------------------
  Q <- numeric(n_arms)
  log_lik <- 0
  retain <- 1 - forget

  # seq_len() makes the loop a no-op for empty data (1:0 would iterate).
  for (t in seq_len(n_trials)) {
    arm <- choice_idx[t]
    r <- rewards[t]

    # Subjective values and numerically stable softmax.
    V <- lambda * Q + rho_bias
    exp_V <- exp(V - max(V))
    probs <- exp_V / sum(exp_V)
    # Floor the probabilities so log() never sees 0, then renormalise.
    probs <- pmax(probs, 1e-10)
    probs <- probs / sum(probs)

    log_lik <- log_lik + log(probs[arm])

    # Learning rate depends on the outcome: the exact values -3000 / 3000
    # select alpha_BS / alpha_JP, otherwise the sign of the reward decides.
    if (r == -3000) {
      alpha_used <- alpha_BS
    } else if (r == 3000) {
      alpha_used <- alpha_JP
    } else if (r < 0) {
      alpha_used <- alpha_loss
    } else {
      alpha_used <- alpha_gain
    }

    # Chosen arm: Q <- Q + alpha * (r - Q); other arms: Q <- Q * (1 - f).
    chosen_value <- Q[arm] + alpha_used * (r - Q[arm])
    Q <- Q * retain
    Q[arm] <- chosen_value
  }

  if (return_negLL) -log_lik else log_lik
}
|
|
|
|
# ============================================================================
|
|
# CONFIGURATIONS DES MODÈLES EMBOÎTÉS
|
|
# ============================================================================
|
|
|
|
#' Configurations of the nested Q-learning models.
#'
#' @return A named list with one entry per model. Each entry is a list with
#'   `name`, the structure flags `n_alpha`, `n_forget`, `n_lambda`, `has_rho`,
#'   the parameter count `n_params`, the ordered `param_names`, and the
#'   `lower` / `upper` search bounds (same order as `param_names`).
get_model_configs <- function() {
  # Build one configuration from its structure. Names and bounds are laid
  # out in the order alpha(s), forget(s), lambda(s), rho(s); bounds are
  # symmetric: +/-5 for alphas and forgets, +/-3 for lambdas, +/-10 for rhos.
  build_config <- function(name, n_alpha, n_forget, n_lambda, has_rho) {
    alpha_labels <- switch(
      as.character(n_alpha),
      "1" = "alpha",
      "2" = c("alpha_loss", "alpha_gain"),
      "4" = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP")
    )
    forget_labels <- if (n_forget == 1) {
      "forget"
    } else {
      paste0("forget_", seq_len(n_forget))
    }
    lambda_labels <- if (n_lambda == 1) {
      "lambda"
    } else {
      paste0("lambda_", seq_len(n_lambda))
    }
    rho_labels <- if (has_rho) c("rho_BS", "rho_JP") else character(0)

    labels <- c(alpha_labels, forget_labels, lambda_labels, rho_labels)
    half_width <- c(
      rep(5, n_alpha + n_forget),
      rep(3, n_lambda),
      rep(10, length(rho_labels))
    )

    list(
      name = name,
      n_alpha = n_alpha,
      n_forget = n_forget,
      n_lambda = n_lambda,
      has_rho = has_rho,
      n_params = as.numeric(length(labels)),
      param_names = labels,
      lower = -half_width,
      upper = half_width
    )
  }

  list(
    HOMOGENEOUS = build_config("HOMOGENEOUS", 1, 1, 1, FALSE),
    GAIN_LOSS = build_config("GAIN_LOSS", 2, 1, 1, FALSE),
    BIASED = build_config("BIASED", 2, 4, 4, FALSE),
    REE_BIASED_SIMPLE = build_config("REE_BIASED_SIMPLE", 2, 1, 1, TRUE),
    REE_BIASED_COMPLEX = build_config("REE_BIASED_COMPLEX", 2, 4, 4, TRUE),
    REE_LEARNING_SIMPLE = build_config("REE_LEARNING_SIMPLE", 4, 1, 1, FALSE),
    REE_LEARNING_COMPLEX = build_config(
      "REE_LEARNING_COMPLEX", 4, 4, 4, FALSE
    ),
    REE_LEARNING_BIASED_SIMPLE = build_config(
      "REE_LEARNING_BIASED_SIMPLE", 4, 1, 1, TRUE
    ),
    REE_LEARNING_BIASED_COMPLEX = build_config(
      "REE_LEARNING_BIASED_COMPLEX", 4, 4, 4, TRUE
    )
  )
}
|
|
|
|
# ============================================================================
|
|
# ESTIMATION AVEC VBMC (MÉTHODE BAYÉSIENNE)
|
|
# ============================================================================
|
|
|
|
#' Fit one participant with Variational Bayesian Monte Carlo (PyVBMC).
#'
#' The target density handed to VBMC is the model log-likelihood returned by
#' `qlearning_generic()`; no explicit prior term is added, so within the
#' bounds the target is proportional to the posterior under a flat prior.
#'
#' @param participant_data Data frame of one participant's trials, in the
#'   format expected by `qlearning_generic()`.
#' @param model_config One model configuration (see `get_model_configs()`).
#' @param use_python_vbmc If FALSE, warn and fall back to `fit_participant()`.
#' @return A one-row tibble with fit statistics (negLL, ELBO, AIC, BIC, ...),
#'   one column per parameter (posterior mean) plus `sd_<param>` columns
#'   (posterior SD), and list-columns `vp` and `vbmc_results`.
fit_participant_vbmc <- function(participant_data, model_config, use_python_vbmc = TRUE) {
  if (!use_python_vbmc) {
    warning("VBMC nécessite Python. Utilisation de DEoptim à la place.")
    return(fit_participant(participant_data, model_config))
  }

  # Import PyVBMC through reticulate, with an actionable error if missing.
  pyvbmc <- tryCatch(
    import("pyvbmc"),
    error = function(e) {
      stop("PyVBMC n'est pas installé. Installez-le avec: py_install('pyvbmc')")
    }
  )

  # Log-density callback: PyVBMC passes a numpy array, converted to a plain
  # R vector before evaluating the log-likelihood (VBMC maximises).
  log_density <- function(params_array) {
    qlearning_generic(
      as.vector(params_array), participant_data, model_config,
      return_negLL = FALSE
    )
  }

  # Start at the centre of the box; plausible bounds cover its middle half.
  param_span <- model_config$upper - model_config$lower
  x0 <- (model_config$lower + model_config$upper) / 2
  plb <- model_config$lower + 0.25 * param_span
  pub <- model_config$upper - 0.25 * param_span

  vbmc_obj <- pyvbmc$VBMC(
    log_density = log_density,
    x0 = x0,
    lower_bounds = model_config$lower,
    upper_bounds = model_config$upper,
    plausible_lower_bounds = plb,
    plausible_upper_bounds = pub
  )

  vbmc_result <- vbmc_obj$optimize()
  vp <- vbmc_result[[1]] # variational posterior
  results_dict <- vbmc_result[[2]] # result dictionary

  # moments() only returns the mean unless cov_flag = TRUE; the covariance
  # is required to derive the posterior standard deviations.
  posterior_stats <- vp$moments(orig_flag = TRUE, cov_flag = TRUE)
  posterior_mean <- as.vector(posterior_stats[[1]])
  posterior_sd <- sqrt(diag(as.matrix(posterior_stats[[2]])))

  elbo <- results_dict$elbo
  elbo_sd <- results_dict$elbo_sd

  # Use PyVBMC's own success flag when the result dictionary provides one;
  # fall back to TRUE (the previous hard-coded value) otherwise.
  success_flag <- results_dict$success_flag
  converged <- if (is.null(success_flag)) {
    TRUE
  } else {
    isTRUE(as.logical(success_flag))
  }

  # Information criteria are evaluated at the posterior mean.
  negLL <- qlearning_generic(
    posterior_mean, participant_data, model_config,
    return_negLL = TRUE
  )

  result_df <- tibble(
    model = model_config$name,
    n_params = model_config$n_params,
    negLL = negLL,
    ELBO = elbo,
    ELBO_SD = elbo_sd,
    AIC = 2 * negLL + 2 * model_config$n_params,
    BIC = 2 * negLL + model_config$n_params * log(nrow(participant_data)),
    method = "VBMC",
    n_iterations = results_dict$iterations,
    converged = converged
  )

  # One column per parameter: posterior mean and posterior SD.
  for (i in seq_len(model_config$n_params)) {
    param_name <- model_config$param_names[i]
    result_df[[param_name]] <- posterior_mean[i]
    result_df[[paste0("sd_", param_name)]] <- posterior_sd[i]
  }

  # Keep the full posterior and the raw results for later diagnostics.
  result_df$vp <- list(vp)
  result_df$vbmc_results <- list(results_dict)

  result_df
}
|
|
|
|
# ============================================================================
|
|
# VISUALISATIONS DE CONVERGENCE AVANCÉES
|
|
# ============================================================================
|
|
|
|
#' Diagnostic plots for a table of fitted participants.
#'
#' Each plot is only produced when the columns it needs are present in
#' `fit_results`.
#'
#' @param fit_results Data frame of fit results with a `participant_id`
#'   column and, optionally, `convergence_range`, `converged`,
#'   `hessian_positive_definite`, parameter columns (prefixes alpha, forget,
#'   lambda, rho) and uncertainty columns (prefixes `se_` / `sd_`).
#' @param participant_id Optional id; when given, only that participant's
#'   rows are plotted.
#' @return A named list of ggplot objects, possibly empty, with elements
#'   among `convergence_range`, `param_distribution`,
#'   `param_vs_convergence`, `hessian_quality` and `se_distribution`.
plot_convergence_detailed <- function(fit_results, participant_id = NULL) {
  require(ggplot2)
  require(patchwork)

  if (!is.null(participant_id)) {
    fit_results <- filter(fit_results, participant_id == !!participant_id)
  }

  available <- names(fit_results)
  has_range <- "convergence_range" %in% available
  plots <- list()

  # 1. Spread of negLL across optimiser runs, per participant.
  if (has_range) {
    range_tbl <- mutate(fit_results, participant_id = factor(participant_id))
    plots$convergence_range <-
      ggplot(
        range_tbl,
        aes(x = participant_id, y = convergence_range, fill = converged)
      ) +
      geom_col() +
      scale_y_log10() +
      geom_hline(yintercept = 1, linetype = "dashed", color = "red") +
      coord_flip() +
      theme_minimal() +
      labs(
        title = "Range de convergence par participant",
        subtitle = "Seuil = 1.0 (ligne rouge)",
        y = "Range negLL (log scale)"
      )
  }

  # 2. Histograms of the estimated parameters.
  param_cols <- grep("^alpha|^forget|^lambda|^rho", available, value = TRUE)
  param_cols <- unique(param_cols[!grepl("^se_|^sd_", param_cols)])
  has_params <- length(param_cols) > 0

  if (has_params) {
    param_long <- pivot_longer(
      select(fit_results, participant_id, all_of(param_cols)),
      -participant_id,
      names_to = "parameter",
      values_to = "value"
    )
    plots$param_distribution <-
      ggplot(param_long, aes(x = value, fill = parameter)) +
      geom_histogram(bins = 30, alpha = 0.7) +
      facet_wrap(~parameter, scales = "free", ncol = 3) +
      theme_minimal() +
      theme(legend.position = "none") +
      labs(
        title = "Distribution des paramètres estimés",
        x = "Valeur", y = "Fréquence"
      )
  }

  # 3. Convergence spread against the first (up to four) parameters.
  if (has_range && has_params) {
    key_params <- head(param_cols, 4)
    range_long <- pivot_longer(
      select(fit_results, convergence_range, all_of(key_params)),
      -convergence_range,
      names_to = "parameter",
      values_to = "value"
    )
    plots$param_vs_convergence <-
      ggplot(range_long, aes(x = value, y = convergence_range)) +
      geom_point(alpha = 0.5) +
      geom_smooth(method = "loess", se = FALSE, color = "red") +
      scale_y_log10() +
      facet_wrap(~parameter, scales = "free_x") +
      theme_minimal() +
      labs(
        title = "Convergence vs valeur des paramètres",
        y = "Range convergence (log scale)"
      )
  }

  # 4. One tile per participant, coloured by the Hessian check.
  if ("hessian_positive_definite" %in% available) {
    hessian_tbl <- mutate(fit_results, participant_id = factor(participant_id))
    plots$hessian_quality <-
      ggplot(
        hessian_tbl,
        aes(x = participant_id, y = 1, fill = hessian_positive_definite)
      ) +
      geom_tile() +
      scale_fill_manual(
        values = c("FALSE" = "red", "TRUE" = "green"),
        na.value = "grey"
      ) +
      coord_flip() +
      theme_minimal() +
      theme(
        axis.text.x = element_blank(),
        axis.ticks.x = element_blank()
      ) +
      labs(
        title = "Qualité de la Hessienne",
        subtitle = "Vert = définie positive, Rouge = problème",
        fill = "Hessienne OK", x = "Participant", y = ""
      )
  }

  # 5. Histograms of the parameter uncertainties (se_* / sd_* columns).
  se_cols <- grep("^se_|^sd_", available, value = TRUE)
  if (length(se_cols) > 0) {
    se_long <- pivot_longer(
      select(fit_results, participant_id, all_of(se_cols)),
      -participant_id,
      names_to = "parameter",
      values_to = "se"
    )
    se_long <- mutate(se_long, parameter = str_remove(parameter, "^se_|^sd_"))
    plots$se_distribution <-
      ggplot(se_long, aes(x = se, fill = parameter)) +
      geom_histogram(bins = 30, alpha = 0.7) +
      facet_wrap(~parameter, scales = "free", ncol = 3) +
      theme_minimal() +
      theme(legend.position = "none") +
      labs(
        title = "Distribution des erreurs standard",
        subtitle = "Plus petit = meilleure précision",
        x = "Erreur standard", y = "Fréquence"
      )
  }

  return(plots)
}
|
|
|
|
# Visualisation spécifique VBMC (posteriors bayésiennes)
|
|
#' Posterior diagnostics for one participant fitted with VBMC.
#'
#' Draws samples from the stored variational posterior and plots marginal
#' posteriors, pairwise relations between the first (up to four) parameters,
#' the ELBO trace when the results provide one, and 95% credible intervals.
#'
#' @param vbmc_fit_results Data frame of VBMC fits with a `participant_id`
#'   column, parameter columns, and the list-columns `vp` and `vbmc_results`
#'   produced by `fit_participant_vbmc()`.
#' @param participant_id Id of the participant to plot. If several rows
#'   match, the first one is used.
#' @return A named list of plots: `marginal_posteriors`,
#'   `pairwise_correlations` (when GGally is installed and at least two
#'   parameters exist), `elbo_trace` (when available), `credible_intervals`.
plot_vbmc_diagnostics <- function(vbmc_fit_results, participant_id) {
  require(ggplot2)
  require(patchwork)

  participant_data <- vbmc_fit_results %>%
    filter(participant_id == !!participant_id)

  if (nrow(participant_data) == 0) {
    stop("Participant non trouvé")
  }

  if (!"vp" %in% names(participant_data)) {
    stop("Pas de résultats VBMC disponibles (vp manquant)")
  }

  fit_row <- participant_data[1, ]
  vp <- fit_row$vp[[1]]
  vbmc_results <- fit_row$vbmc_results[[1]]

  # Sample the posterior. The count must be an integer on the Python side
  # (an R double would arrive as a float), and PyVBMC returns a
  # (samples, component index) tuple: keep only the sample matrix.
  n_samples <- 10000L
  draws <- vp$sample(n_samples)
  if (is.list(draws)) {
    draws <- draws[[1]]
  }
  samples_df <- as.data.frame(draws)

  # Parameter names come from the selected fit only; columns that are
  # entirely NA for this row belong to other models bound into the table.
  param_names <- fit_row %>%
    select(
      starts_with("alpha"), starts_with("forget"),
      starts_with("lambda"), starts_with("rho")
    ) %>%
    select(where(function(col) !all(is.na(col)))) %>%
    names()

  if (ncol(samples_df) != length(param_names)) {
    stop(
      "Nombre de paramètres incohérent: ", ncol(samples_df),
      " dimensions échantillonnées pour ", length(param_names),
      " paramètres nommés.",
      call. = FALSE
    )
  }
  names(samples_df) <- param_names

  plots <- list()

  # 1. Marginal posteriors.
  plots$marginal_posteriors <- samples_df %>%
    pivot_longer(everything(), names_to = "parameter", values_to = "value") %>%
    ggplot(aes(x = value)) +
    geom_histogram(
      aes(y = after_stat(density)),
      bins = 50, fill = "steelblue", alpha = 0.7
    ) +
    geom_density(color = "red", linewidth = 1) +
    facet_wrap(~parameter, scales = "free", ncol = 3) +
    theme_minimal() +
    labs(
      title = paste("Posterior distributions - Participant", participant_id),
      subtitle = "Histogramme + densité estimée",
      x = "Valeur", y = "Densité"
    )

  # 2. Pairwise relations for the first (up to four) parameters. GGally is
  # optional: skip this panel with a warning when it is not installed.
  if (length(param_names) >= 2) {
    if (requireNamespace("GGally", quietly = TRUE)) {
      n_plot <- min(4, length(param_names))
      pairs_data <- samples_df[, seq_len(n_plot), drop = FALSE]

      plots$pairwise_correlations <- GGally::ggpairs(
        pairs_data,
        lower = list(continuous = "points"),
        diag = list(continuous = "densityDiag"),
        upper = list(continuous = "cor")
      ) +
        theme_minimal() +
        labs(title = "Corrélations entre paramètres")
    } else {
      warning(
        "Le package GGally n'est pas installé: corrélations ignorées.",
        call. = FALSE
      )
    }
  }

  # 3. ELBO trace, only when the stored results contain one.
  if (!is.null(vbmc_results$elbo_trace)) {
    elbo_trace <- vbmc_results$elbo_trace
    plots$elbo_trace <- tibble(
      iteration = seq_along(elbo_trace),
      ELBO = elbo_trace
    ) %>%
      ggplot(aes(x = iteration, y = ELBO)) +
      geom_line(color = "darkblue", linewidth = 1) +
      geom_point(size = 2, alpha = 0.5) +
      theme_minimal() +
      labs(
        title = "Convergence de l'ELBO",
        subtitle = "Evidence Lower Bound",
        x = "Itération", y = "ELBO"
      )
  }

  # 4. Posterior means with 95% credible intervals from the samples.
  interval_tbl <- tibble(
    parameter = param_names,
    mean = unname(colMeans(samples_df)),
    sd = unname(vapply(samples_df, sd, numeric(1))),
    ci_lower = unname(vapply(
      samples_df, quantile, numeric(1),
      probs = 0.025, names = FALSE
    )),
    ci_upper = unname(vapply(
      samples_df, quantile, numeric(1),
      probs = 0.975, names = FALSE
    ))
  )

  plots$credible_intervals <- interval_tbl %>%
    mutate(parameter = fct_reorder(parameter, mean)) %>%
    ggplot(aes(x = parameter, y = mean)) +
    geom_point(size = 3) +
    geom_errorbar(aes(ymin = ci_lower, ymax = ci_upper), width = 0.2) +
    coord_flip() +
    theme_minimal() +
    labs(
      title = "Estimations postérieures avec intervalles de crédibilité à 95%",
      x = "Paramètre", y = "Valeur"
    )

  plots
}
|
|
|
|
# Fonction pour comparer DEoptim vs VBMC
|
|
#' Compare DEoptim and VBMC fits on a subset of participants.
#'
#' @param data Data frame with columns `participant_id`, `trial`, `choice`
#'   and `reward`.
#' @param participant_ids Optional vector of ids to fit. When NULL, up to
#'   `n_participants` ids are drawn at random from `data`.
#' @param model_config One model configuration (see `get_model_configs()`).
#' @param n_participants Number of participants drawn when
#'   `participant_ids` is NULL.
#' @return A data frame with one row per participant and method. A VBMC
#'   failure is reported on the console and only the DEoptim row is kept.
#'   A comparison boxplot is printed as a side effect.
compare_optimization_methods <- function(data, participant_ids = NULL,
                                         model_config, n_participants = 5) {
  if (is.null(participant_ids)) {
    all_ids <- unique(data$participant_id)
    n_draw <- min(n_participants, length(all_ids))
    # Sample positions rather than values: sample(x, k) treats a single
    # numeric x as 1:x and would return ids that do not exist.
    participant_ids <- all_ids[sample.int(length(all_ids), n_draw)]
  }

  comparison_results <- map_df(participant_ids, function(pid) {
    cat("Participant", pid, "\n")

    participant_data <- data %>%
      filter(participant_id == pid) %>%
      arrange(trial)

    # Method 1: DEoptim (maximum likelihood).
    fit_deoptim <- fit_participant(participant_data, model_config, n_runs = 5)

    # Method 2: VBMC; a failure is reported and yields NULL.
    fit_vbmc <- tryCatch(
      fit_participant_vbmc(
        participant_data, model_config,
        use_python_vbmc = TRUE
      ),
      error = function(e) {
        cat("  VBMC échoué pour participant", pid, ":", e$message, "\n")
        NULL
      }
    )

    deoptim_row <- fit_deoptim %>%
      mutate(method = "DEoptim", participant_id = pid)

    if (is.null(fit_vbmc)) {
      deoptim_row
    } else {
      bind_rows(deoptim_row, fit_vbmc %>% mutate(participant_id = pid))
    }
  })

  # Side-by-side boxplots of negLL and BIC for the two methods.
  p <- comparison_results %>%
    select(participant_id, method, negLL, BIC) %>%
    pivot_longer(c(negLL, BIC), names_to = "metric", values_to = "value") %>%
    ggplot(aes(x = method, y = value, fill = method)) +
    geom_boxplot() +
    facet_wrap(~metric, scales = "free_y") +
    theme_minimal() +
    labs(
      title = "Comparaison DEoptim vs VBMC",
      subtitle = paste(
        "N =", length(unique(comparison_results$participant_id)),
        "participants"
      ),
      y = "Valeur"
    )

  print(p)

  return(comparison_results)
}

# ============================================================================
# MAXIMUM-LIKELIHOOD ESTIMATION WITH DEoptim
# ============================================================================

#' Fit one participant by maximum likelihood with repeated DEoptim runs.
#'
#' NOTE(review): the original header of this function was missing from the
#' file (its body was stranded after the `return()` of
#' `compare_optimization_methods()`). The signature is reconstructed from
#' the call sites `fit_participant(participant_data, model_config)` and
#' `fit_participant(..., n_runs = 5)`; the default `n_runs = 10` is an
#' assumption to be confirmed.
#'
#' @param participant_data Data frame of one participant's trials, in the
#'   format expected by `qlearning_generic()`.
#' @param model_config One model configuration (see `get_model_configs()`).
#' @param n_runs Number of independent DEoptim runs; the run with the lowest
#'   negative log-likelihood is kept.
#' @return A one-row tibble with negLL, AIC, BIC, the across-run spread of
#'   negLL (`convergence_sd`, `convergence_range`), the Hessian check, a
#'   `converged` flag (range < 1), and one estimate plus one `se_<param>`
#'   column per parameter (NA when the Hessian is unusable).
fit_participant <- function(participant_data, model_config, n_runs = 10) {
  all_results <- vector("list", n_runs)

  for (run in seq_len(n_runs)) {
    # NOTE(review): factor() of a single string always has code 1, so the
    # seed is 1000 + run for every model. Kept unchanged so that earlier
    # results stay reproducible; confirm whether per-model seeds were meant.
    set.seed(1000 * as.numeric(factor(model_config$name)) + run)

    result <- DEoptim(
      fn = qlearning_generic,
      lower = model_config$lower,
      upper = model_config$upper,
      data = participant_data,
      model_config = model_config,
      control = DEoptim.control(
        itermax = 200,
        trace = FALSE,
        parallelType = 0,
        NP = max(50, model_config$n_params * 10)
      )
    )

    all_results[[run]] <- list(
      params = result$optim$bestmem,
      negLL = result$optim$bestval
    )
  }

  # Keep the best run; the spread across runs is the convergence check.
  all_negLL <- vapply(all_results, function(x) x$negLL, numeric(1))
  best_run <- which.min(all_negLL)

  negLL_sd <- sd(all_negLL)
  negLL_range <- max(all_negLL) - min(all_negLL)

  params <- all_results[[best_run]]$params
  negLL <- all_results[[best_run]]$negLL

  # Numerical Hessian of the negative log-likelihood at the optimum.
  hessian_result <- tryCatch(
    numDeriv::hessian(
      qlearning_generic, params,
      data = participant_data,
      model_config = model_config
    ),
    error = function(e) NULL
  )

  hessian_positive_definite <- FALSE
  param_se <- rep(NA, model_config$n_params)

  # eigen() stops on NA/Inf entries, so only inspect a finite Hessian.
  if (!is.null(hessian_result) && all(is.finite(hessian_result))) {
    eigenvalues <- eigen(hessian_result, only.values = TRUE)$values
    hessian_positive_definite <- all(eigenvalues > 0)

    if (hessian_positive_definite) {
      # Standard errors from the inverse Hessian.
      param_vcov <- tryCatch(solve(hessian_result), error = function(e) NULL)
      if (!is.null(param_vcov)) {
        param_se <- sqrt(diag(param_vcov))
      }
    }
  }

  result_df <- tibble(
    model = model_config$name,
    n_params = model_config$n_params,
    negLL = negLL,
    AIC = 2 * negLL + 2 * model_config$n_params,
    BIC = 2 * negLL + model_config$n_params * log(nrow(participant_data)),
    convergence_sd = negLL_sd,
    convergence_range = negLL_range,
    hessian_positive_definite = hessian_positive_definite,
    converged = negLL_range < 1
  )

  # One column per estimated parameter and its standard error.
  for (i in seq_len(model_config$n_params)) {
    param_name <- model_config$param_names[i]
    result_df[[param_name]] <- params[i]
    result_df[[paste0("se_", param_name)]] <- param_se[i]
  }

  return(result_df)
}
|
|
|
|
# ============================================================================
|
|
# ESTIMATION POUR TOUS LES PARTICIPANTS
|
|
# ============================================================================
|
|
|
|
#' Fit every requested model to every participant.
#'
#' @param data Data frame with columns `participant_id`, `trial`, `choice`
#'   and `reward`.
#' @param models_to_fit Optional subset of the configurations returned by
#'   `get_model_configs()`, given as model names (or list indices). NULL
#'   fits all models.
#' @return A named list with one data frame per model; each has one row per
#'   participant with `participant_id` as first column.
fit_all_participants_all_models <- function(data, models_to_fit = NULL) {
  model_configs <- get_model_configs()

  if (!is.null(models_to_fit)) {
    # An unknown name would select a NULL configuration and only fail
    # later, deep inside the fit: report it here instead.
    if (is.character(models_to_fit)) {
      unknown <- setdiff(models_to_fit, names(model_configs))
      if (length(unknown) > 0) {
        stop(
          "Modèle(s) inconnu(s): ", paste(unknown, collapse = ", "),
          ". Modèles disponibles: ",
          paste(names(model_configs), collapse = ", "),
          call. = FALSE
        )
      }
    }
    model_configs <- model_configs[models_to_fit]
  }

  participants <- unique(data$participant_id)

  all_results <- list()

  for (model_name in names(model_configs)) {
    cat("\n=== Fitting model:", model_name, "===\n")

    model_config <- model_configs[[model_name]]

    all_results[[model_name]] <- map_df(participants, function(pid) {
      cat("  Participant", pid, "\n")

      # Trials of this participant, in chronological order.
      participant_data <- data %>%
        filter(participant_id == pid) %>%
        arrange(trial)

      fit_participant(participant_data, model_config) %>%
        mutate(participant_id = pid, .before = 1)
    })
  }

  return(all_results)
}
|
|
|
|
# ============================================================================
|
|
# COMPARAISON DES MODÈLES EMBOÎTÉS
|
|
# ============================================================================
|
|
|
|
#' Compare fitted models: global criteria, best model per participant, LRT.
#'
#' Prints three summaries to the console and returns them.
#'
#' @param all_results Named list of per-model fit tables, as returned by
#'   `fit_all_participants_all_models()`.
#' @return A list with `global_comparison` (one row per model, sorted by
#'   total BIC), `best_models_per_participant` (lowest-BIC row per
#'   participant), `lrt_results` (one row per nested pair whose two models
#'   are both present) and the input `all_results`.
compare_nested_models <- function(all_results) {
  model_names <- names(all_results)

  # Aggregate fit statistics of one model across participants.
  summarise_model <- function(model_name) {
    fits <- all_results[[model_name]]
    tibble(
      model = model_name,
      n_params = unique(fits$n_params),
      n_converged = sum(fits$converged),
      mean_negLL = mean(fits$negLL),
      total_negLL = sum(fits$negLL),
      total_AIC = sum(fits$AIC),
      total_BIC = sum(fits$BIC),
      mean_convergence_range = mean(fits$convergence_range)
    )
  }

  global_comparison <- arrange(
    bind_rows(lapply(model_names, summarise_model)),
    total_BIC
  )

  cat("\n=== COMPARAISON GLOBALE DES MODÈLES ===\n")
  print(global_comparison)

  # Lowest-BIC model for each participant.
  bic_table <- bind_rows(lapply(model_names, function(model_name) {
    select(all_results[[model_name]], participant_id, model, BIC, converged)
  }))
  best_models_per_participant <- bic_table %>%
    group_by(participant_id) %>%
    slice_min(BIC, n = 1) %>%
    ungroup()

  cat("\n=== MEILLEUR MODÈLE PAR PARTICIPANT (BIC) ===\n")
  print(table(best_models_per_participant$model))

  cat("\n=== TESTS LRT POUR MODÈLES EMBOÎTÉS ===\n")

  # (simpler, richer) pairs for which the first model is nested in the second.
  nested_pairs <- list(
    c("HOMOGENEOUS", "GAIN_LOSS"),
    c("GAIN_LOSS", "BIASED"),
    c("GAIN_LOSS", "REE_BIASED_SIMPLE"),
    c("REE_BIASED_SIMPLE", "REE_BIASED_COMPLEX"),
    c("GAIN_LOSS", "REE_LEARNING_SIMPLE"),
    c("REE_LEARNING_SIMPLE", "REE_LEARNING_COMPLEX"),
    c("REE_LEARNING_SIMPLE", "REE_LEARNING_BIASED_SIMPLE"),
    c("REE_LEARNING_BIASED_SIMPLE", "REE_LEARNING_BIASED_COMPLEX")
  )

  # Per-participant likelihood-ratio test for one pair, summarised in one
  # row. Pairs with a model that was not fitted contribute nothing.
  test_pair <- function(pair) {
    simple_model <- pair[1]
    complex_model <- pair[2]
    if (!all(c(simple_model, complex_model) %in% model_names)) {
      return(NULL)
    }

    simple_res <- all_results[[simple_model]]
    complex_res <- all_results[[complex_model]]
    extra_params <- unique(complex_res$n_params) - unique(simple_res$n_params)

    simple_side <- select(
      simple_res,
      participant_id, negLL_simple = negLL, converged_simple = converged
    )
    complex_side <- select(
      complex_res,
      participant_id, negLL_complex = negLL, converged_complex = converged
    )

    # Only participants for whom both fits converged are tested.
    lrt_df <- left_join(simple_side, complex_side, by = "participant_id") %>%
      filter(converged_simple, converged_complex) %>%
      mutate(
        LR_stat = 2 * (negLL_simple - negLL_complex),
        df_diff = extra_params,
        p_value = pchisq(LR_stat, df = df_diff, lower.tail = FALSE),
        significant = p_value < 0.05
      )

    tibble(
      simple_model = simple_model,
      complex_model = complex_model,
      n_participants = nrow(lrt_df),
      pct_significant = mean(lrt_df$significant) * 100,
      mean_LR = mean(lrt_df$LR_stat),
      median_p = median(lrt_df$p_value)
    )
  }

  lrt_results <- bind_rows(lapply(nested_pairs, test_pair))

  print(lrt_results)

  list(
    global_comparison = global_comparison,
    best_models_per_participant = best_models_per_participant,
    lrt_results = lrt_results,
    all_results = all_results
  )
}
|
|
|
|
# ============================================================================
|
|
# VISUALISATION
|
|
# ============================================================================
|
|
|
|
#' Four-panel summary figure of a model comparison.
#'
#' @param comparison_results List returned by `compare_nested_models()`
#'   (elements `global_comparison`, `best_models_per_participant`,
#'   `lrt_results`, `all_results`).
#' @return A patchwork composition laid out as (BIC | best model) over
#'   (LRT | convergence). The LRT panel is blank when there are no LRT rows.
plot_model_comparison <- function(comparison_results) {
  require(ggplot2)
  require(patchwork)

  # 1. Total BIC per model.
  bic_tbl <- mutate(
    comparison_results$global_comparison,
    model = fct_reorder(model, total_BIC)
  )
  bic_panel <- ggplot(bic_tbl, aes(x = model, y = total_BIC, fill = model)) +
    geom_col() +
    coord_flip() +
    theme_minimal() +
    labs(
      title = "Comparaison globale des modèles (BIC)",
      subtitle = "Plus bas = meilleur"
    ) +
    theme(legend.position = "none")

  # 2. Number of participants for whom each model has the lowest BIC.
  winners_tbl <- count(comparison_results$best_models_per_participant, model)
  winners_tbl <- mutate(winners_tbl, model = fct_reorder(model, n))
  winners_panel <- ggplot(winners_tbl, aes(x = model, y = n, fill = model)) +
    geom_col() +
    coord_flip() +
    theme_minimal() +
    labs(
      title = "Meilleur modèle par participant",
      y = "Nombre de participants"
    ) +
    theme(legend.position = "none")

  # 3. Share of participants with a significant LRT, per nested pair.
  lrt_tbl <- comparison_results$lrt_results
  if (nrow(lrt_tbl) > 0) {
    lrt_tbl <- mutate(
      lrt_tbl,
      comparison = paste(simple_model, "→", complex_model)
    )
    lrt_panel <- ggplot(
      lrt_tbl,
      aes(
        x = fct_reorder(comparison, pct_significant),
        y = pct_significant
      )
    ) +
      geom_col(fill = "steelblue") +
      geom_hline(yintercept = 50, linetype = "dashed", color = "red") +
      coord_flip() +
      theme_minimal() +
      labs(
        title = "Tests LRT entre modèles emboîtés",
        y = "% participants avec p < 0.05",
        x = "Comparaison"
      )
  } else {
    lrt_panel <- ggplot() + theme_void()
  }

  # 4. Across-run negLL spread per model, split by the converged flag.
  fits <- comparison_results$all_results
  convergence_tbl <- bind_rows(lapply(names(fits), function(model_name) {
    select(fits[[model_name]], model, convergence_range, converged)
  }))
  convergence_panel <- ggplot(
    convergence_tbl,
    aes(x = model, y = convergence_range, fill = converged)
  ) +
    geom_boxplot() +
    scale_y_log10() +
    coord_flip() +
    theme_minimal() +
    labs(
      title = "Convergence par modèle",
      y = "Range negLL (log scale)"
    )

  top_row <- bic_panel | winners_panel
  bottom_row <- lrt_panel | convergence_panel
  top_row / bottom_row
}
|
|
|
|
# ============================================================================
|
|
# EXEMPLE D'UTILISATION
|
|
# ============================================================================
|
|
|
|
# Charger vos données
|
|
# data <- read_csv("votre_fichier.csv")
|
|
# Colonnes requises: participant_id, trial, choice, reward
|
|
|
|
# Estimation de tous les modèles
|
|
# all_results <- fit_all_participants_all_models(data)
|
|
|
|
# Ou seulement certains modèles
|
|
# all_results <- fit_all_participants_all_models(
|
|
# data,
|
|
# models_to_fit = c("HOMOGENEOUS", "GAIN_LOSS", "REE_BIASED_SIMPLE",
|
|
# "REE_LEARNING_SIMPLE", "REE_LEARNING_BIASED_SIMPLE")
|
|
# )
|
|
|
|
# Comparaison
|
|
# comparison <- compare_nested_models(all_results)
|
|
|
|
# Visualisation
|
|
# plot_model_comparison(comparison)
|
|
|
|
# Sauvegarder les résultats
|
|
# for (model_name in names(all_results)) {
|
|
# write_csv(all_results[[model_name]],
|
|
# paste0("results_", model_name, ".csv"))
|
|
# }
|
|
# write_csv(comparison$global_comparison, "global_comparison.csv")
|
|
# write_csv(comparison$best_models_per_participant, "best_models.csv")
|