Rewriting with parallelization
This commit is contained in:
parent
57b53629c7
commit
2670a1ecd2
1 changed file with 208 additions and 124 deletions
202
modelling V4.R
202
modelling V4.R
|
|
@ -5,21 +5,54 @@
|
|||
library(tidyverse)
|
||||
library(DEoptim)
|
||||
library(numDeriv)
|
||||
# foreach future
|
||||
library(foreach)
|
||||
library(doFuture)
|
||||
library(future.callr)
|
||||
|
||||
plan(callr, workers = future::availableCores(omit = 1L))
|
||||
|
||||
# ============================================================================
|
||||
# FONCTION GÉNÉRIQUE DE Q-LEARNING
|
||||
# ============================================================================
|
||||
|
||||
qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) {
|
||||
# Normalise noms de colonnes et conversion des choix en indices numériques
|
||||
if (!("button_value" %in% names(data)) && ("reward" %in% names(data))) {
|
||||
data$button_value <- data$reward
|
||||
}
|
||||
if (!("button_name" %in% names(data))) {
|
||||
if ("option" %in% names(data)) {
|
||||
data$button_name <- data$option
|
||||
} else if (("choice" %in% names(data)) && is.character(data$choice)) {
|
||||
data$button_name <- data$choice
|
||||
}
|
||||
}
|
||||
|
||||
# Conversion des choix en indices numériques
|
||||
if (is.factor(data$button_name) || is.character(data$button_name)) {
|
||||
# If 'choice' is numeric (prepared data), use it directly
|
||||
if (("choice" %in% names(data)) && is.numeric(data$choice)) {
|
||||
data$choice_idx <- data$choice
|
||||
} else if (is.factor(data$button_name) || is.character(data$button_name)) {
|
||||
choice_levels <- c("antifragile", "fragile", "robuste", "vulnerable")
|
||||
data$choice_idx <- match(as.character(data$button_name), choice_levels)
|
||||
} else {
|
||||
} else if ("button_name" %in% names(data)) {
|
||||
data$choice_idx <- data$button_name
|
||||
}
|
||||
|
||||
# Robustness: if mapping produced NAs, try remapping or fail fast with large penalty
|
||||
if (any(is.na(data$choice_idx))) {
|
||||
known_levels <- c("antifragile", "fragile", "robuste", "vulnerable")
|
||||
if (!any(is.na(match(as.character(data$button_name), known_levels)))) {
|
||||
data$choice_idx <- match(as.character(data$button_name), known_levels)
|
||||
} else {
|
||||
if (return_negLL) {
|
||||
return(1e6)
|
||||
} else {
|
||||
return(-1e6)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
n_arms <- 4
|
||||
n_trials <- nrow(data)
|
||||
|
||||
|
|
@ -64,13 +97,23 @@ qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) {
|
|||
|
||||
# RHO(S) - Biais pour événements rares
|
||||
if (model_config$has_rho) {
|
||||
rho_BS <- params[param_idx] # BS avoidance
|
||||
rho_JP <- params[param_idx + 1] # JP seeking
|
||||
rho_BS <- params[param_idx] # BS avoidance
|
||||
rho_JP <- params[param_idx + 1] # JP seeking
|
||||
param_idx <- param_idx + 2
|
||||
} else {
|
||||
rho_BS <- rho_JP <- 0
|
||||
}
|
||||
|
||||
# Detect if rare events actually occur in this participant's data
|
||||
has_BS_seen <- any(data$button_value == -3000, na.rm = TRUE)
|
||||
has_JP_seen <- any(data$button_value == 3000, na.rm = TRUE)
|
||||
|
||||
# If an REE type was never observed, neutralize its rho to avoid non-identifiability
|
||||
if (model_config$has_rho) {
|
||||
if (!has_BS_seen) rho_BS <- 0
|
||||
if (!has_JP_seen) rho_JP <- 0
|
||||
}
|
||||
|
||||
# Initialisation des Q-values
|
||||
Q <- rep(0, n_arms)
|
||||
log_lik <- 0
|
||||
|
|
@ -87,9 +130,9 @@ qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) {
|
|||
# Identification des options susceptibles de produire BS/JP
|
||||
# antifragile (1) = JP possible, fragile (2) = BS possible
|
||||
# vulnerable (4) = BS et JP possibles
|
||||
V[1] <- V[1] + rho_JP # antifragile
|
||||
V[2] <- V[2] + rho_BS # fragile
|
||||
V[4] <- V[4] + rho_BS + rho_JP # vulnerable
|
||||
V[1] <- V[1] + rho_JP # antifragile
|
||||
V[2] <- V[2] + rho_BS # fragile
|
||||
V[4] <- V[4] + rho_BS + rho_JP # vulnerable
|
||||
}
|
||||
|
||||
# Softmax
|
||||
|
|
@ -106,7 +149,21 @@ qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) {
|
|||
Q_new <- Q
|
||||
|
||||
# Choix de l'alpha approprié
|
||||
if (reward == -3000) {
|
||||
# if (reward == -3000) {
|
||||
# alpha_used <- alpha_BS
|
||||
# } else if (reward == 3000) {
|
||||
# alpha_used <- alpha_JP
|
||||
# } else if (reward < 0) {
|
||||
# alpha_used <- alpha_loss
|
||||
# } else {
|
||||
# alpha_used <- alpha_gain
|
||||
# }
|
||||
|
||||
# Fix when there are no extreme rewards while taking them into account
|
||||
if (is.na(reward)) {
|
||||
# skip trials with missing reward
|
||||
next
|
||||
} else if (reward == -3000) {
|
||||
alpha_used <- alpha_BS
|
||||
} else if (reward == 3000) {
|
||||
alpha_used <- alpha_JP
|
||||
|
|
@ -150,7 +207,6 @@ get_model_configs <- function() {
|
|||
lower = c(-5, -5, -3),
|
||||
upper = c(5, 5, 3)
|
||||
),
|
||||
|
||||
GAIN_LOSS = list(
|
||||
name = "GAIN_LOSS",
|
||||
n_alpha = 2,
|
||||
|
|
@ -162,7 +218,6 @@ get_model_configs <- function() {
|
|||
lower = c(-5, -5, -5, -3),
|
||||
upper = c(5, 5, 5, 3)
|
||||
),
|
||||
|
||||
BIASED = list(
|
||||
name = "BIASED",
|
||||
n_alpha = 2,
|
||||
|
|
@ -170,13 +225,14 @@ get_model_configs <- function() {
|
|||
n_lambda = 4,
|
||||
has_rho = FALSE,
|
||||
n_params = 10,
|
||||
param_names = c("alpha_loss", "alpha_gain",
|
||||
"forget_1", "forget_2", "forget_3", "forget_4",
|
||||
"lambda_1", "lambda_2", "lambda_3", "lambda_4"),
|
||||
param_names = c(
|
||||
"alpha_loss", "alpha_gain",
|
||||
"forget_1", "forget_2", "forget_3", "forget_4",
|
||||
"lambda_1", "lambda_2", "lambda_3", "lambda_4"
|
||||
),
|
||||
lower = c(-5, -5, rep(-5, 4), rep(-3, 4)),
|
||||
upper = c(5, 5, rep(5, 4), rep(3, 4))
|
||||
),
|
||||
|
||||
REE_BIASED_SIMPLE = list(
|
||||
name = "REE_BIASED_SIMPLE",
|
||||
n_alpha = 2,
|
||||
|
|
@ -184,12 +240,13 @@ get_model_configs <- function() {
|
|||
n_lambda = 1,
|
||||
has_rho = TRUE,
|
||||
n_params = 6,
|
||||
param_names = c("alpha_loss", "alpha_gain", "forget", "lambda",
|
||||
"rho_BS", "rho_JP"),
|
||||
param_names = c(
|
||||
"alpha_loss", "alpha_gain", "forget", "lambda",
|
||||
"rho_BS", "rho_JP"
|
||||
),
|
||||
lower = c(-5, -5, -5, -3, -10, -10),
|
||||
upper = c(5, 5, 5, 3, 10, 10)
|
||||
),
|
||||
|
||||
REE_BIASED_COMPLEX = list(
|
||||
name = "REE_BIASED_COMPLEX",
|
||||
n_alpha = 2,
|
||||
|
|
@ -197,14 +254,15 @@ get_model_configs <- function() {
|
|||
n_lambda = 4,
|
||||
has_rho = TRUE,
|
||||
n_params = 12,
|
||||
param_names = c("alpha_loss", "alpha_gain",
|
||||
"forget_1", "forget_2", "forget_3", "forget_4",
|
||||
"lambda_1", "lambda_2", "lambda_3", "lambda_4",
|
||||
"rho_BS", "rho_JP"),
|
||||
param_names = c(
|
||||
"alpha_loss", "alpha_gain",
|
||||
"forget_1", "forget_2", "forget_3", "forget_4",
|
||||
"lambda_1", "lambda_2", "lambda_3", "lambda_4",
|
||||
"rho_BS", "rho_JP"
|
||||
),
|
||||
lower = c(-5, -5, rep(-5, 4), rep(-3, 4), -10, -10),
|
||||
upper = c(5, 5, rep(5, 4), rep(3, 4), 10, 10)
|
||||
),
|
||||
|
||||
REE_LEARNING_SIMPLE = list(
|
||||
name = "REE_LEARNING_SIMPLE",
|
||||
n_alpha = 4,
|
||||
|
|
@ -212,12 +270,13 @@ get_model_configs <- function() {
|
|||
n_lambda = 1,
|
||||
has_rho = FALSE,
|
||||
n_params = 6,
|
||||
param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP",
|
||||
"forget", "lambda"),
|
||||
param_names = c(
|
||||
"alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP",
|
||||
"forget", "lambda"
|
||||
),
|
||||
lower = c(-5, -5, -5, -5, -5, -3),
|
||||
upper = c(5, 5, 5, 5, 5, 3)
|
||||
),
|
||||
|
||||
REE_LEARNING_COMPLEX = list(
|
||||
name = "REE_LEARNING_COMPLEX",
|
||||
n_alpha = 4,
|
||||
|
|
@ -225,13 +284,14 @@ get_model_configs <- function() {
|
|||
n_lambda = 4,
|
||||
has_rho = FALSE,
|
||||
n_params = 12,
|
||||
param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP",
|
||||
"forget_1", "forget_2", "forget_3", "forget_4",
|
||||
"lambda_1", "lambda_2", "lambda_3", "lambda_4"),
|
||||
param_names = c(
|
||||
"alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP",
|
||||
"forget_1", "forget_2", "forget_3", "forget_4",
|
||||
"lambda_1", "lambda_2", "lambda_3", "lambda_4"
|
||||
),
|
||||
lower = c(-5, -5, -5, -5, rep(-5, 4), rep(-3, 4)),
|
||||
upper = c(5, 5, 5, 5, rep(5, 4), rep(3, 4))
|
||||
),
|
||||
|
||||
REE_LEARNING_BIASED_SIMPLE = list(
|
||||
name = "REE_LEARNING_BIASED_SIMPLE",
|
||||
n_alpha = 4,
|
||||
|
|
@ -239,12 +299,13 @@ get_model_configs <- function() {
|
|||
n_lambda = 1,
|
||||
has_rho = TRUE,
|
||||
n_params = 8,
|
||||
param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP",
|
||||
"forget", "lambda", "rho_BS", "rho_JP"),
|
||||
param_names = c(
|
||||
"alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP",
|
||||
"forget", "lambda", "rho_BS", "rho_JP"
|
||||
),
|
||||
lower = c(-5, -5, -5, -5, -5, -3, -10, -10),
|
||||
upper = c(5, 5, 5, 5, 5, 3, 10, 10)
|
||||
),
|
||||
|
||||
REE_LEARNING_BIASED_COMPLEX = list(
|
||||
name = "REE_LEARNING_BIASED_COMPLEX",
|
||||
n_alpha = 4,
|
||||
|
|
@ -252,10 +313,12 @@ get_model_configs <- function() {
|
|||
n_lambda = 4,
|
||||
has_rho = TRUE,
|
||||
n_params = 14,
|
||||
param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP",
|
||||
"forget_1", "forget_2", "forget_3", "forget_4",
|
||||
"lambda_1", "lambda_2", "lambda_3", "lambda_4",
|
||||
"rho_BS", "rho_JP"),
|
||||
param_names = c(
|
||||
"alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP",
|
||||
"forget_1", "forget_2", "forget_3", "forget_4",
|
||||
"lambda_1", "lambda_2", "lambda_3", "lambda_4",
|
||||
"rho_BS", "rho_JP"
|
||||
),
|
||||
lower = c(-5, -5, -5, -5, rep(-5, 4), rep(-3, 4), -10, -10),
|
||||
upper = c(5, 5, 5, 5, rep(5, 4), rep(3, 4), 10, 10)
|
||||
)
|
||||
|
|
@ -267,10 +330,13 @@ get_model_configs <- function() {
|
|||
# ============================================================================
|
||||
|
||||
fit_participant <- function(participant_data, model_config, n_runs = 5) {
|
||||
# Detect presence of rare events for this participant
|
||||
has_BS_seen <- any(participant_data$button_value == -3000, na.rm = TRUE)
|
||||
has_JP_seen <- any(participant_data$button_value == 3000, na.rm = TRUE)
|
||||
|
||||
all_results <- vector("list", n_runs)
|
||||
|
||||
for (run in 1:n_runs) {
|
||||
all_results <- foreach(run = 1:n_runs) %dofuture% {
|
||||
set.seed(1000 * as.numeric(factor(model_config$name)) + run)
|
||||
|
||||
result <- DEoptim(
|
||||
|
|
@ -287,7 +353,7 @@ fit_participant <- function(participant_data, model_config, n_runs = 5) {
|
|||
)
|
||||
)
|
||||
|
||||
all_results[[run]] <- list(
|
||||
list(
|
||||
params = result$optim$bestmem,
|
||||
negLL = result$optim$bestval
|
||||
)
|
||||
|
|
@ -304,11 +370,15 @@ fit_participant <- function(participant_data, model_config, n_runs = 5) {
|
|||
negLL <- all_results[[best_run]]$negLL
|
||||
|
||||
# Calcul de la Hessienne
|
||||
hessian_result <- tryCatch({
|
||||
numDeriv::hessian(qlearning_generic, params,
|
||||
data = participant_data,
|
||||
model_config = model_config)
|
||||
}, error = function(e) NULL)
|
||||
hessian_result <- tryCatch(
|
||||
{
|
||||
numDeriv::hessian(qlearning_generic, params,
|
||||
data = participant_data,
|
||||
model_config = model_config
|
||||
)
|
||||
},
|
||||
error = function(e) NULL
|
||||
)
|
||||
|
||||
hessian_positive_definite <- FALSE
|
||||
param_se <- rep(NA, model_config$n_params)
|
||||
|
|
@ -338,6 +408,10 @@ fit_participant <- function(participant_data, model_config, n_runs = 5) {
|
|||
converged = negLL_range < 1
|
||||
)
|
||||
|
||||
# Indicateurs d'événements rares observés (utile pour interprétation des rhos/alphas)
|
||||
result_df$has_BS_seen <- has_BS_seen
|
||||
result_df$has_JP_seen <- has_JP_seen
|
||||
|
||||
# Ajout des paramètres estimés
|
||||
for (i in 1:model_config$n_params) {
|
||||
result_df[[model_config$param_names[i]]] <- params[i]
|
||||
|
|
@ -352,7 +426,6 @@ fit_participant <- function(participant_data, model_config, n_runs = 5) {
|
|||
# ============================================================================
|
||||
|
||||
fit_all_participants_all_models <- function(data, models_to_fit = NULL) {
|
||||
|
||||
model_configs <- get_model_configs()
|
||||
|
||||
if (!is.null(models_to_fit)) {
|
||||
|
|
@ -390,7 +463,6 @@ fit_all_participants_all_models <- function(data, models_to_fit = NULL) {
|
|||
# ============================================================================
|
||||
|
||||
compare_nested_models <- function(all_results) {
|
||||
|
||||
# Comparaison globale
|
||||
global_comparison <- map_df(names(all_results), function(model_name) {
|
||||
results <- all_results[[model_name]]
|
||||
|
|
@ -497,8 +569,10 @@ plot_model_comparison <- function(comparison_results) {
|
|||
geom_col() +
|
||||
coord_flip() +
|
||||
theme_minimal() +
|
||||
labs(title = "Comparaison globale des modèles (BIC)",
|
||||
subtitle = "Plus bas = meilleur") +
|
||||
labs(
|
||||
title = "Comparaison globale des modèles (BIC)",
|
||||
subtitle = "Plus bas = meilleur"
|
||||
) +
|
||||
theme(legend.position = "none")
|
||||
|
||||
# 2. Meilleur modèle par participant
|
||||
|
|
@ -509,25 +583,32 @@ plot_model_comparison <- function(comparison_results) {
|
|||
geom_col() +
|
||||
coord_flip() +
|
||||
theme_minimal() +
|
||||
labs(title = "Meilleur modèle par participant",
|
||||
y = "Nombre de participants") +
|
||||
labs(
|
||||
title = "Meilleur modèle par participant",
|
||||
y = "Nombre de participants"
|
||||
) +
|
||||
theme(legend.position = "none")
|
||||
|
||||
# 3. Tests LRT
|
||||
if (nrow(comparison_results$lrt_results) > 0) {
|
||||
p3 <- comparison_results$lrt_results %>%
|
||||
mutate(comparison = paste(simple_model, "→", complex_model)) %>%
|
||||
ggplot(aes(x = fct_reorder(comparison, pct_significant),
|
||||
y = pct_significant)) +
|
||||
ggplot(aes(
|
||||
x = fct_reorder(comparison, pct_significant),
|
||||
y = pct_significant
|
||||
)) +
|
||||
geom_col(fill = "steelblue") +
|
||||
geom_hline(yintercept = 50, linetype = "dashed", color = "red") +
|
||||
coord_flip() +
|
||||
theme_minimal() +
|
||||
labs(title = "Tests LRT entre modèles emboîtés",
|
||||
y = "% participants avec p < 0.05",
|
||||
x = "Comparaison")
|
||||
labs(
|
||||
title = "Tests LRT entre modèles emboîtés",
|
||||
y = "% participants avec p < 0.05",
|
||||
x = "Comparaison"
|
||||
)
|
||||
} else {
|
||||
p3 <- ggplot() + theme_void()
|
||||
p3 <- ggplot() +
|
||||
theme_void()
|
||||
}
|
||||
|
||||
# 4. Convergence par modèle
|
||||
|
|
@ -540,8 +621,10 @@ plot_model_comparison <- function(comparison_results) {
|
|||
scale_y_log10() +
|
||||
coord_flip() +
|
||||
theme_minimal() +
|
||||
labs(title = "Convergence par modèle",
|
||||
y = "Range negLL (log scale)")
|
||||
labs(
|
||||
title = "Convergence par modèle",
|
||||
y = "Range negLL (log scale)"
|
||||
)
|
||||
|
||||
(p1 | p2) / (p3 | p4)
|
||||
}
|
||||
|
|
@ -553,10 +636,11 @@ plot_model_comparison <- function(comparison_results) {
|
|||
# Charger vos données
|
||||
# data <- read_csv("votre_fichier.csv")
|
||||
# Colonnes requises: participant_id, trial, choice, reward
|
||||
source("load_data.R")
|
||||
|
||||
# Estimation de tous les modèles
|
||||
# all_results <- fit_all_participants_all_models(data)
|
||||
|
||||
fit_all_participants_all_models(data %>% filter(participant_id == "qfmtmjjy"))
|
||||
# Ou seulement certains modèles
|
||||
# all_results <- fit_all_participants_all_models(
|
||||
# data,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue