#!/usr/bin/env Rscript
# =============================================================================
# BioF3 Proteomics Module 04: Visualization — heatmap, volcano, network
# =============================================================================
#
# Produces 6 paper-ready figures from the DEP UbiLength results:
#   1. Volcano plots for all contrasts vs Ctrl (one facet per contrast)
#   2. Heatmap of top 30 DE proteins (z-score)
#   3. Protein intensity profiles (top 6 DE proteins)
#   4. Upset-style bar chart of DE protein overlap across contrasts
#   5. LFC comparison scatter (Ubi4 vs Ubi6)
#   6. Top 20 DE proteins by |LFC| (Ubi6 vs Ctrl), bars colored by direction
#
# Usage:
#   Rscript scripts/proteomics/prot04_visualization_sci.R
# =============================================================================

# Keep character columns as character (already the default since R 4.0) and
# fix the RNG seed so random steps later in the script are reproducible.
options(stringsAsFactors = FALSE)
set.seed(42)

# Locate this script on disk.
#
# Rscript passes the script as `--file=<path>` and encodes every space in that
# path as "~+~". The encoding is undone before the path is resolved; otherwise
# normalizePath(mustWork = TRUE) fails for checkouts that live under a
# directory whose name contains spaces.
#
# @param args     Full command line, as returned by
#                 commandArgs(trailingOnly = FALSE).
# @param fallback Path used when no `--file=` argument is present (for example
#                 when the script is source()d); it is not required to exist.
# @return Normalized path of the script.
resolve_script_path <- function(args = commandArgs(trailingOnly = FALSE),
                                fallback = "scripts/proteomics/prot04_visualization_sci.R") {
  file_arg <- grep("^--file=", args, value = TRUE)
  if (length(file_arg) == 0) {
    return(normalizePath(fallback, mustWork = FALSE))
  }
  raw_path <- sub("^--file=", "", file_arg[[1]])
  normalizePath(gsub("~+~", " ", raw_path, fixed = TRUE), mustWork = TRUE)
}

args <- commandArgs(trailingOnly = FALSE)
script_path <- resolve_script_path(args)
script_dir <- dirname(script_path)
# The script sits two directory levels below the project root.
project_root <- dirname(dirname(script_dir))
# BIOF3_OUTPUT_DIR overrides the default figure directory inside the project.
output_dir <- Sys.getenv(
  "BIOF3_OUTPUT_DIR",
  file.path(project_root, "static", "img", "tutorial", "proteomics", "prot04")
)
dir.create(output_dir, recursive = TRUE, showWarnings = FALSE)

suppressPackageStartupMessages({
  library(DEP)
  library(ggplot2)
  library(dplyr)
  library(tidyr)
  library(patchwork)
  library(pheatmap)
  library(RColorBrewer)
  library(ggrepel)
  library(SummarizedExperiment)
})

# BioF3 palette: named hex colours, looked up by name in the figures below
# (for example biof3_colors[["red"]]).
biof3_colors <- c(
  green  = "#0f766e",
  mint   = "#43d1ae",
  blue   = "#2563eb",
  amber  = "#f59e0b",
  red    = "#dc2626",
  slate  = "#475569",
  violet = "#7c3aed",
  pink   = "#ec4899"
)
# Shared ggplot2 theme for the BioF3 figures.
#
# Starts from theme_classic() and overrides text colours and weights; facet
# strips lose their background.
#
# @param base_size Base font size forwarded to theme_classic().
# @return A ggplot2 theme that can be added to a plot with `+`.
theme_biof3 <- function(base_size = 12) {
  ink <- "#111827"
  heading <- "#12342f"
  overrides <- theme(
    axis.text = element_text(color = ink),
    axis.title = element_text(color = ink, face = "bold"),
    plot.title = element_text(color = heading, face = "bold", hjust = 0),
    plot.subtitle = element_text(color = "#526760", hjust = 0),
    legend.title = element_text(face = "bold"),
    strip.background = element_blank(),
    strip.text = element_text(color = heading, face = "bold")
  )
  theme_classic(base_size = base_size) + overrides
}
# Write a plot into `output_dir` at 300 dpi on a white background.
#
# @param p      Plot object handed to ggsave().
# @param name   File name (with extension) created inside `output_dir`.
# @param width  Figure width forwarded to ggsave().
# @param height Figure height forwarded to ggsave().
# @return Whatever ggsave() returns.
save_plot <- function(p, name, width = 8, height = 5) {
  target <- file.path(output_dir, name)
  ggsave(
    filename = target,
    plot = p,
    width = width,
    height = height,
    dpi = 300,
    bg = "white"
  )
}

