#!/usr/bin/env Rscript
# =============================================================================
# BioF3 Bulk RNA-seq Module 04: Volcano, heatmap and enrichment visualization
# =============================================================================
#
# This companion script for bulk RNA-seq module 04:
#   1. Re-runs the DESeq2 + enrichment pipelines from bulk02 / bulk03
#   2. Produces six "paper-ready" figures:
#      - Volcano plot with labeled top genes
#      - Heatmap of top DE genes across samples
#      - Curated glucocorticoid-response gene panel (normalized-count boxplots)
#      - Bubble plot of KEGG enrichment
#      - Gene-to-term network (cnetplot) for the top GO BP terms
#      - Forest-style plot of GSEA NES (point size = -log10 adjusted p-value)
#
# Usage:
#   Rscript scripts/bulk04_visualization_sci.R
#
# Optional env vars:
#   BIOF3_OUTPUT_DIR=/path/to/static/img/tutorial/bulk04
#
# Dependencies:
#   DESeq2, apeglm (used by lfcShrink), airway, clusterProfiler, org.Hs.eg.db,
#   enrichplot, DOSE, ggplot2, dplyr, tibble, patchwork, ggrepel, pheatmap,
#   RColorBrewer, stringr
# =============================================================================

# Keep character columns as character (already the default on R >= 4.0) and
# fix the RNG seed for the rest of the script.
options(stringsAsFactors = FALSE)
set.seed(42)

# ---- paths ---------------------------------------------------------------
# Locate this script from the `--file=` argument; when it is absent, fall back
# to the conventional relative path (which need not exist).
args <- commandArgs(trailingOnly = FALSE)
file_arg <- grep("^--file=", args, value = TRUE)
script_path <- if (length(file_arg) > 0) {
  normalizePath(sub("^--file=", "", file_arg[[1]]), mustWork = TRUE)
} else {
  normalizePath("scripts/bulk04_visualization_sci.R", mustWork = FALSE)
}
script_dir <- dirname(script_path)

# <root>/static/scripts/ is two levels below the project root; any other
# location is treated as one level below it.
project_root <- if (basename(script_dir) == "scripts" && basename(dirname(script_dir)) == "static") {
  dirname(dirname(script_dir))
} else {
  dirname(script_dir)
}

# BIOF3_OUTPUT_DIR overrides the default figure directory. An empty value is
# treated the same as an unset one so figures never get written relative to "".
output_dir <- Sys.getenv("BIOF3_OUTPUT_DIR", "")
if (!nzchar(output_dir)) {
  output_dir <- file.path(project_root, "static", "img", "tutorial", "bulk04")
}
dir.create(output_dir, recursive = TRUE, showWarnings = FALSE)

# Fail now, with a clear message, rather than at the first save after the
# DESeq2 and enrichment steps have already run.
if (!dir.exists(output_dir)) {
  stop("Cannot create output directory: ", output_dir, call. = FALSE)
}

# ---- dependencies --------------------------------------------------------
suppressPackageStartupMessages({
  library(DESeq2)
  library(airway)
  library(clusterProfiler)
  library(org.Hs.eg.db)
  library(enrichplot)
  library(DOSE)
  library(ggplot2)
  library(dplyr)
  library(patchwork)
  library(ggrepel)
  library(pheatmap)
  library(RColorBrewer)
  library(stringr)
})

# Named palette of hex colours shared by every figure in this script.
biof3_colors <- setNames(
  c("#0f766e", "#43d1ae", "#2563eb", "#f59e0b",
    "#dc2626", "#475569", "#7c3aed", "#ec4899"),
  c("green", "mint", "blue", "amber",
    "red", "slate", "violet", "pink")
)

# Shared ggplot2 theme: theme_classic() plus the text colours, bold titles,
# right-hand legend and borderless facet strips used by all figures.
#
# base_size: base font size forwarded to theme_classic().
# Returns a ggplot2 theme object that can be added to a plot with `+`.
theme_biof3 <- function(base_size = 12) {
  ink <- "#111827"
  heading <- "#12342f"

  overrides <- theme(
    axis.text = element_text(color = ink),
    axis.title = element_text(color = ink, face = "bold"),
    plot.title = element_text(color = heading, face = "bold", hjust = 0),
    plot.subtitle = element_text(color = "#526760", hjust = 0),
    legend.title = element_text(face = "bold"),
    legend.position = "right",
    strip.background = element_blank(),
    strip.text = element_text(color = heading, face = "bold")
  )

  theme_classic(base_size = base_size) + overrides
}

