#!/usr/bin/env Rscript
# =============================================================================
# BioF3 Bulk RNA-seq Module 07: DESeq2 vs edgeR vs limma-voom comparison
# =============================================================================
#
# This is the companion script for bulk RNA-seq module 07. It:
#   1. Runs the same DE comparison (trt vs untrt, blocked by cell) on
#      airway through three pipelines:
#      - DESeq2 (Wald)
#      - edgeR (glmQLFTest)
#      - limma-voom
#   2. Produces six figures:
#      - P-value histogram for each tool (3 panels)
#      - DE gene counts at different padj thresholds
#      - Overlap of DE gene sets (one bar per tool-membership pattern)
#      - log2 fold change scatter (DESeq2 vs edgeR, DESeq2 vs limma)
#      - Spearman rank correlation of p-values between tools
#      - Effect-size distribution for tool-specific genes
#
# Usage:
#   Rscript scripts/bulk07_tools_sci.R
#
# Optional env vars:
#   BIOF3_OUTPUT_DIR=/path/to/static/img/tutorial/bulk07
# =============================================================================

options(stringsAsFactors = FALSE)
set.seed(42)

# ---- paths ---------------------------------------------------------------
# Resolve this script's own location so figures land inside the project tree
# regardless of the working directory Rscript was launched from.
args <- commandArgs(trailingOnly = FALSE)
file_arg <- grep("^--file=", args, value = TRUE)
script_path <- if (length(file_arg) > 0) {
  # NOTE(review): R's front end can encode spaces in the --file= value as
  # "~+~"; decode them so normalizePath(mustWork = TRUE) can find the file.
  raw_path <- gsub("~+~", " ", sub("^--file=", "", file_arg[[1]]), fixed = TRUE)
  normalizePath(raw_path, mustWork = TRUE)
} else {
  # Not launched via Rscript (e.g. source()d): assume cwd is the project root.
  normalizePath("scripts/bulk07_tools_sci.R", mustWork = FALSE)
}
script_dir <- dirname(script_path)
# The script may live in <root>/scripts or in <root>/static/scripts.
project_root <- if (basename(script_dir) == "scripts" &&
                    basename(dirname(script_dir)) == "static") {
  dirname(dirname(script_dir))
} else {
  dirname(script_dir)
}

# Sys.getenv() only falls back to `unset` when the variable is absent. An
# exported-but-empty BIOF3_OUTPUT_DIR would otherwise yield "" and every
# figure would be written to file.path("", name), i.e. the filesystem root.
output_dir <- Sys.getenv("BIOF3_OUTPUT_DIR", unset = "")
if (!nzchar(output_dir)) {
  output_dir <- file.path(project_root, "static", "img", "tutorial", "bulk07")
}
dir.create(output_dir, recursive = TRUE, showWarnings = FALSE)
# dir.create() with showWarnings = FALSE fails silently; stop here instead of
# failing later inside ggsave().
if (!dir.exists(output_dir)) {
  stop("Could not create output directory: ", output_dir, call. = FALSE)
}

# ---- dependencies --------------------------------------------------------
suppressPackageStartupMessages({
  library(DESeq2)
  library(edgeR)
  library(limma)
  library(airway)
  library(ggplot2)
  library(dplyr)
  library(patchwork)
  library(tidyr)
})

# Shared BIOF3 palette. Entries are looked up by name in the figures, and
# figure 5 also takes entries 2:4 by position, so keep this order.
biof3_colors <- c(
  green  = "#0f766e",
  mint   = "#43d1ae",
  blue   = "#2563eb",
  amber  = "#f59e0b",
  red    = "#dc2626",
  slate  = "#475569",
  violet = "#7c3aed",
  pink   = "#ec4899"
)

