#!/usr/bin/env Rscript
# =============================================================================
# BioF3 Bulk RNA-seq Module 05: Time course + LRT with fission data
# =============================================================================
#
# This companion script for bulk RNA-seq module 05:
#   1. Loads the Bioconductor fission package (S. pombe wild type vs atf21del
#      mutant, 6 time points: 0 / 15 / 30 / 60 / 120 / 180 min oxidative stress,
#      3 biological replicates each = 36 samples)
#   2. Runs DESeq2 with an interaction design: strain * minute
#   3. Uses LRT (likelihood ratio test) to find genes whose response to time
#      differs between strains
#   4. Produces six figures:
#      - Sample QC: PCA colored by strain / minute
#      - LRT overview: raw p-value histogram + LRT statistic vs expression
#      - Time trajectories for top LRT genes
#      - Cluster time trajectories (DEGreport-style)
#      - Volcano of the strain x minute interaction term at 30 min
#      - Heatmap of top LRT genes with sample annotation
#
# Data source:
#   Bioconductor fission package (no external download)
#   Leong et al. 2014
#
# Usage:
#   Rscript scripts/bulk05_timecourse_sci.R
#
# Optional env vars:
#   BIOF3_OUTPUT_DIR=/path/to/static/img/tutorial/bulk05
# =============================================================================

options(stringsAsFactors = FALSE)
set.seed(42)

# ---- paths ---------------------------------------------------------------
# Resolve the script's own location so the default output directory is
# anchored to the project rather than the caller's working directory.
# Rscript passes the script as `--file=<path>`; otherwise fall back to a
# cwd-relative guess.
args <- commandArgs(trailingOnly = FALSE)
file_arg <- grep("^--file=", args, value = TRUE)
script_path <- if (length(file_arg) > 0) {
  normalizePath(sub("^--file=", "", file_arg[[1]]), mustWork = TRUE)
} else {
  normalizePath("scripts/bulk05_timecourse_sci.R", mustWork = FALSE)
}
script_dir <- dirname(script_path)
# The script may live in <root>/static/scripts (root = grandparent) or in
# <root>/<dir> (root = parent).
project_root <- if (basename(script_dir) == "scripts" &&
                    basename(dirname(script_dir)) == "static") {
  dirname(dirname(script_dir))
} else {
  dirname(script_dir)
}
# Treat an empty BIOF3_OUTPUT_DIR (e.g. `BIOF3_OUTPUT_DIR= Rscript ...`) as
# unset: Sys.getenv() returns "" in that case, and file.path("", name) would
# point every figure at the filesystem root.
output_dir <- Sys.getenv("BIOF3_OUTPUT_DIR", unset = "")
if (!nzchar(output_dir)) {
  output_dir <- file.path(project_root, "static", "img", "tutorial", "bulk05")
}
dir.create(output_dir, recursive = TRUE, showWarnings = FALSE)
# dir.create() only warns (suppressed above) on failure; fail loudly instead
# of erroring later inside ggsave()/png().
if (!dir.exists(output_dir)) {
  stop("Could not create output directory: ", output_dir, call. = FALSE)
}

# ---- dependencies --------------------------------------------------------
suppressPackageStartupMessages({
  library(DESeq2)
  library(fission)
  library(ggplot2)
  library(dplyr)
  library(patchwork)
  library(pheatmap)
  library(RColorBrewer)
  library(ggrepel)
  library(stringr)
  library(tidyr)
})

# Shared BioF3 palette: named hex colours used by the figures below.
# Look up a single colour with `[[`, e.g. biof3_colors[["red"]].
biof3_colors <- c(
  green  = "#0f766e",
  mint   = "#43d1ae",
  blue   = "#2563eb",
  amber  = "#f59e0b",
  red    = "#dc2626",
  slate  = "#475569",
  violet = "#7c3aed",
  pink   = "#ec4899"
)