# Write a ggplot to `output_dir` at 300 dpi on a white background.
#
# p:      plot object to save.
# name:   file name (with extension) relative to the global `output_dir`.
# width:  figure width passed to ggsave().
# height: figure height passed to ggsave().
# Returns whatever ggsave() returns.
save_plot <- function(p, name, width = 8, height = 5) {
  target <- file.path(output_dir, name)
  ggsave(
    filename = target,
    plot = p,
    width = width,
    height = height,
    dpi = 300,
    bg = "white"
  )
}

# ---- DESeq2 + LFC shrink ------------------------------------------------
# Fit DESeq2 on the airway dataset with `cell` as a covariate, then compute
# both the unshrunken results table and apeglm-shrunken fold changes.
message("Running DESeq2 on airway ...")
data("airway")
se <- airway
# Make "untrt" the reference level so the dex coefficient is trt vs untrt.
se$dex <- relevel(se$dex, ref = "untrt")
dds <- DESeqDataSet(se, design = ~ cell + dex)
# Keep genes with at least 10 reads summed over all samples, then fit.
dds <- DESeq(dds[rowSums(counts(dds)) >= 10, ])
res <- results(dds, contrast = c("dex", "trt", "untrt"))
res_shrink <- lfcShrink(dds, coef = "dex_trt_vs_untrt", type = "apeglm")

# ---- map to symbol -------------------------------------------------------
# Build an Ensembl -> (ENTREZID, SYMBOL) lookup, keeping the first mapping per
# Ensembl ID and discarding rows that have no symbol.
message("Mapping Ensembl to SYMBOL / ENTREZID ...")
id_map <- AnnotationDbi::select(
  org.Hs.eg.db,
  keys = rownames(res),
  keytype = "ENSEMBL",
  columns = c("ENTREZID", "SYMBOL")
)
id_map <- distinct(id_map, ENSEMBL, .keep_all = TRUE)
id_map <- filter(id_map, !is.na(SYMBOL))

# Shrunken results as a plain data frame, annotated with the lookup; genes
# whose adjusted p-value is NA are removed.
de_df <- tibble::rownames_to_column(as.data.frame(res_shrink), "ensembl")
de_df <- left_join(de_df, id_map, by = c("ensembl" = "ENSEMBL"))
de_df <- filter(de_df, !is.na(padj))

# Carry along the adjusted p-value from the unshrunken results table.
de_df$padj_raw <- res[de_df$ensembl, "padj"]
message("Genes with symbol: ", sum(!is.na(de_df$SYMBOL)))

# ---- figure 1: volcano plot ---------------------------------------------
# Shrunken log2 fold change against -log10 adjusted p-value, with genes
# classified as Up / Down / NS at |LFC| > 1 and padj < 0.05.
message("Figure 1: Volcano plot")
vol <- de_df |>
  mutate(
    sig = case_when(
      padj_raw < 0.05 & log2FoldChange >  1 ~ "Up",
      padj_raw < 0.05 & log2FoldChange < -1 ~ "Down",
      TRUE ~ "NS"
    ),
    # Floor padj so -log10() stays finite if a value underflows to zero.
    neglogp = -log10(pmax(padj_raw, 1e-300))
  ) |>
  filter(!is.na(padj_raw))

# Label the 10 most significant up and the 10 most significant down genes.
# Genes without a symbol cannot be labeled, so they are removed *before*
# ranking; filtering afterwards could leave fewer than 10 labels per side.
labelable <- vol |> filter(!is.na(SYMBOL))
label_df <- bind_rows(
  labelable |> filter(sig == "Up")   |> arrange(desc(neglogp)) |> head(10),
  labelable |> filter(sig == "Down") |> arrange(desc(neglogp)) |> head(10)
)

