#!/usr/bin/env Rscript
# =============================================================================
# BioF3 Bulk RNA-seq Module 06: Batch effects — ComBat / SVA / batch-in-design
# =============================================================================
#
# Companion script for bulk RNA-seq module 06. It:
#   1. Uses the bladderbatch dataset (the canonical batch-effect demo):
#      - 57 samples, 5 batches, status = Normal / Biopsy / Cancer
#      - status is heavily confounded with batch, which is the ugly case
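#   2. Shows several workflows on the same data:
#      - Ignore the batch (naive)
#      - Correct with ComBat (used here for visualization: PCA / distances)
#      - Include the known batch as a covariate in the design
#      - Estimate surrogate variables with SVA and include them in the design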
#   3. Also shows DESeq2 with airway + cell line as batch variable, the
#      clean RNA-seq case
#   4. Produces six figures:
#      - PCA before / after ComBat (bladderbatch)
#      - Sample-distance heatmaps (hierarchically clustered) before / after ComBat
#      - SVA surrogate variables vs batch (heatmap of correlation)
#      - Number of DE genes at padj < 0.05 for 3 strategies
#      - Overlap of DE genes found by the three strategies
#      - airway: impact of including cell as a blocking factor
#
# Data sources:
#   Bioconductor bladderbatch (ExpressionSet, microarray — batch demo)
#   Bioconductor airway (RNA-seq)
#
# Usage:
#   Rscript scripts/bulk06_batch_sci.R
#
# Optional env vars:
#   BIOF3_OUTPUT_DIR=/path/to/static/img/tutorial/bulk06
#
# Dependencies:
#   sva, limma, DESeq2, bladderbatch, airway, ggplot2, dplyr, patchwork,
#   pheatmap, RColorBrewer
# =============================================================================

options(stringsAsFactors = FALSE)
set.seed(42)

# ---- paths ---------------------------------------------------------------
# Locate this script so outputs land relative to the project regardless of the
# working directory. Rscript passes the script path as `--file=<path>`.
args <- commandArgs(trailingOnly = FALSE)
file_arg <- grep("^--file=", args, value = TRUE)
script_path <- if (length(file_arg) > 0) {
  normalizePath(sub("^--file=", "", file_arg[[1]]), mustWork = TRUE)
} else {
  # no --file= argument (e.g. interactive session): resolve relative to cwd
  normalizePath("scripts/bulk06_batch_sci.R", mustWork = FALSE)
}
script_dir <- dirname(script_path)
# The script may live in <root>/scripts or in <root>/static/scripts.
project_root <- if (basename(script_dir) == "scripts" &&
                    basename(dirname(script_dir)) == "static") {
  dirname(dirname(script_dir))
} else {
  dirname(script_dir)
}
# BIOF3_OUTPUT_DIR overrides the default. A set-but-empty value is treated as
# unset: otherwise file.path("", name) would write the PNGs to the root "/".
output_dir <- Sys.getenv("BIOF3_OUTPUT_DIR", unset = "")
if (!nzchar(output_dir)) {
  output_dir <- file.path(project_root, "static", "img", "tutorial", "bulk06")
}
dir.create(output_dir, recursive = TRUE, showWarnings = FALSE)
if (!dir.exists(output_dir)) {
  stop("Could not create output directory: ", output_dir, call. = FALSE)
}

# ---- dependencies --------------------------------------------------------
suppressPackageStartupMessages({
  library(sva)
  library(limma)
  library(DESeq2)
  library(bladderbatch)
  library(airway)
  library(ggplot2)
  library(dplyr)
  library(patchwork)
  library(pheatmap)
  library(RColorBrewer)
  library(Biobase)
})

# Module colour palette as named hex codes. Position matters: the figures use
# the first five entries, in this order, as the batch colours.
biof3_colors <- c(
  green  = "#0f766e",
  mint   = "#43d1ae",
  blue   = "#2563eb",
  amber  = "#f59e0b",
  red    = "#dc2626",
  slate  = "#475569",
  violet = "#7c3aed",
  pink   = "#ec4899"
)

# Shared ggplot2 theme for every figure in this module: theme_classic() at the
# given base font size, plus dark axis text, bold axis/legend titles,
# left-aligned title and subtitle, and unboxed facet strip labels.
theme_biof3 <- function(base_size = 12) {
  ink   <- "#111827"
  deep  <- "#12342f"
  muted <- "#526760"
  overrides <- theme(
    axis.text        = element_text(color = ink),
    axis.title       = element_text(color = ink, face = "bold"),
    plot.title       = element_text(color = deep, face = "bold", hjust = 0),
    plot.subtitle    = element_text(color = muted, hjust = 0),
    legend.title     = element_text(face = "bold"),
    legend.position  = "right",
    strip.background = element_blank(),
    strip.text       = element_text(color = deep, face = "bold")
  )
  theme_classic(base_size = base_size) + overrides
}

# Save a ggplot / patchwork object as a white-background raster image.
#
# p        plot object accepted by ggsave()
# name     file name (e.g. "01-combat-pca.png"), joined onto out_dir
# width, height  figure size, in ggsave()'s default units (inches)
# dpi      raster resolution (300 by default, as for all figures here)
# out_dir  destination directory; defaults to the script-level output_dir
# ...      further arguments forwarded to ggsave()
# Returns the written path invisibly.
save_plot <- function(p, name, width = 8, height = 5, dpi = 300,
                      out_dir = output_dir, ...) {
  path <- file.path(out_dir, name)
  ggsave(path, p, width = width, height = height, dpi = dpi,
         bg = "white", ...)
  invisible(path)
}

# =============================================================================
# Part A: bladderbatch (microarray, heavy confounding)
# =============================================================================
message("Loading bladderbatch ...")
data("bladderdata")            # provides `bladderEset` (an ExpressionSet)
edata <- exprs(bladderEset)    # feature x sample expression matrix
pheno <- pData(bladderEset)    # per-sample annotations (batch, cancer, ...)
batch <- as.numeric(pheno$batch)
# full model (intercept + cancer status) consumed by num.sv() / sva() below
mod <- model.matrix(~ as.factor(cancer), data = pheno)

# ---- ComBat --------------------------------------------------------------
message("Running ComBat ...")
# mod = NULL: no covariates are passed, so ComBat models batch only
# (parametric empirical-Bayes adjustment, no prior diagnostic plots).
combat_edata <- ComBat(
  dat = edata, batch = batch, mod = NULL,
  par.prior = TRUE, prior.plots = FALSE
)

# PCA over samples: transpose so samples are rows; prcomp() centres, no scaling.
pca_raw <- prcomp(t(edata), scale. = FALSE)
pca_cb  <- prcomp(t(combat_edata), scale. = FALSE)

# First two PC scores joined with the sample annotations, tagged by stage.
pc_scores_df <- function(pca_fit, stage_label) {
  data.frame(
    PC1 = pca_fit$x[, 1], PC2 = pca_fit$x[, 2],
    batch = factor(pheno$batch), cancer = pheno$cancer,
    stage = stage_label
  )
}
stage_levels <- c("Raw expression", "After ComBat")
pca_df_raw <- pc_scores_df(pca_raw, stage_levels[[1]])
pca_df_cb  <- pc_scores_df(pca_cb, stage_levels[[2]])
pca_df <- bind_rows(pca_df_raw, pca_df_cb) |>
  mutate(stage = factor(stage, levels = stage_levels))

batch_palette <- unname(biof3_colors)[1:5]
p1 <- ggplot(pca_df, aes(x = PC1, y = PC2, color = batch, shape = cancer)) +
  geom_point(size = 3.5, alpha = 0.9) +
  facet_wrap(~ stage, scales = "free") +
  scale_color_manual(values = batch_palette) +
  labs(title = "PCA before vs after ComBat (bladderbatch)",
       subtitle = "Raw: PC1 tracks batch; After ComBat: samples regroup by biology (cancer status)",
       x = "PC1", y = "PC2", color = "Batch", shape = "Cancer") +
  theme_biof3()
save_plot(p1, "01-combat-pca.png", width = 13, height = 5.5)

# ---- figure 2: sample-to-sample distance before / after ComBat ----------
message("Figure 2: sample distance heatmap")
# Build (without drawing) a clustered heatmap of sample-to-sample Euclidean
# distances.
#
# mat      expression matrix with samples in columns (like edata)
# title    heatmap title
# pheno_df sample annotations aligned with the columns of mat; must provide
#          `cancer` (Normal / Biopsy / Cancer) and `batch`. Defaults to pheno.
# Returns the pheatmap object (silent = TRUE); its $gtable can be drawn later.
make_dist_hm <- function(mat, title, pheno_df = pheno) {
  d <- dist(t(mat))
  mat_d <- as.matrix(d)
  # use sample index as unique ID; annotation carries the biological info
  sample_ids <- sprintf("s%02d", seq_len(ncol(mat)))
  dimnames(mat_d) <- list(sample_ids, sample_ids)
  anno <- data.frame(
    cancer = pheno_df$cancer,
    batch  = factor(pheno_df$batch),
    row.names = sample_ids
  )
  # colour batches by the levels actually present rather than assuming 1..5
  batch_levels <- levels(anno$batch)
  anno_colors <- list(
    cancer = c(Normal = biof3_colors[["slate"]],
               Biopsy = biof3_colors[["amber"]],
               Cancer = biof3_colors[["red"]]),
    batch  = setNames(unname(biof3_colors)[seq_along(batch_levels)],
                      batch_levels)
  )
  pheatmap(
    mat_d, annotation_row = anno, annotation_colors = anno_colors,
    # cluster on the real sample distances; the default would instead compute
    # Euclidean distances between rows of the distance matrix itself
    clustering_distance_rows = d, clustering_distance_cols = d,
    show_rownames = FALSE, show_colnames = FALSE,
    color = colorRampPalette(rev(brewer.pal(9, "Blues")))(255),
    main = title, silent = TRUE
  )
}
hm1 <- make_dist_hm(edata, "Before ComBat")
hm2 <- make_dist_hm(combat_edata, "After ComBat")

# Draw grobs / gtables side by side (one row, one cell each) into a single PNG.
# pheatmap's print() method starts a new grid page and ignores `vp`, so each
# heatmap's $gtable is drawn into its own layout cell instead. The device is
# closed even if drawing fails.
draw_side_by_side_png <- function(path, panels, width = 14, height = 6,
                                  res = 300) {
  png(path, width = width, height = height, units = "in", res = res,
      bg = "white")
  on.exit(dev.off(), add = TRUE)
  grid::grid.newpage()
  grid::pushViewport(
    grid::viewport(layout = grid::grid.layout(1, length(panels)))
  )
  for (i in seq_along(panels)) {
    grid::pushViewport(grid::viewport(layout.pos.row = 1, layout.pos.col = i))
    grid::grid.draw(panels[[i]])
    grid::upViewport()
  }
  invisible(path)
}
draw_side_by_side_png(
  file.path(output_dir, "02-combat-distance.png"),
  list(hm1$gtable, hm2$gtable)
)

# ---- SVA on bladderbatch ------------------------------------------------
message("Running SVA ...")
mod0 <- model.matrix(~ 1, data = pheno)   # null model: intercept only
n_sv <- num.sv(edata, mod, method = "leek")
message("SVA suggests n.sv = ", n_sv)
# num.sv() may return fewer than two SVs; the SV1-vs-SV2 scatter below needs
# at least two, so force n.sv >= 2 for the demo.
n_sv <- max(n_sv, 2)
svobj <- sva(edata, mod, mod0, n.sv = n_sv)
message("SVs estimated: ", svobj$n.sv)

sv_names <- paste0("SV", seq_len(svobj$n.sv))
sv_df <- as.data.frame(svobj$sv)
colnames(sv_df) <- sv_names
sv_df$batch <- factor(pheno$batch)
sv_df$cancer <- pheno$cancer

# correlation between SVs and batch indicator (one-hot) columns; columns are
# labelled by the actual batch levels (model.matrix keeps level order), not
# by position
batch_mm <- model.matrix(~ factor(batch) - 1, data = pheno)
colnames(batch_mm) <- paste0("batch", levels(factor(pheno$batch)))
cor_mat <- cor(svobj$sv, batch_mm)
rownames(cor_mat) <- sv_names

# Long format: one row per (SV, batch) pair with its Pearson correlation.
cor_long <- as.data.frame(cor_mat) |>
  tibble::rownames_to_column("SV") |>
  tidyr::pivot_longer(-SV, names_to = "batch", values_to = "r")

# Left panel: SV x batch correlation tiles annotated with r.
p3a <- ggplot(cor_long, aes(x = batch, y = SV, fill = r)) +
  geom_tile() +
  geom_text(aes(label = sprintf("%.2f", r)), size = 4, color = "#111827") +
  scale_fill_gradient2(low = biof3_colors[["blue"]], mid = "white",
                       high = biof3_colors[["red"]], midpoint = 0,
                       name = "Pearson r", limits = c(-1, 1)) +
  labs(title = "SVs vs known batches",
       subtitle = "High |r| means that SV captures (at least part of) that batch",
       x = NULL, y = NULL) +
  theme_biof3()

# Right panel: samples in SV1 / SV2 space, coloured by batch.
sv_palette <- unname(biof3_colors)[1:5]
p3b <- ggplot(sv_df, aes(x = SV1, y = SV2, color = batch, shape = cancer)) +
  geom_point(size = 3.5, alpha = 0.9) +
  scale_color_manual(values = sv_palette) +
  labs(title = "First two surrogate variables",
       subtitle = "Batches separate along SV1 / SV2, recovering the hidden structure",
       x = "SV1", y = "SV2", color = "Batch", shape = "Cancer") +
  theme_biof3()

combined_title_theme <- theme(
  plot.title = element_text(face = "bold", color = "#12342f")
)
p3 <- (p3a | p3b) +
  plot_annotation(
    title = "SVA recovers unobserved batch-like structure",
    theme = combined_title_theme
  )
save_plot(p3, "03-sva.png", width = 14, height = 5.5)

# ---- figure 4: number of DE genes with 3 strategies ---------------------
message("Figure 4: three strategies comparison on bladderbatch (Normal vs Cancer)")
# keep only Normal + Cancer for a clean 2-group comparison
keep <- pheno$cancer %in% c("Normal", "Cancer")
pheno_sub <- pheno[keep, ]
pheno_sub$cancer <- factor(pheno_sub$cancer, levels = c("Normal", "Cancer"))
edata_sub <- edata[, keep]

# Each strategy yields an unsorted topTable for coefficient "cancerCancer"
# (Cancer vs Normal), so rows stay in the order of rownames(edata_sub).

# A) naive: no batch correction
design_naive <- model.matrix(~ cancer, data = pheno_sub)
fit_naive <- lmFit(edata_sub, design_naive) |> eBayes()
tt_naive <- topTable(fit_naive, coef = "cancerCancer", number = Inf,
                     sort.by = "none")

# B) include batch in design.
# bladderbatch stores batch as integer codes; without factor() model.matrix()
# would fit one linear slope across the batch labels instead of one offset per
# batch. factor() on the subset also keeps only batches that are present.
design_batch <- model.matrix(~ factor(batch) + cancer, data = pheno_sub)
fit_batch <- lmFit(edata_sub, design_batch) |> eBayes()
tt_batch <- topTable(fit_batch, coef = "cancerCancer", number = Inf,
                     sort.by = "none")

# C) SVA + include SVs (at least two SVs, as in the full-data analysis)
mod_sub  <- model.matrix(~ cancer, data = pheno_sub)
mod0_sub <- model.matrix(~ 1, data = pheno_sub)
n_sv_sub <- max(num.sv(edata_sub, mod_sub, method = "leek"), 2)
svobj_sub <- sva(edata_sub, mod_sub, mod0_sub, n.sv = n_sv_sub)
design_sva <- cbind(mod_sub, svobj_sub$sv)
colnames(design_sva) <- c(colnames(mod_sub),
                          paste0("SV", seq_len(svobj_sub$n.sv)))
fit_sva <- lmFit(edata_sub, design_sva) |> eBayes()
tt_sva <- topTable(fit_sva, coef = "cancerCancer", number = Inf,
                   sort.by = "none")

summary_df <- data.frame(
  strategy = c("Naive", "Batch in design", "SVA"),
  deg_padj05 = c(
    sum(tt_naive$adj.P.Val < 0.05, na.rm = TRUE),
    sum(tt_batch$adj.P.Val < 0.05, na.rm = TRUE),
    sum(tt_sva$adj.P.Val < 0.05,   na.rm = TRUE)
  )
)
summary_df$strategy <- factor(summary_df$strategy, levels = summary_df$strategy)

# Bar chart of DE counts per strategy. DE counts alone cannot tell which
# strategy is "right" (no ground truth here), so the subtitle stays neutral.
p4 <- ggplot(summary_df, aes(x = strategy, y = deg_padj05, fill = strategy)) +
  geom_col(width = 0.6) +
  geom_text(aes(label = deg_padj05), vjust = -0.4, size = 5) +
  # headroom so the count labels above the tallest bar are not clipped
  scale_y_continuous(expand = expansion(mult = c(0, 0.12))) +
  scale_fill_manual(values = c(Naive = biof3_colors[["slate"]],
                               `Batch in design` = biof3_colors[["blue"]],
                               SVA = biof3_colors[["green"]]),
                    guide = "none") +
  labs(title = "DE genes (padj < 0.05) for 3 batch strategies",
       subtitle = "bladderbatch Normal vs Cancer; the DE count depends on how batch is handled (more is not necessarily better)",
       x = NULL, y = "Number of DE genes") +
  theme_biof3()
save_plot(p4, "04-strategy-deg.png", width = 9, height = 5.5)

# ---- figure 5: overlap of DE genes ---------------------------------------
message("Figure 5: DE gene overlap")
# which() drops NA adjusted p-values instead of yielding NA gene names
de_naive <- rownames(tt_naive)[which(tt_naive$adj.P.Val < 0.05)]
de_batch <- rownames(tt_batch)[which(tt_batch$adj.P.Val < 0.05)]
de_sva   <- rownames(tt_sva)[which(tt_sva$adj.P.Val < 0.05)]
de_list <- list(Naive = de_naive, `Batch in design` = de_batch, SVA = de_sva)

# 3-way venn-like counts: membership matrix with one row per DE gene and one
# logical column per strategy (cbind keeps it a matrix for 0 or 1 genes,
# unlike sapply), then tabulate the membership patterns.
all_de <- unique(unlist(de_list, use.names = FALSE))
in_mat <- do.call(cbind, lapply(de_list, function(genes) all_de %in% genes))
pattern <- if (nrow(in_mat) > 0) {
  apply(in_mat, 1, function(x) paste(names(de_list)[x], collapse = " + "))
} else {
  character(0)
}
counts <- sort(table(pattern), decreasing = TRUE)
pat_df <- data.frame(pattern = names(counts), n = as.integer(counts)) |>
  mutate(pattern = factor(pattern, levels = pattern))

p5 <- ggplot(pat_df, aes(x = reorder(pattern, n), y = n, fill = n)) +
  geom_col(width = 0.65) +
  geom_text(aes(label = n), hjust = -0.1, size = 3.5) +
  # headroom so the labels past the longest bar are not clipped
  scale_y_continuous(expand = expansion(mult = c(0, 0.12))) +
  coord_flip() +
  scale_fill_gradient(low = biof3_colors[["slate"]],
                      high = biof3_colors[["red"]], guide = "none") +
  labs(title = "DE gene overlap across strategies",
       subtitle = "Which DE genes are found by which set of strategies",
       x = NULL, y = "Number of genes") +
  theme_biof3()
save_plot(p5, "05-de-overlap.png", width = 10, height = 5.5)

# =============================================================================
# Part B: airway with / without cell as blocking factor
# =============================================================================
message("Part B: airway with vs without `cell` in design")
data("airway")
se <- airway
se$dex <- relevel(se$dex, ref = "untrt")   # untreated as the reference level

# Build a DESeqDataSet with the given design, keep genes with >= 10 total
# counts, and run the DESeq2 pipeline. Both designs use the same filter.
run_airway_deseq <- function(design_formula) {
  dds <- DESeqDataSet(se, design = design_formula)
  dds <- dds[rowSums(counts(dds)) >= 10, ]
  DESeq(dds)
}

# without cell
dds_no <- run_airway_deseq(~ dex)
res_no <- results(dds_no, contrast = c("dex", "trt", "untrt"))

# with cell
dds_yes <- run_airway_deseq(~ cell + dex)
res_yes <- results(dds_yes, contrast = c("dex", "trt", "untrt"))

# Count genes with padj < alpha (optionally also |log2FC| > lfc) in a DESeq2
# results table; NA padj / log2FC values are not counted.
count_de <- function(res, alpha = 0.05, lfc = NULL) {
  hit <- res$padj < alpha
  if (!is.null(lfc)) {
    hit <- hit & abs(res$log2FoldChange) > lfc
  }
  sum(hit, na.rm = TRUE)
}

cmp_df <- data.frame(
  strategy = c("~ dex only", "~ cell + dex"),
  n_padj05 = c(count_de(res_no), count_de(res_yes)),
  n_padj05_lfc1 = c(count_de(res_no, lfc = 1), count_de(res_yes, lfc = 1))
)
cmp_long <- cmp_df |>
  tidyr::pivot_longer(cols = -strategy, names_to = "threshold", values_to = "n") |>
  mutate(threshold = recode(threshold,
                            n_padj05 = "padj < 0.05",
                            n_padj05_lfc1 = "padj < 0.05 & |LFC|>1"),
         strategy = factor(strategy, levels = c("~ dex only", "~ cell + dex")))

p6 <- ggplot(cmp_long, aes(x = threshold, y = n, fill = strategy)) +
  geom_col(position = position_dodge(width = 0.7), width = 0.6) +
  geom_text(aes(label = n), position = position_dodge(width = 0.7),
            vjust = -0.4, size = 4) +
  # headroom so the count labels above the tallest bar are not clipped
  scale_y_continuous(expand = expansion(mult = c(0, 0.12))) +
  scale_fill_manual(values = c(`~ dex only` = biof3_colors[["slate"]],
                               `~ cell + dex` = biof3_colors[["red"]])) +
  labs(title = "Airway: blocking on cell line recovers more DE genes",
       subtitle = "Same data, same contrast (trt vs untrt); design matters",
       x = NULL, y = "Number of DE genes",
       fill = "DESeq2 design") +
  theme_biof3()
save_plot(p6, "06-airway-design.png", width = 10, height = 5.5)

# ---- done ---------------------------------------------------------------
# Report the generated figures and the environment used to build them.
message("=== module bulk06 batch-effect figures ===")
png_files <- list.files(output_dir, pattern = "\\.png$", full.names = FALSE)
print(png_files)
message("Output dir: ", normalizePath(output_dir))
message("Done.")

print(sessionInfo())
