#!/usr/bin/env Rscript
# =============================================================================
# BioF3 Proteomics Module 02: DEP full workflow on UbiLength
# =============================================================================
#
# Companion script for proteomics module 02. It:
#   1. Loads the DEP::UbiLength dataset (HeLa ubiquitination, 4 conditions x 3 reps)
#   2. Runs the standard DEP pipeline: make_unique -> make_se -> filter ->
#      normalize_vsn -> impute(MinProb) -> test_diff(limma) -> add_rejections
#   3. Produces six figures:
#      - Missing value heatmap (before imputation)
#      - Normalization effect (before vs after VSN)
#      - PCA on normalized + imputed data
#      - Sample correlation heatmap
#      - Volcano plot (Ubi6 vs Ctrl)
#      - Number of DE proteins per contrast
#
# Data source:
#   DEP::UbiLength (Bioconductor, no external download)
#   Zhang et al. 2017, Mol Cell 65(5):941-955
#
# Usage:
#   Rscript scripts/proteomics/prot02_dep_sci.R
#
# Dependencies:
#   DEP, SummarizedExperiment, ggplot2, ggrepel, dplyr, patchwork, pheatmap,
#   RColorBrewer, ComplexHeatmap
# =============================================================================

options(stringsAsFactors = FALSE)
set.seed(42)

# ---- paths ---------------------------------------------------------------
# Resolve this script's own location so the figures land in the project tree
# regardless of the directory Rscript is launched from.
args <- commandArgs(trailingOnly = FALSE)
file_arg <- grep("^--file=", args, value = TRUE)
script_path <- if (length(file_arg) > 0) {
  # On Unix-alikes the R front end encodes spaces in --file= as "~+~";
  # decode them so normalizePath() can find scripts under paths with spaces.
  script_file <- gsub("~+~", " ", sub("^--file=", "", file_arg[[1]]),
                      fixed = TRUE)
  normalizePath(script_file, mustWork = TRUE)
} else {
  # Interactive / sourced use: assumes the working directory is the
  # project root.
  normalizePath("scripts/proteomics/prot02_dep_sci.R", mustWork = FALSE)
}
script_dir <- dirname(script_path)
# The script lives in <root>/scripts/proteomics/, two levels below the root.
project_root <- dirname(dirname(script_dir))
# BIOF3_OUTPUT_DIR overrides the default tutorial image directory.
output_dir <- Sys.getenv(
  "BIOF3_OUTPUT_DIR",
  file.path(project_root, "static", "img", "tutorial", "proteomics", "prot02")
)
dir.create(output_dir, recursive = TRUE, showWarnings = FALSE)

# ---- dependencies --------------------------------------------------------
suppressPackageStartupMessages({
  library(DEP)
  library(ggplot2)
  library(dplyr)
  library(patchwork)
  library(pheatmap)
  library(RColorBrewer)
  library(SummarizedExperiment)
})

# Shared BioF3 palette: a named character vector of hex colours that every
# figure in this script looks up by name (e.g. biof3_colors[["green"]]).
biof3_colors <- setNames(
  c("#0f766e", "#43d1ae", "#2563eb", "#f59e0b",
    "#dc2626", "#475569", "#7c3aed", "#ec4899"),
  c("green", "mint", "blue", "amber", "red", "slate", "violet", "pink")
)

# BioF3 house ggplot2 theme.
#
# Starts from theme_classic(base_size) and overrides it with dark axis text,
# bold axis and legend titles, left-aligned title/subtitle, a right-hand
# legend and borderless, bold facet strips.
#
# base_size  base font size passed to theme_classic().
# Returns a ggplot2 theme object to add to a plot with `+`.
theme_biof3 <- function(base_size = 12) {
  ink <- "#111827"
  heading <- "#12342f"
  overrides <- theme(
    axis.text        = element_text(color = ink),
    axis.title       = element_text(color = ink, face = "bold"),
    plot.title       = element_text(color = heading, face = "bold", hjust = 0),
    plot.subtitle    = element_text(color = "#526760", hjust = 0),
    legend.title     = element_text(face = "bold"),
    legend.position  = "right",
    strip.background = element_blank(),
    strip.text       = element_text(color = heading, face = "bold")
  )
  theme_classic(base_size = base_size) + overrides
}

