#!/usr/bin/env Rscript
# ============================================================================
# BioF3 Module 01: Real PBMC 3k Data Access and QC Figures
# ============================================================================
#
# This script uses the public 10x Genomics PBMC 3k dataset.
# It downloads the filtered 10x matrix if needed, reads the real count matrix,
# calculates basic QC metrics, and generates figures used by Module 01.
#
# Data source:
# https://cf.10xgenomics.com/samples/cell-exp/1.1.0/pbmc3k/
#   pbmc3k_filtered_gene_bc_matrices.tar.gz
#
# Usage:
#   Rscript scripts/module01_complete_sci.R
#
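# Optional environment variables:
#   BIOF3_DATA_DIR=/path/to/data
#       Data root; the archive is stored and extracted under
#       <BIOF3_DATA_DIR>/pbmc3k. Default: ~/biof3-data
#   BIOF3_OUTPUT_DIR=/path/to/static/img/tutorial/single-cell/module01
#       Directory that receives the PNG figures.
#       Default: <project root>/static/img/tutorial/modules/module01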
#
# ============================================================================

options(stringsAsFactors = FALSE)

# Derive the project root from the directory that contains this script.
#
# Supported layouts (all resolve to <root>):
#   <root>/scripts/x.R
#   <root>/static/scripts/x.R
#   <root>/scripts/<group>/x.R          (one level of nesting)
#   <root>/static/scripts/<group>/x.R
# Any other layout falls back to the parent of `script_dir`.
#
# Args:
#   script_dir: directory containing the script (character scalar).
# Returns: the project root path (character scalar).
find_project_root <- function(script_dir) {
  scripts_dir <- script_dir
  # Scripts grouped in a sub-folder of scripts/ (e.g. scripts/single-cell/)
  # must not treat scripts/ itself as the project root.
  if (basename(scripts_dir) != "scripts" &&
      basename(dirname(scripts_dir)) == "scripts") {
    scripts_dir <- dirname(scripts_dir)
  }

  parent_dir <- dirname(scripts_dir)
  if (basename(scripts_dir) == "scripts" && basename(parent_dir) == "static") {
    dirname(parent_dir)
  } else {
    parent_dir
  }
}

# Rscript passes the script location as --file=<path>; when that argument is
# absent (e.g. the file is source()'d) fall back to the path relative to the
# current working directory.
args <- commandArgs(trailingOnly = FALSE)
file_arg <- grep("^--file=", args, value = TRUE)
script_path <- if (length(file_arg) > 0) {
  normalizePath(sub("^--file=", "", file_arg[[1]]), mustWork = TRUE)
} else {
  normalizePath("scripts/single-cell/sc01_data_sci.R", mustWork = FALSE)
}

script_dir <- dirname(script_path)
project_root <- find_project_root(script_dir)

# Read an environment variable, treating "set but empty" the same as unset.
# Sys.getenv() only falls back to its `unset` value when the variable is truly
# absent, so an empty BIOF3_DATA_DIR would otherwise turn pbmc_dir into
# "/pbmc3k".
#
# Args:
#   name: environment variable name.
#   default: value returned when the variable is unset or empty.
# Returns: the variable's value, or `default`.
env_or_default <- function(name, default) {
  value <- Sys.getenv(name, unset = "")
  if (nzchar(value)) value else default
}

data_root <- env_or_default(
  "BIOF3_DATA_DIR",
  file.path(path.expand("~"), "biof3-data")
)
pbmc_dir <- file.path(data_root, "pbmc3k")
output_dir <- env_or_default(
  "BIOF3_OUTPUT_DIR",
  file.path(project_root, "static", "img", "tutorial", "modules", "module01")
)

# dir.create() is silent when the directory already exists, but with warnings
# suppressed it is also silent on real failures (permissions, bad path), so
# verify the result explicitly.
for (dir_path in c(pbmc_dir, output_dir)) {
  dir.create(dir_path, recursive = TRUE, showWarnings = FALSE)
  if (!dir.exists(dir_path)) {
    stop("Could not create directory: ", dir_path, call. = FALSE)
  }
}

# Remote archive and the local paths derived from it. The archive unpacks to
# filtered_gene_bc_matrices/hg19/ containing the three 10x matrix files.
pbmc_archive_name <- "pbmc3k_filtered_gene_bc_matrices.tar.gz"
pbmc_url <- paste0(
  "https://cf.10xgenomics.com/samples/cell-exp/1.1.0/pbmc3k/",
  pbmc_archive_name
)
pbmc_tar <- file.path(pbmc_dir, pbmc_archive_name)

matrix_dir <- file.path(pbmc_dir, "filtered_gene_bc_matrices", "hg19")
barcodes_file <- file.path(matrix_dir, "barcodes.tsv")
genes_file <- file.path(matrix_dir, "genes.tsv")
matrix_file <- file.path(matrix_dir, "matrix.mtx")

# Packages the script attaches below; any that cannot be loaded is installed
# from the CRAN cloud mirror first.
required_packages <- c(
  "Matrix",
  "ggplot2",
  "dplyr",
  "pheatmap",
  "patchwork",
  "scales"
)

for (pkg in required_packages) {
  if (requireNamespace(pkg, quietly = TRUE)) {
    next
  }
  install.packages(pkg, repos = "https://cloud.r-project.org/")
}

library(Matrix)
library(ggplot2)
library(dplyr)
library(pheatmap)
library(patchwork)
library(scales)

# Shared ggplot2 theme for the Module 01 figures.
#
# Args:
#   base_size: base font size forwarded to theme_classic().
# Returns: theme_classic() with the BioF3 text colours, bold titles,
#   left-aligned title/subtitle and a right-hand legend.
theme_biof3 <- function(base_size = 12) {
  ink <- "#111827"

  biof3_overrides <- theme(
    legend.position = "right",
    legend.title = element_text(face = "bold"),
    plot.subtitle = element_text(color = "#526760", hjust = 0),
    plot.title = element_text(color = "#12342f", face = "bold", hjust = 0),
    axis.title = element_text(color = ink, face = "bold"),
    axis.text = element_text(color = ink)
  )

  theme_classic(base_size = base_size) + biof3_overrides
}

# Named colour palette (hex codes) shared by all figures; look colours up by
# name, e.g. biof3_colors["green"].
biof3_colors <- setNames(
  c("#0f766e", "#43d1ae", "#2563eb", "#f59e0b", "#dc2626", "#475569", "#e8f5ef"),
  c("green", "mint", "blue", "amber", "red", "slate", "pale")
)

# Ensure the PBMC 3k archive exists locally and is extracted.
#
# Reads the script-level paths `pbmc_url`, `pbmc_tar`, `pbmc_dir` and
# `matrix_file`. Each step is skipped when its result is already on disk:
# the download when `pbmc_tar` exists, the extraction when `matrix_file`
# exists.
#
# The archive is downloaded to "<pbmc_tar>.part" and renamed only after a
# successful transfer, so an interrupted download can never be mistaken for a
# complete archive on the next run.
#
# Returns: NULL, invisibly.
# Errors: stops with a descriptive message when the download, the rename or
#   the extraction fails.
download_pbmc3k <- function() {
  if (!file.exists(pbmc_tar)) {
    message("Downloading PBMC 3k data from 10x Genomics...")

    partial_tar <- paste0(pbmc_tar, ".part")
    # Give slow connections more than the default timeout, and always restore
    # the previous option and clean up the partial file.
    old_opts <- options(timeout = max(300, getOption("timeout", 60)))
    on.exit(options(old_opts), add = TRUE)
    on.exit(unlink(partial_tar), add = TRUE)

    status <- tryCatch(
      download.file(pbmc_url, destfile = partial_tar, mode = "wb", quiet = FALSE),
      error = function(e) {
        stop(
          "Failed to download ", pbmc_url, ": ", conditionMessage(e),
          call. = FALSE
        )
      }
    )
    # Some download methods signal failure through a non-zero status instead
    # of an error.
    if (!identical(as.integer(status), 0L)) {
      stop(
        "Failed to download ", pbmc_url, " (status ", status, ")",
        call. = FALSE
      )
    }
    if (!file.rename(partial_tar, pbmc_tar)) {
      stop("Could not move downloaded archive to: ", pbmc_tar, call. = FALSE)
    }
  } else {
    message("Found existing archive: ", pbmc_tar)
  }

  if (!file.exists(matrix_file)) {
    message("Extracting PBMC 3k archive...")
    status <- utils::untar(pbmc_tar, exdir = pbmc_dir)
    if (!identical(as.integer(status), 0L)) {
      stop(
        "Failed to extract ", pbmc_tar, " (status ", status, "); ",
        "delete the archive and re-run to download it again",
        call. = FALSE
      )
    }
  } else {
    message("Found extracted matrix: ", matrix_file)
  }

  invisible(NULL)
}

# Load the PBMC 3k count matrix together with its gene and barcode labels.
#
# Calls download_pbmc3k() first, then reads `matrix_file` (Matrix Market),
# `genes_file` and `barcodes_file` (tab-separated, no header).
#
# Returns: a list with
#   counts       - CsparseMatrix; rownames are unique gene symbols, colnames
#                  are cell barcodes
#   gene_ids     - first column of the genes table
#   gene_symbols - second column of the genes table (first column when the
#                  table has only one), made unique with make.unique()
#   barcodes     - first column of the barcodes table
# Errors: stops when any of the three files is missing, or when the label
#   tables do not match the matrix dimensions.
read_pbmc3k <- function() {
  download_pbmc3k()

  required_files <- c(matrix_file, genes_file, barcodes_file)
  missing_files <- required_files[!file.exists(required_files)]
  if (length(missing_files) > 0) {
    stop(
      "PBMC 3k matrix files are incomplete under: ", matrix_dir,
      " (missing: ", paste(basename(missing_files), collapse = ", "), ")"
    )
  }

  counts <- Matrix::readMM(matrix_file)
  # Identifiers are read verbatim: no quote handling and no type conversion.
  genes <- read.delim(
    genes_file,
    header = FALSE, sep = "\t", quote = "", colClasses = "character"
  )
  barcodes <- read.delim(
    barcodes_file,
    header = FALSE, sep = "\t", quote = "", colClasses = "character"
  )

  # Fail with a clear message instead of the generic dimnames error that
  # assigning mismatched labels would raise.
  if (nrow(genes) != nrow(counts) || nrow(barcodes) != ncol(counts)) {
    stop(
      "PBMC 3k labels do not match the matrix: matrix is ",
      nrow(counts), " x ", ncol(counts), ", but found ",
      nrow(genes), " genes and ", nrow(barcodes), " barcodes under: ",
      matrix_dir
    )
  }

  gene_ids <- genes[[1]]
  # Symbols can repeat across gene IDs; rownames need to be unique.
  symbol_column <- if (ncol(genes) >= 2) 2L else 1L
  gene_symbols <- make.unique(genes[[symbol_column]])

  rownames(counts) <- gene_symbols
  colnames(counts) <- barcodes[[1]]

  list(
    counts = as(counts, "CsparseMatrix"),
    gene_ids = gene_ids,
    gene_symbols = gene_symbols,
    barcodes = barcodes[[1]]
  )
}

message("=== Loading real PBMC 3k data ===")
pbmc <- read_pbmc3k()
counts <- pbmc[["counts"]]

message("Genes: ", nrow(counts))
message("Cells: ", ncol(counts))

# Per-cell QC metrics (one value per matrix column):
#   n_counts   - sum of all counts in the cell
#   n_features - number of genes with a non-zero count
#   percent_mt - share of counts from genes whose symbol starts with "MT-",
#                or NA for every cell when no such gene is present
n_counts <- Matrix::colSums(counts)
n_features <- Matrix::colSums(counts > 0)

mt_genes <- grepl("^MT-", rownames(counts))
if (any(mt_genes)) {
  mt_counts <- Matrix::colSums(counts[mt_genes, , drop = FALSE])
  percent_mt <- mt_counts / n_counts * 100
} else {
  percent_mt <- rep(NA_real_, ncol(counts))
}

# as.numeric() strips the barcode names carried over from the matrix.
qc <- data.frame(
  cell = colnames(counts),
  nCount_RNA = as.numeric(n_counts),
  nFeature_RNA = as.numeric(n_features),
  percent_mt = as.numeric(percent_mt)
)

# Per-gene summaries (one value per matrix row): total counts, number of cells
# with a non-zero count, and mean count across all cells.
gene_detected_cells <- Matrix::rowSums(counts > 0)
gene_total_counts <- Matrix::rowSums(counts)
gene_mean_counts <- gene_total_counts / ncol(counts)

gene_summary <- data.frame(
  gene = rownames(counts),
  total_counts = as.numeric(gene_total_counts),
  detected_cells = as.numeric(gene_detected_cells),
  mean_counts = as.numeric(gene_mean_counts)
)

# Headline per-cell medians for the console log.
message(
  "Median counts per cell: ",
  round(median(qc[["nCount_RNA"]]))
)
message(
  "Median genes per cell: ",
  round(median(qc[["nFeature_RNA"]]))
)
message(
  "Median mitochondrial percent: ",
  round(median(qc[["percent_mt"]], na.rm = TRUE), 2),
  "%"
)

# ============================================================================
# Figure 1: Top expressed genes in PBMC 3k
# ============================================================================

message("Generating Figure 1: Top expressed genes")

# Rank genes by total counts after dropping symbols that start with MT-, RPL
# or RPS, and keep the top 12.
rankable_genes <- filter(gene_summary, !grepl("^MT-|^RPL|^RPS", gene))
top_genes <- slice_head(arrange(rankable_genes, desc(total_counts)), n = 12)
# Reversed factor levels put the highest-count gene at the top once the axes
# are flipped.
top_genes$gene <- factor(top_genes$gene, levels = rev(top_genes$gene))

fig1_labels <- labs(
  title = "Top expressed genes in PBMC 3k",
  subtitle = "Real 10x Genomics filtered gene-barcode matrix",
  x = NULL,
  y = "Total UMI counts"
)

p1 <- ggplot(top_genes, aes(x = gene, y = total_counts)) +
  geom_col(fill = biof3_colors["green"], width = 0.72) +
  coord_flip() +
  scale_y_continuous(labels = comma) +
  fig1_labels +
  theme_biof3()

ggsave(
  filename = file.path(output_dir, "01-gene-expression-bar.png"),
  plot = p1,
  width = 8,
  height = 5,
  dpi = 300,
  bg = "white"
)

# ============================================================================
# Figure 2: Total counts per cell
# ============================================================================

message("Generating Figure 2: Total counts per cell")

# The median is drawn as a dashed line and repeated in the subtitle.
median_counts <- median(qc$nCount_RNA)

fig2_median_line <- geom_vline(
  xintercept = median_counts,
  linetype = "dashed",
  color = biof3_colors["amber"],
  linewidth = 0.8
)

p2 <- ggplot(qc, aes(x = nCount_RNA)) +
  geom_histogram(
    bins = 50,
    fill = biof3_colors["green"],
    color = "white",
    alpha = 0.88
  ) +
  fig2_median_line +
  scale_x_continuous(labels = comma) +
  labs(
    title = "PBMC 3k library size distribution",
    subtitle = paste0("Median total counts = ", comma(round(median_counts))),
    x = "Total UMI counts per cell",
    y = "Number of cells"
  ) +
  theme_biof3()

ggsave(
  filename = file.path(output_dir, "02-cell-counts-distribution.png"),
  plot = p2,
  width = 8,
  height = 5,
  dpi = 300,
  bg = "white"
)

# ============================================================================
# Figure 3: Mean expression per gene
# ============================================================================

message("Generating Figure 3: Mean expression per gene")

# Genes never detected have a mean of zero, which cannot be placed on a log
# axis, so only genes seen in at least one cell are plotted.
expressed_genes <- filter(gene_summary, detected_cells > 0)
fig3_subtitle <- paste0(
  nrow(expressed_genes),
  " genes detected in at least one cell"
)

p3 <- ggplot(expressed_genes, aes(x = mean_counts)) +
  geom_histogram(
    bins = 60,
    fill = biof3_colors["blue"],
    color = "white",
    alpha = 0.84
  ) +
  scale_x_log10(labels = label_number(accuracy = 0.001)) +
  labs(
    title = "PBMC 3k mean gene expression",
    subtitle = fig3_subtitle,
    x = "Mean UMI counts per cell (log10 scale)",
    y = "Number of genes"
  ) +
  theme_biof3()

ggsave(
  filename = file.path(output_dir, "03-gene-mean-distribution.png"),
  plot = p3,
  width = 8,
  height = 5,
  dpi = 300,
  bg = "white"
)

# ============================================================================
# Figure 4: Real expression matrix heatmap
# ============================================================================

message("Generating Figure 4: Expression matrix heatmap")

# Rows: the 30 highest-count genes among those detected in at least 100 cells,
# excluding symbols that start with MT-.
candidate_genes <- gene_summary %>%
  filter(detected_cells >= 100, !grepl("^MT-", gene)) %>%
  arrange(desc(total_counts)) %>%
  slice_head(n = 30) %>%
  pull(gene)

# Columns: up to 80 cells picked at evenly spaced ranks of the library-size
# ordering. Capping at the number of available cells (and de-duplicating the
# rounded ranks) avoids repeated columns on datasets with fewer than 80 cells.
cell_order <- order(qc$nCount_RNA)
n_heat_cells <- min(80L, length(cell_order))
heat_ranks <- unique(
  round(seq(1, length(cell_order), length.out = n_heat_cells))
)
selected_cells <- colnames(counts)[cell_order[heat_ranks]]

# Scale each cell to 10,000 total counts, then log1p-transform.
heat_counts <- as.matrix(counts[candidate_genes, selected_cells, drop = FALSE])
cell_totals <- qc$nCount_RNA[match(selected_cells, qc$cell)]
heat_norm <- t(t(heat_counts) / cell_totals * 10000)
heat_log <- log1p(heat_norm)

png(
  file.path(output_dir, "04-expression-matrix-heatmap.png"),
  width = 8, height = 6, units = "in", res = 300
)
# Close the device even when pheatmap() fails so later plots are not written
# into this file. Assigning the result keeps the top level from auto-printing.
# The title names the actual scale (per 10,000 counts, not per million).
heatmap_result <- tryCatch(
  pheatmap(
    heat_log,
    color = colorRampPalette(c("#f8fafc", "#43d1ae", "#0f172a"))(100),
    cluster_rows = TRUE,
    cluster_cols = TRUE,
    show_colnames = FALSE,
    fontsize = 8,
    fontsize_row = 7,
    main = "PBMC 3k expression matrix (log1p counts per 10k)"
  ),
  finally = dev.off()
)

# ============================================================================
# Figure 5: QC scatter plot
# ============================================================================

message("Generating Figure 5: QC scatter plot")

# Colour encodes mitochondrial percentage; cells without a value fall back to
# the slate colour.
mt_color_scale <- scale_color_gradient(
  low = biof3_colors["green"],
  high = biof3_colors["red"],
  na.value = biof3_colors["slate"]
)

fig5_labels <- labs(
  title = "PBMC 3k quality control metrics",
  subtitle = "Each point is one cell barcode from the filtered 10x matrix",
  x = "Total UMI counts",
  y = "Detected genes",
  color = "Mitochondrial %"
)

p5 <- ggplot(qc, aes(x = nCount_RNA, y = nFeature_RNA, color = percent_mt)) +
  geom_point(alpha = 0.72, size = 1.25) +
  scale_x_continuous(labels = comma) +
  scale_y_continuous(labels = comma) +
  mt_color_scale +
  fig5_labels +
  theme_biof3()

ggsave(
  filename = file.path(output_dir, "05-qc-scatter.png"),
  plot = p5,
  width = 8,
  height = 5,
  dpi = 300,
  bg = "white"
)

# ============================================================================
# Figure 6: Public data source guide (conceptual)
# ============================================================================

message("Generating Figure 6: Public data source guide")

# Hand-written availability scores, one row per source and one column per
# feature: 0 = "Limited", 0.5 = "Sometimes", 1 = "Common" (see legend labels).
data_sources <- c("GEO", "SRA", "CELLxGENE", "HCA", "10x Datasets")
data_features <- c(
  "Processed matrix", "Raw reads", "Cell annotation", "Interactive view"
)
availability <- rbind(
  c(1, 0.5, 0.5, 0),
  c(0, 1, 0, 0),
  c(1, 0, 1, 1),
  c(1, 0.5, 1, 0.5),
  c(1, 0.5, 0.5, 0.5)
)

# Long format, one row per (source, feature) pair; transposing first makes
# as.vector() walk the table row by row.
source_matrix <- data.frame(
  source = rep(data_sources, each = length(data_features)),
  feature = rep(data_features, times = length(data_sources)),
  available = as.vector(t(availability))
)

availability_scale <- scale_fill_gradientn(
  colors = c("#f1f5f9", "#a7f3d0", "#0f766e"),
  breaks = c(0, 0.5, 1),
  labels = c("Limited", "Sometimes", "Common")
)

p6 <- ggplot(source_matrix, aes(x = feature, y = source, fill = available)) +
  geom_tile(color = "white", linewidth = 0.8) +
  availability_scale +
  labs(
    title = "How public data sources differ",
    subtitle = "Conceptual guide; always check the original dataset page",
    x = NULL,
    y = NULL,
    fill = "Availability"
  ) +
  theme_biof3() +
  theme(axis.text.x = element_text(angle = 30, hjust = 1))

ggsave(
  filename = file.path(output_dir, "06-database-comparison.png"),
  plot = p6,
  width = 8,
  height = 5,
  dpi = 300,
  bg = "white"
)

# ============================================================================
# Figure 7: Reproducible data workflow (conceptual)
# ============================================================================

message("Generating Figure 7: Reproducible data workflow")

# Five numbered steps laid out left to right on a single horizontal line.
workflow_steps <- c(
  "Choose dataset",
  "Record source",
  "Download data",
  "Inspect matrix",
  "Save outputs"
)
workflow <- data.frame(
  step = factor(workflow_steps, levels = workflow_steps),
  order = seq_along(workflow_steps)
)

fig7_theme <- theme_void(base_size = 12) +
  theme(
    plot.title = element_text(color = "#12342f", face = "bold"),
    plot.subtitle = element_text(color = "#526760"),
    plot.margin = margin(18, 18, 18, 18)
  )

p7 <- ggplot(workflow, aes(x = order, y = 1)) +
  # Connector line behind the step markers.
  annotate(
    "segment",
    x = 1, xend = 5, y = 1, yend = 1,
    color = "#cbd5e1", linewidth = 1.2
  ) +
  # Each marker is a pale halo with a smaller green disc on top.
  geom_point(size = 12, color = biof3_colors["pale"]) +
  geom_point(size = 9, color = biof3_colors["green"]) +
  # Step number inside the disc, step name below it.
  geom_text(aes(label = order), color = "white", fontface = "bold", size = 4.5) +
  geom_text(aes(label = step), y = 0.76, size = 3.8, color = "#12342f") +
  scale_x_continuous(limits = c(0.6, 5.4), breaks = NULL) +
  scale_y_continuous(limits = c(0.62, 1.2), breaks = NULL) +
  labs(
    title = "A reproducible public-data workflow",
    subtitle = "Use real data, record provenance, and keep outputs traceable",
    x = NULL,
    y = NULL
  ) +
  fig7_theme

ggsave(
  filename = file.path(output_dir, "07-workflow.png"),
  plot = p7,
  width = 9,
  height = 4.5,
  dpi = 300,
  bg = "white"
)

# ============================================================================
# Figure 8: Combined real QC metrics
# ============================================================================

message("Generating Figure 8: Combined QC metrics")

# Build one 45-bin QC histogram panel; the three panels differ only in the
# mapped column, fill colour, opacity and x-axis label.
qc_histogram <- function(mapping, fill, alpha, x_label) {
  ggplot(qc, mapping) +
    geom_histogram(bins = 45, fill = fill, color = "white", alpha = alpha) +
    labs(x = x_label, y = "Cells") +
    theme_biof3(10)
}

# The two count-based panels use thousands separators on the x axis.
p8a <- qc_histogram(
  aes(x = nCount_RNA), biof3_colors["green"], 0.86, "Total UMI counts"
) +
  scale_x_continuous(labels = comma)

p8b <- qc_histogram(
  aes(x = nFeature_RNA), biof3_colors["blue"], 0.82, "Detected genes"
) +
  scale_x_continuous(labels = comma)

p8c <- qc_histogram(
  aes(x = percent_mt), biof3_colors["amber"], 0.82, "Mitochondrial %"
)

fig8_annotation <- plot_annotation(
  title = "PBMC 3k real QC metrics",
  subtitle = "Calculated from the 10x Genomics filtered gene-barcode matrix",
  theme = theme(
    plot.title = element_text(color = "#12342f", face = "bold", size = 14),
    plot.subtitle = element_text(color = "#526760", size = 11)
  )
)

# Side-by-side layout via patchwork.
p8 <- (p8a | p8b | p8c) + fig8_annotation

ggsave(
  filename = file.path(output_dir, "08-qc-combined.png"),
  plot = p8,
  width = 12,
  height = 4.2,
  dpi = 300,
  bg = "white"
)

# Final report: list the PNG files now present in the output directory and
# echo the resolved output and data locations.
message("=== Generated Module 01 figures ===")
print(list.files(output_dir, pattern = "\\.png$", full.names = FALSE))

message("Output directory: ", normalizePath(output_dir))
message("Data directory: ", normalizePath(pbmc_dir))
message("Done.")

# Print explicitly: a bare sessionInfo() is only shown through top-level
# auto-printing, which does not happen when the script is run via source().
print(sessionInfo())
