#!/usr/bin/env Rscript
# ============================================================================
# BioF3 Module 05: Integration Concepts with Real PBMC 3k-Derived Views
# ============================================================================

options(stringsAsFactors = FALSE)

# Locate this script on disk. When the interpreter was started with a
# `--file=<path>` argument that path is used (and must exist); otherwise the
# repository-relative path below is resolved against the working directory.
args <- commandArgs(trailingOnly = FALSE)
file_arg <- grep("^--file=", args, value = TRUE)
script_path <- if (length(file_arg) > 0) {
  raw_path <- sub("^--file=", "", file_arg[[1]])
  # NOTE(review): Rscript on Unix-alikes is reported to encode spaces in the
  # script path as "~+~"; decoding keeps mustWork = TRUE from failing on such
  # paths. Confirm against the R versions this project supports.
  normalizePath(gsub("~+~", " ", raw_path, fixed = TRUE), mustWork = TRUE)
} else {
  normalizePath("scripts/single-cell/sc05_integration2_sci.R", mustWork = FALSE)
}

# Derive the project root from the directory that contains the script.
#
# script_dir: directory holding this script.
# Returns the parent of the nearest `scripts` directory, looking at
# `script_dir` itself and then at its parent (so both `scripts/` and
# `scripts/<group>/` layouts work). If that `scripts` directory sits inside
# `static`, the parent of `static` is returned instead. When neither
# candidate is named `scripts`, falls back to `dirname(script_dir)`.
resolve_project_root <- function(script_dir) {
  for (candidate in c(script_dir, dirname(script_dir))) {
    if (basename(candidate) == "scripts") {
      scripts_parent <- dirname(candidate)
      if (basename(scripts_parent) == "static") {
        return(dirname(scripts_parent))
      }
      return(scripts_parent)
    }
  }
  dirname(script_dir)
}

script_dir <- dirname(script_path)
project_root <- resolve_project_root(script_dir)

# Read a directory path from an environment variable.
#
# name:    environment variable to look up.
# default: path used when the variable is unset or set to an empty string
#          (an empty value would otherwise resolve paths against "/").
env_dir <- function(name, default) {
  value <- Sys.getenv(name, unset = "")
  if (nzchar(value)) value else default
}

# Create `path` (with any missing parents) and stop with a clear message if
# the directory still does not exist afterwards. Returns `path` invisibly.
ensure_dir <- function(path) {
  dir.create(path, recursive = TRUE, showWarnings = FALSE)
  if (!dir.exists(path)) {
    stop("Could not create directory: ", path, call. = FALSE)
  }
  invisible(path)
}

# Input data lives under BIOF3_DATA_DIR (default ~/biof3-data); figures are
# written to BIOF3_OUTPUT_DIR (default: the module05 image folder under the
# project root).
data_root <- env_dir("BIOF3_DATA_DIR", file.path(path.expand("~"), "biof3-data"))
pbmc_dir <- file.path(data_root, "pbmc3k")
output_dir <- env_dir(
  "BIOF3_OUTPUT_DIR",
  file.path(project_root, "static", "img", "tutorial", "modules", "module05")
)

ensure_dir(pbmc_dir)
ensure_dir(output_dir)

# Packages attached further down. Any package whose namespace cannot be
# loaded is installed from the CRAN cloud mirror, one package at a time.
required_packages <- c(
  "Matrix", "ggplot2", "dplyr", "tidyr",
  "patchwork", "scales", "uwot", "pheatmap"
)
for (pkg in required_packages) {
  if (requireNamespace(pkg, quietly = TRUE)) {
    next
  }
  install.packages(pkg, repos = "https://cloud.r-project.org/")
}

library(Matrix)
library(ggplot2)
library(dplyr)
library(tidyr)
library(patchwork)
library(scales)
library(uwot)
library(pheatmap)

# Shared ggplot2 theme for the figures in this script: theme_classic() plus
# black axis text, bold black titles, a centred gray subtitle, a white plot
# background and bold facet labels without a strip background.
#
# base_size: base font size passed to theme_classic() (default 12).
# Returns theme_classic(base_size) with the overrides added.
theme_sci <- function(base_size = 12) {
  overrides <- theme(
    axis.text = element_text(color = "black"),
    axis.title = element_text(face = "bold", color = "black"),
    plot.title = element_text(face = "bold", hjust = 0.5, color = "black"),
    plot.subtitle = element_text(color = "gray30", hjust = 0.5),
    legend.title = element_text(face = "bold"),
    plot.background = element_rect(fill = "white", color = NA),
    strip.background = element_blank(),
    strip.text = element_text(face = "bold", color = "black")
  )
  theme_classic(base_size = base_size) + overrides
}

