mia-stage-2024/code/applications/utils.R

142 lines
4.8 KiB
R

#' Select the n most recent files
#'
#' @param data_folder The folder in which data files are located.
#' @param n The number of files to return. Defaults to 4.
#' @param pattern An optional regular expression passed to [list.files()];
#'   only file names matching it are considered. Defaults to `NULL` (all
#'   entries of `data_folder`).
#'
#' @details Files are ranked by the `ctime` column of [file.info()], newest
#'   first. What `ctime` means is OS dependent (status-change time on
#'   Unix-alikes, creation time on Windows native file systems).
#'   Sub-directories of `data_folder` are ignored. This function returns the
#'   `n` most recent files and emits a warning if `n` is larger than the
#'   number of matching files, in which case all of them are returned.
#'
#' @return A character vector of at most `n` file paths (built with
#'   [file.path()]), most recent first.
get_recent_files <- function(data_folder, n = 4, pattern = NULL) {
  # List the folder once so the metadata and the returned paths stay in sync.
  file_paths <- file.path(
    data_folder,
    list.files(data_folder, pattern = pattern)
  )
  files_info <- file.info(file_paths, extra_cols = FALSE)
  # list.files() always includes sub-directories in non-recursive listings
  # (`include.dirs` only applies when recursive = TRUE), so drop them here.
  # `%in% FALSE` also drops entries whose metadata is NA (e.g. vanished).
  is_file <- files_info[["isdir"]] %in% FALSE
  file_paths <- file_paths[is_file]
  change_times <- files_info[["ctime"]][is_file]
  # Newest first.
  file_paths <- file_paths[order(change_times, decreasing = TRUE)]
  if (n > length(file_paths)) {
    warning(
      "n = ", n,
      " is too large ! It should be at most ",
      length(file_paths),
      call. = FALSE
    )
  }
  head(file_paths, n = n)
}
#' Identify models
#'
#' Label each file path with the model variant found in it.
#'
#' @param files_vec A character vector of file paths.
#' @param pattern A regular expression (stringr syntax) whose first match in
#'   each element of `files_vec` becomes that element's name. Defaults to
#'   `"(iid|pirho|pi|rho)"`; `pirho` is listed before `pi` in the alternation.
#'
#' @return `files_vec` with its values unchanged and its names set to the
#'   first match of `pattern` in each element (`NA` when nothing matches).
identify_models <- function(files_vec, pattern = "(iid|pirho|pi|rho)") {
  # NOTE(review): the pattern is matched against the whole string, so a
  # directory component containing e.g. "pi" would also be picked up.
  model_labels <- stringr::str_extract(string = files_vec, pattern = pattern)
  setNames(files_vec, model_labels)
}
# Build the size table of one collection, tagged with `collection_idx`.
.collection_size_rows <- function(collection, collection_idx) {
  data.frame(
    collection_id = factor(collection_idx),
    M = collection[["M"]],
    net_id = factor(collection[["net_id"]]),
    nr = collection[["n"]][[1]],
    nc = collection[["n"]][[2]],
    Qr = collection[["Q"]][[1]],
    Qc = collection[["Q"]][[2]]
  )
}

#' Tabulate network sizes and block numbers per collection
#'
#' @param collection_list Either a list of collections, or a single
#'   collection for which `is.list()` is `FALSE` (presumably a fitted model
#'   object). Each collection must provide `M`, `net_id`, `n` (row sizes in
#'   `n[[1]]`, column sizes in `n[[2]]`) and `Q` (`Q[[1]]` row blocks,
#'   `Q[[2]]` column blocks).
#'
#' @return A data frame with columns `collection_id` (factor, position of the
#'   collection in `collection_list`), `M`, `net_id`, `nr`, `nc`, `Qr` and
#'   `Qc`, with one row per element of `net_id` (assuming `n[[1]]` and
#'   `n[[2]]` have the same length). `NULL` for an empty list.
build_graph_size_dataframe <- function(collection_list) {
  # A single collection is handled as a list of one, so it gets id 1.
  if (!is.list(collection_list)) {
    collection_list <- list(collection_list)
  }
  per_collection <- lapply(seq_along(collection_list), function(idx) {
    .collection_size_rows(collection_list[[idx]], idx)
  })
  do.call("rbind", per_collection)
}
#' Extract the best clustering of networks
#'
#' @param clustering An object accepted by the `l` argument of
#'   `colSBM::extract_best_partition()`.
#'
#' @return A named integer vector with one entry per network, named by its
#'   `net_id` and giving the index (within the best partition) of the
#'   collection it belongs to. `NULL` if the partition is empty.
extract_clustering <- function(clustering) {
  partition <- colSBM::extract_best_partition(
    l = clustering,
    unnest = TRUE
  )
  # A non-list result (presumably a single collection) is handled as a list
  # of one.
  if (!is.list(partition)) {
    partition <- list(partition)
  }
  # lapply() + unlist() always concatenates the per-collection vectors and
  # keeps their net_id names. sapply() would instead simplify to a matrix
  # when every collection has the same size > 1, and that matrix's rownames
  # only carried the net_ids of the first collection (the others ended up
  # named NA).
  memberships <- lapply(seq_along(partition), function(idx) {
    setNames(
      rep(idx, partition[[idx]][["M"]]),
      partition[[idx]][["net_id"]]
    )
  })
  unlist(memberships)
}
#' Penalty term for the connectivity parameters (alpha)
#'
#' @param model A fitted model providing `n` (`n[[1]]` row sizes, `n[[2]]`
#'   column sizes), `Q`, `free_mixture_row`, `free_mixture_col` and, when a
#'   free mixture is used, `Calpha`.
#'
#' @return A count of connectivity parameters multiplied by
#'   `log(sum(n[[1]] * n[[2]]))`. The count is `sum(Calpha)` when either
#'   dimension has a free mixture, and `Q[1] * Q[2]` otherwise.
compute_alpha_penalty <- function(model) {
  n_rows <- model[["n"]][[1]]
  n_cols <- model[["n"]][[2]]
  log_n_cells <- log(sum(n_rows * n_cols))
  has_free_mixture <- model[["free_mixture_row"]] ||
    model[["free_mixture_col"]]
  n_alpha <- if (has_free_mixture) {
    # NOTE(review): Calpha is presumably the logical support of alpha.
    sum(model$Calpha)
  } else {
    model$Q[1] * model$Q[2]
  }
  n_alpha * log_n_cells
}
#' Penalty term for the row block proportions (pi)
#'
#' @param model A fitted model providing `free_mixture_row`, `Q`, `n`
#'   (`n[[1]]` holds the row sizes) and, when `free_mixture_row` is `TRUE`,
#'   `Cpi` (`Cpi[[1]]` is the Q x M row support).
#'
#' @return With a free row mixture, the sum over networks of
#'   `(number of supported row blocks - 1) * log(n_rows)`. Otherwise
#'   `(Q[1] - 1) * log(total number of rows)`.
compute_pi_penalty <- function(model) {
  if (model$free_mixture_row) {
    # Columns of Cpi[[1]] are networks, so each network pays for its own
    # supported row blocks, scaled by its own number of rows.
    support_sizes <- colSums(model$Cpi[[1]])
    sum((support_sizes - 1) * log(model$n[[1]]))
  } else {
    # No free mixture on the rows: all Q[1] blocks are shared and the
    # penalty uses the total number of rows across networks.
    (model$Q[1] - 1) * log(sum(model$n[[1]]))
  }
}
#' Penalty term for the column block proportions (rho)
#'
#' @param model A fitted model providing `free_mixture_col`, `Q`, `n`
#'   (`n[[2]]` holds the column sizes) and, when `free_mixture_col` is `TRUE`,
#'   `Cpi` (`Cpi[[2]]` is the Q x M column support).
#'
#' @return With a free column mixture, the sum over networks of
#'   `(number of supported column blocks - 1) * log(n_cols)`. Otherwise
#'   `(Q[2] - 1) * log(total number of columns)`.
compute_rho_penalty <- function(model) {
  if (model$free_mixture_col) {
    # Columns of Cpi[[2]] are networks, so each network pays for its own
    # supported column blocks, scaled by its own number of columns.
    support_sizes <- colSums(model$Cpi[[2]])
    sum((support_sizes - 1) * log(model$n[[2]]))
  } else {
    # No free mixture on the columns: all Q[2] blocks are shared and the
    # penalty uses the total number of columns across networks.
    (model$Q[2] - 1) * log(sum(model$n[[2]]))
  }
}
#' Penalty term for the block supports (S) along one dimension
#'
#' @param model A fitted model providing `free_mixture_row`,
#'   `free_mixture_col`, `Q`, `M` and `Cpi` (`Cpi[[dim]]` is the support
#'   matrix for that dimension, one column per network).
#' @param dim `1` for rows, `2` for columns.
#'
#' @return `0` when `dim` has no free mixture. Otherwise `-2 * log_p_Q`, where
#'   `log_p_Q = -M * log(Q[dim]) - sum(log(choose(Q[dim], support sizes)))`.
compute_S_penalty <- function(model, dim) {
  free_on_dim <- (dim == 1 && model$free_mixture_row) ||
    (dim == 2 && model$free_mixture_col)
  if (!free_on_dim) {
    return(0)
  }
  n_blocks <- model$Q[dim]
  support_sizes <- colSums(model$Cpi[[dim]])
  log_p_Q <- -model$M * log(n_blocks) -
    sum(log(choose(rep(n_blocks, model$M), support_sizes)))
  -2 * log_p_Q
}
#' Break down the BIC-L of a fitted model
#'
#' @param model A fitted model providing `compute_vbound()` and
#'   `compute_BICL()` methods, plus the fields read by
#'   `compute_pi_penalty()`, `compute_rho_penalty()`, `compute_S_penalty()`
#'   and `compute_alpha_penalty()`.
#'
#' @return A data frame (one row when the model methods return scalars) with
#'   the following columns:
#'   - `vbound`: the variational bound.
#'   - `pen_pi`, `pen_S_pi`, `pen_rho`, `pen_S_rho`, `pen_alpha`: each
#'     penalty multiplied by 0.5.
#'   - `package_bicl`: the value returned by `model$compute_BICL()`.
#'   - `computed_bicl`: `vbound` minus the sum of the halved penalties, so
#'     the two BIC-L values can be compared.
get_details_bicl <- function(model) {
  details <- data.frame(
    vbound = model$compute_vbound(),
    pen_pi = 0.5 * compute_pi_penalty(model),
    pen_S_pi = 0.5 * compute_S_penalty(model, dim = 1),
    pen_rho = 0.5 * compute_rho_penalty(model),
    pen_S_rho = 0.5 * compute_S_penalty(model, dim = 2),
    pen_alpha = 0.5 * compute_alpha_penalty(model),
    package_bicl = model$compute_BICL()
  )
  # Select the penalties by name and sum them per row. The former
  # sum(across(pen_pi:pen_alpha)) depended on column order and would collapse
  # all rows into a single total.
  penalty_cols <- c("pen_pi", "pen_S_pi", "pen_rho", "pen_S_rho", "pen_alpha")
  details[["computed_bicl"]] <- details[["vbound"]] -
    unname(rowSums(details[penalty_cols]))
  details
}