#!/usr/bin/env Rscript
# =============================================================================
# BioF3 Bulk RNA-seq Module 03: GO / KEGG / GSEA on airway DE results
# =============================================================================
#
# This companion script for the bulk RNA-seq module 03:
#   1. Loads the Bioconductor airway data and runs a quick DESeq2 pass
#      (same inputs as bulk02)
#   2. Maps Ensembl IDs to Entrez IDs for clusterProfiler
#   3. Runs GO (BP/MF/CC), KEGG over-representation on significant DE genes
#   4. Runs GO BP GSEA on the full ranked gene list
#   5. Produces six figures: GO BP dotplot, GO BP/MF/CC barplot, GO emap,
#      KEGG dotplot, GSEA waterfall, GSEA running-score
#
# Data source:
#   Bioconductor airway package (no download required after install)
#
# Usage:
#   Rscript scripts/bulk03_enrichment_sci.R
#
# Optional env vars:
#   BIOF3_OUTPUT_DIR=/path/to/static/img/tutorial/bulk03
#
# Dependencies:
#   DESeq2, airway, clusterProfiler, org.Hs.eg.db, AnnotationDbi, enrichplot,
#   DOSE, ggplot2, dplyr, tibble, patchwork, ggrepel, stringr
# =============================================================================

# Keep character columns as character. This is already the default on
# R >= 4.0; it is set explicitly here.
options(stringsAsFactors = FALSE)
# Fix the RNG state once, before any analysis step runs.
set.seed(42)

# ---- paths ---------------------------------------------------------------
# Locate this script from the `--file=` entry of the command line. When that
# entry is absent, fall back to the script's path relative to the working
# directory (not required to exist).
args <- commandArgs(trailingOnly = FALSE)
file_arg <- grep("^--file=", args, value = TRUE)
script_path <- if (length(file_arg) > 0) {
  normalizePath(sub("^--file=", "", file_arg[[1]]), mustWork = TRUE)
} else {
  normalizePath("scripts/bulk03_enrichment_sci.R", mustWork = FALSE)
}
script_dir <- dirname(script_path)
# A script in <root>/static/scripts sits two levels below the project root;
# any other location is treated as one level below it.
project_root <- if (basename(script_dir) == "scripts" && basename(dirname(script_dir)) == "static") {
  dirname(dirname(script_dir))
} else {
  dirname(script_dir)
}
# BIOF3_OUTPUT_DIR overrides the default figure directory. A variable that is
# set but empty is treated like an unset one: otherwise output_dir would be ""
# and file.path("", name) would aim every figure at the filesystem root.
output_dir <- Sys.getenv("BIOF3_OUTPUT_DIR", unset = "")
if (!nzchar(output_dir)) {
  output_dir <- file.path(project_root, "static", "img", "tutorial", "bulk03")
}
dir.create(output_dir, recursive = TRUE, showWarnings = FALSE)
# dir.create() warnings are silenced above, so verify the directory really
# exists now rather than failing at the first ggsave() after the analysis.
if (!dir.exists(output_dir)) {
  stop("Cannot create output directory: ", output_dir, call. = FALSE)
}

# ---- dependencies --------------------------------------------------------
suppressPackageStartupMessages({
  library(DESeq2)
  library(airway)
  library(clusterProfiler)
  library(org.Hs.eg.db)
  library(enrichplot)
  library(DOSE)
  library(ggplot2)
  library(dplyr)
  library(patchwork)
  library(stringr)
})

# Named hex palette shared by all figures in this script.
biof3_colors <- c(
  green  = "#0f766e",
  mint   = "#43d1ae",
  blue   = "#2563eb",
  amber  = "#f59e0b",
  red    = "#dc2626",
  slate  = "#475569",
  violet = "#7c3aed",
  pink   = "#ec4899"
)

# Shared ggplot2 theme: theme_classic() plus the project's text colours,
# bold titles, a right-hand legend and borderless facet strips.
#
# base_size: base font size handed to theme_classic().
# Returns a ggplot2 theme object that can be added to a plot with `+`.
theme_biof3 <- function(base_size = 12) {
  ink <- "#111827"
  heading <- "#12342f"

  overrides <- theme(
    axis.text = element_text(color = ink),
    axis.title = element_text(color = ink, face = "bold"),
    plot.title = element_text(color = heading, face = "bold", hjust = 0),
    plot.subtitle = element_text(color = "#526760", hjust = 0),
    legend.title = element_text(face = "bold"),
    legend.position = "right",
    strip.background = element_blank(),
    strip.text = element_text(color = heading, face = "bold")
  )

  theme_classic(base_size = base_size) + overrides
}