p1 <- ggplot(vol, aes(x = log2FoldChange, y = neglogp, color = sig)) +
  geom_point(size = 0.6, alpha = 0.5) +
  geom_text_repel(data = label_df, aes(label = SYMBOL), size = 3.3,
                  max.overlaps = 30, seed = 42, color = "#111827") +
  scale_color_manual(values = c(Up = biof3_colors[["red"]],
                                Down = biof3_colors[["blue"]],
                                NS = "#9ca3af")) +
  geom_vline(xintercept = c(-1, 1), linetype = "dashed", color = "grey50") +
  geom_hline(yintercept = -log10(0.05), linetype = "dashed", color = "grey50") +
  labs(title = "Volcano plot: dex-treated vs control",
       subtitle = "Top 10 up and top 10 down genes labeled; dashed lines at LFC=1 and padj=0.05",
       x = "log2 fold change (shrunken)",
       y = "-log10(padj)", color = NULL) +
  theme_biof3()
save_plot(p1, "01-volcano.png", width = 10, height = 7)

# ---- figure 2: top DE gene heatmap --------------------------------------
# Row z-scored VST expression of the 40 genes with the smallest adjusted
# p-value (one row per gene symbol), annotated by treatment and cell line.
message("Figure 2: heatmap of top DE genes")
vsd <- vst(dds, blind = FALSE)
top_sig <- de_df |>
  filter(!is.na(padj_raw)) |>
  arrange(padj_raw) |>
  filter(!is.na(SYMBOL)) |>
  distinct(SYMBOL, .keep_all = TRUE) |>
  head(40)

# z-score normalize per gene; drop = FALSE keeps a matrix even for one row
mat <- assay(vsd)[top_sig$ensembl, , drop = FALSE]
rownames(mat) <- top_sig$SYMBOL
mat_z <- t(scale(t(mat)))

anno_col <- data.frame(
  dex  = colData(vsd)$dex,
  cell = colData(vsd)$cell,
  row.names = colnames(vsd)
)
anno_colors <- list(
  dex  = c(untrt = biof3_colors[["slate"]], trt = biof3_colors[["red"]]),
  cell = setNames(
    c(biof3_colors[["green"]], biof3_colors[["blue"]],
      biof3_colors[["amber"]], biof3_colors[["violet"]]),
    levels(colData(vsd)$cell)
  )
)

png(file.path(output_dir, "02-heatmap-top40.png"),
    width = 10, height = 10, units = "in", res = 300, bg = "white")
# Close the device even if pheatmap() fails, so a failed run does not leave an
# open PNG device behind; invisible() keeps the return value from being
# auto-printed at top level.
invisible(tryCatch(
  pheatmap(
    mat_z,
    color = colorRampPalette(c(biof3_colors[["blue"]], "white", biof3_colors[["red"]]))(100),
    # 101 break points for 100 colours, symmetric around zero
    breaks = seq(-3, 3, length.out = 101),
    annotation_col = anno_col,
    annotation_colors = anno_colors,
    show_colnames = FALSE, fontsize_row = 8,
    main = "Top 40 DE genes (row z-score of VST)"
  ),
  finally = dev.off()
))

# ---- figure 3: counts boxplot for a curated gene set --------------------
# Per-gene boxplots of normalized counts by treatment for a fixed list of
# symbols; the figure is skipped when none of them is present in `de_df`.
message("Figure 3: curated gene panel expression")
# classic glucocorticoid response: DUSP1, KLF15, ANGPTL4, FKBP5 etc.
panel_symbols <- c("DUSP1", "FKBP5", "ANGPTL4", "KLF15", "SPARCL1", "ZBTB16")
panel_df <- distinct(
  filter(de_df, SYMBOL %in% panel_symbols),
  SYMBOL,
  .keep_all = TRUE
)

if (nrow(panel_df) > 0) {
  # One data frame of per-sample counts per panel gene, tagged with its symbol
  count_list <- lapply(seq_len(nrow(panel_df)), function(i) {
    d <- plotCounts(
      dds,
      gene = panel_df$ensembl[[i]],
      intgroup = c("dex", "cell"),
      returnData = TRUE
    )
    d$symbol <- panel_df$SYMBOL[[i]]
    d
  })
  count_long <- bind_rows(count_list)

  # Facets are ordered by increasing adjusted p-value
  symbol_order <- panel_df$SYMBOL[order(panel_df$padj_raw)]
  count_long$symbol <- factor(count_long$symbol, levels = symbol_order)

  dex_fill <- c(untrt = biof3_colors[["slate"]], trt = biof3_colors[["red"]])

  p3 <- ggplot(count_long, aes(x = dex, y = count, fill = dex)) +
    geom_boxplot(width = 0.5, outlier.shape = NA, alpha = 0.85) +
    geom_jitter(width = 0.12, height = 0, size = 1.5, alpha = 0.8) +
    facet_wrap(~ symbol, scales = "free_y", ncol = 3) +
    scale_fill_manual(values = dex_fill, guide = "none") +
    scale_y_log10() +
    labs(
      title = "Curated glucocorticoid response genes",
      subtitle = "Normalized counts (log10) across 8 samples; boxes span IQR",
      x = NULL,
      y = "Normalized counts (log10)"
    ) +
    theme_biof3()
  save_plot(p3, "03-curated-panel.png", width = 11, height = 6.5)
}