# Remote archive, its local copy under pbmc_dir, and the directory in which
# the matrix files are looked up after extraction.
pbmc_archive_name <- "pbmc3k_filtered_gene_bc_matrices.tar.gz"
pbmc_url <- paste0(
  "https://cf.10xgenomics.com/samples/cell-exp/1.1.0/pbmc3k/",
  pbmc_archive_name
)
pbmc_tar <- file.path(pbmc_dir, pbmc_archive_name)
matrix_dir <- file.path(pbmc_dir, "filtered_gene_bc_matrices", "hg19")

# Make sure the PBMC 3k matrix files are present under `matrix_dir`.
#
# Downloads the archive from `pbmc_url` to `pbmc_tar` when the archive is not
# on disk, then extracts it into `pbmc_dir` when `matrix.mtx` is missing.
# Reads the file-level `pbmc_url`, `pbmc_tar`, `pbmc_dir` and `matrix_dir`.
# Stops with an error when the download fails or when extraction does not
# produce `matrix.mtx`. Returns `matrix_dir` invisibly.
download_pbmc3k <- function() {
  matrix_file <- file.path(matrix_dir, "matrix.mtx")

  if (!file.exists(pbmc_tar)) {
    message("Downloading PBMC 3k data from 10x Genomics...")
    # Download next to the final location and rename only on success, so an
    # interrupted transfer never leaves a truncated archive that later runs
    # would accept just because the file exists.
    partial_tar <- paste0(pbmc_tar, ".part")
    on.exit(unlink(partial_tar), add = TRUE)
    # Give slow connections more room than the session's current timeout,
    # and restore the previous setting when the function exits.
    old_options <- options(timeout = max(300, getOption("timeout", 60)))
    on.exit(options(old_options), add = TRUE)

    status <- download.file(pbmc_url, destfile = partial_tar, mode = "wb", quiet = FALSE)
    downloaded <- identical(as.integer(status), 0L) && file.rename(partial_tar, pbmc_tar)
    if (!downloaded) {
      stop("Failed to download PBMC 3k archive from ", pbmc_url, call. = FALSE)
    }
  }

  if (!file.exists(matrix_file)) {
    message("Extracting PBMC 3k archive...")
    utils::untar(pbmc_tar, exdir = pbmc_dir)
    if (!file.exists(matrix_file)) {
      stop(
        "Extracting ", pbmc_tar, " did not produce ", matrix_file,
        "; delete the archive and re-run to download it again.",
        call. = FALSE
      )
    }
  }

  invisible(matrix_dir)
}

# Load the PBMC 3k count matrix from the extracted 10x files.
#
# Calls download_pbmc3k() first, then reads matrix.mtx, genes.tsv and
# barcodes.tsv from `matrix_dir`. Rows are named by the second column of
# genes.tsv (made unique with make.unique()), columns by the first column of
# barcodes.tsv. Returns the matrix coerced to "CsparseMatrix".
read_pbmc3k <- function() {
  download_pbmc3k()

  # Both annotation files are header-less, tab-separated tables.
  read_annotation <- function(file_name) {
    read.delim(file.path(matrix_dir, file_name), header = FALSE, sep = "\t")
  }

  counts <- Matrix::readMM(file.path(matrix_dir, "matrix.mtx"))
  gene_table <- read_annotation("genes.tsv")
  barcode_table <- read_annotation("barcodes.tsv")

  rownames(counts) <- make.unique(gene_table[[2]])
  colnames(counts) <- barcode_table[[1]]

  as(counts, "CsparseMatrix")
}

message("=== Loading real PBMC 3k matrix ===")
counts <- read_pbmc3k()

# Per-cell totals: summed counts and number of genes with a non-zero count.
n_counts <- Matrix::colSums(counts)
n_features <- Matrix::colSums(counts > 0)

# Per-gene statistics across all cells. Variance is computed as
# mean(x^2) - mean(x)^2; dispersion is variance divided by the mean.
gene_mean <- Matrix::rowMeans(counts)
gene_summary <- data.frame(
  gene = rownames(counts),
  mean_counts = as.numeric(gene_mean),
  variance = as.numeric(Matrix::rowMeans(counts^2) - gene_mean^2),
  detected_cells = as.numeric(Matrix::rowSums(counts > 0))
)
# Keep genes detected in at least 10 cells with a positive mean.
gene_summary <- filter(gene_summary, detected_cells >= 10, mean_counts > 0)
gene_summary <- mutate(gene_summary, dispersion = variance / mean_counts)

