#!/usr/bin/env Rscript
# =============================================================================
# BioF3 Bulk RNA-seq Module 02: DESeq2 differential expression on airway
# =============================================================================
#
# This is the companion script for bulk RNA-seq module 02. It:
#   1. Loads the Bioconductor airway data (8 samples: 4 human airway smooth
#      muscle cell lines, each treated with dexamethasone or left untreated)
#   2. Runs the standard DESeq2 workflow: design -> DESeq() -> results()
#   3. Produces six figures covering library sizes, sample QC (PCA + sample-
#      to-sample distances), dispersion, p-value histogram, MA plot, and
#      a top DE gene count plot
#   4. Writes a results table (all tested genes with shrunken log2 fold
#      changes, sorted by adjusted p-value) to the output directory
#
# Data source:
#   Bioconductor airway package (no download required after install)
#   Himes et al. 2014, PLoS ONE 9(6): e99625
#
# Usage:
#   Rscript scripts/bulk02_deseq2_sci.R
#
# Optional env vars:
#   BIOF3_OUTPUT_DIR=/path/to/static/img/tutorial/bulk02
#
# Dependencies:
#   DESeq2, airway, apeglm (for lfcShrink), ggplot2, dplyr, tibble, patchwork,
#   RColorBrewer, ggrepel, pheatmap
# =============================================================================

options(stringsAsFactors = FALSE)
set.seed(42)  # fixed seed so random layout steps (e.g. jitter) are reproducible

# ---- paths ---------------------------------------------------------------
# Resolve this script's own location so default outputs land inside the
# project tree regardless of the working directory.
# Rscript passes the script as `--file=<path>` and encodes spaces in that path
# as "~+~", so decode them before normalizing. When `--file=` is absent (e.g.
# the code is run interactively) fall back to the conventional relative path.
args <- commandArgs(trailingOnly = FALSE)
file_arg <- grep("^--file=", args, value = TRUE)
script_path <- if (length(file_arg) > 0) {
  raw_path <- sub("^--file=", "", file_arg[[1]])
  normalizePath(gsub("~+~", " ", raw_path, fixed = TRUE), mustWork = TRUE)
} else {
  normalizePath("scripts/bulk02_deseq2_sci.R", mustWork = FALSE)
}
script_dir <- dirname(script_path)
# For <root>/static/scripts/ go up two levels; otherwise treat the parent of
# the script's directory as the project root.
project_root <- if (basename(script_dir) == "scripts" &&
                    basename(dirname(script_dir)) == "static") {
  dirname(dirname(script_dir))
} else {
  dirname(script_dir)
}

# BIOF3_OUTPUT_DIR overrides the default figure directory. Sys.getenv() only
# substitutes `unset` when the variable is missing entirely, so an
# exported-but-empty value is also treated as "not set"; otherwise
# file.path("", name) would write figures at the filesystem root.
default_output_dir <- file.path(project_root, "static", "img", "tutorial", "bulk02")
output_dir <- Sys.getenv("BIOF3_OUTPUT_DIR", unset = "")
if (!nzchar(output_dir)) {
  output_dir <- default_output_dir
}
dir.create(output_dir, recursive = TRUE, showWarnings = FALSE)
# dir.create() reports failure only via its return value / a suppressed
# warning; fail loudly instead of erroring later on the first figure.
if (!dir.exists(output_dir)) {
  stop("Could not create output directory: ", output_dir, call. = FALSE)
}

# ---- dependencies --------------------------------------------------------
suppressPackageStartupMessages({
  library(DESeq2)
  library(airway)
  library(ggplot2)
  library(dplyr)
  library(patchwork)
  library(RColorBrewer)
  library(ggrepel)
  library(pheatmap)
  library(grid)
})

# BioF3 brand palette: hex colours keyed by a short name.
# Use `[[` (e.g. biof3_colors[["red"]]) to get a bare string; single-bracket
# lookup keeps the name attached to the value.
biof3_colors <- setNames(
  c("#0f766e", "#43d1ae", "#2563eb", "#f59e0b",
    "#dc2626", "#475569", "#7c3aed", "#ec4899"),
  c("green", "mint", "blue", "amber",
    "red", "slate", "violet", "pink")
)

