#!/usr/bin/env Rscript
# ============================================================================
# BioF3 Module 04: QC, PCA, UMAP and Cell Annotation with Real PBMC 3k Data
# ============================================================================

# Keep strings as character vectors even on R versions older than 4.0.
options(stringsAsFactors = FALSE)

# Locate this script: when launched through Rscript the path arrives as a
# `--file=` argument; otherwise fall back to a conventional relative path.
args <- commandArgs(trailingOnly = FALSE)
file_arg <- grep("^--file=", args, value = TRUE)
if (length(file_arg) > 0) {
  script_path <- normalizePath(sub("^--file=", "", file_arg[[1]]), mustWork = TRUE)
} else {
  script_path <- normalizePath("scripts/single-cell/sc04_integration_sci.R", mustWork = FALSE)
}

# The project root is two levels up when the script sits in `static/scripts`,
# otherwise it is the parent of the script directory.
script_dir <- dirname(script_path)
script_parent <- dirname(script_dir)
in_static_scripts <- basename(script_dir) == "scripts" && basename(script_parent) == "static"
project_root <- if (in_static_scripts) dirname(script_parent) else script_parent

# Data and figure locations; both can be overridden by environment variables.
default_data_root <- file.path(path.expand("~"), "biof3-data")
default_output_dir <- file.path(project_root, "static", "img", "tutorial", "modules", "module04")
data_root <- Sys.getenv("BIOF3_DATA_DIR", unset = default_data_root)
pbmc_dir <- file.path(data_root, "pbmc3k")
output_dir <- Sys.getenv("BIOF3_OUTPUT_DIR", unset = default_output_dir)

for (needed_dir in c(pbmc_dir, output_dir)) {
  dir.create(needed_dir, recursive = TRUE, showWarnings = FALSE)
}

# Packages the pipeline needs; any that cannot be loaded are installed from
# CRAN one at a time before the library() calls below.
required_packages <- c("Matrix", "ggplot2", "dplyr", "tidyr", "patchwork", "scales", "uwot", "pheatmap", "ggrepel")
for (pkg in required_packages) {
  if (requireNamespace(pkg, quietly = TRUE)) {
    next
  }
  install.packages(pkg, repos = "https://cloud.r-project.org/")
}

library(Matrix)
library(ggplot2)
library(dplyr)
library(tidyr)
library(patchwork)
library(scales)
library(uwot)
library(pheatmap)
library(ggrepel)

# Shared publication-style ggplot2 theme: theme_classic() with black axis
# text, bold centred titles, white backgrounds and borderless facet strips.
#
# base_size: base font size forwarded to theme_classic().
# Returns a ggplot2 theme object that can be added to a plot.
theme_sci <- function(base_size = 12) {
  bold_black <- element_text(face = "bold", color = "black")
  white_fill <- element_rect(fill = "white", color = NA)

  overrides <- theme(
    axis.text = element_text(color = "black"),
    axis.title = bold_black,
    plot.title = element_text(face = "bold", hjust = 0.5, color = "black"),
    plot.subtitle = element_text(color = "gray30", hjust = 0.5),
    legend.title = element_text(face = "bold"),
    panel.background = white_fill,
    plot.background = white_fill,
    strip.background = element_blank(),
    strip.text = bold_black
  )

  theme_classic(base_size = base_size) + overrides
}

# Remote location of the PBMC 3k archive and the local paths derived from it.
pbmc_archive_name <- "pbmc3k_filtered_gene_bc_matrices.tar.gz"
pbmc_url <- paste0(
  "https://cf.10xgenomics.com/samples/cell-exp/1.1.0/pbmc3k/",
  pbmc_archive_name
)
pbmc_tar <- file.path(pbmc_dir, pbmc_archive_name)
# Directory expected to hold matrix.mtx, genes.tsv and barcodes.tsv once the
# archive has been extracted into pbmc_dir.
matrix_dir <- file.path(pbmc_dir, "filtered_gene_bc_matrices", "hg19")