# The 1,200 highest-dispersion genes, after dropping names that start with
# MT-, RPL or RPS.
hvg <- gene_summary %>%
  filter(!grepl("^MT-|^RPL|^RPS", gene)) %>%
  arrange(desc(dispersion)) %>%
  slice_head(n = 1200)

# After the inner t() there is one row per cell, so dividing by n_counts
# divides every cell by its own total before rescaling to 10,000.
hvg_counts <- counts[hvg$gene, ]
norm_counts <- t(t(hvg_counts) / n_counts * 10000)
log_norm <- log1p(norm_counts)

# Centre and scale each gene across cells; scale() yields NA where that is
# undefined, and those entries are set to 0.
scaled <- t(scale(t(as.matrix(log_norm))))
scaled[is.na(scaled)] <- 0

# PCA on the already-scaled matrix (cells as rows), keeping 30 components.
pca <- prcomp(t(scaled), center = FALSE, scale. = FALSE, rank. = 30)
pca_df <- as.data.frame(pca$x[, 1:30])
pca_df$cell <- rownames(pca_df)

# Sort by barcode, then cut the sorted rows into three near-equal consecutive
# groups ("Subset 1".."Subset 3") and attach the per-cell QC totals.
pca_df <- arrange(pca_df, cell)
pca_df <- mutate(
  pca_df,
  barcode_subset = factor(
    paste("Subset", ntile(row_number(), 3)),
    levels = paste("Subset", 1:3)
  ),
  nCount_RNA = as.numeric(n_counts[cell]),
  nFeature_RNA = as.numeric(n_features[cell])
)

# k-means with four clusters on PC1-PC10. The starting centres are four rows
# spaced evenly through the sorted table rather than a random draw.
kmeans_pcs <- paste0("PC", 1:10)
center_rows <- round(seq(1, nrow(pca_df), length.out = 4))
initial_centers <- as.matrix(pca_df[center_rows, kmeans_pcs])
km <- kmeans(
  as.matrix(pca_df[, kmeans_pcs]),
  centers = initial_centers,
  iter.max = 100
)
pca_df$cluster <- factor(km$cluster)

# UMAP on PC1-PC20. The layout depends on random numbers, so the seed is
# fixed here: without it the embedding, and every figure drawn from it,
# changes from run to run while the rest of the pipeline is deterministic.
set.seed(42)
umap_coords <- uwot::umap(
  as.matrix(pca_df[, paste0("PC", 1:20)]),
  n_neighbors = 20,
  min_dist = 0.35,
  metric = "cosine",
  n_threads = 1,
  verbose = FALSE,
  ret_model = FALSE
)

# One row per cell: UMAP coordinates plus the subset, cluster and QC columns
# carried over from pca_df (same row order).
embedding <- data.frame(
  cell = pca_df$cell,
  UMAP1 = umap_coords[, 1],
  UMAP2 = umap_coords[, 2],
  barcode_subset = pca_df$barcode_subset,
  cluster = pca_df$cluster,
  nCount_RNA = pca_df$nCount_RNA,
  nFeature_RNA = pca_df$nFeature_RNA
)

# Marker genes used to label clusters.
marker_sets <- list(
  "T cells" = c("IL7R", "CCR7", "LTB"),
  "B cells" = c("MS4A1", "CD79A", "CD74"),
  "Monocytes" = c("LYZ", "S100A8", "S100A9"),
  "NK cells" = c("NKG7", "GNLY", "GZMB")
)

# Per-cell score for each cell type: the mean log-normalised expression of
# its markers. log_norm only contains the selected high-dispersion genes, so
# markers outside that set are dropped. A set with no remaining markers is
# skipped with a warning, because averaging over zero rows would produce NaN
# scores that silently feed into the annotation.
per_cell_marker_score <- bind_rows(lapply(names(marker_sets), function(cell_type) {
  genes <- intersect(marker_sets[[cell_type]], rownames(log_norm))
  if (length(genes) == 0) {
    warning(
      "No marker genes for '", cell_type, "' are present in log_norm; skipping it.",
      call. = FALSE
    )
    return(NULL)
  }
  score <- Matrix::colMeans(log_norm[genes, , drop = FALSE])
  data.frame(cell = colnames(log_norm), cell_type = cell_type, score = as.numeric(score))
}))