# Save a ggplot object into `output_dir` with a white background.
#
# p              ggplot object to write.
# name           file name relative to `output_dir`; ggsave() infers the
#                device from its extension.
# width, height  figure size in ggsave()'s default units (inches).
# dpi            resolution; 300 matches the res used for the png() figures.
# ...            further arguments forwarded to ggplot2::ggsave().
# Returns the written file path invisibly.
save_plot <- function(p, name, width = 8, height = 5, dpi = 300, ...) {
  path <- file.path(output_dir, name)
  ggsave(path, p, width = width, height = height, dpi = dpi, bg = "white", ...)
  invisible(path)
}

# ---- load + preprocess ---------------------------------------------------
message("Loading UbiLength data ...")
# Load the example data explicitly from DEP (into the global environment).
data("UbiLength", package = "DEP")
data("UbiLength_ExpDesign", package = "DEP")

message("Proteins: ", nrow(UbiLength), "  Samples: ", nrow(UbiLength_ExpDesign))
print(UbiLength_ExpDesign)

# Make unique names and build SummarizedExperiment
data_unique <- make_unique(UbiLength, "Gene.names", "Protein.IDs", delim = ";")
# Literal match: as a regex, "." would match any character.
lfq_cols <- grep("LFQ.intensity.", colnames(data_unique), fixed = TRUE)
if (length(lfq_cols) == 0) {
  stop("No 'LFQ.intensity.' columns found in UbiLength.", call. = FALSE)
}
data_se <- make_se(data_unique, lfq_cols, UbiLength_ExpDesign)
message("SE object: ", nrow(data_se), " proteins x ", ncol(data_se), " samples")

# ---- figure 1: missing value heatmap ------------------------------------
message("Figure 1: missing value heatmap")
png(file.path(output_dir, "01-missing-heatmap.png"),
    width = 9, height = 6, units = "in", res = 300, bg = "white")
# Close the PNG device even if plotting fails, so later base-graphics output
# is not drawn into this file. invisible() suppresses top-level auto-printing
# of the returned value (a bare dev.off() used to print "null device").
invisible(tryCatch(plot_missval(data_se), finally = dev.off()))

# ---- filter: keep proteins with no missing values in >= 1 condition ------
message("Filtering ...")
# thr is the number of missing values tolerated in at least one condition;
# thr = 0 keeps proteins quantified in every replicate of some condition.
data_filt <- filter_missval(data_se, thr = 0)
message(paste0("After filtering: ", nrow(data_filt), " proteins"))

# ---- figure 2: normalization effect -------------------------------------
message("Figure 2: normalization (before vs after VSN)")
data_norm <- normalize_vsn(data_filt)

# Intensities before (filtered make_se output) and after VSN; both matrices
# share the same proteins x samples layout.
mat_before <- assay(data_filt)
mat_after  <- assay(data_norm)

stage_levels <- c("Before normalization", "After VSN normalization")

# Reshape a proteins x samples matrix to long format. as.vector() reads the
# matrix column by column, so each sample name repeats nrow(mat) times.
intensity_long <- function(mat, stage) {
  data.frame(
    value = as.vector(mat),
    sample = rep(colnames(mat), each = nrow(mat)),
    stage = stage
  )
}

df_norm <- bind_rows(
  intensity_long(mat_before, stage_levels[[1]]),
  intensity_long(mat_after, stage_levels[[2]])
) |>
  filter(!is.na(value)) |>
  mutate(
    stage = factor(stage, levels = stage_levels),
    # Keep the SummarizedExperiment column order on the x axis instead of
    # ggplot2's default alphabetical ordering of character values.
    sample = factor(sample, levels = colnames(mat_before))
  )

p2 <- ggplot(df_norm, aes(x = sample, y = value, fill = stage)) +
  geom_boxplot(outlier.size = 0.3, width = 0.6) +
  facet_wrap(~ stage, scales = "free_y", ncol = 1) +
  # Named values tie each colour to its stage rather than to level position.
  scale_fill_manual(
    values = setNames(
      c(biof3_colors[["slate"]], biof3_colors[["green"]]),
      stage_levels
    ),
    guide = "none"
  ) +
  labs(title = "Intensity distribution before vs after VSN",
       subtitle = "VSN stabilizes variance and aligns medians across samples",
       x = NULL, y = "Intensity") +
  theme_biof3() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1, size = 9))
save_plot(p2, "02-normalization.png", width = 11, height = 8)

# ---- imputation ----------------------------------------------------------
message("Imputing missing values (MinProb) ...")
# MinProb fills gaps with random low-intensity draws (q sets the low quantile
# used), so the result depends on the RNG state seeded at the top.
data_imp <- impute(data_norm, q = 0.01, fun = "MinProb")