# Ensure the PBMC 3k matrix files are available locally.
#
# Downloads the archive to `pbmc_tar` if it is not there yet and extracts it
# into `pbmc_dir` if `matrix.mtx` is missing from `matrix_dir`.
#
# The download goes to a temporary `.part` file that is only renamed to
# `pbmc_tar` on success, so an interrupted or failed transfer can never be
# mistaken for a complete archive on the next run.
#
# Returns NULL invisibly; stops with an error if the download or the
# extraction does not produce the expected files.
download_pbmc3k <- function() {
  if (!file.exists(pbmc_tar)) {
    message("Downloading PBMC 3k data from 10x Genomics...")
    partial_tar <- paste0(pbmc_tar, ".part")
    # Remove any leftover partial file on every exit path (no-op after rename).
    on.exit(unlink(partial_tar), add = TRUE)
    status <- download.file(pbmc_url, destfile = partial_tar, mode = "wb", quiet = FALSE)
    if (!identical(as.integer(status), 0L) || !file.rename(partial_tar, pbmc_tar)) {
      stop("Failed to download PBMC 3k archive from ", pbmc_url, call. = FALSE)
    }
  }

  matrix_file <- file.path(matrix_dir, "matrix.mtx")
  if (!file.exists(matrix_file)) {
    message("Extracting PBMC 3k archive...")
    utils::untar(pbmc_tar, exdir = pbmc_dir)
    if (!file.exists(matrix_file)) {
      stop(
        "Extraction of ", pbmc_tar, " did not produce ", matrix_file,
        "; delete the archive and re-run to download it again.",
        call. = FALSE
      )
    }
  }

  invisible(NULL)
}

# Load the PBMC 3k count matrix from the extracted 10x files.
#
# Calls download_pbmc3k() first, then reads matrix.mtx and labels rows with
# the (made-unique) second column of genes.tsv and columns with the first
# column of barcodes.tsv.
#
# Returns the counts as a CsparseMatrix (genes x cells).
read_pbmc3k <- function() {
  download_pbmc3k()

  # Read one column of a header-less, tab-separated file in matrix_dir.
  read_tsv_column <- function(file_name, column) {
    tbl <- read.delim(file.path(matrix_dir, file_name), header = FALSE, sep = "\t")
    tbl[[column]]
  }

  mat <- Matrix::readMM(file.path(matrix_dir, "matrix.mtx"))
  rownames(mat) <- make.unique(read_tsv_column("genes.tsv", 2))
  colnames(mat) <- read_tsv_column("barcodes.tsv", 1)
  as(mat, "CsparseMatrix")
}

message("=== Loading real PBMC 3k matrix ===")
counts <- read_pbmc3k()

# Per-cell QC metrics computed on the unfiltered matrix:
#   n_counts   - total counts per cell
#   n_features - number of genes with a non-zero count per cell
#   percent_mt - percentage of counts from genes whose name starts with "MT-"
mt_genes <- grepl("^MT-", rownames(counts))
n_counts <- Matrix::colSums(counts)
n_features <- Matrix::colSums(counts > 0)
percent_mt <- Matrix::colSums(counts[mt_genes, , drop = FALSE]) / n_counts * 100

# `qc` keeps one row per ORIGINAL cell (passing or not) so the QC figures can
# show both retained and filtered cells.
qc <- data.frame(
  cell = colnames(counts),
  nCount_RNA = as.numeric(n_counts),
  nFeature_RNA = as.numeric(n_features),
  percent_mt = as.numeric(percent_mt)
) %>%
  mutate(pass_qc = nFeature_RNA > 200 & nFeature_RNA < 2500 & percent_mt < 5)

# Apply the QC filter. Previously `pass_qc` was only computed and plotted, so
# cells that failed QC still entered HVG selection, PCA, clustering and UMAP,
# contradicting the "cells retained" workflow figure. `counts` and `n_counts`
# are subset together because normalisation divides one by the other.
keep_cells <- which(qc$pass_qc)
if (length(keep_cells) == 0) {
  stop("No cells passed the QC thresholds; cannot continue.", call. = FALSE)
}
message(sprintf("QC filter: retaining %d of %d cells", length(keep_cells), nrow(qc)))
counts <- counts[, keep_cells, drop = FALSE]
n_counts <- n_counts[keep_cells]