if (nrow(per_cell_marker_score) == 0) {
  stop("None of the marker genes are present in log_norm; cannot annotate clusters.", call. = FALSE)
}

# Average the per-cell scores within each cluster.
cluster_marker_score <- per_cell_marker_score %>%
  left_join(select(embedding, cell, cluster), by = "cell") %>%
  group_by(cluster, cell_type) %>%
  summarise(score = mean(score), .groups = "drop")

# Each cluster takes the cell type with the highest mean score (one row per
# cluster, ties broken by row order).
cluster_annotation <- cluster_marker_score %>%
  group_by(cluster) %>%
  slice_max(score, n = 1, with_ties = FALSE) %>%
  ungroup()

embedding <- embedding %>%
  left_join(select(cluster_annotation, cluster, cell_type), by = "cluster")

message("=== Generating real-data figures ===")

# Figure 01: a static checklist of metadata sources, three technical and
# three biological. Every bar has the same length (value = 1); the text label
# next to each bar carries the information.
batch_sources <- data.frame(
  category = rep(c("Technical", "Biological"), each = 3),
  source = c(
    "Sequencing run", "Operator", "Reagent lot",
    "Donor", "Tissue site", "Treatment state"
  ),
  scope = rep(c("Metadata to inspect", "Biology to preserve"), each = 3)
)
# Reversed levels so the first source is drawn at the top of the y axis.
batch_sources$source <- factor(batch_sources$source, levels = rev(batch_sources$source))
batch_sources$value <- 1

p1 <- ggplot(batch_sources, aes(x = value, y = source, fill = category)) +
  geom_col(color = "black", linewidth = 0.25, width = 0.68) +
  geom_text(aes(label = scope), hjust = -0.06, size = 3.4, fontface = "bold") +
  # The x range extends past 1 to leave room for the labels right of the bars.
  scale_x_continuous(limits = c(0, 1.55), breaks = NULL) +
  scale_fill_manual(values = c("Technical" = "#E64B35", "Biological" = "#4DBBD5")) +
  labs(
    title = "Integration metadata checks before combining real samples",
    x = NULL,
    y = NULL,
    fill = "Category"
  ) +
  theme_sci()
ggsave(
  file.path(output_dir, "01-batch-sources.png"),
  p1,
  width = 10, height = 6, dpi = 300, bg = "white"
)

# Figure 02: integration methods, each labelled with the input it works on.
# Bar heights are simply the method's position in the table (1..5).
methods <- data.frame(
  method = c("CCA", "Harmony", "scVI", "LIGER", "ComBat"),
  input = c(
    "Anchors", "PCA embeddings", "Raw counts",
    "Matrix factorization", "Expression matrix"
  ),
  typical_use = c(
    "Seurat workflows", "Fast correction", "Deep generative model",
    "Cross-dataset factors", "Bulk-like correction"
  )
)
# Factor levels in table order keep the x axis in the order listed above.
methods$method <- factor(methods$method, levels = methods$method)
methods$value <- seq_along(methods$method)

p2 <- ggplot(methods, aes(x = method, y = value, fill = method)) +
  geom_col(color = "black", linewidth = 0.25, width = 0.65) +
  geom_text(aes(label = input), vjust = -0.45, size = 3.2, fontface = "bold") +
  # Extra headroom above the bars for the text labels.
  scale_y_continuous(breaks = NULL, expand = expansion(mult = c(0, 0.22))) +
  scale_fill_brewer(palette = "Set2") +
  labs(
    title = "Common integration methods and their expected inputs",
    x = "Method",
    y = NULL
  ) +
  theme_sci() +
  theme(legend.position = "none")
ggsave(
  file.path(output_dir, "02-methods-comparison.png"),
  p2,
  width = 10, height = 6, dpi = 300, bg = "white"
)

# Point layer and colour scale shared by the two subset scatter plots; built
# fresh on each call so the plots do not share layer objects.
subset_scatter_layers <- function() {
  list(
    geom_point(size = 1.2, alpha = 0.68),
    scale_color_brewer(palette = "Dark2")
  )
}

# Figure 03: cells on PC1/PC2, coloured by barcode subset.
p3 <- ggplot(pca_df, aes(x = PC1, y = PC2, color = barcode_subset)) +
  subset_scatter_layers() +
  labs(
    title = "Real PBMC 3k subsets in PCA space",
    subtitle = "Deterministic barcode subsets from the public matrix",
    x = "PC1",
    y = "PC2",
    color = "Subset"
  ) +
  theme_sci()