# Write plot `p` to `name` inside the script-level `output_dir`, at 300 dpi
# on a white background.
#
# p:             plot object passed on to ggsave().
# name:          file name (with extension) relative to `output_dir`.
# width, height: figure size forwarded to ggsave().
# Returns whatever ggsave() returns.
save_plot <- function(p, name, width = 8, height = 5) {
  target <- file.path(output_dir, name)
  ggsave(
    filename = target,
    plot = p,
    width = width,
    height = height,
    dpi = 300,
    bg = "white"
  )
}

# ---- quick DESeq2 pass ---------------------------------------------------
message("Running DESeq2 on airway ...")
data("airway")
se <- airway
# Make the untreated samples the reference level of the treatment factor.
se$dex <- relevel(se$dex, ref = "untrt")
dds <- DESeqDataSet(se, design = ~ cell + dex)
# Keep only rows whose counts, summed over all samples, reach 10; then fit.
dds <- DESeq(dds[rowSums(counts(dds)) >= 10, ])
res <- results(dds, contrast = c("dex", "trt", "untrt"))
# Move the row names into an `ensembl` column and drop rows without an
# adjusted p-value.
res_df <- res |>
  as.data.frame() |>
  tibble::rownames_to_column(var = "ensembl") |>
  filter(!is.na(padj))
message(
  "DE genes (padj < 0.05 & |LFC| > 1): ",
  with(res_df, sum(padj < 0.05 & abs(log2FoldChange) > 1))
)

# ---- map Ensembl to Entrez ----------------------------------------------
message("Mapping Ensembl IDs to Entrez IDs ...")
# The result rows are keyed by Ensembl gene IDs such as "ENSG00000000003".
# Look up ENTREZID and SYMBOL for each, keep the first row returned per
# Ensembl ID, then discard IDs that have no Entrez match.
id_map <- org.Hs.eg.db |>
  AnnotationDbi::select(
    keys    = res_df$ensembl,
    keytype = "ENSEMBL",
    columns = c("ENTREZID", "SYMBOL")
  ) |>
  distinct(ENSEMBL, .keep_all = TRUE) |>
  filter(!is.na(ENTREZID))

# Inner join: genes without a mapping fall out of the results table.
res_df <- inner_join(res_df, id_map, by = c("ensembl" = "ENSEMBL"))
message("After ID mapping: ", nrow(res_df), " genes")

# ---- ORA input: significant DE genes (both directions) ------------------
# Gene set for ORA: padj < 0.05 and |LFC| > 1, up and down combined.
sig_genes <- res_df |> filter(padj < 0.05, abs(log2FoldChange) > 1)

# Direction-specific subsets of the same selection; in this section they
# only feed the summary message.
sig_up   <- sig_genes |> filter(log2FoldChange >  1)
sig_down <- sig_genes |> filter(log2FoldChange < -1)
message("Upregulated: ", nrow(sig_up), "  Downregulated: ", nrow(sig_down))

# Background for ORA: every gene that survived filtering and ID mapping.
universe <- res_df[["ENTREZID"]]

# ---- GSEA input: ranked list by sign(LFC) * -log10(p) -------------------
# p-values that underflowed to exactly 0 belong to the strongest signals, so
# they are floored at the smallest positive double rather than dropped
# (-log10(0) would be Inf and cannot be ranked).
rnk <- res_df |>
  filter(!is.na(pvalue)) |>
  mutate(
    stat = sign(log2FoldChange) * -log10(pmax(pvalue, .Machine$double.xmin))
  ) |>
  # Several Ensembl IDs can map to one Entrez ID. Keep the entry with the
  # largest |stat| per Entrez ID; taking the first row of a descending sort
  # would always favour the most positive score.
  arrange(desc(abs(stat))) |>
  distinct(ENTREZID, .keep_all = TRUE) |>
  arrange(desc(stat))
# Named numeric vector in decreasing order, one entry per Entrez ID.
rank_vec <- rnk$stat
names(rank_vec) <- rnk$ENTREZID

# ---- 1. GO BP over-representation ---------------------------------------
message("GO BP over-representation ...")
# ORA of the significant genes against the mapped background, followed by
# simplify() (cutoff 0.7, keeping the term with the smallest p.adjust).
ego <- enrichGO(
  gene = sig_genes$ENTREZID, universe = universe,
  OrgDb = org.Hs.eg.db, keyType = "ENTREZID", ont = "BP",
  pAdjustMethod = "BH", pvalueCutoff = 0.05, qvalueCutoff = 0.1,
  readable = TRUE
) |>
  simplify(cutoff = 0.7, by = "p.adjust", select_fun = min)