# ---- DEP pipeline --------------------------------------------------------
message("Running DEP pipeline ...")

# Example LFQ table and its experimental design.
data(UbiLength)
data(UbiLength_ExpDesign)

# Unique protein names, then a SummarizedExperiment built from the LFQ
# intensity columns and the experimental design.
data_unique <- make_unique(UbiLength, "Gene.names", "Protein.IDs", delim = ";")
lfq_cols <- grep("LFQ.intensity.", colnames(data_unique))
data_se <- make_se(data_unique, lfq_cols, UbiLength_ExpDesign)

# Filter on missing values, VSN-normalize, then impute with the MinProb method.
# `data_imp` is the intensity matrix source for the heatmap and profile figures.
data_imp <- data_se |>
  filter_missval(thr = 0) |>
  normalize_vsn() |>
  impute(fun = "MinProb", q = 0.01)

# Test against the Ctrl condition and flag hits at alpha = 0.05, lfc = 1.
data_diff <- test_diff(data_imp, type = "control", control = "Ctrl")
dep <- add_rejections(data_diff, alpha = 0.05, lfc = 1)

# Per-protein results table used by every figure below.
res_df <- as.data.frame(rowData(dep))

# ---- figure 1: multi-contrast volcano -----------------------------------
message("Figure 1: multi-contrast volcano")
contrasts <- c("Ubi1_vs_Ctrl", "Ubi4_vs_Ctrl", "Ubi6_vs_Ctrl")

# One long-format table per contrast (gene, lfc, padj, contrast). A contrast
# whose "_diff" column is absent from res_df yields NULL; rows without an
# adjusted p-value are removed.
vol_list <- lapply(contrasts, function(ct) {
  diff_name <- paste0(ct, "_diff")
  if (!(diff_name %in% colnames(res_df))) {
    return(NULL)
  }
  per_contrast <- data.frame(
    gene = res_df$name,
    lfc = res_df[[diff_name]],
    padj = res_df[[paste0(ct, "_p.adj")]],
    contrast = ct,
    stringsAsFactors = FALSE
  )
  filter(per_contrast, !is.na(padj))
})

# Classify each point with the same cut-offs that are drawn as dashed guides
# (padj < 0.05, |log2FC| > 1) and shorten the contrast labels for the facets.
vol_all <- bind_rows(vol_list) |>
  mutate(
    sig = case_when(
      padj < 0.05 & lfc > 1 ~ "Up",
      padj < 0.05 & lfc < -1 ~ "Down",
      TRUE ~ "NS"
    ),
    contrast = gsub("_vs_Ctrl", "", contrast)
  )

sig_palette <- c(
  Up = biof3_colors[["red"]],
  Down = biof3_colors[["blue"]],
  NS = "#9ca3af"
)

# padj is floored at 1e-16 so the y axis stays finite.
p1 <- ggplot(
  vol_all,
  aes(x = lfc, y = -log10(pmax(padj, 1e-16)), color = sig)
) +
  geom_point(size = 0.6, alpha = 0.5) +
  facet_wrap(~ contrast, ncol = 3) +
  scale_color_manual(values = sig_palette) +
  geom_vline(xintercept = c(-1, 1), linetype = "dashed", color = "grey50") +
  geom_hline(yintercept = -log10(0.05), linetype = "dashed", color = "grey50") +
  labs(
    title = "Volcano plots: all contrasts vs Ctrl",
    subtitle = "Ubi chain length increases left to right; more DE proteins with longer chains",
    x = "log2 fold change",
    y = "-log10(padj)",
    color = NULL
  ) +
  theme_biof3()
save_plot(p1, "01-volcano-multi.png", width = 14, height = 5)

# ---- figure 2: heatmap of top 30 DE proteins ----------------------------
message("Figure 2: heatmap top 30 DE")

# Column names of the Ubi6-vs-Ctrl significance flag and adjusted p-value.
# grep(...)[1] is NA when nothing matches, so fail early with a clear message
# instead of a cryptic error inside filter().
sig_col <- grep("Ubi6_vs_Ctrl_significant", colnames(res_df), value = TRUE)[1]
padj_col <- grep("Ubi6_vs_Ctrl_p.adj", colnames(res_df), value = TRUE)[1]
if (is.na(sig_col) || is.na(padj_col)) {
  stop("Ubi6_vs_Ctrl significance / adjusted p-value columns not found in the DEP results",
       call. = FALSE)
}

# Significant proteins ordered by adjusted p-value; at most 30 are kept.
top30 <- res_df |>
  filter(!is.na(.data[[sig_col]]) & .data[[sig_col]] == TRUE) |>
  arrange(.data[[padj_col]]) |>
  head(30)