# Shared ggplot2 theme for all BioF3 figures: theme_classic() at `base_size`
# with dark axis text, bold titles, left-aligned title/subtitle, a right-hand
# legend and borderless facet strips.
theme_biof3 <- function(base_size = 12) {
  ink <- "#111827"      # axis text and axis titles
  heading <- "#12342f"  # plot titles and facet strip labels
  muted <- "#526760"    # subtitles

  overrides <- theme(
    axis.text        = element_text(color = ink),
    axis.title       = element_text(color = ink, face = "bold"),
    plot.title       = element_text(color = heading, face = "bold", hjust = 0),
    plot.subtitle    = element_text(color = muted, hjust = 0),
    legend.title     = element_text(face = "bold"),
    legend.position  = "right",
    strip.background = element_blank(),
    strip.text       = element_text(color = heading, face = "bold")
  )
  theme_classic(base_size = base_size) + overrides
}

# Save ggplot `p` as `name` inside the global `output_dir`: 300 dpi, white
# background, size in inches. Returns ggsave()'s result.
save_plot <- function(p, name, width = 8, height = 5) {
  target <- file.path(output_dir, name)
  ggsave(
    filename = target,
    plot = p,
    width = width,
    height = height,
    dpi = 300,
    bg = "white"
  )
}

# ---- load data -----------------------------------------------------------
message("Loading airway data ...")
data("airway")
se <- airway
# Make "untrt" the reference level so fold changes read as trt vs untrt.
se$dex <- relevel(se$dex, ref = "untrt")

n_samples <- ncol(se)
n_genes <- nrow(se)
message("Samples: ", n_samples, "  Genes: ", n_genes)
condition_str <- paste(levels(se$dex), collapse = " / ")
message("Conditions: ", condition_str)

# ---- build DESeq object --------------------------------------------------
message("Building DESeqDataSet and running DESeq() ...")
# Cell line enters the model as a blocking factor; dex is the effect tested.
dds <- DESeqDataSet(se, design = ~ cell + dex)

# Pre-filter: drop genes with fewer than 10 reads summed across all samples,
# which shrinks the object before model fitting.
total_reads <- rowSums(counts(dds))
dds <- dds[total_reads >= 10, ]
message("After pre-filtering: ", nrow(dds), " genes")

dds <- DESeq(dds)

# ---- results -------------------------------------------------------------
# Wald test for dex treatment (trt vs untrt), adjusting for cell line.
# `alpha` tunes DESeq2's independent filtering and should equal the FDR cutoff
# used downstream (padj < 0.05 for DE calls), not the default of 0.1.
# Because "untrt" is the reference level, the named coefficient below is the
# trt-vs-untrt comparison; naming it explicitly lets the same table be handed
# to lfcShrink().
coef_name <- "dex_trt_vs_untrt"
stopifnot(coef_name %in% resultsNames(dds))
res <- results(dds, name = coef_name, alpha = 0.05)
message("Results summary:")
summary(res)

# Shrink LFCs with apeglm for a less noisy MA plot. Passing `res` keeps the
# pvalue/padj columns of res_shrink identical to those of `res`, instead of
# recomputing them with default settings.
res_shrink <- lfcShrink(dds, coef = coef_name, res = res, type = "apeglm")

# ---- figure 1: library sizes --------------------------------------------
message("Figure 1: library sizes")
# Total counts per sample (millions) of the pre-filtered count matrix.
lib <- data.frame(
  sample = colnames(dds),
  condition = dds$dex,
  cell      = dds$cell,
  size_m    = colSums(counts(dds)) / 1e6
)
# Use `[[` so the values are bare strings. c(untrt = biof3_colors["slate"])
# would name the entry "untrt.slate", which matches no level of `condition`,
# so ggplot2 would fall back to its grey na.value for every bar.
dex_fill <- c(untrt = biof3_colors[["slate"]], trt = biof3_colors[["red"]])
p1 <- ggplot(lib, aes(x = sample, y = size_m, fill = condition)) +
  geom_col(width = 0.7) +
  scale_fill_manual(values = dex_fill) +
  labs(title = "Library size per sample",
       subtitle = "Total reads (million) after pre-filtering (>= 10 reads / gene)",
       x = NULL, y = "Reads (million)", fill = "Condition") +
  theme_biof3() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))
save_plot(p1, "01-library-size.png", width = 9, height = 5)

# ---- figure 2: PCA on VST ------------------------------------------------
message("Figure 2: PCA on VST")
# blind = FALSE: the VST uses the design information (~ cell + dex) rather
# than a design-blind transform.
vsd <- vst(dds, blind = FALSE)
pca <- plotPCA(vsd, intgroup = c("dex", "cell"), returnData = TRUE)
pct <- round(100 * attr(pca, "percentVar"))
# `[[` yields bare strings; c(untrt = biof3_colors["slate"]) would produce the
# name "untrt.slate", which matches no `dex` level and renders points grey.
dex_colors <- c(untrt = biof3_colors[["slate"]], trt = biof3_colors[["red"]])
p2 <- ggplot(pca, aes(x = PC1, y = PC2, color = dex, shape = cell)) +
  geom_point(size = 4, alpha = 0.9) +
  geom_text_repel(aes(label = name), size = 3, color = "#374151") +
  scale_color_manual(values = dex_colors) +
  labs(title = "PCA on variance-stabilized counts",
       subtitle = "Color = dex treatment, shape = cell line; PC1 should separate conditions",
       x = paste0("PC1 (", pct[1], "%)"),
       y = paste0("PC2 (", pct[2], "%)"),
       color = "dex", shape = "cell") +
  theme_biof3()