message(
  "GO BP terms passing: ",
  nrow(ego@result[ego@result[["p.adjust"]] < 0.05, , drop = FALSE])
)

# ---- figure 1: GO dotplot -----------------------------------------------
message("Figure 1: GO BP dotplot")
# NOTE(review): some enrichplot releases appear to map p.adjust to `fill`
# rather than `colour` (confirm against the installed version). The palette
# is registered for both aesthetics so it takes effect either way; a scale
# for an unmapped aesthetic is simply unused.
p1 <- dotplot(ego, showCategory = 15, label_format = 50) +
  scale_color_gradient(
    low = biof3_colors[["red"]],
    high = biof3_colors[["blue"]],
    aesthetics = c("colour", "fill")
  ) +
  labs(title = "GO biological process enrichment",
       subtitle = "Top 15 BP terms from |LFC|>1 & padj<0.05 DE genes",
       x = "Gene ratio") +
  theme_biof3() +
  theme(axis.text.y = element_text(size = 9))
save_plot(p1, "01-go-bp-dotplot.png", width = 10, height = 8)

# ---- figure 2: GO barplot by category (BP / MF / CC) --------------------
message("Figure 2: GO barplot")
# One ORA per ontology. Each element is the top 8 rows (by p.adjust) of the
# simplified result, tagged with its ontology, or NULL when enrichGO returned
# nothing for that ontology.
ego_all <- lapply(
  setNames(nm = c("BP", "MF", "CC")),
  function(ont) {
    e <- enrichGO(
      gene          = sig_genes$ENTREZID, universe = universe,
      OrgDb         = org.Hs.eg.db, keyType = "ENTREZID",
      ont           = ont, pAdjustMethod = "BH",
      pvalueCutoff  = 0.05, qvalueCutoff = 0.1, readable = TRUE
    )
    if (is.null(e) || nrow(e@result) == 0) {
      return(NULL)
    }
    # Fall back to the unsimplified result if simplify() fails.
    e_s <- tryCatch(
      simplify(e, cutoff = 0.7, by = "p.adjust", select_fun = min),
      error = function(err) e
    )
    top <- head(arrange(e_s@result, p.adjust), 8)
    top$ont <- rep(ont, nrow(top))
    top
  }
)
ego_all <- Filter(Negate(is.null), ego_all)

if (length(ego_all) > 0) {
  ego_bar <- bind_rows(ego_all) |>
    mutate(Description = str_trunc(Description, 55),
           neglogp = -log10(p.adjust),
           ont = factor(ont, levels = c("BP", "MF", "CC")))

  p2 <- ggplot(ego_bar, aes(x = reorder(Description, neglogp), y = neglogp, fill = ont)) +
    geom_col(width = 0.75) +
    coord_flip() +
    facet_wrap(~ ont, scales = "free_y", ncol = 1) +
    scale_fill_manual(values = c(BP = biof3_colors[["green"]],
                                 MF = biof3_colors[["blue"]],
                                 CC = biof3_colors[["violet"]]),
                      guide = "none") +
    labs(title = "Top GO terms in BP / MF / CC",
         subtitle = "Top 8 terms per ontology, ranked by -log10(p.adjust)",
         x = NULL, y = "-log10(p.adjust)") +
    theme_biof3()
  save_plot(p2, "02-go-bar-bpmfcc.png", width = 11, height = 10)
} else {
  # Without any rows there is no Description column to plot.
  message("Skipping figure 2 (no GO terms returned)")
}

# ---- figure 3: GO emap plot --------------------------------------------
message("Figure 3: GO enrichment map")
# Compute pairwise term similarity on the GO BP result, then draw the map
# for 25 categories.
ego_pw <- pairwise_termsim(ego)
p3 <- emapplot(ego_pw, showCategory = 25) +
  ggtitle(
    label = "GO BP term similarity map",
    subtitle = "Terms sharing genes are connected; clusters suggest common biology"
  ) +
  theme_biof3() +
  # Blank out the axis decorations that theme_biof3() would otherwise draw.
  theme(
    axis.line = element_blank(),
    axis.ticks = element_blank(),
    axis.text = element_blank(),
    axis.title = element_blank()
  )
save_plot(p3, name = "03-go-emap.png", width = 11, height = 8)

# ---- figure 4: KEGG dotplot (ORA) ---------------------------------------
message("Figure 4: KEGG over-representation")
# Any error from enrichKEGG is reported and turned into NULL so the rest of
# the script can continue without this figure.
ekegg <- tryCatch(
  enrichKEGG(
    gene          = sig_genes$ENTREZID,
    organism      = "hsa",
    universe      = universe,
    pAdjustMethod = "BH",
    pvalueCutoff  = 0.05,
    qvalueCutoff  = 0.2
  ),
  error = function(e) {
    message("KEGG online lookup failed: ", conditionMessage(e))
    NULL
  }
)