# Per-gene summary statistics on the raw counts. The variance is the
# population form E[x^2] - E[x]^2, and dispersion is variance / mean.
gene_mean <- as.numeric(Matrix::rowMeans(counts))
gene_mean_of_squares <- as.numeric(Matrix::rowMeans(counts^2))

gene_summary <- data.frame(
  gene = rownames(counts),
  mean_counts = gene_mean,
  variance = gene_mean_of_squares - gene_mean^2,
  detected_cells = as.numeric(Matrix::rowSums(counts > 0))
) %>%
  # Keep genes detected in at least 10 cells with a positive mean.
  filter(detected_cells >= 10 & mean_counts > 0) %>%
  mutate(dispersion = variance / mean_counts)

# Highly variable genes: the 1200 highest-dispersion genes after dropping
# genes whose names start with MT-, RPL or RPS.
excluded_gene_pattern <- "^MT-|^RPL|^RPS"
hvg <- gene_summary %>%
  filter(!grepl(excluded_gene_pattern, gene)) %>%
  arrange(desc(dispersion)) %>%
  slice_head(n = 1200)

# Normalise HVG counts to counts-per-10,000 per cell. The transpose puts
# cells on rows so that dividing by `n_counts` recycles one total per cell.
hvg_counts <- counts[hvg$gene, ]
cells_by_genes <- t(hvg_counts) / n_counts * 10000
norm_counts <- t(cells_by_genes)
log_norm <- log1p(norm_counts)

# Z-score each gene across cells; scale() works column-wise, hence the
# transposes. Entries that come out NA are set to 0.
gene_scaled_by_cell <- scale(t(as.matrix(log_norm)))
scaled <- t(gene_scaled_by_cell)
scaled[is.na(scaled)] <- 0

# PCA on the already centred/scaled matrix (cells as rows), keeping 30 PCs.
pca <- prcomp(t(scaled), center = FALSE, scale. = FALSE, rank. = 30)
pc_scores <- as.data.frame(pca$x[, 1:30])
pc_scores[["cell"]] <- rownames(pc_scores)
pca_df <- left_join(pc_scores, qc, by = "cell")

# k-means on the first 10 PCs with 6 clusters. The starting centres are six
# cells picked at evenly spaced row positions of pca_df, so the result does
# not depend on the random number generator.
cluster_pcs <- paste0("PC", 1:10)
center_rows <- round(seq(1, nrow(pca_df), length.out = 6))
initial_centers <- as.matrix(pca_df[center_rows, cluster_pcs])
km <- kmeans(
  as.matrix(pca_df[, cluster_pcs]),
  centers = initial_centers,
  iter.max = 100
)
pca_df$cluster <- factor(km$cluster)

# UMAP on the first 20 PCs (cosine metric, single-threaded).
# UMAP is stochastic; fix the RNG seed so the layout, and therefore figures
# 07-09, are the same on every run instead of changing between invocations.
set.seed(42)
umap_coords <- uwot::umap(
  as.matrix(pca_df[, paste0("PC", 1:20)]),
  n_neighbors = 20,
  min_dist = 0.35,
  metric = "cosine",
  n_threads = 1,
  verbose = FALSE,
  ret_model = FALSE
)

# One row per embedded cell: UMAP coordinates, k-means cluster and the QC
# metrics joined in by cell barcode.
embedding <- data.frame(
  cell = pca_df$cell,
  UMAP1 = umap_coords[, 1],
  UMAP2 = umap_coords[, 2],
  cluster = pca_df$cluster
) %>%
  left_join(qc, by = "cell")

# Marker genes used to label clusters, keyed by cell-type name.
marker_sets <- list(
  "CD4 T" = c("IL7R", "CCR7", "LTB"),
  "B cells" = c("MS4A1", "CD79A", "CD74"),
  "Monocytes" = c("LYZ", "S100A8", "S100A9"),
  "NK cells" = c("NKG7", "GNLY", "GZMB"),
  "Platelets" = c("PPBP", "PF4"),
  "Dendritic" = c("FCER1A", "CST3")
)