# ---- figure 4: KEGG bubble plot (from bulk03 pipeline) ------------------
# Over-representation of KEGG pathways among significant genes
# (padj < 0.05, |LFC| > 1) against all Entrez-mapped tested genes.
message("Figure 4: KEGG bubble plot")
res_entrez <- de_df |> filter(!is.na(ENTREZID)) |> distinct(ENTREZID, .keep_all = TRUE)
sig_entrez <- res_entrez |> filter(padj_raw < 0.05 & abs(log2FoldChange) > 1) |> pull(ENTREZID)

# enrichKEGG can fail (it is wrapped here so a failure only skips this figure)
ekegg <- tryCatch(
  enrichKEGG(gene = sig_entrez, organism = "hsa",
             universe = res_entrez$ENTREZID,
             pAdjustMethod = "BH",
             pvalueCutoff = 0.1, qvalueCutoff = 0.2),
  error = function(e) { message("KEGG failed: ", conditionMessage(e)); NULL }
)

# Use as.data.frame() rather than the raw @result slot so that the
# pvalueCutoff / qvalueCutoff requested above are honoured.
# NOTE(review): in current DOSE releases @result holds every tested pathway
# and only as.data.frame() applies the cutoffs - confirm for the pinned version.
ekegg_hits <- if (is.null(ekegg)) NULL else as.data.frame(ekegg)

if (!is.null(ekegg_hits) && nrow(ekegg_hits) > 0) {
  ekegg_df <- ekegg_hits |>
    arrange(p.adjust) |>
    head(15) |>
    mutate(
      Description = str_trunc(Description, 50),
      # GeneRatio is a "k/n" string; split and divide instead of eval(parse())
      GeneRatio_num = vapply(
        strsplit(GeneRatio, "/", fixed = TRUE),
        function(parts) as.numeric(parts[[1]]) / as.numeric(parts[[2]]),
        numeric(1)
      ),
      neglogp = -log10(p.adjust)
    )
  p4 <- ggplot(ekegg_df,
               aes(x = GeneRatio_num, y = reorder(Description, GeneRatio_num),
                   size = Count, color = neglogp)) +
    geom_point() +
    scale_color_gradient(low = biof3_colors[["blue"]],
                         high = biof3_colors[["red"]], name = "-log10(p.adj)") +
    scale_size_continuous(range = c(3, 9), name = "Gene count") +
    labs(title = "KEGG enrichment bubble plot",
         subtitle = "Top 15 pathways; bubble size = gene count, color = -log10(p.adj)",
         x = "Gene ratio", y = NULL) +
    theme_biof3()
  save_plot(p4, "04-kegg-bubble.png", width = 11, height = 7)
} else {
  message("No KEGG pathways passed the cutoffs; skipping Figure 4")
}

# ---- figure 5: cnetplot (gene-to-term network) for top GO BP ------------
# GO BP over-representation on the same gene list / universe as Figure 4,
# drawn as a gene-term network with genes coloured by shrunken log2 FC.
message("Figure 5: cnetplot (gene-term network)")
ego <- enrichGO(
  gene = sig_entrez, universe = res_entrez$ENTREZID,
  OrgDb = org.Hs.eg.db, keyType = "ENTREZID",
  ont = "BP", pAdjustMethod = "BH",
  pvalueCutoff = 0.05, qvalueCutoff = 0.1, readable = TRUE
)
# Collapse similar terms; if that fails, fall back to the unsimplified result
# and say so instead of failing silently.
ego_s <- tryCatch(
  simplify(ego, cutoff = 0.7, by = "p.adjust"),
  error = function(e) {
    message("simplify() failed, using unsimplified GO result: ", conditionMessage(e))
    ego
  }
)