# House ggplot2 theme: theme_classic() with dark axis text, bold titles,
# left-aligned title/subtitle, legend on the right and plain facet strips.
#
# base_size: base font size forwarded to theme_classic().
# Returns a ggplot2 theme object to be added to a plot with `+`.
theme_biof3 <- function(base_size = 12) {
  ink <- "#111827"
  heading <- "#12342f"
  base <- theme_classic(base_size = base_size)
  tweaks <- theme(
    axis.text = element_text(color = ink),
    axis.title = element_text(color = ink, face = "bold"),
    plot.title = element_text(color = heading, face = "bold", hjust = 0),
    plot.subtitle = element_text(color = "#526760", hjust = 0),
    legend.title = element_text(face = "bold"),
    legend.position = "right",
    strip.background = element_blank(),
    strip.text = element_text(color = heading, face = "bold")
  )
  base + tweaks
}

# Save plot `p` as `name` inside the global `output_dir` at 300 dpi on a
# white background. ggsave() infers the device from the file extension.
#
# width, height: figure size passed straight to ggsave().
# Returns whatever ggsave() returns (invisibly).
save_plot <- function(p, name, width = 8, height = 5) {
  target <- file.path(output_dir, name)
  ggsave(
    filename = target,
    plot = p,
    width = width,
    height = height,
    dpi = 300,
    bg = "white"
  )
}

# ---- load airway + pre-filter --------------------------------------------
message("Loading airway ...")
data("airway")
se <- airway
# Make "untrt" the reference level so the treatment coefficient is "dextrt".
se$dex <- relevel(se$dex, ref = "untrt")
coldata <- colData(se)
raw_counts <- assay(se)
# Light pre-filter: drop genes with fewer than 10 reads summed over samples.
keep <- rowSums(raw_counts) >= 10
cnt <- raw_counts[keep, ]
message("Genes after filter: ", nrow(cnt), "  Samples: ", ncol(cnt))

# ---- DESeq2 --------------------------------------------------------------
message("Running DESeq2 ...")
dds <- DESeqDataSetFromMatrix(
  countData = cnt,
  colData = coldata,
  design = ~ cell + dex
)
dds <- DESeq(dds)
res_deseq <- results(dds, contrast = c("dex", "trt", "untrt"))
# Tidy to one row per gene and keep only genes with a non-NA adjusted
# p-value; the inner joins below therefore restrict all tools to this set.
deseq_df <- as.data.frame(res_deseq) |>
  tibble::rownames_to_column("gene") |>
  select(gene,
         deseq_lfc = log2FoldChange,
         deseq_pvalue = pvalue,
         deseq_padj = padj) |>
  filter(!is.na(deseq_padj))

# ---- edgeR ---------------------------------------------------------------
message("Running edgeR ...")
# Same blocked design as DESeq2; with dex releveled, the treatment effect is
# the "dextrt" column of the model matrix.
design_mat <- model.matrix(~ cell + dex, data = as.data.frame(coldata))
dge <- DGEList(cnt) |>
  calcNormFactors(method = "TMM") |>
  estimateDisp(design_mat)
fit_edg <- glmQLFit(dge, design_mat)
qlf <- glmQLFTest(fit_edg, coef = "dextrt")
# n = Inf and sort.by = "none": every gene, kept in input order.
tt_edg <- topTags(qlf, n = Inf, sort.by = "none")$table
edg_df <- with(tt_edg, data.frame(
  gene         = rownames(tt_edg),
  edger_lfc    = logFC,
  edger_pvalue = PValue,
  edger_padj   = FDR
))

# ---- limma-voom ---------------------------------------------------------
message("Running limma-voom ...")
# voom reuses the TMM-normalised DGEList and design built for edgeR.
v <- voom(dge, design_mat, plot = FALSE)
fit_v <- eBayes(lmFit(v, design_mat))
# number = Inf and sort.by = "none": every gene, kept in input order.
tt_lim <- topTable(fit_v, coef = "dextrt", number = Inf, sort.by = "none")
lim_df <- with(tt_lim, data.frame(
  gene         = rownames(tt_lim),
  limma_lfc    = logFC,
  limma_pvalue = P.Value,
  limma_padj   = adj.P.Val
))

# ---- merge ---------------------------------------------------------------
# Keep genes present in all three result tables. edgeR and limma report every
# filtered gene, so this is effectively the set of genes DESeq2 kept.
all_df <- Reduce(
  function(acc, tbl) inner_join(acc, tbl, by = "gene"),
  list(deseq_df, edg_df, lim_df)
)
message("Genes shared by all three tools: ", nrow(all_df))