# Per-cell score for one cell type: the mean log-normalised expression of its
# marker genes that are present in log_norm. Returns NULL when none are.
score_marker_set <- function(cell_type) {
  present_markers <- intersect(marker_sets[[cell_type]], rownames(log_norm))
  if (length(present_markers) == 0) {
    return(NULL)
  }
  per_cell <- Matrix::colMeans(log_norm[present_markers, , drop = FALSE])
  data.frame(cell = colnames(log_norm), cell_type = cell_type, score = as.numeric(per_cell))
}

# Average the per-cell scores within each (cluster, cell type) pair.
per_cell_scores <- bind_rows(lapply(names(marker_sets), score_marker_set))
cell_to_cluster <- select(embedding, cell, cluster)
cluster_marker_score <- per_cell_scores %>%
  left_join(cell_to_cluster, by = "cell") %>%
  group_by(cluster, cell_type) %>%
  summarise(score = mean(score), .groups = "drop")

# Each cluster takes the cell type with the highest mean score (one row per
# cluster; ties resolved by taking the first row).
cluster_annotation <- cluster_marker_score %>%
  group_by(cluster) %>%
  slice_max(score, n = 1, with_ties = FALSE) %>%
  ungroup()

cluster_to_type <- select(cluster_annotation, cluster, cell_type)
embedding <- left_join(embedding, cluster_to_type, by = "cluster")

# All distinct marker genes that exist as rows of log_norm, in the order they
# first appear in marker_sets.
all_markers <- unique(unlist(marker_sets, use.names = FALSE))
marker_genes <- all_markers[all_markers %in% rownames(log_norm)]

# Long-format expression for one gene.
#
# gene: row name of log_norm.
# Returns a data frame with columns cell, gene and expression, one row per
# column of log_norm.
get_expression <- function(gene) {
  gene_values <- as.numeric(log_norm[gene, ])
  data.frame(
    cell = colnames(log_norm),
    gene = gene,
    expression = gene_values
  )
}

# Stack every marker gene and attach each cell's cluster, annotated cell type
# and UMAP coordinates.
cell_metadata <- select(embedding, cell, cluster, cell_type, UMAP1, UMAP2)
marker_long <- lapply(marker_genes, get_expression) %>%
  bind_rows() %>%
  left_join(cell_metadata, by = "cell")

message("=== Generating real-data figures ===")

# Figure 01: workflow overview. The first two stages show 100 %; every stage
# from the QC filter onward shows the share of cells flagged pass_qc.
pct_pass_qc <- mean(qc$pass_qc) * 100
workflow_stages <- c("Raw matrix", "QC metrics", "QC filter", "Normalize", "HVG", "PCA", "UMAP", "Annotation")
workflow <- data.frame(
  step = seq_along(workflow_stages),
  stage = workflow_stages,
  cells_retained = c(100, 100, rep(pct_pass_qc, 6))
)

workflow_layers <- list(
  geom_line(color = "#00A087", linewidth = 1.3),
  geom_point(size = 4, color = "#00A087", fill = "white", shape = 21, stroke = 1.5),
  geom_text(aes(label = stage), vjust = -1.3, size = 3.1, fontface = "bold"),
  scale_x_continuous(breaks = workflow$step),
  # Upper limit above 100 leaves room for the stage labels over the points.
  scale_y_continuous(limits = c(0, 108), labels = label_percent(scale = 1)),
  labs(title = "PBMC 3k single-cell analysis workflow", x = "Analysis step", y = "Cells retained"),
  theme_sci()
)
p1 <- ggplot(workflow, aes(x = step, y = cells_retained)) + workflow_layers
ggsave(
  filename = file.path(output_dir, "01-workflow.png"),
  plot = p1,
  width = 11, height = 6, dpi = 300, bg = "white"
)