# House ggplot2 theme: theme_classic() at `base_size`, overlaid with
# dark-ink axis text, bold left-aligned titles, a right-hand legend and
# borderless facet strips.
theme_biof3 <- function(base_size = 12) {
  ink <- "#111827"
  heading <- "#12342f"
  overrides <- theme(
    axis.text        = element_text(color = ink),
    axis.title       = element_text(color = ink, face = "bold"),
    plot.title       = element_text(color = heading, face = "bold", hjust = 0),
    plot.subtitle    = element_text(color = "#526760", hjust = 0),
    legend.title     = element_text(face = "bold"),
    legend.position  = "right",
    strip.background = element_blank(),
    strip.text       = element_text(color = heading, face = "bold")
  )
  theme_classic(base_size = base_size) + overrides
}

# Save plot `p` as `name` inside the global `output_dir` at 300 dpi on a white
# background (`width`/`height` in ggsave's default units). Returns whatever
# ggsave() returns.
save_plot <- function(p, name, width = 8, height = 5) {
  target <- file.path(output_dir, name)
  ggsave(
    filename = target,
    plot = p,
    width = width,
    height = height,
    dpi = 300,
    bg = "white"
  )
}

# ---- load data -----------------------------------------------------------
message("Loading fission ...")
data("fission", package = "fission")
# Set baselines to the wild-type strain and time 0 so the model coefficients
# are expressed relative to them (e.g. "strainmut.minute30").
fission$strain <- relevel(fission$strain, ref = "wt")
fission$minute <- relevel(fission$minute, ref = "0")
message("Samples: ", ncol(fission), "  Genes: ", nrow(fission))
print(table(strain = fission$strain, minute = fission$minute))

# ---- DESeq2 with interaction + LRT --------------------------------------
message("Building DESeqDataSet with strain * minute design ...")
full_design <- ~ strain + minute + strain:minute
reduced_design <- ~ strain + minute
dds <- DESeqDataSet(fission, design = full_design)
# Drop genes with fewer than 10 reads summed over all samples.
keep <- rowSums(counts(dds)) >= 10
dds <- dds[keep, ]

message("Running LRT (full: ~ strain + minute + strain:minute, reduced: ~ strain + minute) ...")
# The LRT compares the full model with one lacking the strain:minute terms,
# i.e. it asks whether the response over time differs between strains.
dds <- DESeq(dds, test = "LRT", reduced = reduced_design)

res_lrt <- results(dds)
message("LRT results:")
summary(res_lrt)

# Indices of genes with a significant interaction; which() skips NA padj.
sig_lrt <- which(res_lrt$padj < 0.05)
message("LRT padj < 0.05: ", length(sig_lrt), " genes")

# ---- figure 1: PCA on VST -----------------------------------------------
message("Figure 1: PCA on VST")
vsd <- vst(dds, blind = FALSE)
pca <- plotPCA(vsd, intgroup = c("strain", "minute"), returnData = TRUE)
pct <- round(100 * attr(pca, "percentVar"))
# The palest YlOrRd shades (e.g. #FFFFB2) are nearly invisible on a white
# background, which would hide the t = 0 baseline samples. Use the darkest
# n steps of the 9-colour ramp instead, one per time point.
minute_levels <- levels(factor(pca$minute))
minute_cols <- setNames(
  tail(brewer.pal(9, "YlOrRd"), length(minute_levels)),
  minute_levels
)
p1 <- ggplot(pca, aes(x = PC1, y = PC2, color = minute, shape = strain)) +
  geom_point(size = 4, alpha = 0.9) +
  scale_color_manual(values = minute_cols) +
  labs(title = "PCA on VST: fission time course",
       subtitle = "Color = time point; shape = strain (wt vs mut). PC1 tracks time, PC2 separates strains.",
       x = paste0("PC1 (", pct[1], "%)"),
       y = paste0("PC2 (", pct[2], "%)"),
       color = "Minute", shape = "Strain") +
  theme_biof3()
save_plot(p1, "01-pca-timecourse.png", width = 9, height = 6)