# ---- figure 1: p-value histogram per tool -------------------------------
message("Figure 1: p-value histograms")
# Raw p-value column -> display name of the tool that produced it.
pvalue_labels <- c(
  deseq_pvalue = "DESeq2",
  edger_pvalue = "edgeR",
  limma_pvalue = "limma-voom"
)
hist_df <- all_df |>
  select(gene, all_of(names(pvalue_labels))) |>
  pivot_longer(-gene, names_to = "tool", values_to = "pvalue") |>
  mutate(tool = unname(pvalue_labels[tool]))

tool_fill <- c(
  DESeq2       = biof3_colors[["green"]],
  edgeR        = biof3_colors[["blue"]],
  `limma-voom` = biof3_colors[["violet"]]
)
bin_edges <- seq(0, 1, by = 0.025)

p1 <- ggplot(hist_df, aes(x = pvalue, fill = tool)) +
  geom_histogram(breaks = bin_edges, color = "white", linewidth = 0.2) +
  facet_wrap(~ tool, ncol = 3, scales = "free_y") +
  scale_fill_manual(values = tool_fill, guide = "none") +
  labs(title = "Raw p-value distribution per tool",
       subtitle = "Healthy: peak at 0 + flat tail. Three tools agree on overall shape here.",
       x = "Raw p-value", y = "Number of genes") +
  theme_biof3()
save_plot(p1, "01-pvalue-hist.png", width = 13, height = 5)

# ---- figure 2: DE counts at different padj thresholds -------------------
message("Figure 2: DE counts at multiple thresholds")
thresholds <- c(0.001, 0.01, 0.05, 0.1)
# Column of all_df holding each tool's adjusted p-values.
padj_cols <- c(
  DESeq2       = "deseq_padj",
  edgeR        = "edger_padj",
  `limma-voom` = "limma_padj"
)
# Number of shared genes that `tool` calls DE at padj < th.
count_de <- function(tool, th) sum(all_df[[padj_cols[[tool]]]] < th)

# Build the long table directly: one row per (threshold, tool), thresholds
# varying slowest and tools in DESeq2 / edgeR / limma-voom order.
cnt_long <- expand_grid(threshold = thresholds, tool = names(padj_cols))
cnt_long$n <- vapply(
  seq_len(nrow(cnt_long)),
  function(i) count_de(cnt_long$tool[[i]], cnt_long$threshold[[i]]),
  integer(1)
)
# Discrete, ordered x axis.
cnt_long$threshold <- factor(cnt_long$threshold, levels = thresholds)

tool_fill <- c(
  DESeq2       = biof3_colors[["green"]],
  edgeR        = biof3_colors[["blue"]],
  `limma-voom` = biof3_colors[["violet"]]
)
dodge <- position_dodge(width = 0.7)

p2 <- ggplot(cnt_long, aes(x = threshold, y = n, fill = tool)) +
  geom_col(position = dodge, width = 0.65) +
  geom_text(aes(label = n), position = dodge, vjust = -0.3, size = 3.2) +
  scale_fill_manual(values = tool_fill) +
  labs(title = "Number of DE genes at each padj threshold",
       subtitle = "Three tools give numbers within ~10% of each other on airway",
       x = "padj threshold", y = "Number of DE genes",
       fill = "Tool") +
  theme_biof3()
save_plot(p2, "02-de-counts.png", width = 11, height = 5.5)

# ---- figure 3: DE set overlap -------------------------------------------
message("Figure 3: DE gene set overlap (padj < 0.05)")
# DE gene IDs per tool at padj < 0.05 (reused by figure 6).
sets <- list(
  DESeq2 = all_df$gene[all_df$deseq_padj < 0.05],
  edgeR  = all_df$gene[all_df$edger_padj < 0.05],
  `limma-voom` = all_df$gene[all_df$limma_padj < 0.05]
)