# Figure 02: violin + box plot of the three QC metrics, drawn on a shared
# y axis. The factor levels fix the left-to-right order of the metrics.
qc_metric_names <- c("nFeature_RNA", "nCount_RNA", "percent_mt")
qc_long <- qc[qc_metric_names] %>%
  pivot_longer(everything(), names_to = "metric", values_to = "value") %>%
  mutate(metric = factor(metric, levels = qc_metric_names))

p2 <- ggplot(qc_long, aes(x = metric, y = value, fill = metric))
p2 <- p2 +
  geom_violin(trim = FALSE, scale = "width", color = "black", linewidth = 0.25) +
  geom_boxplot(width = 0.12, fill = "white", outlier.shape = NA, linewidth = 0.4) +
  scale_fill_manual(values = c("#3C5488", "#00A087", "#E64B35")) +
  scale_y_continuous(labels = comma) +
  labs(title = "PBMC 3k quality control metrics", x = "Metric", y = "Value") +
  theme_sci() +
  theme(legend.position = "none", axis.text.x = element_text(angle = 25, hjust = 1))
ggsave(
  filename = file.path(output_dir, "02-qc-violin.png"),
  plot = p2,
  width = 9, height = 6, dpi = 300, bg = "white"
)

# Figure 03: QC scatter plots coloured by pass/fail status, with the filter
# thresholds drawn as dashed red lines.

# Layers shared by both panels; built fresh for every plot.
qc_status_layers <- function() {
  list(
    geom_point(size = 1.2, alpha = 0.65),
    scale_x_continuous(labels = comma),
    scale_color_manual(values = c("FALSE" = "#E64B35", "TRUE" = "#00A087"), labels = c("Filtered", "Retained")),
    theme_sci()
  )
}

# Dashed horizontal line(s) marking QC cut-off value(s).
qc_cutoff_line <- function(cutoffs) {
  geom_hline(yintercept = cutoffs, linetype = "dashed", color = "red", linewidth = 0.7)
}

# Left panel: detected genes vs. total counts (thresholds 200 and 2500).
p3a <- ggplot(qc, aes(x = nCount_RNA, y = nFeature_RNA, color = pass_qc)) +
  qc_status_layers() +
  qc_cutoff_line(c(200, 2500)) +
  scale_y_continuous(labels = comma) +
  labs(x = "Total UMI counts", y = "Detected genes", color = "QC status")

# Right panel: mitochondrial percentage vs. total counts (threshold 5).
p3b <- ggplot(qc, aes(x = nCount_RNA, y = percent_mt, color = pass_qc)) +
  qc_status_layers() +
  qc_cutoff_line(5) +
  labs(x = "Total UMI counts", y = "Mitochondrial %", color = "QC status")

p3 <- (p3a | p3b) + plot_annotation(title = "Quality control filtering on real PBMC 3k cells")
ggsave(
  filename = file.path(output_dir, "03-qc-scatter.png"),
  plot = p3,
  width = 14, height = 6, dpi = 300, bg = "white"
)

# Figure 04: mean vs. dispersion on log-log axes, highlighting the selected
# highly variable genes and labelling the ten with the highest dispersion.
top_hvg <- slice_head(hvg, n = 10)

p4 <- ggplot(gene_summary, aes(x = mean_counts, y = dispersion, color = gene %in% hvg$gene))
p4 <- p4 +
  geom_point(size = 1.2, alpha = 0.45) +
  geom_text_repel(data = top_hvg, aes(label = gene), size = 3, fontface = "bold", max.overlaps = 20) +
  scale_x_log10(labels = label_number()) +
  scale_y_log10(labels = label_number()) +
  scale_color_manual(values = c("FALSE" = "gray70", "TRUE" = "#E64B35"), labels = c("Other genes", "Highly variable")) +
  labs(title = "Highly variable genes in PBMC 3k", x = "Mean expression", y = "Dispersion", color = "Gene type") +
  theme_sci()
ggsave(
  filename = file.path(output_dir, "04-hvg.png"),
  plot = p4,
  width = 10, height = 7, dpi = 300, bg = "white"
)