# ---- figure 2: LRT p-value histogram + statistic vs expression ----------
message("Figure 2: LRT p-value distribution")
# Left panel: raw LRT p-values, NA p-values dropped.
raw_p <- res_lrt$pvalue
p_hist_df <- data.frame(pvalue = raw_p[!is.na(raw_p)])
p2a <- ggplot(p_hist_df, aes(x = pvalue)) +
  geom_histogram(
    breaks = seq(0, 1, 0.025),
    fill = biof3_colors[["green"]],
    color = "white",
    linewidth = 0.2
  ) +
  labs(
    title = "LRT raw p-value distribution",
    subtitle = "Peak at 0 = many genes with strain-specific time response",
    x = "Raw p-value",
    y = "Number of genes"
  ) +
  theme_biof3()

# Right panel: LRT statistic against mean expression for genes that have an
# adjusted p-value, highlighted when padj < 0.05.
lrt_df <- data.frame(
  gene     = rownames(res_lrt),
  baseMean = res_lrt$baseMean,
  stat     = res_lrt$stat,
  padj     = res_lrt$padj
)
lrt_df <- lrt_df[!is.na(lrt_df$padj), ]
n_sig_lrt <- sum(lrt_df$padj < 0.05)
lrt_palette <- c(`TRUE` = biof3_colors[["red"]], `FALSE` = "#9ca3af")

p2b <- ggplot(lrt_df, aes(x = log10(baseMean + 1), y = stat, color = padj < 0.05)) +
  geom_point(size = 0.5, alpha = 0.5) +
  scale_color_manual(values = lrt_palette) +
  labs(
    title = "LRT test statistic vs expression",
    subtitle = paste0(n_sig_lrt, " genes with padj < 0.05 (interaction non-zero)"),
    x = "log10(baseMean + 1)",
    y = "LRT statistic",
    color = "padj < 0.05"
  ) +
  theme_biof3()

overview_title_theme <- theme(plot.title = element_text(face = "bold", color = "#12342f"))
p2 <- (p2a | p2b) +
  plot_annotation(
    title = "LRT: which genes have strain-specific time response?",
    theme = overview_title_theme
  )
save_plot(p2, "02-lrt-overview.png", width = 13, height = 5.5)

# ---- figure 3: time trajectories for top 6 LRT genes --------------------
message("Figure 3: top LRT gene trajectories")
top_ids <- head(rownames(res_lrt)[order(res_lrt$padj)], 6)
# One long data frame of per-sample normalized counts, tagged with the gene
# id and a numeric time so the x axis is spaced by real minutes.
traj_list <- lapply(top_ids, function(g) {
  d <- plotCounts(dds, gene = g, intgroup = c("strain", "minute"), returnData = TRUE)
  d$gene <- g
  d$minute_num <- as.numeric(as.character(d$minute))
  d
})
traj_df <- bind_rows(traj_list)
# Keep panels in LRT rank order; facet_wrap() would otherwise sort by name.
traj_df$gene <- factor(traj_df$gene, levels = top_ids)

p3 <- ggplot(traj_df, aes(x = minute_num, y = count, color = strain, group = strain)) +
  geom_point(size = 1.5, alpha = 0.7) +
  # The log10 y scale is applied before stat_summary(), so this averages
  # log10(count): the line is the per-strain geometric mean of the counts.
  stat_summary(fun = mean, geom = "line", linewidth = 1) +
  facet_wrap(~ gene, scales = "free_y", ncol = 3) +
  scale_color_manual(values = c(wt = biof3_colors[["slate"]],
                                mut = biof3_colors[["red"]])) +
  scale_y_log10() +
  scale_x_continuous(breaks = c(0, 15, 30, 60, 120, 180)) +
  labs(title = "Top 6 LRT genes: time trajectories",
       subtitle = "Points = biological replicates; lines = per-strain geometric mean (mean on log10 scale)",
       x = "Minutes after stress", y = "Normalized counts (log10)",
       color = "Strain") +
  theme_biof3()
