#!/usr/bin/env Rscript
# Command-line interface (positional arguments):
#   <metrics_fname> <orig_fname> <embed_fname> <n_idx>
# - metrics_fname: CSV of existing metrics; new metric columns are appended
#   and the same file is overwritten at the end of the script.
# - orig_fname, embed_fname: forwarded unchanged to sam(), slt() and ari()
#   from metrics.R (their expected format is defined there).
# - n_idx: integer forwarded to the same metric functions.
args <- commandArgs(trailingOnly = TRUE)

# Fail fast with a usage message instead of letting missing arguments turn
# into NA values that only surface deep inside the metric functions.
if (length(args) < 4L) {
  stop(
    "usage: <script> <metrics_fname> <orig_fname> <embed_fname> <n_idx>",
    call. = FALSE
  )
}

metrics_fname <- args[1]
orig_fname <- args[2]
embed_fname <- args[3]
n_idx <- as.integer(args[4])
if (is.na(n_idx)) {
  stop("n_idx must be an integer, got '", args[4], "'", call. = FALSE)
}

# Compute the following metrics:
# - sam_x: structure alignment metric for x data (the larger, the better)
# - sam_y: structure alignment metric for y data (the larger, the better)
# - slt_mix: mixing via Silhouette width (the larger, the better)
# - slt_clust: quality of embeddings for clustering via Silhouette width (the larger, the better)
# - slt_f1: an integrated metric using both slt_mix and slt_clust (the larger, the better)
# - ari_mix: mixing via adjusted Rand index (ARI) (the larger, the better)
# - ari_clust: quality of embeddings for clustering via adjusted Rand index (the larger, the better)
# - ari_f1: presumably an integrated metric using both ari_mix and ari_clust,
#   analogous to slt_f1 (confirm in metrics.R)
# The following metrics are described for reference but are NOT computed by
# this script:
# - lisi_mix: mixing via Local Inverse Simpson’s Index (LISI) (the larger, the better)
# - lisi_clust: quality of embeddings for clustering via LISI (the larger, the better)
# - kbet: mixing via k-nearest neighbour batch effect test (kBET) (the larger, the better)
# - avg_mix: mixing metric via two sample test, averaged over all clusters (the larger, the better)
# Load the metric helpers sam(), slt() and ari(). The relative path is
# resolved against the current working directory, so run this script from
# the directory that contains metrics.R. (The former setwd("./") was a no-op
# and has been dropped.)
# NOTE(review): read_csv(), cols(), add_column(), write_csv() and %>% are not
# loaded in this script; presumably metrics.R attaches the packages that
# provide them — confirm.
source("metrics.R")

# Load the existing metrics table that the new columns are appended to.
# An empty cols() spec lets readr guess every column type.
metrics <- read_csv(metrics_fname, col_types = cols())


# Calculate structure alignment metrics, separately for the x and y data.
# Sys.time() is used (not Sys.Date()) so that the "%c" timestamp shows the
# actual wall-clock time; a Date has no time part and always printed 00:00:00.
print(paste0(
  format(Sys.time(), "%c"),
  ": calculating structure alignment metrics..."
))
sam_x <- sam(
  orig_fname = orig_fname, embed_fname = embed_fname,
  n_idx = n_idx, data_idx = "x"
)
sam_y <- sam(
  orig_fname = orig_fname, embed_fname = embed_fname,
  n_idx = n_idx, data_idx = "y"
)
metrics <- metrics %>% add_column(sam_x = sam_x, sam_y = sam_y)

# Calculate Silhouette-width metrics. Columns 1-3 of slt()'s result are stored
# as slt_mix, slt_clust and slt_f1 respectively.
# Sys.time() (not Sys.Date()) so the "%c" timestamp carries the real time.
print(paste0(
  format(Sys.time(), "%c"),
  ": calculating Silhouette width..."
))
slt_res <- slt(
  orig_fname = orig_fname, embed_fname = embed_fname, n_idx = n_idx
)
metrics <- metrics %>%
  add_column(
    slt_mix = slt_res[, 1],
    slt_clust = slt_res[, 2],
    slt_f1 = slt_res[, 3]
  )
# Calculate adjusted Rand index (ARI) metrics. Columns 1-3 of ari()'s result
# are stored as ari_mix, ari_clust and ari_f1 respectively.
# Sys.time() (not Sys.Date()) so the "%c" timestamp carries the real time.
print(paste0(
  format(Sys.time(), "%c"),
  ": calculating adjusted Rand index..."
))
ari_res <- ari(
  orig_fname = orig_fname, embed_fname = embed_fname, n_idx = n_idx
)
metrics <- metrics %>%
  add_column(
    ari_mix = ari_res[, 1],
    ari_clust = ari_res[, 2],
    ari_f1 = ari_res[, 3]
  )

# Overwrite the input metrics file with the extended table.
write_csv(metrics, metrics_fname)