# Figure 05: PCA elbow plot.
# prcomp() keeps ALL singular values in `sdev` even when `rank.` limits the
# number of returned components, so plotting seq_along(pca$sdev) would draw
# every possible PC (one per HVG) and squash the PC-20 marker and annotation
# into the left edge. Plot only the PCs actually retained in `pca$x`, while
# still dividing by the total variance so the percentages stay correct.
n_pcs_shown <- min(ncol(pca$x), length(pca$sdev))
pct_variance <- pca$sdev^2 / sum(pca$sdev^2) * 100
pca_variance <- data.frame(
  PC = seq_len(n_pcs_shown),
  variance = pct_variance[seq_len(n_pcs_shown)]
)
p5 <- ggplot(pca_variance, aes(x = PC, y = variance)) +
  geom_line(color = "#4DBBD5", linewidth = 1.3) +
  geom_point(size = 1.8, color = "#4DBBD5") +
  geom_vline(xintercept = 20, linetype = "dashed", color = "red", linewidth = 0.8) +
  annotate("text", x = 25, y = max(pca_variance$variance) * 0.75, label = "First 20 PCs used for UMAP", color = "red", fontface = "bold") +
  labs(title = "PCA variance explained", x = "Principal component", y = "Variance explained (%)") +
  theme_sci()
ggsave(file.path(output_dir, "05-pca-elbow.png"), p5, width = 10, height = 6, dpi = 300, bg = "white")

# Figure 06: cells on the first two principal components, coloured by their
# k-means cluster.
pca_scatter_labels <- labs(title = "PCA of real PBMC 3k cells", x = "PC1", y = "PC2", color = "Cluster")
p6 <- ggplot(pca_df, aes(x = PC1, y = PC2, color = cluster)) +
  geom_point(size = 1.3, alpha = 0.72) +
  scale_color_brewer(palette = "Dark2") +
  pca_scatter_labels +
  theme_sci()
ggsave(
  filename = file.path(output_dir, "06-pca-plot.png"),
  plot = p6,
  width = 9, height = 7, dpi = 300, bg = "white"
)

# Figure 07: UMAP coloured by cluster, with each cluster's id printed at the
# median UMAP position of its cells.
cluster_centers <- embedding %>%
  group_by(cluster) %>%
  summarise(UMAP1 = median(UMAP1), UMAP2 = median(UMAP2), .groups = "drop")

cluster_id_labels <- geom_text(
  data = cluster_centers,
  aes(label = cluster),
  size = 5, fontface = "bold", color = "black"
)
p7 <- ggplot(embedding, aes(x = UMAP1, y = UMAP2, color = cluster)) +
  geom_point(size = 1.2, alpha = 0.72) +
  cluster_id_labels +
  scale_color_brewer(palette = "Dark2") +
  labs(title = "UMAP clustering of PBMC 3k", x = "UMAP1", y = "UMAP2", color = "Cluster") +
  theme_sci()
ggsave(
  filename = file.path(output_dir, "07-umap-clusters.png"),
  plot = p7,
  width = 10, height = 7, dpi = 300, bg = "white"
)

# UMAP feature plot for a single gene.
#
# gene:  gene symbol to colour the cells by.
# label: plot title.
#
# Expression is taken from `log_norm` when the gene is one of its rows. Since
# `log_norm` only holds the highly variable genes, a marker outside that set
# is normalised on the fly from `counts` with the same formula
# (log1p(count / n_counts * 10000)) instead of failing with an opaque
# "subscript out of bounds" error. A gene absent from both raises a clear
# error.
#
# Returns a ggplot object.
plot_marker <- function(gene, label) {
  if (gene %in% rownames(log_norm)) {
    cells <- colnames(log_norm)
    expression <- as.numeric(log_norm[gene, ])
  } else if (gene %in% rownames(counts)) {
    cells <- colnames(counts)
    expression <- log1p(as.numeric(counts[gene, ]) / as.numeric(n_counts) * 10000)
  } else {
    stop("Marker gene '", gene, "' is not present in the count matrix.", call. = FALSE)
  }

  values <- data.frame(cell = cells, expression = expression)
  left_join(embedding, values, by = "cell") %>%
    ggplot(aes(x = UMAP1, y = UMAP2, color = expression)) +
    geom_point(size = 1.1, alpha = 0.8) +
    scale_color_viridis_c(option = "plasma") +
    labs(title = label, x = "UMAP1", y = "UMAP2", color = "log1p") +
    theme_sci() +
    theme(legend.position = "right")
}