save_plot(p3, "03-top-trajectories.png", width = 12, height = 7)

# ---- figure 4: clustered trajectories of top 500 LRT genes --------------
message("Figure 4: z-score trajectories of top-LRT gene clusters")
n_top_genes <- 500
n_clusters <- 4
top500 <- head(rownames(res_lrt)[order(res_lrt$padj)], n_top_genes)
vsd_mat <- assay(vsd)[top500, , drop = FALSE]
# Condition label per sample, e.g. "wt_0" or "mut_30".
grp <- paste(colData(vsd)$strain, colData(vsd)$minute, sep = "_")
grp_levels <- unique(grp)
# Mean VST per gene and strain:minute condition -> genes x conditions matrix
# (row names = genes from rowMeans, column names = condition labels).
avg <- vapply(
  grp_levels,
  function(g) rowMeans(vsd_mat[, grp == g, drop = FALSE]),
  FUN.VALUE = numeric(nrow(vsd_mat))
)
# z-score each GENE across conditions. scale() standardises columns, so
# transpose to conditions x genes, scale, and transpose back.
z_mat <- t(scale(t(avg)))
# Genes constant across all conditions get NaN z-scores; they carry no shape
# information and would poison the distance matrix, so drop them.
z_mat <- z_mat[complete.cases(z_mat), , drop = FALSE]
# Hierarchically cluster genes by trajectory shape.
hc <- hclust(dist(z_mat))
cl <- cutree(hc, k = n_clusters)
z_df <- as.data.frame(z_mat) |>
  tibble::rownames_to_column("gene") |>
  mutate(cluster = paste0("C", cl[gene])) |>
  pivot_longer(cols = -c(gene, cluster), names_to = "group", values_to = "z")
z_df$strain <- sub("_.+$", "", z_df$group)
z_df$minute <- as.numeric(sub("^.+_", "", z_df$group))

p4 <- ggplot(z_df, aes(x = minute, y = z, color = strain, group = interaction(gene, strain))) +
  geom_line(alpha = 0.06) +
  stat_summary(aes(group = strain), fun = mean, geom = "line", linewidth = 1.2) +
  facet_wrap(~ cluster, ncol = 2, scales = "free_y") +
  scale_color_manual(values = c(wt = biof3_colors[["slate"]],
                                mut = biof3_colors[["red"]])) +
  scale_x_continuous(breaks = c(0, 15, 30, 60, 120, 180)) +
  labs(title = paste0("Top ", n_top_genes, " LRT genes clustered into ",
                      n_clusters, " trajectory groups"),
       subtitle = "Thin lines = individual genes (z-score of VST); thick lines = per-strain mean",
       x = "Minutes after stress", y = "z-score (VST)",
       color = "Strain") +
  theme_biof3()
save_plot(p4, "04-cluster-trajectories.png", width = 12, height = 8)

# ---- figure 5: strain x time interaction at 30 min ------------------------
message("Figure 5: strain-specific DE at 30 minutes")
# Wald test of a single coefficient of the already-fitted full model.
# DESeq2 lets results(test = "Wald") be used on an LRT fit, so there is no
# need to re-run DESeq() on a second copy of the data.
# NOTE: "strainmut.minute30" is the interaction coefficient ONLY: how much the
# mut-vs-wt log2 fold change at 30 min differs from the mut-vs-wt log2 fold
# change at time 0. The total mut-vs-wt effect at 30 min would instead be
# contrast = list(c("strain_mut_vs_wt", "strainmut.minute30")).
interaction_30 <- "strainmut.minute30"
if (!interaction_30 %in% resultsNames(dds)) {
  stop("Coefficient '", interaction_30, "' not found in resultsNames(dds)",
       call. = FALSE)
}
res30 <- results(dds, name = interaction_30, test = "Wald")
message("Strain:time=30 interaction DE at padj<0.05: ",
        sum(!is.na(res30$padj) & res30$padj < 0.05))

res30_df <- as.data.frame(res30) |>
  tibble::rownames_to_column("gene") |>
  filter(!is.na(padj)) |>
  mutate(
    sig = case_when(
      padj < 0.05 & log2FoldChange >  1 ~ "Up",
      padj < 0.05 & log2FoldChange < -1 ~ "Down",
      TRUE ~ "NS"
    )
  )
# Label the 10 most significant genes that pass both padj and |LFC| > 1.
label30 <- res30_df |>
  filter(sig != "NS") |>
  arrange(padj) |>
  head(10)

p5 <- ggplot(res30_df,
             aes(x = log2FoldChange, y = -log10(pmax(padj, 1e-30)), color = sig)) +
  geom_point(size = 0.6, alpha = 0.5) +
  geom_text_repel(data = label30, aes(label = gene), size = 3,
                  max.overlaps = 25, color = "#111827", seed = 42) +
  scale_color_manual(values = c(Up = biof3_colors[["red"]],
                                Down = biof3_colors[["blue"]],
                                NS = "#9ca3af")) +
  geom_vline(xintercept = c(-1, 1), linetype = "dashed", color = "grey50") +
  geom_hline(yintercept = -log10(0.05), linetype = "dashed", color = "grey50") +
  labs(title = "Strain x Time=30 min interaction volcano",
       subtitle = "Genes whose mut-vs-wt difference at 30 min departs from the time-0 baseline",
       x = "log2 fold change (interaction term at t=30)",
       y = "-log10(padj)", color = NULL) +
  theme_biof3()
save_plot(p5, "05-contrast-30min.png", width = 10, height = 6.5)

# ---- figure 6: heatmap of top 40 LRT genes ------------------------------
message("Figure 6: heatmap of top 40 LRT genes")
top40 <- head(rownames(res_lrt)[order(res_lrt$padj)], 40)
# Row z-scores of VST values; columns sorted by strain (factor level order)
# and then by numeric time.
mat <- assay(vsd)[top40, , drop = FALSE]
mat_z <- t(scale(t(mat)))
col_order <- order(colData(vsd)$strain, as.numeric(as.character(colData(vsd)$minute)))
mat_z <- mat_z[, col_order, drop = FALSE]

minute_levels <- levels(colData(vsd)$minute)
anno_col <- data.frame(
  strain = colData(vsd)$strain[col_order],
  minute = factor(colData(vsd)$minute[col_order], levels = minute_levels),
  row.names = colnames(mat_z)
)
anno_colors <- list(
  strain = c(wt = biof3_colors[["slate"]], mut = biof3_colors[["red"]]),
  # One YlOrRd shade per time point; brewer.pal() needs n >= 3.
  minute = setNames(
    brewer.pal(max(3, length(minute_levels)), "YlOrRd")[seq_along(minute_levels)],
    minute_levels
  )
)

png(file.path(output_dir, "06-heatmap-top-lrt.png"),
    width = 11, height = 8, units = "in", res = 300, bg = "white")
# Always close the PNG device, even if pheatmap() errors, so later output is
# not silently redirected into a half-written file. invisible() also stops
# dev.off()'s return value from being auto-printed.
invisible(tryCatch(
  pheatmap(
    mat_z,
    color = colorRampPalette(c(biof3_colors[["blue"]], "white", biof3_colors[["red"]]))(100),
    breaks = seq(-3, 3, length.out = 101),
    cluster_cols = FALSE,
    annotation_col = anno_col,
    annotation_colors = anno_colors,
    show_colnames = FALSE, fontsize_row = 7,
    main = "Top 40 LRT genes (row z-score of VST)"
  ),
  finally = dev.off()
))

# ---- done ---------------------------------------------------------------
# Report which PNGs ended up in the output directory, then session details.
message("=== module bulk05 timecourse figures ===")
png_outputs <- list.files(output_dir, pattern = "\\.png$")
print(png_outputs)
message("Output dir: ", normalizePath(output_dir))
message("Done.")

print(sessionInfo())