# Label each gene called DE by at least one tool with the " + "-joined names
# of the tools that called it (e.g. "DESeq2 + edgeR"), in names(sets) order.
# Built tool by tool with vectorised %in%: the previous sapply() + apply()
# form broke when exactly one gene was DE, because sapply() then returns a
# plain vector instead of a matrix and apply(, 1, ) errors.
all_de <- unique(unlist(sets, use.names = FALSE))
pattern <- character(length(all_de))
for (tool_name in names(sets)) {
  hit <- all_de %in% sets[[tool_name]]
  pattern[hit] <- ifelse(nzchar(pattern[hit]),
                         paste(pattern[hit], tool_name, sep = " + "),
                         tool_name)
}
counts <- sort(table(pattern), decreasing = TRUE)
pat_df <- data.frame(pattern = names(counts), n = as.integer(counts))

p3 <- ggplot(pat_df, aes(x = reorder(pattern, n), y = n, fill = n)) +
  geom_col(width = 0.7) +
  geom_text(aes(label = n), hjust = -0.1, size = 3.5) +
  coord_flip() +
  scale_fill_gradient(low = biof3_colors[["slate"]],
                      high = biof3_colors[["red"]], guide = "none") +
  labs(title = "DE gene overlap across tools (padj < 0.05)",
       subtitle = "Large 'DESeq2 + edgeR + limma-voom' bar = high concordance",
       x = NULL, y = "Number of genes") +
  theme_biof3()
save_plot(p3, "03-overlap.png", width = 10, height = 5)

# ---- figure 4: log2FC scatter -------------------------------------------
message("Figure 4: log2FC scatter (pairwise)")
lfc_df <- all_df |>
  select(gene, DESeq2 = deseq_lfc, edgeR = edger_lfc, `limma-voom` = limma_lfc)

# One DESeq2-vs-<tool> scatter of log2FC with a dashed identity line; the
# title reports the Pearson correlation rounded to 3 decimals.
lfc_panel <- function(tool, point_color, title_label, y_label) {
  r_value <- round(cor(lfc_df$DESeq2, lfc_df[[tool]]), 3)
  ggplot(lfc_df, aes(x = DESeq2, y = .data[[tool]])) +
    geom_point(size = 0.4, alpha = 0.3, color = point_color) +
    geom_abline(slope = 1, intercept = 0, color = biof3_colors["red"],
                linetype = "dashed", linewidth = 0.5) +
    labs(title = paste0(title_label, " (r = ", r_value, ")"),
         x = "DESeq2 log2FC", y = y_label) +
    theme_biof3()
}

p4a <- lfc_panel("edgeR", biof3_colors["blue"],
                 "DESeq2 vs edgeR", "edgeR log2FC")
p4b <- lfc_panel("limma-voom", biof3_colors["violet"],
                 "DESeq2 vs limma-voom", "limma-voom log2FC")

shared_title_theme <- theme(
  plot.title = element_text(face = "bold", color = "#12342f"),
  plot.subtitle = element_text(color = "#526760")
)
p4 <- (p4a | p4b) + plot_annotation(
  title = "log2 fold change agreement across tools",
  subtitle = "Dashed red = identity line; tight diagonal = tools agree on effect size",
  theme = shared_title_theme
)
save_plot(p4, "04-lfc-scatter.png", width = 12, height = 5.5)

# ---- figure 5: rank correlation of p-values ----------------------------
message("Figure 5: rank correlation of pvalues")
# Spearman correlation ranks its inputs internally, so the raw p-values can
# be correlated directly; ranking them first gives the same coefficient.
pval_rho <- cor(
  all_df[, c("deseq_pvalue", "edger_pvalue", "limma_pvalue")],
  method = "spearman"
)
rc <- data.frame(
  pair = c("DESeq2 vs edgeR", "DESeq2 vs limma-voom", "edgeR vs limma-voom"),
  spearman = c(
    pval_rho["deseq_pvalue", "edger_pvalue"],
    pval_rho["deseq_pvalue", "limma_pvalue"],
    pval_rho["edger_pvalue", "limma_pvalue"]
  )
)