# Figure 08: 2 x 2 grid of marker-gene feature plots on the UMAP.
marker_row_top <- plot_marker("IL7R", "IL7R (T cells)") | plot_marker("S100A8", "S100A8 (Monocytes)")
marker_row_bottom <- plot_marker("MS4A1", "MS4A1 (B cells)") | plot_marker("NKG7", "NKG7 (NK cells)")
p8 <- marker_row_top / marker_row_bottom +
  plot_annotation(title = "Real marker gene expression on PBMC 3k UMAP")
ggsave(
  filename = file.path(output_dir, "08-umap-genes.png"),
  plot = p8,
  width = 14, height = 12, dpi = 300, bg = "white"
)

# Figure 09: UMAP coloured by annotated cell type, labelled at the median
# UMAP position of each type.
type_centers <- embedding %>%
  group_by(cell_type) %>%
  summarise(UMAP1 = median(UMAP1), UMAP2 = median(UMAP2), .groups = "drop")

cell_type_labels <- geom_text_repel(
  data = type_centers,
  aes(label = cell_type),
  color = "black", fontface = "bold", size = 4, max.overlaps = Inf
)
p9 <- ggplot(embedding, aes(x = UMAP1, y = UMAP2, color = cell_type)) +
  geom_point(size = 1.2, alpha = 0.72) +
  cell_type_labels +
  scale_color_brewer(palette = "Set2") +
  labs(title = "Marker-based cell type annotation", x = "UMAP1", y = "UMAP2", color = "Cell type") +
  theme_sci()
ggsave(
  filename = file.path(output_dir, "09-cell-types.png"),
  plot = p9,
  width = 11, height = 7, dpi = 300, bg = "white"
)

# Mean marker expression per (gene, annotated cell type), spread to one
# column per cell type; missing combinations are filled with 0.
heatmap_data <- marker_long %>%
  group_by(gene, cell_type) %>%
  summarise(expression = mean(expression), .groups = "drop") %>%
  pivot_wider(names_from = cell_type, values_from = expression, values_fill = 0)

# Genes x cell types matrix, z-scored within each gene (row). scale() works
# on columns, hence the transposes; NA results are replaced by 0.
cell_type_columns <- heatmap_data[, -1]
heat_mat <- as.matrix(cell_type_columns)
rownames(heat_mat) <- heatmap_data$gene
row_z_transposed <- scale(t(heat_mat))
heat_mat <- t(row_z_transposed)
heat_mat[is.na(heat_mat)] <- 0

# Figure 10: clustered heatmap of row-scaled marker expression.
# The PNG device is closed in a `finally` handler so a pheatmap failure cannot
# leave a dangling graphics device, and hierarchical clustering is requested
# only along dimensions with at least two entries (it cannot run on a single
# row or column, e.g. when every cluster received the same annotation).
png(file.path(output_dir, "10-marker-heatmap.png"), width = 10, height = 8, units = "in", res = 300, bg = "white")
invisible(tryCatch(
  pheatmap(
    heat_mat,
    color = colorRampPalette(c("#3C5488", "white", "#E64B35"))(100),
    cluster_rows = nrow(heat_mat) > 1,
    cluster_cols = ncol(heat_mat) > 1,
    fontsize = 11,
    border_color = "gray85",
    main = "PBMC 3k marker gene expression by annotated cell type"
  ),
  finally = dev.off()
))

# Final report: list the PNG files in the output directory, show where they
# were written, and print the session details.
message("=== Generated Module 04 real-data figures ===")
generated_pngs <- list.files(output_dir, pattern = "\\.png$", full.names = FALSE)
print(generated_pngs)
message("Output directory: ", normalizePath(output_dir))
sessionInfo()
