Rewriting with parallelization
This commit is contained in:
parent
57b53629c7
commit
2670a1ecd2
1 changed files with 208 additions and 124 deletions
332
modelling V4.R
332
modelling V4.R
|
|
@ -5,27 +5,60 @@
|
||||||
library(tidyverse)
|
library(tidyverse)
|
||||||
library(DEoptim)
|
library(DEoptim)
|
||||||
library(numDeriv)
|
library(numDeriv)
|
||||||
|
# foreach future
|
||||||
|
library(foreach)
|
||||||
|
library(doFuture)
|
||||||
|
library(future.callr)
|
||||||
|
|
||||||
|
plan(callr, workers = future::availableCores(omit = 1L))
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# FONCTION GÉNÉRIQUE DE Q-LEARNING
|
# FONCTION GÉNÉRIQUE DE Q-LEARNING
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) {
|
qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) {
|
||||||
|
# Normalise noms de colonnes et conversion des choix en indices numériques
|
||||||
# Conversion des choix en indices numériques
|
if (!("button_value" %in% names(data)) && ("reward" %in% names(data))) {
|
||||||
if (is.factor(data$button_name) || is.character(data$button_name)) {
|
data$button_value <- data$reward
|
||||||
|
}
|
||||||
|
if (!("button_name" %in% names(data))) {
|
||||||
|
if ("option" %in% names(data)) {
|
||||||
|
data$button_name <- data$option
|
||||||
|
} else if (("choice" %in% names(data)) && is.character(data$choice)) {
|
||||||
|
data$button_name <- data$choice
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# If 'choice' is numeric (prepared data), use it directly
|
||||||
|
if (("choice" %in% names(data)) && is.numeric(data$choice)) {
|
||||||
|
data$choice_idx <- data$choice
|
||||||
|
} else if (is.factor(data$button_name) || is.character(data$button_name)) {
|
||||||
choice_levels <- c("antifragile", "fragile", "robuste", "vulnerable")
|
choice_levels <- c("antifragile", "fragile", "robuste", "vulnerable")
|
||||||
data$choice_idx <- match(as.character(data$button_name), choice_levels)
|
data$choice_idx <- match(as.character(data$button_name), choice_levels)
|
||||||
} else {
|
} else if ("button_name" %in% names(data)) {
|
||||||
data$choice_idx <- data$button_name
|
data$choice_idx <- data$button_name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Robustness: if mapping produced NAs, try remapping or fail fast with large penalty
|
||||||
|
if (any(is.na(data$choice_idx))) {
|
||||||
|
known_levels <- c("antifragile", "fragile", "robuste", "vulnerable")
|
||||||
|
if (!any(is.na(match(as.character(data$button_name), known_levels)))) {
|
||||||
|
data$choice_idx <- match(as.character(data$button_name), known_levels)
|
||||||
|
} else {
|
||||||
|
if (return_negLL) {
|
||||||
|
return(1e6)
|
||||||
|
} else {
|
||||||
|
return(-1e6)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
n_arms <- 4
|
n_arms <- 4
|
||||||
n_trials <- nrow(data)
|
n_trials <- nrow(data)
|
||||||
|
|
||||||
# Extraction des paramètres selon la configuration du modèle
|
# Extraction des paramètres selon la configuration du modèle
|
||||||
param_idx <- 1
|
param_idx <- 1
|
||||||
|
|
||||||
# ALPHA(S)
|
# ALPHA(S)
|
||||||
if (model_config$n_alpha == 1) {
|
if (model_config$n_alpha == 1) {
|
||||||
alpha_loss <- alpha_gain <- alpha_BS <- alpha_JP <- plogis(params[param_idx])
|
alpha_loss <- alpha_gain <- alpha_BS <- alpha_JP <- plogis(params[param_idx])
|
||||||
|
|
@ -43,7 +76,7 @@ qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) {
|
||||||
alpha_JP <- plogis(params[param_idx + 3])
|
alpha_JP <- plogis(params[param_idx + 3])
|
||||||
param_idx <- param_idx + 4
|
param_idx <- param_idx + 4
|
||||||
}
|
}
|
||||||
|
|
||||||
# FORGET(S)
|
# FORGET(S)
|
||||||
if (model_config$n_forget == 1) {
|
if (model_config$n_forget == 1) {
|
||||||
forget <- rep(plogis(params[param_idx]), n_arms)
|
forget <- rep(plogis(params[param_idx]), n_arms)
|
||||||
|
|
@ -52,7 +85,7 @@ qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) {
|
||||||
forget <- plogis(params[param_idx:(param_idx + 3)])
|
forget <- plogis(params[param_idx:(param_idx + 3)])
|
||||||
param_idx <- param_idx + 4
|
param_idx <- param_idx + 4
|
||||||
}
|
}
|
||||||
|
|
||||||
# LAMBDA(S)
|
# LAMBDA(S)
|
||||||
if (model_config$n_lambda == 1) {
|
if (model_config$n_lambda == 1) {
|
||||||
lambda <- rep(exp(params[param_idx]), n_arms)
|
lambda <- rep(exp(params[param_idx]), n_arms)
|
||||||
|
|
@ -61,52 +94,76 @@ qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) {
|
||||||
lambda <- exp(params[param_idx:(param_idx + 3)])
|
lambda <- exp(params[param_idx:(param_idx + 3)])
|
||||||
param_idx <- param_idx + 4
|
param_idx <- param_idx + 4
|
||||||
}
|
}
|
||||||
|
|
||||||
# RHO(S) - Biais pour événements rares
|
# RHO(S) - Biais pour événements rares
|
||||||
if (model_config$has_rho) {
|
if (model_config$has_rho) {
|
||||||
rho_BS <- params[param_idx] # BS avoidance
|
rho_BS <- params[param_idx] # BS avoidance
|
||||||
rho_JP <- params[param_idx + 1] # JP seeking
|
rho_JP <- params[param_idx + 1] # JP seeking
|
||||||
param_idx <- param_idx + 2
|
param_idx <- param_idx + 2
|
||||||
} else {
|
} else {
|
||||||
rho_BS <- rho_JP <- 0
|
rho_BS <- rho_JP <- 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Detect if rare events actually occur in this participant's data
|
||||||
|
has_BS_seen <- any(data$button_value == -3000, na.rm = TRUE)
|
||||||
|
has_JP_seen <- any(data$button_value == 3000, na.rm = TRUE)
|
||||||
|
|
||||||
|
# If an REE type was never observed, neutralize its rho to avoid non-identifiability
|
||||||
|
if (model_config$has_rho) {
|
||||||
|
if (!has_BS_seen) rho_BS <- 0
|
||||||
|
if (!has_JP_seen) rho_JP <- 0
|
||||||
|
}
|
||||||
|
|
||||||
# Initialisation des Q-values
|
# Initialisation des Q-values
|
||||||
Q <- rep(0, n_arms)
|
Q <- rep(0, n_arms)
|
||||||
log_lik <- 0
|
log_lik <- 0
|
||||||
|
|
||||||
for (t in 1:n_trials) {
|
for (t in 1:n_trials) {
|
||||||
choice <- data$choice_idx[t]
|
choice <- data$choice_idx[t]
|
||||||
reward <- data$button_value[t]
|
reward <- data$button_value[t]
|
||||||
|
|
||||||
# Calcul des valeurs subjectives V(t)
|
# Calcul des valeurs subjectives V(t)
|
||||||
V <- lambda * Q
|
V <- lambda * Q
|
||||||
|
|
||||||
# Ajout des biais pour événements rares si le modèle le permet
|
# Ajout des biais pour événements rares si le modèle le permet
|
||||||
if (model_config$has_rho) {
|
if (model_config$has_rho) {
|
||||||
# Identification des options susceptibles de produire BS/JP
|
# Identification des options susceptibles de produire BS/JP
|
||||||
# antifragile (1) = JP possible, fragile (2) = BS possible
|
# antifragile (1) = JP possible, fragile (2) = BS possible
|
||||||
# vulnerable (4) = BS et JP possibles
|
# vulnerable (4) = BS et JP possibles
|
||||||
V[1] <- V[1] + rho_JP # antifragile
|
V[1] <- V[1] + rho_JP # antifragile
|
||||||
V[2] <- V[2] + rho_BS # fragile
|
V[2] <- V[2] + rho_BS # fragile
|
||||||
V[4] <- V[4] + rho_BS + rho_JP # vulnerable
|
V[4] <- V[4] + rho_BS + rho_JP # vulnerable
|
||||||
}
|
}
|
||||||
|
|
||||||
# Softmax
|
# Softmax
|
||||||
V_max <- max(V)
|
V_max <- max(V)
|
||||||
exp_V <- exp(V - V_max)
|
exp_V <- exp(V - V_max)
|
||||||
probs <- exp_V / sum(exp_V)
|
probs <- exp_V / sum(exp_V)
|
||||||
probs <- pmax(probs, 1e-10)
|
probs <- pmax(probs, 1e-10)
|
||||||
probs <- probs / sum(probs)
|
probs <- probs / sum(probs)
|
||||||
|
|
||||||
# Log-likelihood
|
# Log-likelihood
|
||||||
log_lik <- log_lik + log(probs[choice])
|
log_lik <- log_lik + log(probs[choice])
|
||||||
|
|
||||||
# Mise à jour Q-learning
|
# Mise à jour Q-learning
|
||||||
Q_new <- Q
|
Q_new <- Q
|
||||||
|
|
||||||
# Choix de l'alpha approprié
|
# Choix de l'alpha approprié
|
||||||
if (reward == -3000) {
|
# if (reward == -3000) {
|
||||||
|
# alpha_used <- alpha_BS
|
||||||
|
# } else if (reward == 3000) {
|
||||||
|
# alpha_used <- alpha_JP
|
||||||
|
# } else if (reward < 0) {
|
||||||
|
# alpha_used <- alpha_loss
|
||||||
|
# } else {
|
||||||
|
# alpha_used <- alpha_gain
|
||||||
|
# }
|
||||||
|
|
||||||
|
# Fix when there are no extreme rewards while taking them into account
|
||||||
|
if (is.na(reward)) {
|
||||||
|
# skip trials with missing reward
|
||||||
|
next
|
||||||
|
} else if (reward == -3000) {
|
||||||
alpha_used <- alpha_BS
|
alpha_used <- alpha_BS
|
||||||
} else if (reward == 3000) {
|
} else if (reward == 3000) {
|
||||||
alpha_used <- alpha_JP
|
alpha_used <- alpha_JP
|
||||||
|
|
@ -115,17 +172,17 @@ qlearning_generic <- function(params, data, model_config, return_negLL = TRUE) {
|
||||||
} else {
|
} else {
|
||||||
alpha_used <- alpha_gain
|
alpha_used <- alpha_gain
|
||||||
}
|
}
|
||||||
|
|
||||||
# Option choisie : Q(t+1) = Q(t) + alpha * (r(t) - Q(t))
|
# Option choisie : Q(t+1) = Q(t) + alpha * (r(t) - Q(t))
|
||||||
Q_new[choice] <- Q[choice] + alpha_used * (reward - Q[choice])
|
Q_new[choice] <- Q[choice] + alpha_used * (reward - Q[choice])
|
||||||
|
|
||||||
# Options non choisies : Q(t+1) = Q(t) * (1 - f)
|
# Options non choisies : Q(t+1) = Q(t) * (1 - f)
|
||||||
not_chosen <- setdiff(1:n_arms, choice)
|
not_chosen <- setdiff(1:n_arms, choice)
|
||||||
Q_new[not_chosen] <- Q[not_chosen] * (1 - forget[not_chosen])
|
Q_new[not_chosen] <- Q[not_chosen] * (1 - forget[not_chosen])
|
||||||
|
|
||||||
Q <- Q_new
|
Q <- Q_new
|
||||||
}
|
}
|
||||||
|
|
||||||
if (return_negLL) {
|
if (return_negLL) {
|
||||||
return(-log_lik)
|
return(-log_lik)
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -150,7 +207,6 @@ get_model_configs <- function() {
|
||||||
lower = c(-5, -5, -3),
|
lower = c(-5, -5, -3),
|
||||||
upper = c(5, 5, 3)
|
upper = c(5, 5, 3)
|
||||||
),
|
),
|
||||||
|
|
||||||
GAIN_LOSS = list(
|
GAIN_LOSS = list(
|
||||||
name = "GAIN_LOSS",
|
name = "GAIN_LOSS",
|
||||||
n_alpha = 2,
|
n_alpha = 2,
|
||||||
|
|
@ -162,7 +218,6 @@ get_model_configs <- function() {
|
||||||
lower = c(-5, -5, -5, -3),
|
lower = c(-5, -5, -5, -3),
|
||||||
upper = c(5, 5, 5, 3)
|
upper = c(5, 5, 5, 3)
|
||||||
),
|
),
|
||||||
|
|
||||||
BIASED = list(
|
BIASED = list(
|
||||||
name = "BIASED",
|
name = "BIASED",
|
||||||
n_alpha = 2,
|
n_alpha = 2,
|
||||||
|
|
@ -170,13 +225,14 @@ get_model_configs <- function() {
|
||||||
n_lambda = 4,
|
n_lambda = 4,
|
||||||
has_rho = FALSE,
|
has_rho = FALSE,
|
||||||
n_params = 10,
|
n_params = 10,
|
||||||
param_names = c("alpha_loss", "alpha_gain",
|
param_names = c(
|
||||||
"forget_1", "forget_2", "forget_3", "forget_4",
|
"alpha_loss", "alpha_gain",
|
||||||
"lambda_1", "lambda_2", "lambda_3", "lambda_4"),
|
"forget_1", "forget_2", "forget_3", "forget_4",
|
||||||
|
"lambda_1", "lambda_2", "lambda_3", "lambda_4"
|
||||||
|
),
|
||||||
lower = c(-5, -5, rep(-5, 4), rep(-3, 4)),
|
lower = c(-5, -5, rep(-5, 4), rep(-3, 4)),
|
||||||
upper = c(5, 5, rep(5, 4), rep(3, 4))
|
upper = c(5, 5, rep(5, 4), rep(3, 4))
|
||||||
),
|
),
|
||||||
|
|
||||||
REE_BIASED_SIMPLE = list(
|
REE_BIASED_SIMPLE = list(
|
||||||
name = "REE_BIASED_SIMPLE",
|
name = "REE_BIASED_SIMPLE",
|
||||||
n_alpha = 2,
|
n_alpha = 2,
|
||||||
|
|
@ -184,12 +240,13 @@ get_model_configs <- function() {
|
||||||
n_lambda = 1,
|
n_lambda = 1,
|
||||||
has_rho = TRUE,
|
has_rho = TRUE,
|
||||||
n_params = 6,
|
n_params = 6,
|
||||||
param_names = c("alpha_loss", "alpha_gain", "forget", "lambda",
|
param_names = c(
|
||||||
"rho_BS", "rho_JP"),
|
"alpha_loss", "alpha_gain", "forget", "lambda",
|
||||||
|
"rho_BS", "rho_JP"
|
||||||
|
),
|
||||||
lower = c(-5, -5, -5, -3, -10, -10),
|
lower = c(-5, -5, -5, -3, -10, -10),
|
||||||
upper = c(5, 5, 5, 3, 10, 10)
|
upper = c(5, 5, 5, 3, 10, 10)
|
||||||
),
|
),
|
||||||
|
|
||||||
REE_BIASED_COMPLEX = list(
|
REE_BIASED_COMPLEX = list(
|
||||||
name = "REE_BIASED_COMPLEX",
|
name = "REE_BIASED_COMPLEX",
|
||||||
n_alpha = 2,
|
n_alpha = 2,
|
||||||
|
|
@ -197,14 +254,15 @@ get_model_configs <- function() {
|
||||||
n_lambda = 4,
|
n_lambda = 4,
|
||||||
has_rho = TRUE,
|
has_rho = TRUE,
|
||||||
n_params = 12,
|
n_params = 12,
|
||||||
param_names = c("alpha_loss", "alpha_gain",
|
param_names = c(
|
||||||
"forget_1", "forget_2", "forget_3", "forget_4",
|
"alpha_loss", "alpha_gain",
|
||||||
"lambda_1", "lambda_2", "lambda_3", "lambda_4",
|
"forget_1", "forget_2", "forget_3", "forget_4",
|
||||||
"rho_BS", "rho_JP"),
|
"lambda_1", "lambda_2", "lambda_3", "lambda_4",
|
||||||
|
"rho_BS", "rho_JP"
|
||||||
|
),
|
||||||
lower = c(-5, -5, rep(-5, 4), rep(-3, 4), -10, -10),
|
lower = c(-5, -5, rep(-5, 4), rep(-3, 4), -10, -10),
|
||||||
upper = c(5, 5, rep(5, 4), rep(3, 4), 10, 10)
|
upper = c(5, 5, rep(5, 4), rep(3, 4), 10, 10)
|
||||||
),
|
),
|
||||||
|
|
||||||
REE_LEARNING_SIMPLE = list(
|
REE_LEARNING_SIMPLE = list(
|
||||||
name = "REE_LEARNING_SIMPLE",
|
name = "REE_LEARNING_SIMPLE",
|
||||||
n_alpha = 4,
|
n_alpha = 4,
|
||||||
|
|
@ -212,12 +270,13 @@ get_model_configs <- function() {
|
||||||
n_lambda = 1,
|
n_lambda = 1,
|
||||||
has_rho = FALSE,
|
has_rho = FALSE,
|
||||||
n_params = 6,
|
n_params = 6,
|
||||||
param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP",
|
param_names = c(
|
||||||
"forget", "lambda"),
|
"alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP",
|
||||||
|
"forget", "lambda"
|
||||||
|
),
|
||||||
lower = c(-5, -5, -5, -5, -5, -3),
|
lower = c(-5, -5, -5, -5, -5, -3),
|
||||||
upper = c(5, 5, 5, 5, 5, 3)
|
upper = c(5, 5, 5, 5, 5, 3)
|
||||||
),
|
),
|
||||||
|
|
||||||
REE_LEARNING_COMPLEX = list(
|
REE_LEARNING_COMPLEX = list(
|
||||||
name = "REE_LEARNING_COMPLEX",
|
name = "REE_LEARNING_COMPLEX",
|
||||||
n_alpha = 4,
|
n_alpha = 4,
|
||||||
|
|
@ -225,13 +284,14 @@ get_model_configs <- function() {
|
||||||
n_lambda = 4,
|
n_lambda = 4,
|
||||||
has_rho = FALSE,
|
has_rho = FALSE,
|
||||||
n_params = 12,
|
n_params = 12,
|
||||||
param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP",
|
param_names = c(
|
||||||
"forget_1", "forget_2", "forget_3", "forget_4",
|
"alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP",
|
||||||
"lambda_1", "lambda_2", "lambda_3", "lambda_4"),
|
"forget_1", "forget_2", "forget_3", "forget_4",
|
||||||
|
"lambda_1", "lambda_2", "lambda_3", "lambda_4"
|
||||||
|
),
|
||||||
lower = c(-5, -5, -5, -5, rep(-5, 4), rep(-3, 4)),
|
lower = c(-5, -5, -5, -5, rep(-5, 4), rep(-3, 4)),
|
||||||
upper = c(5, 5, 5, 5, rep(5, 4), rep(3, 4))
|
upper = c(5, 5, 5, 5, rep(5, 4), rep(3, 4))
|
||||||
),
|
),
|
||||||
|
|
||||||
REE_LEARNING_BIASED_SIMPLE = list(
|
REE_LEARNING_BIASED_SIMPLE = list(
|
||||||
name = "REE_LEARNING_BIASED_SIMPLE",
|
name = "REE_LEARNING_BIASED_SIMPLE",
|
||||||
n_alpha = 4,
|
n_alpha = 4,
|
||||||
|
|
@ -239,12 +299,13 @@ get_model_configs <- function() {
|
||||||
n_lambda = 1,
|
n_lambda = 1,
|
||||||
has_rho = TRUE,
|
has_rho = TRUE,
|
||||||
n_params = 8,
|
n_params = 8,
|
||||||
param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP",
|
param_names = c(
|
||||||
"forget", "lambda", "rho_BS", "rho_JP"),
|
"alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP",
|
||||||
|
"forget", "lambda", "rho_BS", "rho_JP"
|
||||||
|
),
|
||||||
lower = c(-5, -5, -5, -5, -5, -3, -10, -10),
|
lower = c(-5, -5, -5, -5, -5, -3, -10, -10),
|
||||||
upper = c(5, 5, 5, 5, 5, 3, 10, 10)
|
upper = c(5, 5, 5, 5, 5, 3, 10, 10)
|
||||||
),
|
),
|
||||||
|
|
||||||
REE_LEARNING_BIASED_COMPLEX = list(
|
REE_LEARNING_BIASED_COMPLEX = list(
|
||||||
name = "REE_LEARNING_BIASED_COMPLEX",
|
name = "REE_LEARNING_BIASED_COMPLEX",
|
||||||
n_alpha = 4,
|
n_alpha = 4,
|
||||||
|
|
@ -252,10 +313,12 @@ get_model_configs <- function() {
|
||||||
n_lambda = 4,
|
n_lambda = 4,
|
||||||
has_rho = TRUE,
|
has_rho = TRUE,
|
||||||
n_params = 14,
|
n_params = 14,
|
||||||
param_names = c("alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP",
|
param_names = c(
|
||||||
"forget_1", "forget_2", "forget_3", "forget_4",
|
"alpha_loss", "alpha_gain", "alpha_BS", "alpha_JP",
|
||||||
"lambda_1", "lambda_2", "lambda_3", "lambda_4",
|
"forget_1", "forget_2", "forget_3", "forget_4",
|
||||||
"rho_BS", "rho_JP"),
|
"lambda_1", "lambda_2", "lambda_3", "lambda_4",
|
||||||
|
"rho_BS", "rho_JP"
|
||||||
|
),
|
||||||
lower = c(-5, -5, -5, -5, rep(-5, 4), rep(-3, 4), -10, -10),
|
lower = c(-5, -5, -5, -5, rep(-5, 4), rep(-3, 4), -10, -10),
|
||||||
upper = c(5, 5, 5, 5, rep(5, 4), rep(3, 4), 10, 10)
|
upper = c(5, 5, 5, 5, rep(5, 4), rep(3, 4), 10, 10)
|
||||||
)
|
)
|
||||||
|
|
@ -267,12 +330,15 @@ get_model_configs <- function() {
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
fit_participant <- function(participant_data, model_config, n_runs = 5) {
|
fit_participant <- function(participant_data, model_config, n_runs = 5) {
|
||||||
|
# Detect presence of rare events for this participant
|
||||||
|
has_BS_seen <- any(participant_data$button_value == -3000, na.rm = TRUE)
|
||||||
|
has_JP_seen <- any(participant_data$button_value == 3000, na.rm = TRUE)
|
||||||
|
|
||||||
all_results <- vector("list", n_runs)
|
all_results <- vector("list", n_runs)
|
||||||
|
|
||||||
for (run in 1:n_runs) {
|
all_results <- foreach(run = 1:n_runs) %dofuture% {
|
||||||
set.seed(1000 * as.numeric(factor(model_config$name)) + run)
|
set.seed(1000 * as.numeric(factor(model_config$name)) + run)
|
||||||
|
|
||||||
result <- DEoptim(
|
result <- DEoptim(
|
||||||
fn = qlearning_generic,
|
fn = qlearning_generic,
|
||||||
lower = model_config$lower,
|
lower = model_config$lower,
|
||||||
|
|
@ -286,37 +352,41 @@ fit_participant <- function(participant_data, model_config, n_runs = 5) {
|
||||||
NP = max(50, model_config$n_params * 10)
|
NP = max(50, model_config$n_params * 10)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
all_results[[run]] <- list(
|
list(
|
||||||
params = result$optim$bestmem,
|
params = result$optim$bestmem,
|
||||||
negLL = result$optim$bestval
|
negLL = result$optim$bestval
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
# Sélection du meilleur run
|
# Sélection du meilleur run
|
||||||
all_negLL <- sapply(all_results, function(x) x$negLL)
|
all_negLL <- sapply(all_results, function(x) x$negLL)
|
||||||
best_run <- which.min(all_negLL)
|
best_run <- which.min(all_negLL)
|
||||||
|
|
||||||
negLL_sd <- sd(all_negLL)
|
negLL_sd <- sd(all_negLL)
|
||||||
negLL_range <- max(all_negLL) - min(all_negLL)
|
negLL_range <- max(all_negLL) - min(all_negLL)
|
||||||
|
|
||||||
params <- all_results[[best_run]]$params
|
params <- all_results[[best_run]]$params
|
||||||
negLL <- all_results[[best_run]]$negLL
|
negLL <- all_results[[best_run]]$negLL
|
||||||
|
|
||||||
# Calcul de la Hessienne
|
# Calcul de la Hessienne
|
||||||
hessian_result <- tryCatch({
|
hessian_result <- tryCatch(
|
||||||
numDeriv::hessian(qlearning_generic, params,
|
{
|
||||||
data = participant_data,
|
numDeriv::hessian(qlearning_generic, params,
|
||||||
model_config = model_config)
|
data = participant_data,
|
||||||
}, error = function(e) NULL)
|
model_config = model_config
|
||||||
|
)
|
||||||
|
},
|
||||||
|
error = function(e) NULL
|
||||||
|
)
|
||||||
|
|
||||||
hessian_positive_definite <- FALSE
|
hessian_positive_definite <- FALSE
|
||||||
param_se <- rep(NA, model_config$n_params)
|
param_se <- rep(NA, model_config$n_params)
|
||||||
|
|
||||||
if (!is.null(hessian_result)) {
|
if (!is.null(hessian_result)) {
|
||||||
eigenvalues <- eigen(hessian_result, only.values = TRUE)$values
|
eigenvalues <- eigen(hessian_result, only.values = TRUE)$values
|
||||||
hessian_positive_definite <- all(eigenvalues > 0)
|
hessian_positive_definite <- all(eigenvalues > 0)
|
||||||
|
|
||||||
if (hessian_positive_definite) {
|
if (hessian_positive_definite) {
|
||||||
param_vcov <- tryCatch(solve(hessian_result), error = function(e) NULL)
|
param_vcov <- tryCatch(solve(hessian_result), error = function(e) NULL)
|
||||||
if (!is.null(param_vcov)) {
|
if (!is.null(param_vcov)) {
|
||||||
|
|
@ -324,7 +394,7 @@ fit_participant <- function(participant_data, model_config, n_runs = 5) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# Création du tibble de résultats
|
# Création du tibble de résultats
|
||||||
result_df <- tibble(
|
result_df <- tibble(
|
||||||
model = model_config$name,
|
model = model_config$name,
|
||||||
|
|
@ -337,13 +407,17 @@ fit_participant <- function(participant_data, model_config, n_runs = 5) {
|
||||||
hessian_positive_definite = hessian_positive_definite,
|
hessian_positive_definite = hessian_positive_definite,
|
||||||
converged = negLL_range < 1
|
converged = negLL_range < 1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Indicateurs d'événements rares observés (utile pour interprétation des rhos/alphas)
|
||||||
|
result_df$has_BS_seen <- has_BS_seen
|
||||||
|
result_df$has_JP_seen <- has_JP_seen
|
||||||
|
|
||||||
# Ajout des paramètres estimés
|
# Ajout des paramètres estimés
|
||||||
for (i in 1:model_config$n_params) {
|
for (i in 1:model_config$n_params) {
|
||||||
result_df[[model_config$param_names[i]]] <- params[i]
|
result_df[[model_config$param_names[i]]] <- params[i]
|
||||||
result_df[[paste0("se_", model_config$param_names[i])]] <- param_se[i]
|
result_df[[paste0("se_", model_config$param_names[i])]] <- param_se[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
return(result_df)
|
return(result_df)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -352,36 +426,35 @@ fit_participant <- function(participant_data, model_config, n_runs = 5) {
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
fit_all_participants_all_models <- function(data, models_to_fit = NULL) {
|
fit_all_participants_all_models <- function(data, models_to_fit = NULL) {
|
||||||
|
|
||||||
model_configs <- get_model_configs()
|
model_configs <- get_model_configs()
|
||||||
|
|
||||||
if (!is.null(models_to_fit)) {
|
if (!is.null(models_to_fit)) {
|
||||||
model_configs <- model_configs[models_to_fit]
|
model_configs <- model_configs[models_to_fit]
|
||||||
}
|
}
|
||||||
|
|
||||||
participants <- unique(data$participant_id)
|
participants <- unique(data$participant_id)
|
||||||
|
|
||||||
all_results <- list()
|
all_results <- list()
|
||||||
|
|
||||||
for (model_name in names(model_configs)) {
|
for (model_name in names(model_configs)) {
|
||||||
cat("\n=== Fitting model:", model_name, "===\n")
|
cat("\n=== Fitting model:", model_name, "===\n")
|
||||||
|
|
||||||
model_config <- model_configs[[model_name]]
|
model_config <- model_configs[[model_name]]
|
||||||
|
|
||||||
model_results <- map_df(participants, function(pid) {
|
model_results <- map_df(participants, function(pid) {
|
||||||
cat(" Participant", pid, "\n")
|
cat(" Participant", pid, "\n")
|
||||||
|
|
||||||
participant_data <- data %>%
|
participant_data <- data %>%
|
||||||
filter(participant_id == pid) %>%
|
filter(participant_id == pid) %>%
|
||||||
arrange(trial)
|
arrange(trial)
|
||||||
|
|
||||||
fit <- fit_participant(participant_data, model_config)
|
fit <- fit_participant(participant_data, model_config)
|
||||||
fit %>% mutate(participant_id = pid, .before = 1)
|
fit %>% mutate(participant_id = pid, .before = 1)
|
||||||
})
|
})
|
||||||
|
|
||||||
all_results[[model_name]] <- model_results
|
all_results[[model_name]] <- model_results
|
||||||
}
|
}
|
||||||
|
|
||||||
return(all_results)
|
return(all_results)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -390,11 +463,10 @@ fit_all_participants_all_models <- function(data, models_to_fit = NULL) {
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
compare_nested_models <- function(all_results) {
|
compare_nested_models <- function(all_results) {
|
||||||
|
|
||||||
# Comparaison globale
|
# Comparaison globale
|
||||||
global_comparison <- map_df(names(all_results), function(model_name) {
|
global_comparison <- map_df(names(all_results), function(model_name) {
|
||||||
results <- all_results[[model_name]]
|
results <- all_results[[model_name]]
|
||||||
|
|
||||||
tibble(
|
tibble(
|
||||||
model = model_name,
|
model = model_name,
|
||||||
n_params = unique(results$n_params),
|
n_params = unique(results$n_params),
|
||||||
|
|
@ -407,10 +479,10 @@ compare_nested_models <- function(all_results) {
|
||||||
)
|
)
|
||||||
}) %>%
|
}) %>%
|
||||||
arrange(total_BIC)
|
arrange(total_BIC)
|
||||||
|
|
||||||
cat("\n=== COMPARAISON GLOBALE DES MODÈLES ===\n")
|
cat("\n=== COMPARAISON GLOBALE DES MODÈLES ===\n")
|
||||||
print(global_comparison)
|
print(global_comparison)
|
||||||
|
|
||||||
# Meilleur modèle par participant (BIC)
|
# Meilleur modèle par participant (BIC)
|
||||||
best_models_per_participant <- map_df(names(all_results), function(model_name) {
|
best_models_per_participant <- map_df(names(all_results), function(model_name) {
|
||||||
all_results[[model_name]] %>%
|
all_results[[model_name]] %>%
|
||||||
|
|
@ -419,13 +491,13 @@ compare_nested_models <- function(all_results) {
|
||||||
group_by(participant_id) %>%
|
group_by(participant_id) %>%
|
||||||
slice_min(BIC, n = 1) %>%
|
slice_min(BIC, n = 1) %>%
|
||||||
ungroup()
|
ungroup()
|
||||||
|
|
||||||
cat("\n=== MEILLEUR MODÈLE PAR PARTICIPANT (BIC) ===\n")
|
cat("\n=== MEILLEUR MODÈLE PAR PARTICIPANT (BIC) ===\n")
|
||||||
print(table(best_models_per_participant$model))
|
print(table(best_models_per_participant$model))
|
||||||
|
|
||||||
# Comparaison par paires de modèles emboîtés
|
# Comparaison par paires de modèles emboîtés
|
||||||
cat("\n=== TESTS LRT POUR MODÈLES EMBOÎTÉS ===\n")
|
cat("\n=== TESTS LRT POUR MODÈLES EMBOÎTÉS ===\n")
|
||||||
|
|
||||||
# Exemples de paires emboîtées
|
# Exemples de paires emboîtées
|
||||||
nested_pairs <- list(
|
nested_pairs <- list(
|
||||||
c("HOMOGENEOUS", "GAIN_LOSS"),
|
c("HOMOGENEOUS", "GAIN_LOSS"),
|
||||||
|
|
@ -437,15 +509,15 @@ compare_nested_models <- function(all_results) {
|
||||||
c("REE_LEARNING_SIMPLE", "REE_LEARNING_BIASED_SIMPLE"),
|
c("REE_LEARNING_SIMPLE", "REE_LEARNING_BIASED_SIMPLE"),
|
||||||
c("REE_LEARNING_BIASED_SIMPLE", "REE_LEARNING_BIASED_COMPLEX")
|
c("REE_LEARNING_BIASED_SIMPLE", "REE_LEARNING_BIASED_COMPLEX")
|
||||||
)
|
)
|
||||||
|
|
||||||
lrt_results <- map_df(nested_pairs, function(pair) {
|
lrt_results <- map_df(nested_pairs, function(pair) {
|
||||||
simple_model <- pair[1]
|
simple_model <- pair[1]
|
||||||
complex_model <- pair[2]
|
complex_model <- pair[2]
|
||||||
|
|
||||||
if (simple_model %in% names(all_results) && complex_model %in% names(all_results)) {
|
if (simple_model %in% names(all_results) && complex_model %in% names(all_results)) {
|
||||||
simple_res <- all_results[[simple_model]]
|
simple_res <- all_results[[simple_model]]
|
||||||
complex_res <- all_results[[complex_model]]
|
complex_res <- all_results[[complex_model]]
|
||||||
|
|
||||||
# LRT par participant
|
# LRT par participant
|
||||||
lrt_df <- simple_res %>%
|
lrt_df <- simple_res %>%
|
||||||
select(participant_id, negLL_simple = negLL, converged_simple = converged) %>%
|
select(participant_id, negLL_simple = negLL, converged_simple = converged) %>%
|
||||||
|
|
@ -460,7 +532,7 @@ compare_nested_models <- function(all_results) {
|
||||||
p_value = pchisq(LR_stat, df = df_diff, lower.tail = FALSE),
|
p_value = pchisq(LR_stat, df = df_diff, lower.tail = FALSE),
|
||||||
significant = p_value < 0.05
|
significant = p_value < 0.05
|
||||||
)
|
)
|
||||||
|
|
||||||
tibble(
|
tibble(
|
||||||
simple_model = simple_model,
|
simple_model = simple_model,
|
||||||
complex_model = complex_model,
|
complex_model = complex_model,
|
||||||
|
|
@ -471,9 +543,9 @@ compare_nested_models <- function(all_results) {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
print(lrt_results)
|
print(lrt_results)
|
||||||
|
|
||||||
list(
|
list(
|
||||||
global_comparison = global_comparison,
|
global_comparison = global_comparison,
|
||||||
best_models_per_participant = best_models_per_participant,
|
best_models_per_participant = best_models_per_participant,
|
||||||
|
|
@ -489,7 +561,7 @@ compare_nested_models <- function(all_results) {
|
||||||
plot_model_comparison <- function(comparison_results) {
|
plot_model_comparison <- function(comparison_results) {
|
||||||
require(ggplot2)
|
require(ggplot2)
|
||||||
require(patchwork)
|
require(patchwork)
|
||||||
|
|
||||||
# 1. Comparaison BIC globale
|
# 1. Comparaison BIC globale
|
||||||
p1 <- comparison_results$global_comparison %>%
|
p1 <- comparison_results$global_comparison %>%
|
||||||
mutate(model = fct_reorder(model, total_BIC)) %>%
|
mutate(model = fct_reorder(model, total_BIC)) %>%
|
||||||
|
|
@ -497,10 +569,12 @@ plot_model_comparison <- function(comparison_results) {
|
||||||
geom_col() +
|
geom_col() +
|
||||||
coord_flip() +
|
coord_flip() +
|
||||||
theme_minimal() +
|
theme_minimal() +
|
||||||
labs(title = "Comparaison globale des modèles (BIC)",
|
labs(
|
||||||
subtitle = "Plus bas = meilleur") +
|
title = "Comparaison globale des modèles (BIC)",
|
||||||
|
subtitle = "Plus bas = meilleur"
|
||||||
|
) +
|
||||||
theme(legend.position = "none")
|
theme(legend.position = "none")
|
||||||
|
|
||||||
# 2. Meilleur modèle par participant
|
# 2. Meilleur modèle par participant
|
||||||
p2 <- comparison_results$best_models_per_participant %>%
|
p2 <- comparison_results$best_models_per_participant %>%
|
||||||
count(model) %>%
|
count(model) %>%
|
||||||
|
|
@ -509,27 +583,34 @@ plot_model_comparison <- function(comparison_results) {
|
||||||
geom_col() +
|
geom_col() +
|
||||||
coord_flip() +
|
coord_flip() +
|
||||||
theme_minimal() +
|
theme_minimal() +
|
||||||
labs(title = "Meilleur modèle par participant",
|
labs(
|
||||||
y = "Nombre de participants") +
|
title = "Meilleur modèle par participant",
|
||||||
|
y = "Nombre de participants"
|
||||||
|
) +
|
||||||
theme(legend.position = "none")
|
theme(legend.position = "none")
|
||||||
|
|
||||||
# 3. Tests LRT
|
# 3. Tests LRT
|
||||||
if (nrow(comparison_results$lrt_results) > 0) {
|
if (nrow(comparison_results$lrt_results) > 0) {
|
||||||
p3 <- comparison_results$lrt_results %>%
|
p3 <- comparison_results$lrt_results %>%
|
||||||
mutate(comparison = paste(simple_model, "→", complex_model)) %>%
|
mutate(comparison = paste(simple_model, "→", complex_model)) %>%
|
||||||
ggplot(aes(x = fct_reorder(comparison, pct_significant),
|
ggplot(aes(
|
||||||
y = pct_significant)) +
|
x = fct_reorder(comparison, pct_significant),
|
||||||
|
y = pct_significant
|
||||||
|
)) +
|
||||||
geom_col(fill = "steelblue") +
|
geom_col(fill = "steelblue") +
|
||||||
geom_hline(yintercept = 50, linetype = "dashed", color = "red") +
|
geom_hline(yintercept = 50, linetype = "dashed", color = "red") +
|
||||||
coord_flip() +
|
coord_flip() +
|
||||||
theme_minimal() +
|
theme_minimal() +
|
||||||
labs(title = "Tests LRT entre modèles emboîtés",
|
labs(
|
||||||
y = "% participants avec p < 0.05",
|
title = "Tests LRT entre modèles emboîtés",
|
||||||
x = "Comparaison")
|
y = "% participants avec p < 0.05",
|
||||||
|
x = "Comparaison"
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
p3 <- ggplot() + theme_void()
|
p3 <- ggplot() +
|
||||||
|
theme_void()
|
||||||
}
|
}
|
||||||
|
|
||||||
# 4. Convergence par modèle
|
# 4. Convergence par modèle
|
||||||
p4 <- map_df(names(comparison_results$all_results), function(model_name) {
|
p4 <- map_df(names(comparison_results$all_results), function(model_name) {
|
||||||
comparison_results$all_results[[model_name]] %>%
|
comparison_results$all_results[[model_name]] %>%
|
||||||
|
|
@ -540,9 +621,11 @@ plot_model_comparison <- function(comparison_results) {
|
||||||
scale_y_log10() +
|
scale_y_log10() +
|
||||||
coord_flip() +
|
coord_flip() +
|
||||||
theme_minimal() +
|
theme_minimal() +
|
||||||
labs(title = "Convergence par modèle",
|
labs(
|
||||||
y = "Range negLL (log scale)")
|
title = "Convergence par modèle",
|
||||||
|
y = "Range negLL (log scale)"
|
||||||
|
)
|
||||||
|
|
||||||
(p1 | p2) / (p3 | p4)
|
(p1 | p2) / (p3 | p4)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -553,14 +636,15 @@ plot_model_comparison <- function(comparison_results) {
|
||||||
# Charger vos données
|
# Charger vos données
|
||||||
# data <- read_csv("votre_fichier.csv")
|
# data <- read_csv("votre_fichier.csv")
|
||||||
# Colonnes requises: participant_id, trial, choice, reward
|
# Colonnes requises: participant_id, trial, choice, reward
|
||||||
|
source("load_data.R")
|
||||||
|
|
||||||
# Estimation de tous les modèles
|
# Estimation de tous les modèles
|
||||||
# all_results <- fit_all_participants_all_models(data)
|
# all_results <- fit_all_participants_all_models(data)
|
||||||
|
fit_all_participants_all_models(data %>% filter(participant_id == "qfmtmjjy"))
|
||||||
# Ou seulement certains modèles
|
# Ou seulement certains modèles
|
||||||
# all_results <- fit_all_participants_all_models(
|
# all_results <- fit_all_participants_all_models(
|
||||||
# data,
|
# data,
|
||||||
# models_to_fit = c("HOMOGENEOUS", "GAIN_LOSS", "REE_BIASED_SIMPLE",
|
# models_to_fit = c("HOMOGENEOUS", "GAIN_LOSS", "REE_BIASED_SIMPLE",
|
||||||
# "REE_LEARNING_SIMPLE", "REE_LEARNING_BIASED_SIMPLE")
|
# "REE_LEARNING_SIMPLE", "REE_LEARNING_BIASED_SIMPLE")
|
||||||
# )
|
# )
|
||||||
|
|
||||||
|
|
@ -572,8 +656,8 @@ plot_model_comparison <- function(comparison_results) {
|
||||||
|
|
||||||
# Sauvegarder les résultats
|
# Sauvegarder les résultats
|
||||||
# for (model_name in names(all_results)) {
|
# for (model_name in names(all_results)) {
|
||||||
# write_csv(all_results[[model_name]],
|
# write_csv(all_results[[model_name]],
|
||||||
# paste0("results_", model_name, ".csv"))
|
# paste0("results_", model_name, ".csv"))
|
||||||
# }
|
# }
|
||||||
# write_csv(comparison$global_comparison, "global_comparison.csv")
|
# write_csv(comparison$global_comparison, "global_comparison.csv")
|
||||||
# write_csv(comparison$best_models_per_participant, "best_models.csv")
|
# write_csv(comparison$best_models_per_participant, "best_models.csv")
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue