Script to measure impact of VEM parameters

This commit is contained in:
Louis Lacoste 2024-06-11 17:48:02 +02:00
parent 5e6de16d7b
commit 871dc734b8

View file

@ -0,0 +1,141 @@
## ----libraries-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------
library(colSBM)
library(aricode)
library(here)
## ----constants-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------
base_folder <- here("code", "results", "investigating", "vem_steps")
if (!dir.exists(base_folder)) {
dir.create(base_folder, recursive = TRUE)
}
net_seed <- 0
test_seeds <- c(12, 3)
epsilons <- c(0.1, 0.4)
vem_steps <- seq(10, 300, by = 40)
conditions <- expand.grid(
seeds = test_seeds,
epsilons = epsilons,
vem_steps = vem_steps
)
base_alpha <- matrix(rep(0.3, 9L), nrow = 3L)
pi <- c(0.3, 0.2, 0.5)
rho <- c(0.55, 0.15, 0.3)
M <- 10L
nr <- c(rep(30L, M / 2L), rep(95L, M / 2L))
nc <- c(rep(40L, M / 2L), rep(70L, M / 2L))
## ----functions-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------
generate_net <- function(eps, net_seed = 0) {
set.seed(net_seed)
as_alpha <- base_alpha + matrix(
c(
eps, -eps / 2L, -eps / 2L,
-eps / 2L, eps, -eps / 2L,
-eps / 2L, -eps / 2L, eps
),
nrow = 3L
)
cp_alpha <- base_alpha + matrix(
c(
3L * eps / 2L, eps, eps / 2L,
eps, eps / 2L, 0L,
eps / 2L, 0L, -eps / 2L
),
nrow = 3L
)
dis_alpha <- base_alpha + matrix(
c(
-eps / 2L, eps, eps,
eps, -eps / 2L, eps,
eps, eps, -eps / 2L
),
nrow = 3L
)
collection <- c(
generate_bipartite_collection(
nr = nr, nc = nc,
pi = pi, rho = rho,
alpha = as_alpha, M = M
),
generate_bipartite_collection(
nr = nr, nc = nc,
pi = pi, rho = rho,
alpha = cp_alpha, M = M
),
generate_bipartite_collection(
nr = nr, nc = nc,
pi = pi, rho = rho,
alpha = dis_alpha, M = M
)
)
names(collection) <- c(
0 + seq(0, M %/% 2), 0 + seq(M %/% 2 + 1, M - 1),
10 + seq(0, M %/% 2), 10 + seq(M %/% 2 + 1, M - 1),
20 + seq(0, M %/% 2), 20 + seq(M %/% 2 + 1, M - 1)
)
return(collection)
}
## ----gen_networks--------------------------------------------------------------------------------------------------------------------------------------------------------------------------
list_collections <- lapply(epsilons, function(eps) {
generate_net(eps = eps, net_seed = net_seed)
})
names(list_collections) <- epsilons
true_clustering <- c(rep(1, M), rep(2, M), rep(3, M))
## ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
results <- lapply(seq_len(nrow(conditions)), function(idx) {
current_seed <- conditions[["seeds"]][idx]
eps <- conditions[["epsilons"]][idx]
max_vem_steps <- conditions[["vem_steps"]][idx]
collection <- list_collections[[as.character(eps)]]
set.seed(current_seed)
start_time <- Sys.time()
clust <- clusterize_bipartite_networks(
netlist = collection, net_id = names(collection),
colsbm_model = "iid", fit_opts = list(max_vem_steps = max_vem_steps),
global_opts = list(
verbosity = 2L,
nb_cores = parallelly::availableCores(omit = 1L)
)
)
stop_time <- Sys.time()
elapsed_time <- stop_time - start_time
unlisted_best_partition <- extract_best_bipartite_partition(clust)
if (!is.list(unlisted_best_partition)) {
unlisted_best_partition <- list(unlisted_best_partition)
}
clustering_vec <- sort(unlist(lapply(seq_len(length(unlisted_best_partition)), function(idx) {
ids_nets <- as.numeric(unlisted_best_partition[[idx]]$net_id)
names(ids_nets) <- rep(idx, length(ids_nets))
ids_nets
})))
cluster_membership <- as.numeric(names(clustering_vec))
ari <- ARI(cluster_membership, true_clustering)
data.frame(eps = eps, seed = current_seed, max_vem_steps = max_vem_steps, ari = ari, elapsed_time = elapsed_time, start_time = start_time, stop_time = stop_time, clustering = matrix(cluster_membership, nrow = 1L))
})
to_save <- do.call(rbind, results)