if (nrow(top30) == 0) {
  stop("No significant proteins for Ubi6 vs Ctrl; the heatmap cannot be drawn",
       call. = FALSE)
}

# drop = FALSE keeps a matrix (with its row names) even when a single protein
# is significant.
mat <- assay(data_imp)[top30$name, , drop = FALSE]
# Row-wise z-score: scale() works on columns, hence the double transpose.
mat_z <- t(scale(t(mat)))

anno_col <- data.frame(condition = colData(data_imp)$condition, row.names = colnames(data_imp))
anno_colors <- list(condition = c(Ctrl = biof3_colors[["slate"]], Ubi1 = biof3_colors[["mint"]],
                                  Ubi4 = biof3_colors[["amber"]], Ubi6 = biof3_colors[["red"]]))

png(file.path(output_dir, "02-heatmap-top30.png"), width = 9, height = 8, units = "in", res = 300, bg = "white")
# The device is closed even if pheatmap() fails, so no PNG device leaks.
# invisible() keeps the returned pheatmap object from being auto-printed (and
# re-drawn) after the device has been closed.
invisible(tryCatch(
  pheatmap(mat_z, color = colorRampPalette(c(biof3_colors[["blue"]], "white", biof3_colors[["red"]]))(100),
           # 100 colours spread over z-scores from -3 to 3.
           breaks = seq(-3, 3, length.out = 101), annotation_col = anno_col,
           annotation_colors = anno_colors, show_colnames = FALSE, fontsize_row = 8,
           # Row clustering needs at least two rows.
           cluster_rows = nrow(mat_z) > 1,
           main = "Top 30 DE proteins (Ubi6 vs Ctrl, row z-score)"),
  finally = dev.off()
))

# ---- figure 3: protein intensity profiles --------------------------------
message("Figure 3: intensity profiles")

# Up to six proteins from the top of the Ubi6-vs-Ctrl ranking. Indexing with
# `top6` (rather than positions 1:6) avoids NA subscripts when fewer than six
# proteins are significant.
top6 <- head(top30$name, 6)
if (length(top6) == 0) {
  stop("No significant proteins available for the intensity profiles", call. = FALSE)
}
# drop = FALSE keeps a matrix even for a single protein.
prof_mat <- assay(data_imp)[top6, , drop = FALSE]

# Long format: one row per sample x protein. The condition labels are taken
# from colData() in the same sample order as the matrix columns.
prof_df <- as.data.frame(t(prof_mat)) |>
  tibble::rownames_to_column("sample") |>
  mutate(condition = colData(data_imp)$condition) |>
  pivot_longer(cols = all_of(top6), names_to = "protein", values_to = "intensity")

p3 <- ggplot(prof_df, aes(x = condition, y = intensity, fill = condition)) +
  # Outliers are hidden in the boxplot because every sample is drawn by the
  # jitter layer.
  geom_boxplot(width = 0.5, outlier.shape = NA, alpha = 0.8) +
  geom_jitter(width = 0.1, size = 1.5, alpha = 0.7) +
  facet_wrap(~ protein, scales = "free_y", ncol = 3) +
  scale_fill_manual(values = c(Ctrl = biof3_colors[["slate"]], Ubi1 = biof3_colors[["mint"]],
                               Ubi4 = biof3_colors[["amber"]], Ubi6 = biof3_colors[["red"]]), guide = "none") +
  labs(title = "Top 6 DE protein intensity profiles",
       subtitle = "VSN-normalized + imputed intensities across conditions",
       x = NULL, y = "Normalized intensity") +
  theme_biof3()
save_plot(p3, "03-intensity-profiles.png", width = 11, height = 7)

# ---- figure 4: upset-style overlap across contrasts ----------------------
message("Figure 4: DE overlap across contrasts")

# One set of significant protein names per "<contrast>_significant" column.
sig_cols <- grep("_significant$", colnames(res_df), value = TRUE)
de_sets <- lapply(sig_cols, function(col) res_df$name[res_df[[col]] %in% TRUE])
names(de_sets) <- gsub("_vs_Ctrl_significant", "", sig_cols)

all_de <- unique(unlist(de_sets))
if (length(all_de) == 0) {
  stop("No significant proteins in any contrast; the overlap figure cannot be drawn",
       call. = FALSE)
}

# For every DE protein, the combination of contrasts it is significant in,
# e.g. "Ubi4 + Ubi6". vapply() returns a character vector of fixed type for
# any number of proteins, whereas sapply() + apply() broke when only one
# protein (or none) was significant.
pattern <- vapply(all_de, function(protein) {
  member_of <- vapply(de_sets, function(set) protein %in% set, logical(1))
  paste(names(de_sets)[member_of], collapse = " + ")
}, character(1), USE.NAMES = FALSE)