# ---- figure 3: PCA -------------------------------------------------------
message("Figure 3: PCA")
pca_mat <- assay(data_imp)
# prcomp(scale. = TRUE) stops on constant variables, so drop proteins whose
# variance across samples is zero or NA before the PCA. pca_mat itself is kept
# whole because the correlation heatmap below reuses it.
protein_var <- apply(pca_mat, 1, stats::var)
pca_input <- pca_mat[!is.na(protein_var) & protein_var > 0, , drop = FALSE]
pca_res <- prcomp(t(pca_input), scale. = TRUE)
# Row 2 of the importance table is the proportion of variance explained.
pct <- round(100 * summary(pca_res)$importance[2, 1:2])
pca_df <- data.frame(
  PC1 = pca_res$x[, 1],
  PC2 = pca_res$x[, 2],
  condition = colData(data_imp)$condition,
  replicate = colData(data_imp)$replicate,
  label = colData(data_imp)$label
)

p3 <- ggplot(pca_df, aes(x = PC1, y = PC2, color = condition)) +
  geom_point(size = 4, alpha = 0.9) +
  ggrepel::geom_text_repel(aes(label = label), size = 3, color = "#374151") +
  scale_color_manual(values = c(
    Ctrl = biof3_colors[["slate"]], Ubi1 = biof3_colors[["mint"]],
    Ubi4 = biof3_colors[["amber"]], Ubi6 = biof3_colors[["red"]]
  )) +
  labs(title = "PCA on normalized + imputed intensities",
       subtitle = paste0("PC1 (", pct[1], "%) separates treatment from control"),
       x = paste0("PC1 (", pct[1], "%)"),
       y = paste0("PC2 (", pct[2], "%)"),
       color = "Condition") +
  theme_biof3()
save_plot(p3, "03-pca.png", width = 9, height = 6)

# ---- figure 4: sample correlation heatmap --------------------------------
message("Figure 4: sample correlation heatmap")
cor_mat <- cor(pca_mat, use = "pairwise.complete.obs")
anno <- data.frame(
  condition = colData(data_imp)$condition,
  row.names = colnames(data_imp)
)
anno_colors <- list(
  condition = c(Ctrl = biof3_colors[["slate"]], Ubi1 = biof3_colors[["mint"]],
                Ubi4 = biof3_colors[["amber"]], Ubi6 = biof3_colors[["red"]])
)
# The colour scale starts at 0.85. If any correlation is lower, extend the
# scale down to the next 0.01 step so every cell stays inside the breaks.
cor_floor <- min(0.85, floor(min(cor_mat, na.rm = TRUE) * 100) / 100)
png(file.path(output_dir, "04-correlation.png"),
    width = 8, height = 6.5, units = "in", res = 300, bg = "white")
# Close the device even if pheatmap() fails. invisible() suppresses top-level
# auto-printing of the returned object.
invisible(tryCatch(
  pheatmap(
    cor_mat,
    annotation_col = anno,
    annotation_colors = anno_colors,
    color = colorRampPalette(c("white", biof3_colors[["green"]]))(100),
    breaks = seq(cor_floor, 1, length.out = 101),
    display_numbers = TRUE, number_format = "%.3f", fontsize_number = 8,
    main = "Sample-to-sample Pearson correlation"
  ),
  finally = dev.off()
))

# ---- differential analysis -----------------------------------------------
message("Running limma-based differential test ...")
# Contrast every condition against the "Ctrl" group.
data_diff <- test_diff(data_imp, control = "Ctrl", type = "control")
# Flag significant proteins: FDR threshold alpha, |log2 fold change| >= lfc.
dep <- add_rejections(data_diff, lfc = 1, alpha = 0.05)

# ---- figure 5: volcano (Ubi6 vs Ctrl) -----------------------------------
message("Figure 5: volcano plot (Ubi6 vs Ctrl)")
res_df <- rowData(dep) |> as.data.frame()
# Locate the Ubi6_vs_Ctrl columns written by test_diff() / add_rejections().
# fixed = TRUE so the "." in "p.adj" is matched literally.
lfc_col <- grep("Ubi6_vs_Ctrl_diff", colnames(res_df),
                value = TRUE, fixed = TRUE)
p_col   <- grep("Ubi6_vs_Ctrl_p.adj", colnames(res_df),
                value = TRUE, fixed = TRUE)
sig_col <- grep("Ubi6_vs_Ctrl_significant", colnames(res_df),
                value = TRUE, fixed = TRUE)