ggsave(
  file.path(output_dir, "03-before-integration.png"),
  p3,
  width = 10, height = 7, dpi = 300, bg = "white"
)

# Figure 04: the same cells on the UMAP coordinates.
p4 <- ggplot(embedding, aes(x = UMAP1, y = UMAP2, color = barcode_subset)) +
  subset_scatter_layers() +
  labs(
    title = "Real PBMC 3k subsets in UMAP space",
    subtitle = "UMAP from real PCA coordinates",
    x = "UMAP1",
    y = "UMAP2",
    color = "Subset"
  ) +
  theme_sci()
ggsave(
  file.path(output_dir, "04-after-integration.png"),
  p4,
  width = 10, height = 7, dpi = 300, bg = "white"
)

# Figure 05: figures 03 and 04 side by side with shortened panel titles.
pca_panel <- p3 + ggtitle("PCA view")
umap_panel <- p4 + ggtitle("UMAP view")
p5 <- pca_panel | umap_panel
p5 <- p5 + plot_annotation(title = "Real PBMC 3k subset views from computed coordinates")
ggsave(
  file.path(output_dir, "05-comparison.png"),
  p5,
  width = 16, height = 7, dpi = 300, bg = "white"
)

# Figure 06: UMAP coloured by the marker-based cell type of each cluster.
p6 <- ggplot(embedding, aes(x = UMAP1, y = UMAP2, color = cell_type)) +
  geom_point(size = 1.2, alpha = 0.72) +
  scale_color_brewer(palette = "Set2") +
  labs(
    title = "Marker-based PBMC cell type view",
    x = "UMAP1",
    y = "UMAP2",
    color = "Cell type"
  ) +
  theme_sci()
ggsave(
  file.path(output_dir, "06-cell-types.png"),
  p6,
  width = 10, height = 7, dpi = 300, bg = "white"
)

# Per-subset QC summary: median UMI count, median detected genes, cell count.
subset_qc <- summarise(
  group_by(embedding, barcode_subset),
  median_umi = median(nCount_RNA),
  median_genes = median(nFeature_RNA),
  cells = n(),
  .groups = "drop"
)

# Long format, with each metric expressed relative to its mean across the
# subsets (1 = equal to the mean).
qc_long <- pivot_longer(
  subset_qc,
  -barcode_subset,
  names_to = "metric",
  values_to = "value"
)
qc_long <- qc_long %>%
  group_by(metric) %>%
  mutate(relative = value / mean(value)) %>%
  ungroup()

# Left panel: relative QC metrics per subset; the dashed line marks the mean.
p7a <- ggplot(qc_long, aes(x = barcode_subset, y = relative, fill = metric)) +
  geom_col(position = "dodge", color = "black", linewidth = 0.25) +
  geom_hline(yintercept = 1, linetype = "dashed", color = "gray40") +
  scale_y_continuous(labels = label_percent()) +
  scale_fill_manual(values = c("#00A087", "#3C5488", "#F39B7F")) +
  labs(
    title = "Subset QC balance",
    x = NULL,
    y = "Relative to mean",
    fill = "Metric"
  ) +
  theme_sci()

# Fraction of each cluster's cells that comes from each barcode subset.
mixing <- count(embedding, cluster, barcode_subset)
mixing <- mixing %>%
  group_by(cluster) %>%
  mutate(fraction = n / sum(n)) %>%
  ungroup()

# Right panel: stacked subset composition per cluster.
p7b <- ggplot(mixing, aes(x = cluster, y = fraction, fill = barcode_subset)) +
  geom_col(color = "black", linewidth = 0.2) +
  scale_y_continuous(labels = label_percent()) +
  scale_fill_brewer(palette = "Dark2") +
  labs(
    title = "Subset composition by cluster",
    x = "Cluster",
    y = "Fraction",
    fill = "Subset"
  ) +
  theme_sci()

p7 <- p7a | p7b
ggsave(
  file.path(output_dir, "07-quality-metrics.png"),
  p7,
  width = 14, height = 6, dpi = 300, bg = "white"
)