# Zoomed y window: keep the 0.9 floor normally, but extend it downward when a
# correlation falls near or below 0.9. A fixed c(0.9, 1.01) window silently
# hid such bars and their labels.
y_lower <- min(c(0.9, rc$spearman - 0.05), na.rm = TRUE)

p5 <- ggplot(rc, aes(x = pair, y = spearman, fill = pair)) +
  geom_col(width = 0.5) +
  geom_text(aes(label = sprintf("%.3f", spearman)), vjust = -0.4, size = 4.5) +
  coord_cartesian(ylim = c(y_lower, 1.01)) +
  scale_fill_manual(values = unname(biof3_colors)[2:4], guide = "none") +
  labs(title = "Spearman rank correlation of p-values",
       subtitle = "The ranking of genes by significance is essentially the same across tools",
       x = NULL, y = "Spearman correlation") +
  theme_biof3() +
  theme(axis.text.x = element_text(angle = 30, hjust = 1))
save_plot(p5, "05-rank-cor.png", width = 9, height = 5.5)

# ---- figure 6: tool-unique genes: where do they disagree? ---------------
message("Figure 6: tool-unique gene characteristics")
# Genes called DE (padj < 0.05) by exactly one tool, i.e. absent from the
# union of the other two tools' DE sets, plus genes called by all three.
deseq_only <- setdiff(sets$DESeq2, union(sets$edgeR, sets$`limma-voom`))
edger_only <- setdiff(sets$edgeR, union(sets$DESeq2, sets$`limma-voom`))
limma_only <- setdiff(sets$`limma-voom`, union(sets$DESeq2, sets$edgeR))
shared_all <- Reduce(intersect, sets)

# |log2 fold change| (DESeq2 estimate) of the given genes. This is an effect
# size, not baseMean.
get_abs_lfc <- function(genes) abs(all_df$deseq_lfc[all_df$gene %in% genes])

# Long-format rows for one category. rep() sizes the label column to the
# data, so a category with no genes contributes zero rows. A scalar label
# with numeric(0) makes data.frame() fail ("differing number of rows: 1, 0").
category_df <- function(label, genes) {
  abs_lfc <- get_abs_lfc(genes)
  data.frame(category = rep(label, length(abs_lfc)), abs_lfc = abs_lfc)
}

category_levels <- c("DESeq2 only", "edgeR only",
                     "limma-voom only", "Shared by all")
lfc_cat_df <- bind_rows(
  category_df("DESeq2 only", deseq_only),
  category_df("edgeR only", edger_only),
  category_df("limma-voom only", limma_only),
  category_df("Shared by all", shared_all)
) |>
  mutate(category = factor(category, levels = category_levels))

p6 <- ggplot(lfc_cat_df, aes(x = category, y = abs_lfc, fill = category)) +
  geom_violin(width = 0.8, alpha = 0.8, color = NA) +
  geom_boxplot(width = 0.15, fill = "white", alpha = 0.8, outlier.size = 0.3) +
  scale_fill_manual(values = c(`DESeq2 only` = biof3_colors[["green"]],
                               `edgeR only`  = biof3_colors[["blue"]],
                               `limma-voom only` = biof3_colors[["violet"]],
                               `Shared by all` = biof3_colors[["red"]]),
                    guide = "none") +
  labs(title = "|log2 fold change| distribution by DE category",
       subtitle = "Tool-unique genes typically have smaller effect sizes — borderline calls",
       x = NULL, y = "|log2 fold change| (DESeq2 estimate)") +
  theme_biof3() +
  theme(axis.text.x = element_text(angle = 20, hjust = 1))
save_plot(p6, "06-unique-genes.png", width = 11, height = 6)

# ---- done ---------------------------------------------------------------
message("=== module bulk07 tool-comparison figures ===")
# Bare file names of every PNG now present in the output directory.
png_files <- list.files(output_dir, pattern = "\\.png$", full.names = FALSE)
print(png_files)
message("Output dir: ", normalizePath(output_dir))
message("Done.")

# Record package versions alongside the figures' run log.
print(sessionInfo())