if (length(lfc_col) > 0 && length(p_col) > 0) {
  # Use the add_rejections() flag so the Up/Down counts agree with the
  # per-contrast totals in figure 6. If that column is absent, fall back to
  # the thresholds passed to add_rejections() above (alpha = 0.05, lfc = 1).
  # NOTE(review): the fallback assumes DEP compares inclusively
  # (<= / >=); confirm against the DEP version in use.
  is_sig <- if (length(sig_col) > 0) {
    res_df[[sig_col[1]]]
  } else {
    res_df[[p_col[1]]] <= 0.05 & abs(res_df[[lfc_col[1]]]) >= 1
  }
  vol_df <- data.frame(
    gene = res_df$name,
    lfc  = res_df[[lfc_col[1]]],
    padj = res_df[[p_col[1]]],
    is_sig = is_sig %in% TRUE  # NA flags count as not significant
  ) |>
    filter(!is.na(padj)) |>
    mutate(sig = case_when(
      is_sig & lfc > 0 ~ "Up",
      is_sig & lfc < 0 ~ "Down",
      TRUE ~ "NS"
    ))
  label_df <- vol_df |> filter(sig != "NS") |> arrange(padj) |> head(15)

  p5 <- ggplot(vol_df, aes(x = lfc, y = -log10(pmax(padj, 1e-16)), color = sig)) +
    geom_point(size = 1, alpha = 0.6) +
    ggrepel::geom_text_repel(data = label_df, aes(label = gene),
                             size = 3, max.overlaps = 20, color = "#111827") +
    scale_color_manual(values = c(Up = biof3_colors[["red"]],
                                  Down = biof3_colors[["blue"]],
                                  NS = "#9ca3af")) +
    # Reference lines at the add_rejections() thresholds.
    geom_vline(xintercept = c(-1, 1), linetype = "dashed", color = "grey50") +
    geom_hline(yintercept = -log10(0.05), linetype = "dashed", color = "grey50") +
    labs(title = "Volcano: Ubi6 vs Ctrl",
         subtitle = paste0("Up: ", sum(vol_df$sig == "Up"),
                           "  Down: ", sum(vol_df$sig == "Down"),
                           "  (padj<0.05, |LFC|>1)"),
         x = "log2 fold change", y = "-log10(padj)", color = NULL) +
    theme_biof3()
  save_plot(p5, "05-volcano.png", width = 10, height = 7)
} else {
  warning("Ubi6_vs_Ctrl_diff / Ubi6_vs_Ctrl_p.adj not found in rowData(dep); ",
          "skipping volcano plot.", call. = FALSE)
}

# ---- figure 6: DE protein counts per contrast ---------------------------
message("Figure 6: DE protein counts per contrast")
# Per-contrast flags end in "_significant"; a column named exactly
# "significant" (if present) does not match this pattern.
sig_cols <- grep("_significant$", colnames(res_df), value = TRUE)
if (length(sig_cols) > 0) {
  counts_df <- data.frame(
    contrast = gsub("_vs_", " vs ", sub("_significant$", "", sig_cols)),
    # vapply guarantees exactly one numeric count per contrast.
    n_sig = vapply(sig_cols,
                   function(col) sum(res_df[[col]], na.rm = TRUE),
                   numeric(1), USE.NAMES = FALSE)
  )

  p6 <- ggplot(counts_df, aes(x = reorder(contrast, n_sig), y = n_sig, fill = n_sig)) +
    geom_col(width = 0.6) +
    geom_text(aes(label = n_sig), hjust = -0.2, size = 4) +
    coord_flip() +
    # Extra headroom past the longest bar so its count label is not clipped.
    scale_y_continuous(expand = expansion(mult = c(0, 0.15))) +
    scale_fill_gradient(low = biof3_colors[["slate"]], high = biof3_colors[["red"]], guide = "none") +
    labs(title = "Significant DE proteins per contrast",
         subtitle = "padj < 0.05 & |LFC| > 1; all vs Ctrl",
         x = NULL, y = "Number of DE proteins") +
    theme_biof3()
  save_plot(p6, "06-de-counts.png", width = 9, height = 5)
} else {
  warning("No '_significant' columns found in rowData(dep); ",
          "skipping DE count plot.", call. = FALSE)
}

# ---- done ---------------------------------------------------------------
# Summarise what was written, then record the package versions for
# reproducibility.
message("=== proteomics prot02 DEP figures ===")
print(list.files(output_dir, full.names = FALSE, pattern = "\\.png$"))
message(paste0("Output dir: ", normalizePath(output_dir)))
message("Done.")

print(sessionInfo())
