Script to measure impact of VEM parameters
This commit is contained in:
parent
5e6de16d7b
commit
871dc734b8
1 changed files with 141 additions and 0 deletions
141
code/analysis/investigating/impact_of_vem_max_steps.R
Normal file
141
code/analysis/investigating/impact_of_vem_max_steps.R
Normal file
|
|
@ -0,0 +1,141 @@
|
|||
## ----libraries-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
library(colSBM)
|
||||
library(aricode)
|
||||
library(here)
|
||||
|
||||
|
||||
## ----constants-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
base_folder <- here("code", "results", "investigating", "vem_steps")
|
||||
|
||||
if (!dir.exists(base_folder)) {
|
||||
dir.create(base_folder, recursive = TRUE)
|
||||
}
|
||||
|
||||
net_seed <- 0
|
||||
test_seeds <- c(12, 3)
|
||||
epsilons <- c(0.1, 0.4)
|
||||
vem_steps <- seq(10, 300, by = 40)
|
||||
|
||||
conditions <- expand.grid(
|
||||
seeds = test_seeds,
|
||||
epsilons = epsilons,
|
||||
vem_steps = vem_steps
|
||||
)
|
||||
|
||||
|
||||
base_alpha <- matrix(rep(0.3, 9L), nrow = 3L)
|
||||
pi <- c(0.3, 0.2, 0.5)
|
||||
rho <- c(0.55, 0.15, 0.3)
|
||||
M <- 10L
|
||||
nr <- c(rep(30L, M / 2L), rep(95L, M / 2L))
|
||||
nc <- c(rep(40L, M / 2L), rep(70L, M / 2L))
|
||||
|
||||
|
||||
## ----functions-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
generate_net <- function(eps, net_seed = 0) {
|
||||
set.seed(net_seed)
|
||||
as_alpha <- base_alpha + matrix(
|
||||
c(
|
||||
eps, -eps / 2L, -eps / 2L,
|
||||
-eps / 2L, eps, -eps / 2L,
|
||||
-eps / 2L, -eps / 2L, eps
|
||||
),
|
||||
nrow = 3L
|
||||
)
|
||||
|
||||
cp_alpha <- base_alpha + matrix(
|
||||
c(
|
||||
3L * eps / 2L, eps, eps / 2L,
|
||||
eps, eps / 2L, 0L,
|
||||
eps / 2L, 0L, -eps / 2L
|
||||
),
|
||||
nrow = 3L
|
||||
)
|
||||
|
||||
dis_alpha <- base_alpha + matrix(
|
||||
c(
|
||||
-eps / 2L, eps, eps,
|
||||
eps, -eps / 2L, eps,
|
||||
eps, eps, -eps / 2L
|
||||
),
|
||||
nrow = 3L
|
||||
)
|
||||
|
||||
collection <- c(
|
||||
generate_bipartite_collection(
|
||||
nr = nr, nc = nc,
|
||||
pi = pi, rho = rho,
|
||||
alpha = as_alpha, M = M
|
||||
),
|
||||
generate_bipartite_collection(
|
||||
nr = nr, nc = nc,
|
||||
pi = pi, rho = rho,
|
||||
alpha = cp_alpha, M = M
|
||||
),
|
||||
generate_bipartite_collection(
|
||||
nr = nr, nc = nc,
|
||||
pi = pi, rho = rho,
|
||||
alpha = dis_alpha, M = M
|
||||
)
|
||||
)
|
||||
names(collection) <- c(
|
||||
0 + seq(0, M %/% 2), 0 + seq(M %/% 2 + 1, M - 1),
|
||||
10 + seq(0, M %/% 2), 10 + seq(M %/% 2 + 1, M - 1),
|
||||
20 + seq(0, M %/% 2), 20 + seq(M %/% 2 + 1, M - 1)
|
||||
)
|
||||
|
||||
return(collection)
|
||||
}
|
||||
|
||||
|
||||
## ----gen_networks--------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
list_collections <- lapply(epsilons, function(eps) {
|
||||
generate_net(eps = eps, net_seed = net_seed)
|
||||
})
|
||||
names(list_collections) <- epsilons
|
||||
|
||||
true_clustering <- c(rep(1, M), rep(2, M), rep(3, M))
|
||||
|
||||
|
||||
## ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
results <- lapply(seq_len(nrow(conditions)), function(idx) {
|
||||
current_seed <- conditions[["seeds"]][idx]
|
||||
eps <- conditions[["epsilons"]][idx]
|
||||
max_vem_steps <- conditions[["vem_steps"]][idx]
|
||||
|
||||
collection <- list_collections[[as.character(eps)]]
|
||||
|
||||
set.seed(current_seed)
|
||||
start_time <- Sys.time()
|
||||
clust <- clusterize_bipartite_networks(
|
||||
netlist = collection, net_id = names(collection),
|
||||
colsbm_model = "iid", fit_opts = list(max_vem_steps = max_vem_steps),
|
||||
global_opts = list(
|
||||
verbosity = 2L,
|
||||
nb_cores = parallelly::availableCores(omit = 1L)
|
||||
)
|
||||
)
|
||||
stop_time <- Sys.time()
|
||||
|
||||
elapsed_time <- stop_time - start_time
|
||||
|
||||
unlisted_best_partition <- extract_best_bipartite_partition(clust)
|
||||
|
||||
if (!is.list(unlisted_best_partition)) {
|
||||
unlisted_best_partition <- list(unlisted_best_partition)
|
||||
}
|
||||
|
||||
clustering_vec <- sort(unlist(lapply(seq_len(length(unlisted_best_partition)), function(idx) {
|
||||
ids_nets <- as.numeric(unlisted_best_partition[[idx]]$net_id)
|
||||
names(ids_nets) <- rep(idx, length(ids_nets))
|
||||
ids_nets
|
||||
})))
|
||||
|
||||
cluster_membership <- as.numeric(names(clustering_vec))
|
||||
|
||||
ari <- ARI(cluster_membership, true_clustering)
|
||||
|
||||
data.frame(eps = eps, seed = current_seed, max_vem_steps = max_vem_steps, ari = ari, elapsed_time = elapsed_time, start_time = start_time, stop_time = stop_time, clustering = matrix(cluster_membership, nrow = 1L))
|
||||
})
|
||||
|
||||
to_save <- do.call(rbind, results)
|
||||
Loading…
Add table
Reference in a new issue