if (!is.null(ekegg) && nrow(ekegg@result) > 0) {
  # Convert the gene IDs in the result to symbols for display.
  ekegg <- setReadable(ekegg, OrgDb = org.Hs.eg.db, keyType = "ENTREZID")
  # NOTE(review): some enrichplot releases appear to map p.adjust to `fill`
  # rather than `colour` (confirm against the installed version). The palette
  # is registered for both aesthetics so it takes effect either way.
  p4 <- dotplot(ekegg, showCategory = 15, label_format = 50) +
    scale_color_gradient(
      low = biof3_colors[["red"]],
      high = biof3_colors[["blue"]],
      aesthetics = c("colour", "fill")
    ) +
    labs(title = "KEGG pathway enrichment",
         subtitle = "Top 15 pathways from |LFC|>1 & padj<0.05 DE genes",
         x = "Gene ratio") +
    theme_biof3()
  save_plot(p4, "04-kegg-dotplot.png", width = 10, height = 8)
} else {
  message("Skipping figure 4 (KEGG unavailable)")
}

# ---- GSEA on GO BP (input for the waterfall and running-score figures) ---
message("GSEA GO BP ...")
# Gene-set sizes are limited to 15-500 and the p-value cutoff is 0.25.
# NOTE(review): seed = TRUE presumably makes the run reproducible under the
# script-level set.seed() - confirm against the gseGO documentation.
gsea_go <- gseGO(
  geneList = rank_vec,
  OrgDb = org.Hs.eg.db, keyType = "ENTREZID", ont = "BP",
  minGSSize = 15, maxGSSize = 500,
  pvalueCutoff = 0.25, pAdjustMethod = "BH",
  seed = TRUE, verbose = FALSE
)
message("GSEA GO BP terms passing: ", nrow(gsea_go@result))

# Figure 5: waterfall of top NES terms (up & down)
gsea_df <- gsea_go@result |>
  arrange(desc(NES))
if (nrow(gsea_df) > 0) {
  top_up <- head(gsea_df, 10)
  top_down <- tail(gsea_df, 10)
  gsea_water <- bind_rows(top_up, top_down) |>
    # With fewer than 20 terms the head and tail slices overlap. Duplicated
    # rows would be stacked by geom_col() and inflate the plotted NES, so
    # each term is kept once.
    distinct(ID, .keep_all = TRUE) |>
    mutate(direction = ifelse(NES > 0, "Activated", "Suppressed"),
           Description = str_trunc(Description, 55))
  p5 <- ggplot(gsea_water, aes(x = reorder(Description, NES), y = NES, fill = direction)) +
    geom_col(width = 0.75) +
    coord_flip() +
    scale_fill_manual(values = c(Activated  = biof3_colors[["red"]],
                                 Suppressed = biof3_colors[["blue"]])) +
    labs(title = "GSEA (GO BP): top activated and suppressed terms",
         subtitle = "Ranked by NES (normalized enrichment score)",
         x = NULL, y = "NES", fill = NULL) +
    theme_biof3()
  save_plot(p5, "05-gsea-waterfall.png", width = 11, height = 8)
} else {
  # Nothing to draw when no term passed the GSEA cutoff.
  message("Skipping figure 5 (no GSEA terms)")
}

# ---- figure 6: GSEA running-score for one term --------------------------
message("Figure 6: GSEA running-score plot")
if (nrow(gsea_go@result) > 0) {
  # Pick the term with the highest |NES| and read its ID and description
  # from the same row.
  pick_row <- which.max(abs(gsea_go@result$NES))
  pick_id <- gsea_go@result$ID[pick_row]
  pick_desc <- gsea_go@result$Description[pick_row]
  p6 <- gseaplot2(gsea_go, geneSetID = pick_id,
                  title = paste0(pick_id, ": ", pick_desc),
                  pvalue_table = TRUE,
                  color = biof3_colors[["red"]])
  # Same ggsave() settings as every other figure, via the shared helper.
  save_plot(p6, "06-gsea-running-score.png", width = 10, height = 7)
} else {
  message("Skipping figure 6 (no GSEA terms)")
}

# ---- done ---------------------------------------------------------------
message("=== module bulk03 enrichment figures ===")
# Show the PNG files now present in the output directory (names only).
print(list.files(path = output_dir, pattern = "\\.png$"))
message(paste0("Output dir: ", normalizePath(output_dir)))
message("Done.")

# Report the R session and loaded package versions.
sessionInfo()