save_plot(p2, "02-pca.png", width = 9, height = 6)

# ---- figure 3: sample-to-sample distance heatmap -------------------------
message("Figure 3: sample-to-sample distance heatmap")
# Euclidean distances between samples on the VST scale. The same `dist`
# object is displayed and used for the row/column clustering.
sample_dist <- dist(t(assay(vsd)))
sample_labels <- paste(vsd$cell, vsd$dex, sep = "_")
dist_mat <- as.matrix(sample_dist)
rownames(dist_mat) <- sample_labels
colnames(dist_mat) <- NULL
# Row annotation keyed by the same labels as the matrix rows.
anno <- data.frame(
  dex  = vsd$dex,
  cell = vsd$cell,
  row.names = sample_labels
)
anno_colors <- list(
  dex  = c(untrt = biof3_colors[["slate"]], trt = biof3_colors[["red"]]),
  cell = setNames(
    c(biof3_colors[["green"]], biof3_colors[["blue"]],
      biof3_colors[["amber"]], biof3_colors[["violet"]]),
    levels(vsd$cell)
  )
)
heatmap_path <- file.path(output_dir, "03-sample-distance.png")
png(heatmap_path, width = 9, height = 6, units = "in", res = 300, bg = "white")
# Close the PNG device even if pheatmap() fails, so no device is left open;
# invisible() stops Rscript from auto-printing the return value.
invisible(tryCatch(
  pheatmap(
    dist_mat,
    clustering_distance_rows = sample_dist,
    clustering_distance_cols = sample_dist,
    annotation_row = anno,
    annotation_colors = anno_colors,
    color = colorRampPalette(rev(brewer.pal(9, "Blues")))(255),
    main = "Sample-to-sample distances (Euclidean on VST)",
    fontsize = 10
  ),
  finally = dev.off()
))

# ---- figure 4: dispersion estimates --------------------------------------
message("Figure 4: dispersion estimates")
# DESeq() stores three dispersion values per gene:
#   mcols(dds)$dispGeneEst - gene-wise estimate
#   mcols(dds)$dispFit     - fitted mean-dispersion trend at that mean
#   dispersions(dds)       - final (shrunken) estimate used for testing
# The gene-wise column must come from dispGeneEst; dispersions() returns the
# final values.
disp_df <- data.frame(
  mean_counts = mcols(dds)$baseMean,
  disp_gene   = mcols(dds)$dispGeneEst,
  disp_fit    = mcols(dds)$dispFit,
  final_disp  = dispersions(dds)
)
# Drop zero-mean genes, which cannot be shown on the log10 x axis.
disp_df <- disp_df[disp_df$mean_counts > 0, ]
p4 <- ggplot(disp_df, aes(x = mean_counts)) +
  geom_point(aes(y = disp_gene), size = 0.6, alpha = 0.4,
             color = biof3_colors[["slate"]]) +
  geom_point(aes(y = final_disp), size = 0.6, alpha = 0.4,
             color = biof3_colors[["blue"]]) +
  geom_line(aes(y = disp_fit), color = biof3_colors[["red"]], linewidth = 1) +
  scale_x_log10() + scale_y_log10() +
  labs(title = "DESeq2 dispersion estimates",
       subtitle = "Grey = gene-wise estimate; blue = final (shrunken) estimate; red = fitted trend",
       x = "Mean of normalized counts (log10)",
       y = "Dispersion (log10)") +
  theme_biof3()
save_plot(p4, "04-dispersion.png", width = 9, height = 5.5)

# ---- figure 5: p-value histogram + MA plot -------------------------------
message("Figure 5: p-value histogram + MA plot")

# (a) Histogram of raw p-values; genes with NA p-values are left out.
raw_p <- res$pvalue
p_hist_df <- data.frame(pvalue = raw_p[!is.na(raw_p)])
p5a <- ggplot(p_hist_df, aes(x = pvalue)) +
  geom_histogram(
    breaks = seq(0, 1, 0.025),
    fill = biof3_colors["blue"],
    color = "white",
    linewidth = 0.2
  ) +
  labs(title = "Raw p-value distribution",
       subtitle = "A flat tail and a peak near 0 = well-behaved test",
       x = "Raw p-value", y = "Number of genes") +
  theme_biof3()