# Number of proteins per combination, most frequent first.
counts <- sort(table(pattern), decreasing = TRUE)
pat_df <- data.frame(pattern = names(counts), n = as.integer(counts))

p4 <- ggplot(pat_df, aes(x = reorder(pattern, n), y = n, fill = n)) +
  geom_col(width = 0.7) + geom_text(aes(label = n), hjust = -0.1, size = 3.5) +
  coord_flip() +
  scale_fill_gradient(low = biof3_colors[["slate"]], high = biof3_colors[["red"]], guide = "none") +
  labs(title = "DE protein overlap across contrasts",
       subtitle = "Which proteins are shared by multiple Ubi-length comparisons",
       x = NULL, y = "Number of proteins") +
  theme_biof3()
save_plot(p4, "04-de-overlap.png", width = 10, height = 5)

# ---- figure 5: LFC scatter Ubi4 vs Ubi6 ---------------------------------
message("Figure 5: LFC scatter")

# Both fold-change columns must exist; `[[` would silently return NULL for a
# missing column and data.frame() would then fail with an unrelated message.
lfc_cols <- c("Ubi4_vs_Ctrl_diff", "Ubi6_vs_Ctrl_diff")
absent_cols <- lfc_cols[!(lfc_cols %in% colnames(res_df))]
if (length(absent_cols) > 0) {
  stop("Missing fold-change column(s) in the DEP results: ",
       paste(absent_cols, collapse = ", "), call. = FALSE)
}
lfc4 <- res_df[["Ubi4_vs_Ctrl_diff"]]
lfc6 <- res_df[["Ubi6_vs_Ctrl_diff"]]

# Proteins with a fold change in both contrasts, and their Pearson correlation.
scat_df <- data.frame(gene = res_df$name, Ubi4 = lfc4, Ubi6 = lfc6) |> filter(!is.na(Ubi4) & !is.na(Ubi6))
r_val <- round(cor(scat_df$Ubi4, scat_df$Ubi6, use = "complete.obs"), 3)

# Colours are extracted with [[ ]] (unnamed scalars), as in the other figures.
# The subtitle states what the diagonal literally shows: "stronger effect"
# would be wrong for proteins with negative fold changes.
p5 <- ggplot(scat_df, aes(x = Ubi4, y = Ubi6)) +
  geom_point(size = 0.6, alpha = 0.4, color = biof3_colors[["violet"]]) +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = biof3_colors[["red"]]) +
  labs(title = paste0("LFC correlation: Ubi4 vs Ubi6 (r = ", r_val, ")"),
       subtitle = "Points above the dashed diagonal have a higher log2FC with Ubi6 than with Ubi4",
       x = "log2FC (Ubi4 vs Ctrl)", y = "log2FC (Ubi6 vs Ctrl)") +
  theme_biof3()
save_plot(p5, "05-lfc-scatter.png", width = 8, height = 7)

# ---- figure 6: top DE protein bar by direction ---------------------------
message("Figure 6: top DE proteins bar")
lfc_col6 <- "Ubi6_vs_Ctrl_diff"

# The 20 significant Ubi6-vs-Ctrl proteins with the largest absolute fold
# change, each tagged with its direction. `sig_col` is the significance column
# located for figure 2.
# desc() is namespaced because Bioconductor packages attached after dplyr may
# mask it.
bar_df <- res_df |>
  filter(!is.na(.data[[sig_col]]) & .data[[sig_col]] == TRUE) |>
  mutate(lfc = .data[[lfc_col6]], direction = ifelse(lfc > 0, "Up", "Down")) |>
  arrange(dplyr::desc(abs(lfc))) |>
  head(20)

# Bars are ordered by signed fold change via reorder(name, lfc).
p6 <- ggplot(bar_df, aes(x = reorder(name, lfc), y = lfc, fill = direction)) +
  geom_col(width = 0.7) + coord_flip() +
  scale_fill_manual(values = c(Up = biof3_colors[["red"]], Down = biof3_colors[["blue"]])) +
  labs(title = "Top 20 DE proteins by |LFC| (Ubi6 vs Ctrl)",
       subtitle = "Sorted by absolute fold change; color = direction",
       x = NULL, y = "log2 fold change", fill = NULL) +
  theme_biof3()
save_plot(p6, "06-top-de-bar.png", width = 10, height = 7)

# List the PNG files present in the output directory, then report the session
# details.
message("=== proteomics prot04 visualization figures ===")
png_files <- list.files(output_dir, pattern = "\\.png$", full.names = FALSE)
print(png_files)
message("Done.")
sessionInfo()