# we need fold change info for coloring genes
fc_vec <- setNames(de_df$log2FoldChange, de_df$SYMBOL)
fc_vec <- fc_vec[!is.na(fc_vec) & !is.na(names(fc_vec))]

# Same guard as Figure 4: only plot when there is at least one enriched term,
# otherwise skip the figure rather than abort the script.
if (!is.null(ego_s) && nrow(as.data.frame(ego_s)) > 0) {
  p5 <- cnetplot(ego_s,
                 showCategory = 6,
                 foldChange = fc_vec) +
    scale_color_gradient2(low = biof3_colors[["blue"]], mid = "grey90",
                          high = biof3_colors[["red"]], midpoint = 0,
                          name = "log2 FC") +
    labs(title = "Gene-term network (cnetplot)",
         subtitle = "Top 6 GO BP terms and the DE genes driving them") +
    theme_biof3() +
    # a network layout has no meaningful axes
    theme(axis.text = element_blank(), axis.ticks = element_blank(),
          axis.title = element_blank(), axis.line = element_blank())
  save_plot(p5, "05-cnetplot.png", width = 12, height = 9)
} else {
  message("No enriched GO BP terms; skipping Figure 5")
}

# ---- figure 6: GSEA NES forest plot ------------------------------------
# GSEA on GO BP using a signed -log10(p) ranking, then a lollipop chart of the
# 10 highest and 10 lowest NES terms.
message("Figure 6: GSEA NES forest")
rnk <- de_df |>
  filter(!is.na(ENTREZID) & !is.na(pvalue)) |>
  # Clamp p-values that underflowed to 0 instead of dropping those genes:
  # they are the strongest signals and belong at the ends of the ranking.
  # The 1e-300 floor matches the one used for the volcano plot.
  mutate(stat = sign(log2FoldChange) * -log10(pmax(pvalue, 1e-300))) |>
  arrange(desc(stat))
rank_vec <- rnk$stat
names(rank_vec) <- rnk$ENTREZID
# keep the first (highest-ranked) entry per Entrez ID
rank_vec <- rank_vec[!duplicated(names(rank_vec))]

gsea_go <- gseGO(
  geneList = rank_vec, OrgDb = org.Hs.eg.db, ont = "BP", keyType = "ENTREZID",
  pAdjustMethod = "BH", pvalueCutoff = 0.25,
  minGSSize = 15, maxGSSize = 500, verbose = FALSE, seed = TRUE
)

# Same guard as Figure 4: skip the figure when GSEA returned no terms.
if (!is.null(gsea_go) && nrow(gsea_go@result) > 0) {
  gsea_df <- gsea_go@result |>
    arrange(NES)
  # distinct() removes overlap between the two ends when there are < 20 terms
  top_nes <- bind_rows(
    gsea_df |> arrange(desc(NES)) |> head(10),
    gsea_df |> arrange(NES) |> head(10)
  ) |>
    distinct(ID, .keep_all = TRUE) |>
    mutate(Description = str_trunc(Description, 55),
           direction = ifelse(NES > 0, "Activated", "Suppressed"),
           neglogp = -log10(p.adjust))

  p6 <- ggplot(top_nes, aes(x = NES, y = reorder(Description, NES),
                            color = direction, size = neglogp)) +
    geom_segment(aes(x = 0, xend = NES, yend = Description),
                 linewidth = 0.3, color = "grey70") +
    geom_point() +
    scale_color_manual(values = c(Activated = biof3_colors[["red"]],
                                  Suppressed = biof3_colors[["blue"]])) +
    scale_size_continuous(range = c(3, 8), name = "-log10(p.adj)") +
    labs(title = "GSEA NES forest plot (GO BP)",
         subtitle = "Top 10 activated and top 10 suppressed terms",
         x = "Normalized Enrichment Score (NES)",
         y = NULL, color = NULL) +
    theme_biof3()
  save_plot(p6, "06-gsea-forest.png", width = 12, height = 8)
} else {
  message("GSEA returned no GO BP terms; skipping Figure 6")
}

# ---- done ---------------------------------------------------------------
# Report the PNG files now present in the output directory, then record the
# session details.
png_files <- list.files(output_dir, pattern = "\\.png$", full.names = FALSE)
message("=== module bulk04 visualization figures ===")
print(png_files)
message("Output dir: ", normalizePath(output_dir))
message("Done.")

sessionInfo()