# (b) MA plot of shrunken LFC vs mean expression. A gene is flagged DE when
# padj < 0.05 and |shrunken LFC| > 1; NA padj counts as not DE.
de_label <- "DE (|LFC|>1, padj<0.05)"
ma_df <- data.frame(
  baseMean = res_shrink$baseMean,
  log2FC   = res_shrink$log2FoldChange,
  padj     = res$padj
)
is_de <- !is.na(ma_df$padj) & ma_df$padj < 0.05 & abs(ma_df$log2FC) > 1
ma_df$sig <- ifelse(is_de, de_label, "not DE")
ma_palette <- setNames(c(biof3_colors[["red"]], "#9ca3af"), c(de_label, "not DE"))

p5b <- ggplot(ma_df, aes(x = baseMean, y = log2FC, color = sig)) +
  geom_point(size = 0.4, alpha = 0.6) +
  geom_hline(yintercept = 0, color = "grey40", linewidth = 0.3) +
  scale_x_log10() +
  scale_color_manual(values = ma_palette) +
  labs(title = "MA plot (apeglm-shrunken log2FC)",
       subtitle = "x = mean expression; y = log2 fold change (trt vs untrt)",
       x = "Mean of normalized counts (log10)",
       y = "log2 fold change",
       color = NULL) +
  theme_biof3() +
  theme(legend.position = "bottom")

# Side-by-side panels under a shared title.
qc_title_theme <- theme(plot.title = element_text(face = "bold", color = "#12342f"))
p5 <- (p5a | p5b) +
  plot_annotation(title = "Global QC of DESeq2 results", theme = qc_title_theme)
save_plot(p5, "05-pvalue-ma.png", width = 14, height = 5.5)

# ---- figure 6: top DE genes box / count plot ----------------------------
message("Figure 6: top DE genes (normalized counts)")
# Pick the 6 genes with the smallest padj (order() puts NA padj last).
ord <- order(res$padj)
top_ids <- rownames(res)[head(ord, 6)]
count_long <- lapply(top_ids, function(g) {
  d <- plotCounts(dds, gene = g, intgroup = c("dex", "cell"), returnData = TRUE)
  d$gene <- g
  d
}) |> bind_rows()
# Order facets by padj rank instead of alphabetically by gene ID.
count_long$gene <- factor(count_long$gene, levels = top_ids)
# Map cell lines to colours in factor-level order. This is the same mapping
# the sample-distance heatmap annotation builds, so a cell line keeps one
# colour across figures.
cell_colors <- setNames(
  c(biof3_colors[["green"]], biof3_colors[["blue"]],
    biof3_colors[["amber"]], biof3_colors[["violet"]]),
  levels(dds$cell)
)
p6 <- ggplot(count_long, aes(x = dex, y = count, color = cell)) +
  geom_jitter(width = 0.15, height = 0, size = 2.2, alpha = 0.9) +
  facet_wrap(~ gene, scales = "free_y", ncol = 3) +
  scale_y_log10() +
  scale_color_manual(values = cell_colors) +
  labs(title = "Top 6 DE genes by padj",
       subtitle = "Normalized counts (log10); each cell line shown separately",
       x = NULL, y = "Normalized counts (log10)", color = "Cell line") +
  theme_biof3()
save_plot(p6, "06-top-genes.png", width = 12, height = 7)

# ---- write DE table -----------------------------------------------------
# Full results table: one row per tested gene (row names become `gene_id`),
# shrunken estimates from res_shrink plus `padj_raw` taken from `res`,
# sorted by padj_raw (NA last).
shrunk_tbl <- tibble::rownames_to_column(as.data.frame(res_shrink), "gene_id")
shrunk_tbl <- mutate(shrunk_tbl, padj_raw = res$padj)
de_tbl <- arrange(shrunk_tbl, padj_raw)

de_path <- file.path(output_dir, "de_genes_airway.tsv")
write.table(de_tbl, file = de_path, sep = "\t", quote = FALSE, row.names = FALSE)
message("Wrote DE table to ", de_path)

# ---- done ---------------------------------------------------------------
message("=== module bulk02 DESeq2 figures ===")
png_files <- list.files(output_dir, pattern = "\\.png$", full.names = FALSE)
print(png_files)
message("Output dir: ", normalizePath(output_dir))
message("Done.")

# Auto-printed at top level: records R and package versions for this run.
sessionInfo()