# Figure 08: UMAP and PCA coordinates stacked into one table with common X/Y
# columns so they can be drawn as facets. The UMAP rows keep all embedding
# columns; the PCA rows carry only the selected ones.
umap_view <- mutate(embedding, method = "UMAP", X = UMAP1, Y = UMAP2)
pca_view <- mutate(pca_df, method = "PCA", X = PC1, Y = PC2) %>%
  select(cell, X, Y, barcode_subset, cluster, method)
method_views <- bind_rows(umap_view, pca_view)

p8 <- ggplot(method_views, aes(x = X, y = Y, color = barcode_subset)) +
  geom_point(size = 0.9, alpha = 0.58) +
  # Free scales: the two coordinate systems have unrelated ranges.
  facet_wrap(~ method, nrow = 1, scales = "free") +
  scale_color_brewer(palette = "Dark2") +
  labs(
    title = "Real coordinate views for integration diagnostics",
    x = "Dimension 1",
    y = "Dimension 2",
    color = "Subset"
  ) +
  theme_sci() +
  theme(legend.position = "bottom")
ggsave(
  file.path(output_dir, "08-methods-umap.png"),
  p8,
  width = 12, height = 6, dpi = 300, bg = "white"
)

# Figure 09: mean log1p expression of selected marker genes per barcode
# subset. Only markers that are rows of log_norm can be plotted.
marker_genes <- c("IL7R", "CCR7", "MS4A1", "CD79A", "LYZ", "S100A8", "NKG7", "GNLY", "PPBP", "CST3")
marker_available <- marker_genes %in% rownames(log_norm)
marker_genes <- marker_genes[marker_available]

# One long table: a row per (gene, cell) pair with that cell's expression.
expression_for_gene <- function(gene) {
  data.frame(
    cell = colnames(log_norm),
    gene = gene,
    expression = as.numeric(log_norm[gene, ])
  )
}
marker_expr <- bind_rows(lapply(marker_genes, expression_for_gene))
marker_expr <- left_join(
  marker_expr,
  select(embedding, cell, barcode_subset),
  by = "cell"
)

marker_means <- marker_expr %>%
  group_by(gene, barcode_subset) %>%
  summarise(expression = mean(expression), .groups = "drop")

p9 <- ggplot(marker_means, aes(x = gene, y = expression, fill = barcode_subset)) +
  geom_col(position = "dodge", color = "black", linewidth = 0.2) +
  scale_fill_brewer(palette = "Dark2") +
  labs(
    title = "Real marker expression across PBMC 3k barcode subsets",
    x = "Marker gene",
    y = "Mean log1p expression",
    fill = "Subset"
  ) +
  theme_sci() +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "top"
  )
ggsave(
  file.path(output_dir, "09-gene-expression.png"),
  p9,
  width = 14, height = 8, dpi = 300, bg = "white"
)

# Figure 10: one point per pipeline stage. The y value is the size of that
# stage's output: cell count for the first three stages, number of selected
# genes, 30 PCs, 2 UMAP dimensions, and the number of distinct cell types.
n_cells <- ncol(counts)
n_cell_types <- length(unique(embedding$cell_type))
workflow <- data.frame(
  step = 1:7,
  stage = c("Load matrix", "QC", "Normalize", "HVG", "PCA", "UMAP", "Validate"),
  output = c(rep(n_cells, 3), nrow(log_norm), 30, 2, n_cell_types)
)

p10 <- ggplot(workflow, aes(x = step, y = output)) +
  geom_line(color = "#00A087", linewidth = 1.3) +
  geom_point(size = 4, color = "#00A087", fill = "white", shape = 21, stroke = 1.4) +
  geom_text(aes(label = stage), vjust = -1.3, size = 3.4, fontface = "bold") +
  scale_x_continuous(breaks = workflow$step) +
  # Extra room at the top so the stage labels above the points are not cut off.
  scale_y_continuous(labels = comma, expand = expansion(mult = c(0.04, 0.22))) +
  labs(
    title = "Real-data integration preparation workflow",
    x = "Step",
    y = "Output count"
  ) +
  theme_sci()
ggsave(
  file.path(output_dir, "10-workflow.png"),
  p10,
  width = 11, height = 6, dpi = 300, bg = "white"
)

# Final report: list the PNG files now in the output directory, show where
# they are, and record the R session for reproducibility.
message("=== Generated Module 05 real-data figures ===")
print(list.files(output_dir, pattern = "\\.png$", full.names = FALSE))
message("Output directory: ", normalizePath(output_dir))
# Printed explicitly: a bare sessionInfo() call only shows up when top-level
# values are auto-printed, so its output is lost when the file is source()d.
print(sessionInfo())
