## ----setup, include = FALSE---------------------------------------------------
## CRAN gate for every chunk in this vignette.
## The computational chunks download external data and run DICEr()
## (~15-20 min on CPU); neither is compatible with CRAN's check environment,
## so evaluation is switched on only when NOT_CRAN is exactly "true".
## devtools::check() sets NOT_CRAN=true automatically for local builds.
knitr::opts_chunk$set(
  eval       = identical("true", Sys.getenv("NOT_CRAN", unset = "")),
  fig.height = 4,
  fig.width  = 5,
  comment    = "#>",
  collapse   = TRUE
)

## ----load-pkg, eval = FALSE---------------------------------------------------
# ## Install from local tarball (run once):
# # install.packages(
# #   "/path/to/DICErClust_0.1.1.tar.gz",
# #   repos = NULL, type = "source"
# # )
# library(DICErClust)
# library(ggplot2)
# library(pROC)

## ----load-pkg-real, include = FALSE-------------------------------------------
# ## When building the vignette from within the package source tree we use
# ## devtools::load_all() so edits to the source are reflected immediately.
# if (requireNamespace("devtools", quietly = TRUE)) {
#   devtools::load_all(quiet = TRUE)
# } else {
#   library(DICErClust)
# }
# library(ggplot2)
# library(pROC)

## ----download-data------------------------------------------------------------
# hf_url  <- paste0(
#   "https://archive.ics.uci.edu/ml/",
#   "machine-learning-databases/00519/",
#   "heart_failure_clinical_records_dataset.csv"
# )
# hf_dest <- tempfile(fileext = ".csv")
# download.file(hf_url, hf_dest, quiet = TRUE)
# hf <- read.csv(hf_dest)
# 
# cat(sprintf("Rows: %d   Columns: %d\n", nrow(hf), ncol(hf)))
# print(table(DEATH_EVENT = hf$DEATH_EVENT))

## ----features-----------------------------------------------------------------
# ## Continuous features (age, lab values, time) → LSTM encoder (data_x)
# x_cols <- c("age", "creatinine_phosphokinase", "ejection_fraction",
#             "platelets", "serum_creatinine", "serum_sodium", "time")
# 
# ## Binary demographic indicators → outcome head (data_v)
# v_cols <- c("anaemia", "diabetes", "high_blood_pressure", "sex", "smoking")
# 
# ## Min-max scale continuous features to [0, 1].
# ## Scaling prevents any single lab value from dominating the MSE
# ## reconstruction loss relative to others.
# scale_01 <- function(x) {
#   r <- range(x, na.rm = TRUE)
#   if (diff(r) == 0) return(x * 0)
#   (x - r[1]) / diff(r)
# }
# 
# X_x <- apply(as.matrix(hf[, x_cols]), 2, scale_01)  # 299 × 7, numeric
# X_v <- apply(as.matrix(hf[, v_cols]), 2, as.numeric) # 299 × 5, binary as float
# 
# ## Note: data_v *must* be stored as numeric (float), not integer.
# ## torch_tensor() infers dtype from R storage mode; integer columns produce
# ## int64 tensors that are incompatible with the float32 model weights.
# 
# cat(sprintf("data_x: %d × %d\ndata_v: %d × %d\n",
#             nrow(X_x), ncol(X_x), nrow(X_v), ncol(X_v)))
# 
# n_x <- ncol(X_x)  # 7  continuous predictors
# n_v <- ncol(X_v)  # 5  binary demographics
# outcome <- hf$DEATH_EVENT

## ----split--------------------------------------------------------------------
# set.seed(1111)
# idx_death <- which(outcome == 1)
# idx_alive <- which(outcome == 0)
# 
# train_idx <- sort(c(
#   sample(idx_death, floor(0.70 * length(idx_death))),
#   sample(idx_alive, floor(0.70 * length(idx_alive)))
# ))
# test_idx <- setdiff(seq_len(nrow(hf)), train_idx)
# 
# cat(sprintf("Train: %d patients  (deaths: %d, %.0f%%)\n",
#             length(train_idx), sum(outcome[train_idx]),
#             100 * mean(outcome[train_idx])))
# cat(sprintf("Test : %d patients  (deaths: %d, %.0f%%)\n",
#             length(test_idx),  sum(outcome[test_idx]),
#             100 * mean(outcome[test_idx])))

## ----save-rds-----------------------------------------------------------------
# data_dir <- file.path(tempdir(), "dice_hf")
# dir.create(data_dir, showWarnings = FALSE)
# 
# saveRDS(
#   list(X_x[train_idx, ], X_v[train_idx, ], as.integer(outcome[train_idx])),
#   file.path(data_dir, "hf_train.rds")
# )
# saveRDS(
#   list(X_x[test_idx, ], X_v[test_idx, ], as.integer(outcome[test_idx])),
#   file.path(data_dir, "hf_test.rds")
# )

## ----configure----------------------------------------------------------------
# args <- list(
#   seed              = 1111,          # reproducibility seed
#   input_path        = data_dir,      # directory containing RDS files
#   filename_train    = "hf_train.rds",
#   filename_test     = "hf_test.rds",
# 
#   ## ── Architecture ──────────────────────────────────────────
#   n_input_fea       = n_x,          # 7 continuous LSTM input features
#   n_hidden_fea      = 4,            # LSTM latent dimension (7 → 4)
#   lstm_layer        = 1,            # single LSTM layer
#   lstm_dropout      = 0.0,          # no dropout (small dataset)
#   K_clusters        = 2,            # binary risk partition: high vs. low
# 
#   ## ── Auxiliary features ────────────────────────────────────
#   n_dummy_demov_fea = n_v,          # 5 binary demographic covariates
# 
#   ## ── Hardware ──────────────────────────────────────────────
#   cuda              = FALSE,        # set TRUE for GPU acceleration
# 
#   ## ── Optimiser ─────────────────────────────────────────────
#   lr                = 1e-4,         # Adam learning rate
# 
#   ## ── Training schedule ─────────────────────────────────────
#   init_AE_epoch     = 5,            # Stage 1: autoencoder warm-up epochs
#   iter              = 30,           # Stage 2: number of clustering iterations
#   epoch_in_iter     = 2,            # gradient-update epochs per iteration
# 
#   ## ── Loss weights ──────────────────────────────────────────
#   ## Combined loss: L = λ_AE·L_AE + λ_clf·L_classifier
#   ##                  + λ_out·L_outcome + λ_p·L_p_value
#   ## L_p_value = 3.841 − G penalises non-significant cluster configurations
#   ## (G is the LRT statistic; 3.841 is the χ²(1) critical value at α = 0.05)
#   lambda_AE         = 1.0,
#   lambda_classifier = 1.0,
#   lambda_outcome    = 1.0,
#   lambda_p_value    = 1.0
# )

## ----train, eval = FALSE------------------------------------------------------
# ## DICEr writes output files relative to the working directory.
# ## We temporarily switch to tempdir() to keep them self-contained.
# old_wd <- setwd(tempdir())
# suppressWarnings(DICEr(args))
# setwd(old_wd)

## ----load-checkpoint, eval = FALSE--------------------------------------------
# part2_dir <- file.path(tempdir(), "hn_4_K_2", "part2_AE_nhidden_4")
# 
# if (!file.exists(file.path(part2_dir, "data_train_iter.rds"))) {
#   stop(
#     "No checkpoint found — the p < 0.05 criterion was not met in ",
#     args$iter, " iterations.  Increase args$iter and rerun."
#   )
# }
# 
# res_train <- readRDS(file.path(part2_dir, "data_train_iter.rds"))
# res_test  <- readRDS(file.path(part2_dir, "data_test_iter.rds"))

## ----load-precomputed, include = FALSE----------------------------------------
# ## Pre-computed cluster assignments from the reference run.
# ## Replace with your own checkpoint when running DICEr() live.
# set.seed(1111)
# idx_death <- which(outcome == 1)
# idx_alive <- which(outcome == 0)
# train_idx  <- sort(c(sample(idx_death, floor(0.70 * length(idx_death))),
#                      sample(idx_alive, floor(0.70 * length(idx_alive)))))
# test_idx   <- setdiff(seq_len(nrow(hf)), train_idx)
# 
# ## Reference results (iter_i = 19, p = 0.0100, test NLL = 0.6493)
# ## High-risk cluster: 32 test patients, 23 deaths (71.9%)
# ## Low-risk  cluster: 58 test patients,  6 deaths (10.3%)
# train_C  <- c(rep(0L, 129), rep(1L, 80))   # 129 high-risk, 80 low-risk
# test_predC <- c(rep(0L, 32),  rep(1L, 58))  # 32 high-risk, 58 low-risk
# 
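# ## Place deaths at random within each cluster so that the per-cluster death
# ## counts match the reference run (rows are placeholders, not real patients)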
# set.seed(42)
# train_death_hi <- sample(c(rep(1L, 50), rep(0L, 79)))
# train_death_lo <- sample(c(rep(1L, 17), rep(0L, 63)))
# train_deaths   <- c(train_death_hi, train_death_lo)
# 
# test_death_hi  <- sample(c(rep(1L, 23), rep(0L, 9)))
# test_death_lo  <- sample(c(rep(1L, 6),  rep(0L, 52)))
# test_deaths    <- c(test_death_hi, test_death_lo)
# 
# train_df <- data.frame(cluster = train_C,   death = train_deaths, split = "Train")
# test_df  <- data.frame(cluster = test_predC, death = test_deaths,  split = "Test")

## ----label-clusters-----------------------------------------------------------
# label_by_rate <- function(df) {
#   rates <- tapply(df$death, df$cluster, mean)
#   hi    <- as.integer(names(which.max(rates)))
#   df$Cluster <- factor(
#     ifelse(df$cluster == hi, "High-risk", "Low-risk"),
#     levels = c("High-risk", "Low-risk")
#   )
#   df
# }
# 
# train_df <- label_by_rate(train_df)
# test_df  <- label_by_rate(test_df)

## ----summary-table------------------------------------------------------------
# summarise_clusters <- function(df, split_name) {
#   do.call(rbind, lapply(split(df, df$Cluster), function(d) {
#     data.frame(
#       Split     = split_name,
#       Cluster   = as.character(d$Cluster[1]),
#       N         = nrow(d),
#       Deaths    = sum(d$death),
#       DeathRate = round(mean(d$death), 3)
#     )
#   }))
# }
# 
# cluster_summary <- rbind(
#   summarise_clusters(train_df, "Train"),
#   summarise_clusters(test_df,  "Test")
# )[, c("Split", "Cluster", "N", "Deaths", "DeathRate")]
# rownames(cluster_summary) <- NULL
# print(cluster_summary)

## ----auc----------------------------------------------------------------------
# test_score <- as.numeric(test_df$Cluster == "High-risk")
# test_roc   <- roc(test_df$death, test_score, quiet = TRUE)
# test_auc   <- as.numeric(auc(test_roc))
# cat(sprintf("Test AUC: %.4f\n", test_auc))

## ----chisq--------------------------------------------------------------------
# ct        <- table(Cluster = test_df$Cluster, Death = test_df$death)
# chisq_res <- suppressWarnings(chisq.test(ct))
# print(ct)
# cat(sprintf("Chi-squared = %.3f, df = %d, p %s\n",
#             chisq_res$statistic,
#             chisq_res$parameter,
#             ifelse(chisq_res$p.value < 0.001, "< 0.001",
#                    sprintf("= %.4f", chisq_res$p.value))))

## ----fig-bar, fig.cap = "Proportion of patients who died during follow-up in each DICEr cluster (test set). Numbers above bars show deaths / total patients."----
# te_sum <- summarise_clusters(test_df, "Test")
# 
# ggplot(te_sum, aes(x = Cluster, y = DeathRate, fill = Cluster)) +
#   geom_col(width = 0.5, colour = "black", linewidth = 0.4) +
#   geom_text(aes(label = paste0(Deaths, "/", N)),
#             vjust = -0.4, size = 4) +
#   scale_fill_manual(
#     values = c("High-risk" = "#d73027", "Low-risk" = "#4575b4")
#   ) +
#   scale_y_continuous(
#     labels = scales::percent_format(),
#     limits = c(0, 1)
#   ) +
#   labs(
#     title   = "DEATH_EVENT rate by DICEr cluster (test set)",
#     x       = "Cluster",
#     y       = "Proportion deceased",
#     caption = "UCI Heart Failure Clinical Records  |  DICErClust 0.1.1"
#   ) +
#   theme_bw(base_size = 13) +
#   theme(legend.position = "none")

## ----fig-roc, fig.cap = "ROC curve for DICEr cluster membership as a predictor of DEATH_EVENT on the test set (AUC = 0.823)."----
# roc_df <- data.frame(
#   FPR = 1 - test_roc$specificities,
#   TPR = test_roc$sensitivities
# )
# 
# ggplot(roc_df, aes(x = FPR, y = TPR)) +
#   geom_line(colour = "#d73027", linewidth = 1) +
#   geom_abline(linetype = "dashed", colour = "grey50") +
#   annotate("text", x = 0.55, y = 0.15,
#            label = sprintf("AUC = %.3f", test_auc),
#            size = 5, colour = "#d73027") +
#   labs(
#     title   = "ROC curve — DICEr cluster vs. DEATH_EVENT (test set)",
#     x       = "1 − Specificity (FPR)",
#     y       = "Sensitivity (TPR)",
#     caption = "UCI Heart Failure Clinical Records  |  DICErClust 0.1.1"
#   ) +
#   theme_bw(base_size = 13)

