###############################################################################
##  Shiny App: Personalized Multi-State Joint Model
##
##  WORKFLOW:
##    1. Upload longitudinal_biomarkers.csv + survival_events.csv
##    2. Auto-merge by patient_id
##    3. Explore data: trajectories, transitions, biomarker profiles
##    4. Fit transition-specific joint models (GAM-Cox with spline surfaces)
##    5. Select a patient → see personalized risk of NEXT event
##       "Given Patient X currently has CVD, what is their risk of
##        developing CKD or Diabetes in the next 1-5 years?"
##
##  install.packages(c("shiny","shinydashboard","shinyWidgets",
##    "nlme","survival","mgcv","ggplot2","viridis","plotly",
##    "dplyr","tidyr","DT","gridExtra","ggrepel"))
##  Optional (extra prediction methods): JM, JMbayes2, jmBIG, JMbdirect
###############################################################################

library(shiny)
library(shinydashboard)
library(shinyWidgets)
library(nlme)
library(survival)
library(mgcv)
library(ggplot2)
library(viridis)
library(plotly)
library(dplyr)
library(tidyr)
library(DT)
library(gridExtra)
if (!requireNamespace("ggrepel", quietly = TRUE)) stop("Package 'ggrepel' is required for the Shiny app. Please install it with: install.packages('ggrepel')")
library(ggrepel)


# ═══════════════════════════════════════════════════════════════
#  UI
# ═══════════════════════════════════════════════════════════════

# Dashboard UI: header, sidebar navigation and one tabItem per workflow step.
# The sidebar menuItem tabName values ("upload", "explore", "fit", "predict",
# "surfaces", "guide") must match the tabItem(tabName = ...) blocks below.
# Input/output IDs declared here (e.g. "file_long", "btn_merge",
# "merge_summary", "risk_card_1") are expected to be bound in server().
ui <- dashboardPage(
  skin = "black",

  dashboardHeader(
    title = span(icon("heartbeat"),
                 " Multi-State Joint Model — Personalized Risk",
                 style = "font-size:13px; font-weight:600;"),
    titleWidth = 420
  ),

  dashboardSidebar(
    width = 260,
    # id = "tabs" exposes the currently selected tab as input$tabs.
    sidebarMenu(
      id = "tabs",
      menuItem("1. Upload & Merge", tabName = "upload", icon = icon("upload")),
      menuItem("2. Data Explorer", tabName = "explore", icon = icon("search")),
      menuItem("3. Fit Joint Models", tabName = "fit", icon = icon("cogs")),
      menuItem("4. Personalized Risk", tabName = "predict", icon = icon("user-md")),
      menuItem("5. Population Surfaces", tabName = "surfaces", icon = icon("mountain")),
      menuItem("Guide", tabName = "guide", icon = icon("book-open"))
    ),
    tags$hr(style = "border-color:#444; margin:8px 15px;"),
    tags$div(style = "padding:8px 15px; color:#888; font-size:10px;",
      HTML("Semi-Parametric Association<br>
            Surfaces for Joint Modeling<br>
            Bhattacharjee (2025)<br><br>
            <b>Packages:</b> jmBIG (2024),<br>
            JMbdirect (2025)<br>
            Bhattacharjee et al."))
  ),

  dashboardBody(
    # Custom CSS injected into <head>: fonts, box accent colours, and helper
    # classes (.risk-card, .status-log, .lme-card, ...) presumably applied to
    # HTML produced by the server-side renderUI outputs.
    tags$head(tags$style(HTML("
      @import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@300;400;600;700&family=Fira+Code:wght@400&display=swap');
      body { font-family: 'Source Sans Pro', sans-serif; }
      .content-wrapper { background: #f4f6fb; }
      .skin-black .main-header .logo { background: #0d1b2a; }
      .skin-black .main-header .navbar { background: #1b2838; }
      .skin-black .main-sidebar { background: #0d1b2a; }

      .box { border-radius: 8px; box-shadow: 0 2px 10px rgba(0,0,0,0.06); }
      .box.box-primary { border-top: 3px solid #1b4965; }
      .box.box-success { border-top: 3px solid #2a9d8f; }
      .box.box-warning { border-top: 3px solid #e76f51; }
      .box.box-info    { border-top: 3px solid #457b9d; }
      .box.box-danger  { border-top: 3px solid #c1121f; }

      .status-ok { color: #2a9d8f; font-weight: 700; }
      .status-fail { color: #c1121f; font-weight: 700; }

      .merge-summary {
        background: linear-gradient(135deg, #f0f7ff, #e8f4f8);
        border-left: 5px solid #1b4965; border-radius: 0 8px 8px 0;
        padding: 16px 20px; margin: 10px 0;
      }

      .risk-card {
        background: white; border-radius: 12px; padding: 20px; text-align: center;
        box-shadow: 0 3px 15px rgba(0,0,0,0.08); margin-bottom: 14px;
      }
      .risk-card .risk-value { font-size: 36px; font-weight: 700; }
      .risk-card .risk-label { font-size: 12px; color: #888; margin-top: 4px; text-transform: uppercase; }
      .risk-low { color: #2a9d8f; }
      .risk-med { color: #e9c46a; }
      .risk-high { color: #e76f51; }
      .risk-vhigh { color: #c1121f; }

      .patient-card {
        background: white; border-radius: 10px; padding: 16px;
        box-shadow: 0 2px 8px rgba(0,0,0,0.06); margin-bottom: 10px;
        border-left: 5px solid #457b9d;
      }

      .interp-panel {
        background: linear-gradient(135deg, #f0f7ff, #e8f4f8);
        border-left: 5px solid #2a9d8f; border-radius: 0 8px 8px 0;
        padding: 16px 20px; margin: 10px 0; line-height: 1.6;
      }
      .interp-panel h4 { color: #1b4965; margin-top: 0; font-weight: 700; }

      .trans-arrow {
        display: inline-block; padding: 4px 14px; border-radius: 6px;
        font-weight: 600; font-size: 13px; margin: 3px; color: white;
      }
      .trans-ckd { background: #457b9d; }
      .trans-cvd { background: #e76f51; }
      .trans-dm  { background: #e9c46a; color: #333; }
      .trans-ar  { background: #2a9d8f; }

      .code-block {
        background: #1b2838; color: #a8dadc; border-radius: 6px;
        padding: 12px 16px; font-family: 'Fira Code', monospace;
        font-size: 12px; margin: 8px 0; overflow-x: auto;
      }

      /* ── Styled model results table ── */
      .model-results-table {
        width: 100%; border-collapse: collapse; font-size: 14px;
        border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.06);
      }
      .model-results-table thead th {
        background: #1b4965; color: white; padding: 12px 14px;
        font-weight: 700; font-size: 13px; text-align: left;
      }
      .model-results-table tbody td {
        padding: 10px 14px; border-bottom: 1px solid #eef2f7;
        font-size: 14px; font-weight: 500;
      }
      .model-results-table tbody tr:hover { background: #f0f7ff; }
      .model-results-table tbody tr:nth-child(even) { background: #fafbfd; }

      /* ── Status log styling ── */
      .status-log {
        background: #0d1b2a; color: #a8dadc; border-radius: 8px;
        padding: 14px 18px; font-family: 'Fira Code', monospace;
        font-size: 13px; line-height: 1.8; max-height: 300px; overflow-y: auto;
      }
      .status-log .log-ok { color: #2a9d8f; font-weight: 700; }
      .status-log .log-fail { color: #e76f51; font-weight: 700; }
      .status-log .log-warn { color: #e9c46a; font-weight: 700; }
      .status-log .log-method { color: #a8dadc; font-weight: 600; }

      /* ── LME Summary styling ── */
      .lme-card {
        background: white; border-radius: 8px; padding: 16px; margin-bottom: 12px;
        box-shadow: 0 2px 6px rgba(0,0,0,0.05); border-left: 4px solid #457b9d;
      }
      .lme-card h5 { color: #1b4965; font-weight: 700; margin: 0 0 8px 0; font-size: 16px; }
      .lme-card .lme-row { display: flex; justify-content: space-between; padding: 4px 0; }
      .lme-card .lme-label { color: #666; font-weight: 600; font-size: 13px; }
      .lme-card .lme-val { color: #1b4965; font-weight: 700; font-size: 14px; }

      /* ── Prediction status styled ── */
      .pred-status-panel {
        background: #0d1b2a; border-radius: 8px; padding: 12px 16px;
        max-height: 280px; overflow-y: auto; margin-top: 8px;
      }
      .pred-status-line {
        padding: 4px 0; font-family: 'Fira Code', monospace; font-size: 12px;
        border-bottom: 1px solid rgba(255,255,255,0.05);
      }
      .pred-ok { color: #2a9d8f; }
      .pred-fail { color: #e76f51; }
      .pred-warn { color: #e9c46a; }

      /* ── Fit progress styling ── */
      .fit-progress { font-family: 'Fira Code', monospace; font-size: 12px;
        background: #f8f9fa; border-radius: 6px; padding: 10px; }
    "))),

    tabItems(

      # ════════════════════════════════════════════════════
      #  TAB 1: UPLOAD & MERGE
      # ════════════════════════════════════════════════════
      # Two CSV uploads (or the built-in simulator) feed the merge step;
      # the right-hand box shows the post-merge summaries.
      tabItem(tabName = "upload",
        fluidRow(
          box(width = 5, status = "primary", solidHeader = TRUE,
              title = span(icon("upload"), " Upload Your Two CSV Files"),
            h4("Step 1: Longitudinal Biomarkers"),
            fileInput("file_long", NULL, accept = ".csv",
                      placeholder = "longitudinal_biomarkers.csv"),
            h4("Step 2: Survival Events"),
            fileInput("file_surv", NULL, accept = ".csv",
                      placeholder = "survival_events.csv"),
            hr(),
            actionButton("btn_merge", "Merge & Validate",
                         icon = icon("link"), class = "btn-primary btn-block",
                         style = "font-weight:600; font-size:15px; padding:12px;"),
            hr(),
            tags$div(style = "background:#f0f7ff; border-radius:8px; padding:14px; margin-top:8px;",
              h4(icon("flask"), " Or: Use Built-in Simulated Data",
                 style = "color:#1b4965; margin-top:0;"),
              p("Generate a realistic CKD/CVD/Diabetes multi-state dataset with
                 3 biomarkers (eGFR, BNP, HbA1c) and bidirectional transitions.",
                style = "font-size:12px; color:#555;"),
              # min/max only constrain the spinner; typed values can fall
              # outside this range.
              numericInput("sim_n_patients", "Number of patients",
                           value = 500, min = 50, max = 50000, step = 50),
              tags$small(style="color:#888;", "Recommended: 200–2000 for quick demo. Up to 50,000 for stress testing."),
              actionButton("btn_simulate", "Generate & Load Simulated Data",
                           icon = icon("dice"), class = "btn-info btn-block",
                           style = "font-weight:600; font-size:14px; padding:10px;")
            ),
            hr(),
            uiOutput("merge_status")
          ),
          box(width = 7, status = "info", solidHeader = TRUE,
              title = span(icon("info-circle"), " Data Summary After Merge"),
            uiOutput("merge_summary"),
            hr(),
            h4(icon("exchange-alt"), " Transition Counts",
               style = "color:#1b4965; font-weight:700;"),
            uiOutput("transition_counts_table"),
            hr(),
            h4(icon("chart-bar"), " Event Burden Distribution",
               style = "color:#1b4965; font-weight:700;"),
            plotOutput("event_burden_plot", height = "250px")
          )
        )
      ),

      # ════════════════════════════════════════════════════
      #  TAB 2: DATA EXPLORER
      # ════════════════════════════════════════════════════
      tabItem(tabName = "explore",
        fluidRow(
          box(width = 4, status = "primary", solidHeader = TRUE,
              title = span(icon("filter"), " Filters"),
            # Filter choices depend on the merged data, so they are rendered
            # server-side.
            uiOutput("ui_entry_filter"),
            uiOutput("ui_trajectory_filter"),
            sliderInput("explore_n_patients", "Show N patients", 5, 50, 15, step = 5)
          ),
          box(width = 8, status = "info", solidHeader = TRUE,
              title = span(icon("chart-line"), " Biomarker Trajectories"),
            tabsetPanel(
              tabPanel("eGFR", plotOutput("plot_egfr", height = "350px")),
              tabPanel("BNP", plotOutput("plot_bnp", height = "350px")),
              tabPanel("HbA1c", plotOutput("plot_hba1c", height = "350px")),
              tabPanel("All Markers", plotOutput("plot_all_markers", height = "400px"))
            )
          )
        ),
        fluidRow(
          box(width = 6, status = "success", solidHeader = TRUE,
              title = "Cumulative Incidence by Event Type",
            plotOutput("cumulative_incidence_plot", height = "320px")
          ),
          box(width = 6, status = "warning", solidHeader = TRUE,
              title = "Biomarker Distributions by Current State",
            plotOutput("biomarker_boxplot", height = "320px")
          )
        )
      ),

      # ════════════════════════════════════════════════════
      #  TAB 3: FIT JOINT MODELS
      # ════════════════════════════════════════════════════
      tabItem(tabName = "fit",
        fluidRow(
          box(width = 4, status = "primary", solidHeader = TRUE,
              title = span(icon("cogs"), " Model Configuration"),
            p(style = "font-size:14px; line-height:1.6;",
              "For each transition (e.g., CVD → CKD), a separate joint model is fitted:"),
            tags$ol(style = "font-size:13px; line-height:1.8;",
              tags$li(tags$b("Longitudinal submodel:"), " mixed-effects regression for each biomarker"),
              tags$li(tags$b("Survival submodel:"), " GAM-Cox with tensor-product spline surface")
            ),
            hr(),
            uiOutput("ui_select_transitions"),
            hr(),
            actionButton("btn_fit", "Fit Transition Models",
                         icon = icon("rocket"), class = "btn-primary btn-block",
                         style = "font-weight:600; font-size:15px; padding:12px;"),
            hr(),
            uiOutput("fit_progress_styled")
          ),
          box(width = 8, status = "info", solidHeader = TRUE,
              title = span(icon("table"), " Fitted Model Results"),
            uiOutput("fit_results_styled"),
            hr(),
            uiOutput("fit_interpretation"),
            hr(),
            h4(icon("chart-bar"), " Longitudinal Submodel Summaries",
               style = "color:#1b4965; font-weight:700;"),
            uiOutput("lme_summaries_styled")
          )
        )
      ),

      # ════════════════════════════════════════════════════
      #  TAB 4: PERSONALIZED RISK PREDICTION (5 METHODS)
      # ════════════════════════════════════════════════════
      # NOTE(review): the banner says 5 methods but sel_methods lists 7.
      tabItem(tabName = "predict",
        fluidRow(
          box(width = 4, status = "danger", solidHeader = TRUE,
              title = span(icon("user-md"), " Select Patient & Methods"),
            uiOutput("ui_patient_select"),
            hr(),
            uiOutput("patient_info_card"),
            hr(),
            sliderInput("pred_horizon", "Prediction horizon (years ahead)",
                        min = 0.5, max = 10, value = 3, step = 0.5),
            # NOTE(review): max = 10 years, while the built-in simulator
            # follows patients for up to 15 years — confirm intended range.
            sliderInput("pred_landmark", "Landmark time (predict FROM this time)",
                        min = 0, max = 10, value = 0, step = 0.5),
            hr(),
            h4(icon("flask"), " Prediction Methods", style = "color:#1b4965;"),
            # Choice values (right-hand side) are the keys delivered in
            # input$sel_methods; the names are display labels only.
            checkboxGroupInput("sel_methods", NULL,
              choices = c(
                "Two-Stage BLUP + GAM-Cox" = "twostage",
                "True JM (ML via JM pkg)" = "true_jm",
                "Landmark (LOCF)" = "landmark",
                "Landmarking 2.0 (BLUP+Slope)" = "landmark_blup",
                "Bayesian JM (JMbayes2)" = "bayesian_jm",
                "jmBIG (Big Data JM)" = "jmbig",
                "JMbdirect (Bidirectional)" = "jmbdirect"
              ),
              selected = c("twostage", "landmark", "landmark_blup")
            ),
            tags$small(style = "color:#888;",
              "Methods 2, 5-7 require optional packages (JM, JMbayes2, jmBIG, JMbdirect).
               Bhattacharjee et al. (2024, 2025)."),
            hr(),
            # ── Stratification controls ──
            h4(icon("layer-group"), " Stratified Analysis", style = "color:#1b4965;"),
            uiOutput("ui_stratify_by"),
            tags$small(style = "color:#888;",
              "Select a demographic variable to compare risk across subgroups.
               Population-level stratified predictions use the fitted model on subgroup data."),
            hr(),
            actionButton("btn_predict", "Compute All Predictions",
                         icon = icon("magic"), class = "btn-danger btn-block",
                         style = "font-weight:600; font-size:15px; padding:12px;"),
            hr(),
            uiOutput("predict_status_styled")
          ),
          box(width = 8, status = "warning", solidHeader = TRUE,
              title = span(icon("heartbeat"), " Multi-Method Dynamic Risk Prediction"),
            fluidRow(
              column(4, uiOutput("risk_card_1")),
              column(4, uiOutput("risk_card_2")),
              column(4, uiOutput("risk_card_3"))
            ),
            hr(),
            tabsetPanel(
              tabPanel("Method Comparison",
                plotOutput("method_comparison_plot", height = "420px")
              ),
              tabPanel("Risk Over Time (Best)",
                plotOutput("risk_over_time_plot", height = "380px")
              ),
              tabPanel("Biomarker Trajectories",
                plotOutput("patient_trajectory_plot", height = "380px")
              ),
              tabPanel("Risk Surface Position",
                plotlyOutput("patient_surface_position", height = "380px")
              ),
              tabPanel("Stratified Analysis",
                uiOutput("stratified_summary_panel"),
                hr(),
                plotOutput("stratified_risk_plot", height = "420px"),
                hr(),
                plotOutput("stratified_biomarker_plot", height = "350px")
              ),
              tabPanel("Demographic Profile",
                plotOutput("demographic_risk_forest", height = "450px"),
                hr(),
                uiOutput("demographic_table_panel")
              ),
              tabPanel("Method Summary",
                uiOutput("method_comparison_table"),
                hr(),
                uiOutput("method_details_panel")
              )
            ),
            hr(),
            uiOutput("prediction_interpretation")
          )
        )
      ),

      # ════════════════════════════════════════════════════
      #  TAB 5: POPULATION SURFACES
      # ════════════════════════════════════════════════════
      tabItem(tabName = "surfaces",
        fluidRow(
          box(width = 12, status = "primary", solidHeader = TRUE,
              title = span(icon("mountain"), " Association Surfaces by Transition"),
            uiOutput("ui_surface_transition"),
            fluidRow(
              column(6, plotlyOutput("surface_3d", height = "420px")),
              column(6,
                plotOutput("surface_heatmap", height = "200px"),
                plotOutput("marginal_slices", height = "200px")
              )
            ),
            hr(),
            uiOutput("surface_interpretation")
          )
        )
      ),

      # ════════════════════════════════════════════════════
      #  TAB 6: GUIDE
      # ════════════════════════════════════════════════════
      # Static documentation: modelling approach and expected CSV schemas.
      tabItem(tabName = "guide",
        fluidRow(
          box(width = 12, status = "success", solidHeader = TRUE,
              title = span(icon("book-open"), " How This Works"),
            fluidRow(
              column(6,
                h3("The Joint Modeling Approach", style = "color:#1b4965;"),
                p("This app implements", tags$b("transition-specific joint models"),
                  "for multi-state disease progression. The key idea:"),
                tags$ol(
                  tags$li(tags$b("Longitudinal submodel:"),
                    "For each biomarker (eGFR, BNP, HbA1c), a mixed-effects model
                     estimates subject-specific trajectories, separating signal from noise."),
                  tags$li(tags$b("Latent summaries:"),
                    "Subject-specific BLUPs (fitted values) η_i(t) represent
                     the 'true' biomarker state at any time."),
                  tags$li(tags$b("Survival submodel:"),
                    "For each transition (e.g., CKD→CVD), a GAM-Cox model links
                     the latent biomarker summaries to the transition hazard via a
                     semi-parametric surface f(η₁, η₂)."),
                  tags$li(tags$b("Personalized prediction:"),
                    "Given a patient's observed biomarker history up to time t,
                     their latent trajectories are projected forward, and the
                     transition-specific hazards yield personalized risk estimates
                     for each possible next event.")
                ),
                div(class = "code-block",
                  HTML("h<sub>transition</sub>(t | η(t)) = h₀(t) · exp{ γ'w + f(η_eGFR(t), η_BNP(t), η_HbA1c(t)) }"))
              ),
              column(6,
                h3("Expected CSV Format", style = "color:#1b4965;"),
                h4("longitudinal_biomarkers.csv"),
                div(class = "code-block",
                  HTML("patient_id, visit_time_years, biomarker, value, unit<br>
                        1, 0.0, eGFR, 81.4, mL/min/1.73m2<br>
                        1, 0.0, BNP, 203.1, pg/mL<br>
                        1, 0.0, HbA1c, 5.3, %<br>
                        1, 0.87, eGFR, 78.8, mL/min/1.73m2<br>...")),
                h4("survival_events.csv"),
                div(class = "code-block",
                  HTML("patient_id, ..., start_time, stop_time, status,<br>
                        event_type, state_from, state_to, transition, ...<br>
                        1, ..., 0.0, 2.17, 1, Diabetes, CVD, Diabetes, CVD→Diabetes, ...<br>
                        1, ..., 2.17, 4.01, 0, Censored, Diabetes, Diabetes, ...<br>...")),
                hr(),
                h4("What the prediction means"),
                p("When you select a patient and a prediction horizon, the app answers:"),
                tags$ul(
                  tags$li(tags$b("'Given this patient's current state and biomarker history,
                           what is their risk of transitioning to CKD, CVD, or Diabetes
                           within the next X years?'")),
                  tags$li("Risks are computed for each possible NEXT transition from the patient's current state."),
                  tags$li("The risk accounts for the patient's", tags$b("full biomarker trajectory,"),
                    "not just the last measurement.")
                )
              )
            )
          )
        )
      )

    ) # end tabItems
  ) # end body
)


# ═══════════════════════════════════════════════════════════════
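#  SERVER-SIDE HELPERS (plot theme, method colours, data simulator);
#  the server() function itself is defined further below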
# ═══════════════════════════════════════════════════════════════

# ═══════════════════════════════════════════════════════════════
#  GLOBAL PLOT THEME — Publication Quality
# ═══════════════════════════════════════════════════════════════
#' Publication-quality ggplot2 theme shared by every plot in the app
#'
#' Starts from `theme_minimal(base_size)` and swaps in custom elements with
#' `%+replace%`. Because `%+replace%` replaces an element wholesale (any
#' property not given becomes NULL), the root `text` element is specified in
#' full here. Otherwise ggplot2 fills its missing size from the built-in
#' default theme (`theme_grey()`, 11 pt). Every element sized relative to
#' `text` via `rel()` (e.g. `plot.caption`, `plot.tag`) would then ignore
#' `base_size`.
#'
#' @param base_size Base font size in points (default 16). Explicit element
#'   sizes below are offsets from this value.
#' @return A complete ggplot2 theme object.
theme_app <- function(base_size = 16) {
  theme_minimal(base_size = base_size) %+replace%
    theme(
      # Root text element: all properties given so nothing falls back to
      # ggplot2's 11-pt defaults (values other than family/colour/size match
      # theme_grey's root text element).
      text = element_text(family = "sans", face = "plain", colour = "#1b2838",
                          size = base_size, lineheight = 0.9,
                          hjust = 0.5, vjust = 0.5, angle = 0,
                          margin = margin(), debug = FALSE),
      plot.title = element_text(face = "bold", size = base_size + 4,
                                 colour = "#1b4965", margin = margin(b = 8)),
      plot.subtitle = element_text(size = base_size - 2, colour = "#555",
                                    margin = margin(b = 12)),
      axis.title = element_text(face = "bold", size = base_size, colour = "#333"),
      axis.text = element_text(size = base_size - 2, colour = "#444", face = "bold"),
      strip.text = element_text(face = "bold", size = base_size + 1,
                                 colour = "#1b4965"),
      legend.text = element_text(size = base_size - 3, face = "bold"),
      legend.title = element_text(face = "bold", size = base_size - 1),
      legend.position = "bottom",
      panel.grid.major = element_line(colour = "#e8edf2", linewidth = 0.4),
      panel.grid.minor = element_blank(),
      plot.background = element_rect(fill = "white", colour = NA),
      panel.background = element_rect(fill = "#fafbfd", colour = NA)
    )
}

# Plot colour for each of the seven prediction methods, keyed by the method's
# display label (the numeric prefix fixes the legend order).
METHOD_COLORS <- setNames(
  c("#1b4965", "#e76f51", "#2a9d8f", "#b5838d",
    "#264653", "#f4a261", "#9b2226"),
  c("1. Two-Stage BLUP + GAM-Cox",
    "2. True JM (ML-EM)",
    "3. Landmark (LOCF)",
    "4. LM 2.0 (BLUP+Slope)",
    "5. Bayesian JM (JMbayes2)",
    "6. jmBIG (Scalable Bayesian)",
    "7. JMbdirect (Bidirectional)")
)

# Line types for the same seven methods. Unnamed; presumably paired with
# METHOD_COLORS by position (verify at the call sites).
METHOD_LINETYPES <- c(
  "solid", "longdash", "dashed", "dotdash",
  "twodash", "dotted", "solid"
)

# ═══════════════════════════════════════════════════════════════
#  BUILT-IN SIMULATOR: Multi-state CKD/CVD/Diabetes data
# ═══════════════════════════════════════════════════════════════
#' Simulate a multi-state CKD/CVD/Diabetes cohort with longitudinal biomarkers
#'
#' Each patient gets demographics and comorbidity flags, an entry disease,
#' irregularly spaced eGFR/BNP/HbA1c measurements (each kept with probability
#' 0.96 per visit), and a multi-state event history. The history has up to 3
#' transitions within `max_t = 15` years and always ends with one censored
#' interval. Output columns match the CSV formats the app accepts on upload.
#'
#' @param n_patients Number of patients (single number >= 1; non-integers are
#'   truncated).
#' @param seed RNG seed for reproducibility (default 42, as before). The
#'   caller's RNG state is restored on exit, so the session-wide random stream
#'   is not reset as a side effect. Use `NULL` to draw from the current RNG
#'   state without seeding.
#' @return A list with `long_df` (one row per biomarker measurement) and
#'   `surv_df` (one row per transition or censoring interval).
generate_simulated_data <- function(n_patients = 500, seed = 42) {
  if (!is.numeric(n_patients) || length(n_patients) != 1L ||
      is.na(n_patients) || n_patients < 1) {
    stop("`n_patients` must be a single number >= 1.", call. = FALSE)
  }
  n_patients <- as.integer(n_patients)  # same truncation `1:n` applied

  # Seed locally: remember (and later restore) the global RNG state.
  if (!is.null(seed)) {
    genv <- globalenv()
    had_seed <- exists(".Random.seed", envir = genv, inherits = FALSE)
    old_seed <- if (had_seed) get(".Random.seed", envir = genv, inherits = FALSE)
    on.exit({
      if (had_seed) {
        assign(".Random.seed", old_seed, envir = genv)
      } else if (exists(".Random.seed", envir = genv, inherits = FALSE)) {
        rm(".Random.seed", envir = genv)
      }
    }, add = TRUE)
    set.seed(seed)
  }

  diseases <- c("CKD", "CVD", "Diabetes")
  states <- c("At-risk", diseases)
  max_t <- 15  # end of follow-up (years)

  # ── Demographic lookup tables ──
  ethnicity_pool <- c("White", "Black/African American", "Hispanic/Latino",
                      "Asian", "Other/Mixed")
  ethnicity_probs <- c(0.42, 0.22, 0.20, 0.10, 0.06)
  insurance_pool <- c("Medicare", "Medicaid", "Private", "Uninsured")
  insurance_probs <- c(0.32, 0.18, 0.40, 0.10)
  education_pool <- c("Less than HS", "High school", "Some college", "Bachelor+")
  education_probs <- c(0.12, 0.28, 0.30, 0.30)
  region_pool <- c("Northeast", "South", "Midwest", "West")
  region_probs <- c(0.22, 0.32, 0.24, 0.22)

  # Marker order matters: it fixes the order of RNG draws per visit.
  marker_units <- c(eGFR = "mL/min/1.73m2", BNP = "pg/mL", HbA1c = "%")

  # One survival interval as a plain list (cheap compared with data.frame()).
  new_interval <- function(start, stop, status, event_type, from, to, traj) {
    list(start = start, stop = stop, status = status, event_type = event_type,
         from = from, to = to,
         transition = paste(from, "→", if (status == 0) "censored" else to),
         trajectory = traj)
  }
  pluck_num <- function(rows, f) vapply(rows, function(r) r[[f]], numeric(1))
  pluck_chr <- function(rows, f) vapply(rows, function(r) r[[f]], character(1))

  # Stack per-patient column lists into one data.frame in a single pass,
  # instead of rbind()-ing one data.frame per row (very slow for large n).
  stack_parts <- function(parts) {
    cols <- names(parts[[1L]])
    out <- lapply(cols, function(nm) {
      unlist(lapply(parts, `[[`, nm), use.names = FALSE)
    })
    names(out) <- cols
    as.data.frame(out, stringsAsFactors = FALSE)
  }

  long_parts <- vector("list", n_patients)
  surv_parts <- vector("list", n_patients)

  for (i in seq_len(n_patients)) {
    # ── Demographics (draw order kept identical for reproducibility) ──
    age <- runif(1, 35, 85)
    sex <- sample(0:1, 1)                                  # 0=Female, 1=Male
    ethnicity <- sample(ethnicity_pool, 1, prob = ethnicity_probs)
    smoking <- sample(0:2, 1, prob = c(0.45, 0.30, 0.25))  # 0=Never,1=Former,2=Current
    bmi_mean <- 28 + ifelse(ethnicity == "Black/African American", 2, 0) +
      ifelse(sex == 1, -0.5, 0.5)
    bmi <- pmax(round(rnorm(1, bmi_mean, 5), 1), 17)
    insurance <- sample(insurance_pool, 1, prob = insurance_probs)
    education <- sample(education_pool, 1, prob = education_probs)
    region <- sample(region_pool, 1, prob = region_probs)

    # Comorbidity flags (influenced by demographics)
    hypertension <- rbinom(1, 1, pmin(0.15 + 0.008 * (age - 40) +
      ifelse(ethnicity == "Black/African American", 0.12, 0), 0.85))
    diabetes_hx <- rbinom(1, 1, pmin(0.05 + 0.004 * (age - 40) +
      ifelse(ethnicity == "Hispanic/Latino", 0.06, 0) +
      ifelse(bmi > 30, 0.08, 0), 0.60))

    # Derived categories
    age_group <- cut(age, breaks = c(0, 45, 55, 65, 75, Inf),
                     labels = c("<45", "45-54", "55-64", "65-74", "75+"),
                     right = FALSE)
    bmi_cat <- cut(bmi, breaks = c(0, 18.5, 25, 30, 35, Inf),
                   labels = c("Underweight", "Normal", "Overweight",
                              "Obese I", "Obese II+"),
                   right = FALSE)

    # Entry disease — influenced by demographics
    entry_probs <- c(CKD = 0.40, CVD = 0.35, Diabetes = 0.25)
    if (ethnicity == "Black/African American") entry_probs["CKD"] <- 0.52
    if (ethnicity == "Hispanic/Latino") entry_probs["Diabetes"] <- 0.36
    if (sex == 1) entry_probs["CVD"] <- entry_probs["CVD"] + 0.06
    if (age > 65) entry_probs["CKD"] <- entry_probs["CKD"] + 0.05
    if (hypertension == 1) entry_probs["CVD"] <- entry_probs["CVD"] + 0.05
    entry_probs <- entry_probs / sum(entry_probs)
    entry <- sample(diseases, 1, prob = entry_probs)

    # ── Biomarker parameters (patient-specific) ──
    eth_egfr_adj <- ifelse(ethnicity == "Black/African American", 8, 0)
    egfr_int <- rnorm(1, 90 + eth_egfr_adj - (age - 50) * 0.4 -
                        ifelse(hypertension == 1, 5, 0), 12)
    egfr_slp <- rnorm(1, -2.5 - ifelse(diabetes_hx == 1, 0.8, 0) -
                        ifelse(ethnicity == "Black/African American", 0.5, 0), 1.5)
    bnp_int <- rnorm(1, 70 + (age - 50) * 1.5 + ifelse(sex == 1, 10, 0) +
                       ifelse(hypertension == 1, 30, 0), 25)
    bnp_slp <- rnorm(1, 5 + ifelse(age > 70, 2, 0), 3)
    hba1c_int <- rnorm(1, 5.8 + ifelse(diabetes_hx == 1, 1.5, 0) +
                         ifelse(ethnicity == "Hispanic/Latino", 0.2, 0) +
                         ifelse(bmi > 30, 0.3, 0), 0.7)
    hba1c_slp <- rnorm(1, 0.12 + ifelse(diabetes_hx == 1, 0.08, 0), 0.10)

    # Visit times
    n_visits <- sample(5:18, 1)
    v_times <- sort(cumsum(c(0, rexp(n_visits - 1, 2))))
    v_times <- v_times[v_times <= max_t]
    if (length(v_times) < 3) v_times <- c(0, 0.5, 1)

    # ── Longitudinal measurements: keep-draw, then noise draw, per marker ──
    kept_t <- numeric(0)
    kept_m <- character(0)
    kept_v <- numeric(0)
    for (vt in v_times) {
      for (mk in names(marker_units)) {
        if (runif(1) > 0.04) {
          val <- switch(mk,
            eGFR  = max(5, egfr_int + egfr_slp * vt + rnorm(1, 0, 5)),
            BNP   = max(5, bnp_int + bnp_slp * vt + rnorm(1, 0, 15)),
            HbA1c = max(4, min(14, hba1c_int + hba1c_slp * vt + rnorm(1, 0, 0.3)))
          )
          kept_t <- c(kept_t, round(vt, 3))
          kept_m <- c(kept_m, mk)
          kept_v <- c(kept_v, round(val, 1))
        }
      }
    }
    long_parts[[i]] <- list(
      patient_id = rep(i, length(kept_t)),
      visit_time_years = kept_t,
      biomarker = kept_m,
      value = kept_v,
      unit = unname(marker_units[kept_m])
    )

    # ── Multi-state survival process ──
    current <- entry
    t_now <- 0
    traj <- entry
    ev_num <- 0
    intervals <- list()

    # Demographic-influenced hazard multiplier
    demo_hr <- exp(0.01 * (age - 55) + 0.08 * smoking + 0.02 * pmax(bmi - 25, 0) +
                   ifelse(hypertension == 1, 0.15, 0) +
                   ifelse(ethnicity == "Black/African American", 0.10, 0))

    while (t_now < max_t && ev_num < 3) {
      possible <- setdiff(states, current)
      rates <- vapply(possible, function(s) {
        base <- 0.08
        if (current == "CKD" && s == "CVD") base <- 0.12
        if (current == "CVD" && s == "CKD") base <- 0.10
        if (current == "Diabetes" && s == "CVD") base <- 0.11
        if (current == "CVD" && s == "Diabetes") base <- 0.09
        if (current == "Diabetes" && s == "CKD") base <- 0.10
        base * demo_hr
      }, numeric(1))

      total_rate <- sum(rates)
      t_next <- t_now + rexp(1, total_rate)

      if (t_next >= max_t) {
        intervals[[length(intervals) + 1L]] <-
          new_interval(t_now, max_t, 0, "censored", current, NA_character_, traj)
        break
      }

      next_state <- sample(possible, 1, prob = rates / total_rate)
      ev_num <- ev_num + 1
      traj <- paste0(traj, " → ", next_state)
      intervals[[length(intervals) + 1L]] <-
        new_interval(t_now, t_next, 1, next_state, current, next_state, traj)

      current <- next_state
      t_now <- t_next
    }

    # Event cap reached before end of follow-up: close with a censored interval
    if (t_now < max_t && ev_num >= 3) {
      intervals[[length(intervals) + 1L]] <-
        new_interval(t_now, max_t, 0, "censored", current, NA_character_, traj)
    }

    n_int <- length(intervals)
    surv_parts[[i]] <- list(
      patient_id = rep(i, n_int),
      start_time = round(pluck_num(intervals, "start"), 3),
      stop_time = round(pluck_num(intervals, "stop"), 3),
      status = pluck_num(intervals, "status"),
      event_type = pluck_chr(intervals, "event_type"),
      state_from = pluck_chr(intervals, "from"),
      state_to = pluck_chr(intervals, "to"),
      transition = pluck_chr(intervals, "transition"),
      entry_disease = rep(entry, n_int),
      trajectory = pluck_chr(intervals, "trajectory"),
      age_baseline = rep(round(age, 1), n_int),
      sex = rep(sex, n_int),
      ethnicity = rep(ethnicity, n_int),
      smoking = rep(smoking, n_int),
      bmi = rep(round(bmi, 1), n_int),
      age_group = rep(as.character(age_group), n_int),
      bmi_category = rep(as.character(bmi_cat), n_int),
      hypertension = rep(hypertension, n_int),
      diabetes_hx = rep(diabetes_hx, n_int),
      insurance = rep(insurance, n_int),
      education = rep(education, n_int),
      region = rep(region, n_int)
    )
  }

  list(long_df = stack_parts(long_parts), surv_df = stack_parts(surv_parts))
}


server <- function(input, output, session) {

  rv <- reactiveValues(
    long_raw = NULL,
    surv_raw = NULL,
    merged = FALSE,
    long_df = NULL,
    surv_df = NULL,
    patient_summary = NULL,
    transitions = NULL,
    lme_fits = NULL,
    gam_fits = NULL,      # list of transition-specific GAM-Cox fits
    eta_dfs = NULL,        # latent summaries per transition
    fit_done = FALSE
  )


  # ════════════════════════════════════════════════════
  #  UPLOAD & MERGE
  # ════════════════════════════════════════════════════

  # Read one uploaded CSV. On failure the user is told why (previously the
  # error was swallowed silently) and NULL is returned, so the matching rv slot
  # is cleared and downstream req() calls keep blocking.
  #   upload: a fileInput value (list with $datapath)
  #   label:  human-readable name of the file used in the notification
  read_uploaded_csv <- function(upload, label) {
    tryCatch(
      read.csv(upload$datapath, stringsAsFactors = FALSE),
      error = function(e) {
        showNotification(paste0("Could not read ", label, " file: ", conditionMessage(e)),
                         type = "error")
        NULL
      }
    )
  }

  observeEvent(input$file_long, {
    rv$long_raw <- read_uploaded_csv(input$file_long, "longitudinal")
  })
  observeEvent(input$file_surv, {
    rv$surv_raw <- read_uploaded_csv(input$file_surv, "survival")
  })

  # ── Simulate data button ──
  # Generate a synthetic cohort and load it into the raw slots. Nothing
  # downstream changes until the user presses "Merge & Validate".
  observeEvent(input$btn_simulate, {
    n_sim <- input$sim_n_patients
    withProgress(message = paste("Simulating", n_sim, "patients..."), value = 0.3, {
      simulated <- generate_simulated_data(n_sim)
      incProgress(0.5, detail = "Processing...")
      rv$long_raw <- simulated$long_df
      rv$surv_raw <- simulated$surv_df
      incProgress(0.2, detail = "Done!")
    })
    loaded_msg <- paste0("✓ Simulated data loaded: ", n_sim, " patients, ",
                         nrow(rv$long_raw), " biomarker rows, ",
                         nrow(rv$surv_raw), " survival rows. Click 'Merge & Validate' to continue.")
    showNotification(loaded_msg, type = "message", duration = 8)
  })

  # Validate the raw uploads, restrict both tables to patients present in both,
  # derive per-patient summaries and the list of observed transitions, then
  # publish everything to `rv`.
  observeEvent(input$btn_merge, {
    req(rv$long_raw, rv$surv_raw)

    long <- rv$long_raw
    surv <- rv$surv_raw

    # Validate required columns. entry_disease, trajectory, age_baseline and
    # sex are read unconditionally when building the patient summary (and
    # age_baseline/sex by the model fit). Without this check, a file lacking
    # them crashed this observer instead of producing a readable error.
    long_req <- c("patient_id","visit_time_years","biomarker","value")
    surv_req <- c("patient_id","start_time","stop_time","status",
                  "event_type","state_from","state_to","transition",
                  "entry_disease","trajectory","age_baseline","sex")
    missing_cols <- c(setdiff(long_req, names(long)), setdiff(surv_req, names(surv)))

    if (length(missing_cols) > 0) {
      showNotification(
        paste("Column names don't match expected format. See Guide tab.",
              "Missing:", paste(missing_cols, collapse = ", ")),
        type = "error")
      return()
    }

    # Check ID overlap
    shared_ids <- intersect(unique(long$patient_id), unique(surv$patient_id))

    if (length(shared_ids) < 10) {
      showNotification(paste("Only", length(shared_ids), "shared patient_ids."), type = "error")
      return()
    }

    # Everything below works on locals and is written to rv only at the end,
    # so an error part-way through cannot leave rv half-updated.
    long_df <- long[long$patient_id %in% shared_ids, ]
    surv_df <- surv[surv$patient_id %in% shared_ids, ]

    # Chronological order within each patient, so the first()/last() calls
    # below (entry state, trajectory, current state) don't depend on CSV order.
    surv_df <- surv_df[order(surv_df$patient_id,
                             as.numeric(surv_df$start_time),
                             as.numeric(surv_df$stop_time)), ]

    # Derive total_events and censor_time if not present
    if (!"total_events" %in% names(surv_df)) {
      surv_df$total_events <- ave(as.numeric(surv_df$status == 1),
                                  surv_df$patient_id, FUN = sum)
    }
    if (!"censor_time" %in% names(surv_df)) {
      surv_df$censor_time <- ave(as.numeric(surv_df$stop_time), surv_df$patient_id,
                                 FUN = function(x) max(x, na.rm = TRUE))
    }

    # Build patient summary; optional demographic columns become NA when absent
    surv_cols <- names(surv_df)
    ps <- surv_df %>%
      group_by(patient_id) %>%
      summarise(
        entry_disease = first(entry_disease),
        trajectory = first(trajectory),
        total_events = max(as.numeric(total_events), na.rm = TRUE),
        censor_time = max(as.numeric(censor_time), na.rm = TRUE),
        current_state = last(state_from[status == 0]),
        age_baseline = first(age_baseline),
        sex = first(sex),
        ethnicity = if ("ethnicity" %in% surv_cols) first(ethnicity) else NA_character_,
        smoking = if ("smoking" %in% surv_cols) first(smoking) else NA_integer_,
        bmi = if ("bmi" %in% surv_cols) first(bmi) else NA_real_,
        age_group = if ("age_group" %in% surv_cols) first(age_group) else NA_character_,
        bmi_category = if ("bmi_category" %in% surv_cols) first(bmi_category) else NA_character_,
        hypertension = if ("hypertension" %in% surv_cols) first(hypertension) else NA_integer_,
        diabetes_hx = if ("diabetes_hx" %in% surv_cols) first(diabetes_hx) else NA_integer_,
        insurance = if ("insurance" %in% surv_cols) first(insurance) else NA_character_,
        education = if ("education" %in% surv_cols) first(education) else NA_character_,
        region = if ("region" %in% surv_cols) first(region) else NA_character_,
        .groups = "drop"
      )
    # Derive age_group if not present but age_baseline is
    if (all(is.na(ps$age_group)) && !all(is.na(ps$age_baseline))) {
      ps$age_group <- as.character(cut(ps$age_baseline,
        breaks = c(0,45,55,65,75,Inf),
        labels = c("<45","45-54","55-64","65-74","75+"), right = FALSE))
    }
    # Derive bmi_category if not present but bmi is
    if (all(is.na(ps$bmi_category)) && !all(is.na(ps$bmi))) {
      ps$bmi_category <- as.character(cut(ps$bmi,
        breaks = c(0,18.5,25,30,35,Inf),
        labels = c("Underweight","Normal","Overweight","Obese I","Obese II+"), right = FALSE))
    }
    # Sex label: numeric codes 0/1 are decoded; text values are used as-is
    ps$sex_label <- if (is.numeric(ps$sex)) {
      ifelse(ps$sex == 0, "Female", ifelse(ps$sex == 1, "Male", "Other"))
    } else {
      as.character(ps$sex)
    }
    # Smoking label: numeric codes 0/1/2 are decoded; text values are used as-is
    # (numeric arithmetic on a character column used to crash here)
    ps$smoking_label <- if (is.numeric(ps$smoking)) {
      ifelse(is.na(ps$smoking), "Unknown", c("Never","Former","Current")[ps$smoking + 1])
    } else {
      ifelse(is.na(ps$smoking), "Unknown", as.character(ps$smoking))
    }
    # If current_state is NA (no censored row), fall back to the last state_to,
    # matched by patient_id rather than by position
    missing_state <- is.na(ps$current_state)
    if (any(missing_state)) {
      last_to <- surv_df %>%
        group_by(patient_id) %>%
        summarise(cs = last(state_to), .groups = "drop")
      ps$current_state[missing_state] <-
        last_to$cs[match(ps$patient_id[missing_state], last_to$patient_id)]
    }

    # Extract unique transitions (status == 1 only; which() drops NA status)
    transitions <- surv_df[which(surv_df$status == 1), ] %>%
      group_by(state_from, state_to, transition) %>%
      summarise(n = n(), .groups = "drop") %>%
      arrange(desc(n))

    rv$long_df <- long_df
    rv$surv_df <- surv_df
    rv$patient_summary <- ps
    rv$transitions <- transitions
    rv$merged <- TRUE
    showNotification(paste("Merged!", length(shared_ids), "patients."), type = "message")
  })

  # Status line under the upload widgets: a hint before merging, row/patient
  # counts afterwards.
  output$merge_status <- renderUI({
    if (rv$merged) {
      counts_line <- paste(nrow(rv$patient_summary), "patients |",
                           nrow(rv$long_df), "longitudinal rows |",
                           nrow(rv$surv_df), "survival rows")
      div(
        p(class = "status-ok", icon("check-circle"), " Merge successful!"),
        p(counts_line),
        p(paste(nrow(rv$transitions), "unique transitions detected"))
      )
    } else {
      div(style = "color:#888;", icon("info-circle"),
          " Upload both files and click Merge.")
    }
  })

  # Summary card shown after a successful merge: patient count, entry-disease
  # and biomarker badges, optional demographic breakdowns and the follow-up
  # distribution. Each demographic section is rendered only when its
  # patient-summary column has at least one non-NA value; the `if` without
  # `else` yields NULL, which contributes nothing to the div.
  output$merge_summary <- renderUI({
    req(rv$merged)
    ps <- rv$patient_summary
    n <- nrow(ps)
    entries <- table(ps$entry_disease)

    # Badge colours per disease state; unlisted states fall back to grey below
    state_colors <- c("CKD"="#457b9d","CVD"="#e76f51","Diabetes"="#e9c46a","At-risk"="#2a9d8f")

    # One badge per entry disease: "<state>: <count> (<percent>%)"
    entry_badges <- lapply(names(entries), function(e) {
      bg <- state_colors[e]; if(is.na(bg)) bg <- "#666"
      tags$span(style = paste0("background:", bg, "; color:white; padding:4px 12px;
                                border-radius:14px; font-weight:700; font-size:13px;
                                margin-right:6px; display:inline-block; margin-bottom:4px;"),
                paste0(e, ": ", entries[e], " (", round(100*entries[e]/n, 1), "%)"))
    })

    # One badge per biomarker present in the merged longitudinal data
    biomarker_badges <- lapply(unique(rv$long_df$biomarker), function(bm) {
      bc <- c("eGFR"="#457b9d","BNP"="#e76f51","HbA1c"="#e9c46a")
      bg <- bc[bm]; if(is.na(bg)) bg <- "#1b4965"
      tags$span(style = paste0("background:", bg, "; color:white; padding:3px 10px;
                                border-radius:10px; font-weight:600; font-size:12px;
                                margin-right:4px;"), bm)
    })

    div(class = "merge-summary",
      tags$h3(style = "margin-top:0; color:#1b4965; font-weight:800;",
              icon("database"), paste("", n, "patients")),
      tags$div(style = "margin:10px 0;",
        tags$span(style = "font-weight:700; font-size:14px; color:#333;", "Entry diseases: "),
        tagList(entry_badges)),
      tags$div(style = "margin:10px 0;",
        tags$span(style = "font-weight:700; font-size:14px; color:#333;", "Biomarkers: "),
        tagList(biomarker_badges)),
      # ── Demographic summary ──
      # Sex: count and percentage of all patients per label
      if (!all(is.na(ps$sex_label))) {
        sex_tab <- table(ps$sex_label)
        tags$div(style = "margin:10px 0;",
          tags$span(style = "font-weight:700; font-size:14px; color:#333;", "Sex: "),
          tags$span(style = "font-size:13px; color:#555;",
            paste(sapply(names(sex_tab), function(s)
              paste0(s, ": ", sex_tab[s], " (", round(100*sex_tab[s]/n,1), "%)")),
              collapse = " | ")))
      },
      # Ethnicity: most frequent first
      if (!all(is.na(ps$ethnicity))) {
        eth_tab <- sort(table(ps$ethnicity), decreasing = TRUE)
        tags$div(style = "margin:10px 0;",
          tags$span(style = "font-weight:700; font-size:14px; color:#333;", "Ethnicity: "),
          tags$span(style = "font-size:13px; color:#555;",
            paste(sapply(names(eth_tab), function(s)
              paste0(s, ": ", eth_tab[s], " (", round(100*eth_tab[s]/n,1), "%)")),
              collapse = " | ")))
      },
      # Age groups: fixed band order; bands with no patients show 0
      if (!all(is.na(ps$age_group))) {
        age_tab <- table(factor(ps$age_group, levels = c("<45","45-54","55-64","65-74","75+")))
        tags$div(style = "margin:10px 0;",
          tags$span(style = "font-weight:700; font-size:14px; color:#333;", "Age groups: "),
          tags$span(style = "font-size:13px; color:#555;",
            paste(sapply(names(age_tab), function(s)
              paste0(s, ": ", age_tab[s])), collapse = " | ")))
      },
      # Follow-up: median and range of per-patient censor_time
      tags$div(style = "margin:10px 0;",
        tags$span(style = "font-weight:700; font-size:14px; color:#333;", "Follow-up: "),
        tags$span(style = "font-size:14px; color:#555;",
          paste0("median ", round(median(ps$censor_time), 1), " years",
                 " (range ", round(min(ps$censor_time), 1), "–",
                 round(max(ps$censor_time), 1), ")")))
    )
  })

  # Table of observed transitions (status == 1), most frequent first: coloured
  # from/to state badges plus a gradient bar scaled to the most frequent one.
  output$transition_counts_table <- renderUI({
    req(rv$transitions)
    tc <- rv$transitions %>% arrange(desc(n))
    # With no observed events there is nothing to tabulate. Previously
    # 1:nrow(tc) iterated over c(1, 0) here and rendered rows of NA.
    validate(need(nrow(tc) > 0, "No observed transitions (status == 1) in the data."))
    state_colors <- c("CKD"="#457b9d","CVD"="#e76f51","Diabetes"="#e9c46a","At-risk"="#2a9d8f")
    # Badge colour for a state; grey for states without a predefined colour
    state_color <- function(state) {
      col <- state_colors[state]
      if (is.na(col)) "#666" else col
    }
    # Bar widths are relative to the most frequent transition (same for every row)
    max_n <- max(tc$n)

    rows <- lapply(seq_len(nrow(tc)), function(k) {
      r <- tc[k, ]
      from_col <- state_color(r$state_from)
      to_col <- state_color(r$state_to)

      # Count bar proportional
      bar_pct <- round(r$n / max_n * 100)

      tags$tr(
        tags$td(style = "padding:8px 12px;",
          tags$span(style = paste0("background:", from_col, "; color:white; padding:3px 10px;
                                    border-radius:12px; font-weight:600; font-size:12px;"),
                    r$state_from)),
        tags$td(style = "padding:8px; text-align:center; font-size:16px; color:#999;", "→"),
        tags$td(style = "padding:8px 12px;",
          tags$span(style = paste0("background:", to_col, "; color:white; padding:3px 10px;
                                    border-radius:12px; font-weight:600; font-size:12px;"),
                    r$state_to)),
        tags$td(style = "padding:8px 14px; width:45%;",
          tags$div(style = "display:flex; align-items:center; gap:8px;",
            tags$div(style = paste0("height:20px; width:", bar_pct, "%; background:linear-gradient(90deg,",
                                     from_col, ",", to_col, "); border-radius:10px; min-width:20px;")),
            tags$span(style = "font-weight:800; font-size:15px; color:#1b4965;", r$n)
          )
        )
      )
    })

    tags$table(class = "model-results-table",
      tags$thead(
        tags$tr(
          tags$th("From"), tags$th(""), tags$th("To"),
          tags$th(style = "text-align:left;", "Count")
        )
      ),
      tags$tbody(rows)
    )
  })

  # Bar chart of how many patients experienced 0, 1, 2, ... disease events.
  output$event_burden_plot <- renderPlot({
    req(rv$patient_summary)
    burden <- setNames(data.frame(table(rv$patient_summary$total_events)),
                       c("Events", "Count"))
    burden_palette <- c("0" = "#a8dadc", "1" = "#457b9d", "2" = "#e76f51", "3" = "#c1121f")
    bold_axes <- theme(legend.position = "none",
                       axis.title = element_text(face = "bold", size = 14),
                       axis.text = element_text(face = "bold", size = 12))

    ggplot(burden, aes(x = Events, y = Count, fill = Events)) +
      geom_col(width = 0.6) +
      geom_text(aes(label = Count), vjust = -0.5, fontface = "bold", size = 5) +
      scale_fill_manual(values = burden_palette) +
      labs(x = "Number of Disease Events", y = "Patients") +
      theme_minimal(base_size = 15) +
      bold_axes
  })


  # ════════════════════════════════════════════════════
  #  DATA EXPLORER
  # ════════════════════════════════════════════════════

  # Entry-disease filter for the explorer: "All" plus each observed entry state.
  output$ui_entry_filter <- renderUI({
    req(rv$patient_summary)
    entry_levels <- sort(unique(rv$patient_summary$entry_disease))
    selectInput("filter_entry", "Entry Disease",
                choices = c("All", entry_levels), selected = "All")
  })

  # Trajectory filter for the explorer: "All" plus each observed trajectory.
  output$ui_trajectory_filter <- renderUI({
    req(rv$patient_summary)
    traj_levels <- sort(unique(rv$patient_summary$trajectory))
    selectInput("filter_traj", "Trajectory",
                choices = c("All", traj_levels), selected = "All")
  })

  # Random subset of patient IDs passing the entry-disease and trajectory
  # filters. The subset size comes from input$explore_n_patients (default 15)
  # and is capped at the number of matching patients.
  # IDs are drawn by indexing with sample.int(): sample(x, n) on a length-1
  # numeric x samples from 1:x, so when exactly one patient matched, the old
  # code returned a random, wrong ID.
  filtered_ids <- reactive({
    req(rv$patient_summary)
    ps <- rv$patient_summary
    # %in% (rather than ==) keeps rows with NA labels from becoming NA rows
    if (!is.null(input$filter_entry) && input$filter_entry != "All")
      ps <- ps[ps$entry_disease %in% input$filter_entry, ]
    if (!is.null(input$filter_traj) && input$filter_traj != "All")
      ps <- ps[ps$trajectory %in% input$filter_traj, ]
    # Explicit NULL default (avoids depending on `%||%`, base R >= 4.4 only)
    n_requested <- input$explore_n_patients
    if (is.null(n_requested)) n_requested <- 15
    n <- min(n_requested, nrow(ps))
    ps$patient_id[sample.int(nrow(ps), n)]
  })

  # Spaghetti plot of one biomarker for the currently sampled patients
  # (filtered_ids()). Each observed event is marked at the biomarker value of
  # the visit closest in time to it.
  #   marker: biomarker name as stored in rv$long_df$biomarker
  # Returns a ggplot, or NULL when there are no measurements of `marker`.
  make_biomarker_plot <- function(marker) {
    req(rv$long_df)
    ids <- filtered_ids()
    df <- rv$long_df[rv$long_df$patient_id %in% ids & rv$long_df$biomarker == marker, ]
    if (nrow(df) == 0) return(NULL)

    # Get events for these patients (which() keeps NA status from adding NA rows)
    evts <- rv$surv_df[which(rv$surv_df$patient_id %in% ids & rv$surv_df$status == 1), ]

    p <- ggplot(df, aes(x = visit_time_years, y = value,
                         group = patient_id, colour = factor(patient_id))) +
      geom_line(linewidth = 1.2, alpha = 0.75) +
      geom_point(size = 2.5, alpha = 0.8) +
      labs(title = paste(marker, "Trajectories"),
           x = "Time (years)", y = marker) +
      theme_minimal(base_size = 15) +
      theme(legend.position = "none",
            plot.title = element_text(face = "bold", size = 18, colour = "#1b4965"),
            axis.title = element_text(face = "bold", size = 14),
            axis.text = element_text(size = 12, face = "bold"))

    # Add event markers
    if (nrow(evts) > 0) {
      evt_marks <- merge(evts[, c("patient_id","stop_time","event_type")],
                          df[, c("patient_id","visit_time_years","value")],
                          by = "patient_id", all.x = FALSE)
      evt_marks$tdiff <- abs(evt_marks$visit_time_years - evt_marks$stop_time)
      # with_ties = FALSE: an event equidistant from two visits would otherwise
      # be drawn twice, at two different heights
      evt_closest <- evt_marks %>%
        group_by(patient_id, stop_time) %>%
        slice_min(tdiff, n = 1, with_ties = FALSE) %>%
        ungroup()

      if (nrow(evt_closest) > 0) {
        p <- p + geom_point(data = evt_closest,
                             aes(x = stop_time, y = value, shape = event_type),
                             size = 5, colour = "black", stroke = 1.5) +
          scale_shape_manual(values = c("CKD" = 17, "CVD" = 15,
                                         "Diabetes" = 18, "Death" = 4),
                              name = "Event")
      }
    }
    p
  }

  # One trajectory plot per biomarker: output id -> biomarker it displays.
  # Each renderPlot() is created inside its own function call, so every plot
  # captures its own marker name.
  invisible(Map(function(output_id, marker_name) {
    output[[output_id]] <- renderPlot({ make_biomarker_plot(marker_name) })
  }, c("plot_egfr", "plot_bnp", "plot_hba1c"), c("eGFR", "BNP", "HbA1c")))

  # Every biomarker for the currently sampled patients, one facet per biomarker
  # with its own y scale.
  output$plot_all_markers <- renderPlot({
    req(rv$long_df)
    sampled_ids <- filtered_ids()
    long_sub <- rv$long_df[rv$long_df$patient_id %in% sampled_ids, ]
    facet_theme <- theme(legend.position = "none",
                         strip.text = element_text(face = "bold", size = 16, colour = "#1b4965"),
                         axis.title = element_text(face = "bold", size = 14),
                         axis.text = element_text(size = 12, face = "bold"))

    ggplot(long_sub, aes(x = visit_time_years, y = value,
                         group = patient_id, colour = factor(patient_id))) +
      geom_line(linewidth = 1.1, alpha = 0.7) +
      geom_point(size = 2, alpha = 0.75) +
      facet_wrap(~ biomarker, scales = "free_y", ncol = 3) +
      labs(x = "Time (years)", y = "Value") +
      theme_minimal(base_size = 15) +
      facet_theme
  })

  # Crude cumulative incidence per event type: the share of all patients with
  # at least one CKD / CVD / Diabetes event by time t, on a 100-point grid.
  # Censoring is not accounted for (no Aalen-Johansen / Kaplan-Meier
  # weighting), so this is a descriptive curve, not a formal estimator.
  output$cumulative_incidence_plot <- renderPlot({
    req(rv$surv_df)
    event_types <- c("CKD","CVD","Diabetes")
    events <- rv$surv_df[which(rv$surv_df$status == 1 &
                                 rv$surv_df$event_type %in% event_types), ]
    # max() of an empty vector is -Inf, and seq() would then fail
    validate(need(nrow(events) > 0, "No CKD, CVD or Diabetes events to plot."))
    tseq <- seq(0, max(events$stop_time, na.rm = TRUE), length.out = 100)
    n_total <- length(unique(rv$surv_df$patient_id))

    ci_df <- do.call(rbind, lapply(event_types, function(evt) {
      data.frame(
        time = tseq,
        ci = vapply(tseq, function(t)
          length(unique(events$patient_id[events$event_type == evt & events$stop_time <= t])) / n_total,
          numeric(1)),
        event = evt
      )
    }))

    ggplot(ci_df, aes(x = time, y = ci, colour = event)) +
      geom_line(linewidth = 2) +
      scale_colour_manual(values = c("CKD"="#457b9d","CVD"="#e76f51","Diabetes"="#e9c46a")) +
      labs(title = "Cumulative Incidence by Event Type",
           x = "Time (years)", y = "Cumulative Incidence", colour = "Event") +
      theme_app(base_size = 15)
  })

  # Distribution of each patient's most recent biomarker value, split by the
  # patient's current state (one facet per biomarker).
  output$biomarker_boxplot <- renderPlot({
    req(rv$long_df, rv$patient_summary)
    state_fill <- c("At-risk"="#2a9d8f","CKD"="#457b9d",
                    "CVD"="#e76f51","Diabetes"="#e9c46a")

    # Latest measurement per patient and biomarker
    latest <- rv$long_df %>%
      group_by(patient_id, biomarker) %>%
      slice_max(visit_time_years, n = 1) %>%
      ungroup()
    patient_states <- rv$patient_summary[, c("patient_id","current_state")]
    latest <- merge(latest, patient_states, by = "patient_id")

    ggplot(latest, aes(x = current_state, y = value, fill = current_state)) +
      geom_boxplot(alpha = 0.7, outlier.size = 1.5) +
      facet_wrap(~ biomarker, scales = "free_y") +
      scale_fill_manual(values = state_fill) +
      labs(title = "Biomarker by Current State", x = "", y = "Value") +
      theme_app(base_size = 14) +
      theme(legend.position = "none")
  })


  # ════════════════════════════════════════════════════
  #  FIT JOINT MODELS
  # ════════════════════════════════════════════════════

  # Checkbox list of observed transitions (most frequent first, as ordered by
  # the merge step), with up to the first six pre-selected. head() returns an
  # empty selection when there are no transitions; the old
  # trans_choices[1:min(6, n)] produced NA in that case.
  output$ui_select_transitions <- renderUI({
    req(rv$transitions)
    trans_choices <- rv$transitions$transition
    checkboxGroupInput("sel_transitions", "Select transitions to model:",
                        choices = trans_choices,
                        selected = head(trans_choices, 6))
  })

  # Two-stage joint model fit.
  #   Stage 1: one linear mixed model (random intercept + slope in time) per
  #            biomarker, shared by all transitions.
  #   Stage 2: for each selected transition, a Cox model fitted with mgcv::gam
  #            (cox.ph family). Its linear predictor contains age, sex and
  #            smooths of the subject-specific biomarker trajectories from
  #            stage 1, evaluated at the midpoint of the at-risk interval.
  # Results go to rv$lme_fits, rv$gam_fits and rv$eta_dfs.
  observeEvent(input$btn_fit, {
    req(rv$long_df, rv$surv_df, input$sel_transitions)

    long <- rv$long_df
    surv <- rv$surv_df
    markers <- unique(long$biomarker)

    # Columns carried into every per-transition analysis set; optional
    # demographic columns are included only when present
    req_cols <- c("patient_id","start_time","stop_time","status",
                   "age_baseline","sex")
    opt_demo_cols <- c("smoking","bmi","ethnicity","age_group","bmi_category",
                       "hypertension","diabetes_hx","insurance","education","region")
    req_cols <- c(req_cols, intersect(opt_demo_cols, names(surv)))

    withProgress(message = "Fitting models...", value = 0, {

      # Step 1: Fit longitudinal submodels (shared across transitions)
      incProgress(0.1, detail = "Longitudinal submodels...")
      lme_fits <- list()
      for (mk in markers) {
        df_mk <- long[long$biomarker == mk, ]
        names(df_mk)[names(df_mk) == "value"] <- "y"
        names(df_mk)[names(df_mk) == "visit_time_years"] <- "time"
        names(df_mk)[names(df_mk) == "patient_id"] <- "id"
        df_mk$id <- factor(df_mk$id)

        # A failed fit yields NULL, so no entry is stored for this marker
        lme_fits[[mk]] <- tryCatch(
          lme(y ~ time, random = ~ time | id, data = df_mk,
              control = lmeControl(opt = "optim", maxIter = 300)),
          error = function(e) {
            showNotification(paste("LME failed for", mk, ":", e$message), type = "warning")
            NULL
          }
        )
      }
      rv$lme_fits <- lme_fits

      # Step 2: For each selected transition, build eta and fit GAM-Cox
      gam_fits <- list()
      eta_dfs <- list()
      n_trans <- length(input$sel_transitions)

      for (tr in input$sel_transitions) {
        incProgress(0.8 / n_trans, detail = paste("Transition:", tr))

        # Transition labels have the form "<from> → <to>"
        tr_parts <- strsplit(tr, " → ")[[1]]
        from_state <- trimws(tr_parts[1])
        to_state <- trimws(tr_parts[2])

        # Event rows: this specific transition (which() drops NA comparisons)
        event_rows <- surv[which(surv$transition == tr & surv$status == 1), ]
        # At risk: censored while in from_state ...
        at_risk_rows <- surv[which(surv$state_from == from_state & surv$status == 0), ]
        # ... or left from_state for a different state (censored for this transition)
        other_trans <- surv[which(surv$state_from == from_state & surv$status == 1 &
                                    surv$state_to != to_state), ]
        # rep() keeps this valid for a zero-row data frame: `$<-` with a
        # length-1 value errors there, which crashed the observer whenever
        # from_state had a single destination.
        other_trans$status <- rep(0, nrow(other_trans))

        analysis_df <- rbind(
          event_rows[, req_cols],
          at_risk_rows[, req_cols],
          other_trans[, req_cols]
        )

        if (nrow(analysis_df) < 20 || sum(analysis_df$status) < 10) {
          showNotification(paste("Too few events for", tr), type = "warning")
          next
        }

        # One row per patient; event rows come first in the rbind and are kept
        analysis_df <- analysis_df[!duplicated(analysis_df$patient_id), ]

        # Survival time = time spent in from_state, floored at a small positive
        # value to avoid zero/negative times
        analysis_df$start_time <- as.numeric(analysis_df$start_time)
        analysis_df$stop_time  <- as.numeric(analysis_df$stop_time)
        analysis_df$time_in_state <- pmax(analysis_df$stop_time - analysis_df$start_time, 0.01)

        # Biomarker latent values are evaluated at the MIDPOINT of the interval
        analysis_df$eta_time <- (analysis_df$start_time + analysis_df$stop_time) / 2

        # Subject-specific trajectory (BLUP intercept + slope * eta_time) per
        # marker. The created column names are tracked explicitly: grepping
        # "^eta_" also matched eta_time, which then entered the smooth surface
        # in place of a biomarker and pushed the third marker out of the model.
        eta_cols <- character(0)
        for (mk in markers) {
          if (is.null(lme_fits[[mk]])) next
          cm <- coef(lme_fits[[mk]])
          eta_col <- paste0("eta_", gsub("[^A-Za-z0-9]", "", mk))
          pid_char <- as.character(analysis_df$patient_id)
          matched <- pid_char %in% rownames(cm)
          analysis_df[[eta_col]] <- NA_real_
          if (any(matched)) {
            analysis_df[[eta_col]][matched] <-
              cm[pid_char[matched], 1] + cm[pid_char[matched], 2] * analysis_df$eta_time[matched]
          }
          eta_cols <- c(eta_cols, eta_col)
        }

        if (length(eta_cols) == 0) {
          showNotification(paste("No longitudinal submodel available for", tr), type = "warning")
          next
        }

        # Remove rows with NA etas
        analysis_df <- analysis_df[complete.cases(analysis_df[, eta_cols, drop = FALSE]), ]

        if (nrow(analysis_df) < 20) next

        # Retained in rv$eta_dfs for reference only; it is NOT the model
        # response. mgcv's cox.ph takes the event/censoring time as the
        # response and the event indicator via `weights`; a two-column response
        # is read as (time, stratum), which silently ignored censoring.
        analysis_df$surv_obj <- with(analysis_df, Surv(time_in_state, status))

        # Fit GAM-Cox: tensor surface over the first two markers, plus a
        # univariate smooth for a third marker if present
        if (length(eta_cols) >= 2) {
          formula_str <- paste0("time_in_state ~ age_baseline + sex + te(",
                                 eta_cols[1], ", ", eta_cols[2],
                                 ", k=c(5,5), bs='tp')")
          if (length(eta_cols) >= 3) {
            formula_str <- paste0("time_in_state ~ age_baseline + sex + te(",
                                   eta_cols[1], ", ", eta_cols[2],
                                   ", k=c(5,5), bs='tp') + s(",
                                   eta_cols[3], ", k=6, bs='tp')")
          }
        } else {
          formula_str <- paste0("time_in_state ~ age_baseline + sex + s(",
                                 eta_cols[1], ", k=8, bs='tp')")
        }

        gam_fit <- tryCatch(
          gam(as.formula(formula_str), family = cox.ph(),
              data = analysis_df, weights = status, method = "REML"),
          error = function(e) {
            showNotification(paste("GAM failed for", tr, ":", conditionMessage(e)),
                             type = "warning")
            NULL
          }
        )

        if (!is.null(gam_fit)) {
          gam_fits[[tr]] <- gam_fit
          eta_dfs[[tr]] <- analysis_df
        }
      }

      rv$gam_fits <- gam_fits
      rv$eta_dfs <- eta_dfs
      rv$fit_done <- TRUE
    })
  })

  # Fit status panel: a waiting hint before fitting, then the number of fitted
  # transition models followed by one line per transition name.
  output$fit_progress_styled <- renderUI({
    if (rv$fit_done) {
      fitted_names <- names(rv$gam_fits)
      header_line <- tags$div(class = "pred-status-line pred-ok",
                              style = "font-size:14px; font-weight:700;",
                              icon("check-circle"),
                              paste(" Successfully fitted", length(rv$gam_fits), "transition models"))
      name_lines <- lapply(fitted_names, function(trans_name) {
        tags$div(class = "pred-status-line", style = "color:#a8dadc; padding-left:16px;",
                 HTML(paste0("&#9656; ", trans_name)))
      })
      tags$div(class = "pred-status-panel", header_line, name_lines)
    } else {
      tags$div(class = "pred-status-panel",
        tags$div(class = "pred-status-line pred-warn",
                 icon("hourglass-half"), " Models not yet fitted. Select transitions and click 'Fit'."))
    }
  })

  # Results table with one row per fitted transition model: analysis-set size,
  # number of events, the EDF of the first smooth term in summary()'s s.table
  # (reported as "Surface EDF") and deviance explained. Rendering requires at
  # least one fitted model.
  output$fit_results_styled <- renderUI({
    req(rv$fit_done, length(rv$gam_fits) > 0)
    rows <- lapply(names(rv$gam_fits), function(tr) {
      gf <- rv$gam_fits[[tr]]
      # summary() failures degrade to "N/A" / "?" rather than breaking the table
      sm <- tryCatch(summary(gf), error = function(e) NULL)
      edf <- if (!is.null(sm) && nrow(sm$s.table) > 0) round(sm$s.table[1, "edf"], 1) else NA
      de  <- if (!is.null(sm)) paste0(round(sm$dev.expl * 100, 1), "%") else "?"
      n_events <- sum(rv$eta_dfs[[tr]]$status)
      n_total <- nrow(rv$eta_dfs[[tr]])

      # EDF badge color
      # (red > 3, amber > 1.5, teal otherwise; grey when unavailable)
      edf_bg <- if (is.na(edf)) "#ccc" else if (edf > 3) "#e76f51" else if (edf > 1.5) "#e9c46a" else "#2a9d8f"
      edf_badge <- tags$span(style = paste0(
        "background:", edf_bg, "; color:white; padding:3px 10px; border-radius:10px;
         font-weight:700; font-size:13px;"),
        ifelse(is.na(edf), "N/A", edf))

      de_badge <- tags$span(style = "background:#f0f7ff; color:#1b4965; padding:3px 10px;
                                     border-radius:10px; font-weight:700; font-size:13px;",
                             de)

      tags$tr(
        tags$td(style = "padding:10px 14px; font-weight:700; font-size:14px;
                         border-left:4px solid #1b4965;", tr),
        tags$td(style = "text-align:center; padding:10px; font-weight:600;", n_total),
        tags$td(style = "text-align:center; padding:10px;",
                tags$span(style="font-weight:700; color:#c1121f; font-size:15px;", n_events)),
        tags$td(style = "text-align:center; padding:10px;", edf_badge),
        tags$td(style = "text-align:center; padding:10px;", de_badge)
      )
    })

    tags$table(class = "model-results-table",
      tags$thead(
        tags$tr(
          tags$th("Transition"),
          tags$th(style = "text-align:center;", "N"),
          tags$th(style = "text-align:center;", "Events"),
          tags$th(style = "text-align:center;", "Surface EDF"),
          tags$th(style = "text-align:center;", "Dev. Explained")
        )
      ),
      tags$tbody(rows)
    )
  })

  # Static interpretation panel shown once models are fitted: the number of
  # fitted transition models plus a legend for the EDF colour badges used in
  # the results table, and pointers to the prediction and surface tabs.
  output$fit_interpretation <- renderUI({
    req(rv$fit_done)
    n_fits <- length(rv$gam_fits)
    div(class = "interp-panel",
      h4(icon("lightbulb"), " Model Interpretation"),
      p(style = "font-size:14px;",
        paste(n_fits, "transition-specific joint models fitted.")),
      p(style = "font-size:13px;",
        tags$span(style="background:#e76f51; color:white; padding:2px 8px; border-radius:8px;
                         font-weight:700; font-size:12px;", "EDF > 3"),
        " indicates substantial nonlinearity — the biomarker-risk
         relationship is more complex than a simple linear effect."),
      p(style = "font-size:13px;",
        tags$span(style="background:#2a9d8f; color:white; padding:2px 8px; border-radius:8px;
                         font-weight:700; font-size:12px;", "EDF ≈ 1"),
        " suggests the relationship is approximately linear — standard parametric models would suffice."),
      p(style = "font-size:13px;",
        "Navigate to", tags$b("Tab 4"), "to select a patient and compute personalized risk,
         or", tags$b("Tab 5"), "to visualize the population-level association surfaces.")
    )
  })

  # One card per biomarker longitudinal submodel, showing:
  #   - the fixed intercept and time slope (rounded to 4 dp; the slope is
  #     coloured red when negative, teal otherwise);
  #   - the StdDev column of VarCorr(). For nlme::lme this presumably holds the
  #     random-intercept, random-slope and residual SDs — confirm against nlme.
  output$lme_summaries_styled <- renderUI({
    req(rv$lme_fits)
    cards <- lapply(names(rv$lme_fits), function(mk) {
      # Skip markers without a fitted model
      if (is.null(rv$lme_fits[[mk]])) return(NULL)
      fe <- round(fixef(rv$lme_fits[[mk]]), 4)
      vc <- round(as.numeric(VarCorr(rv$lme_fits[[mk]])[, "StdDev"]), 4)

      # Card accent colour per known biomarker; dark blue otherwise
      biomarker_colors <- c("eGFR" = "#457b9d", "BNP" = "#e76f51", "HbA1c" = "#e9c46a")
      bc <- biomarker_colors[mk]
      if (is.na(bc)) bc <- "#1b4965"

      tags$div(class = "lme-card", style = paste0("border-left-color:", bc, ";"),
        tags$h5(icon("dna"), paste("", mk)),
        tags$div(class = "lme-row",
          tags$span(class = "lme-label", "Intercept (β₀):"),
          tags$span(class = "lme-val", fe[1])
        ),
        tags$div(class = "lme-row",
          tags$span(class = "lme-label", "Time slope (β₁):"),
          tags$span(class = "lme-val", style = paste0("color:", if(fe[2]<0) "#c1121f" else "#2a9d8f"),
                    paste0(ifelse(fe[2]>0,"+",""), fe[2]))
        ),
        tags$div(class = "lme-row",
          tags$span(class = "lme-label", "Random effects SD:"),
          tags$span(class = "lme-val", paste(vc, collapse = " | "))
        )
      )
    })
    tagList(cards)
  })


  # ════════════════════════════════════════════════════
  #  PERSONALIZED RISK PREDICTION
  # ════════════════════════════════════════════════════

  # Patient picker. Each option is labelled
  # "ID <id> | <entry disease> | <sex>[ <age>y][ | <ethnicity>] | <n> events"
  # and its value is the patient_id.
  output$ui_patient_select <- renderUI({
    req(rv$patient_summary)
    ps <- rv$patient_summary
    age_part <- ifelse(is.na(ps$age_baseline), "",
                       paste0(" ", round(ps$age_baseline, 0), "y"))
    eth_part <- ifelse(is.na(ps$ethnicity), "", paste0(" | ", ps$ethnicity))
    option_labels <- paste0("ID ", ps$patient_id, " | ", ps$entry_disease, " | ",
                            ps$sex_label, age_part, eth_part,
                            " | ", ps$total_events, " events")
    selectizeInput("sel_patient", "Choose Patient:",
                   choices = setNames(ps$patient_id, option_labels),
                   options = list(maxOptions = 500))
  })

  # Profile card for the selected patient: entry/current state, demographics,
  # trajectory, event history and latest value of each biomarker.
  # NOTE(review): input$sel_patient is converted with as.numeric(), so
  # non-numeric patient IDs match nothing and the card renders empty.
  output$patient_info_card <- renderUI({
    req(input$sel_patient, rv$patient_summary, rv$long_df)
    pid <- as.numeric(input$sel_patient)
    ps <- rv$patient_summary[rv$patient_summary$patient_id == pid, ]
    if (nrow(ps) == 0) return(NULL)

    # Get events
    evts <- rv$surv_df[rv$surv_df$patient_id == pid & rv$surv_df$status == 1, ]
    # Last biomarker values
    last_vals <- rv$long_df %>%
      filter(patient_id == pid) %>%
      group_by(biomarker) %>%
      slice_max(visit_time_years, n = 1) %>%
      ungroup()

    # CSS class for the current-state badge; unlisted states use "trans-ar"
    state_class <- switch(as.character(ps$current_state),
      "CKD" = "trans-ckd", "CVD" = "trans-cvd",
      "Diabetes" = "trans-dm", "trans-ar")

    div(class = "patient-card",
      h4(paste0("Patient ", pid), style = "margin-top:0; color:#1b4965;"),
      p(tags$b("Entry: "), ps$entry_disease,
        tags$b(" | Current: "), span(class = paste("trans-arrow", state_class), ps$current_state)),
      # ── Demographics row ──
      p(tags$b("Age: "), round(ps$age_baseline,0), "y",
        if (!is.na(ps$sex_label)) tagList(tags$b(" | Sex: "), ps$sex_label),
        if (!is.na(ps$ethnicity)) tagList(tags$b(" | Ethnicity: "), ps$ethnicity)),
      # Clinical row, shown only when BMI is known
      if (!is.na(ps$bmi))
        p(tags$b("BMI: "), ps$bmi, paste0("(", ps$bmi_category, ")"),
          if (!is.na(ps$smoking_label)) tagList(tags$b(" | Smoking: "), ps$smoking_label),
          if (!is.na(ps$hypertension)) tagList(tags$b(" | HTN: "), ifelse(ps$hypertension==1,"Yes","No")),
          if (!is.na(ps$diabetes_hx)) tagList(tags$b(" | DM Hx: "), ifelse(ps$diabetes_hx==1,"Yes","No"))),
      # Socio-economic row, shown only when insurance is known
      if (!is.na(ps$insurance))
        p(tags$b("Insurance: "), ps$insurance,
          if (!is.na(ps$education)) tagList(tags$b(" | Education: "), ps$education),
          if (!is.na(ps$region)) tagList(tags$b(" | Region: "), ps$region)),
      p(tags$b("Trajectory: "), ps$trajectory),
      p(tags$b("Events: "), ps$total_events,
        tags$b(" | Follow-up: "), round(ps$censor_time, 1), "years"),
      # One bullet per observed transition: "t=<stop_time>y: <transition>"
      if (nrow(evts) > 0) {
        tagList(
          p(tags$b("Event history:")),
          tags$ul(lapply(1:nrow(evts), function(k) {
            tags$li(paste0("t=", evts$stop_time[k], "y: ", evts$transition[k]))
          }))
        )
      },
      # "<biomarker>=<value>" for each biomarker's most recent visit
      if (nrow(last_vals) > 0) {
        p(tags$b("Last biomarkers: "),
          paste(sapply(1:nrow(last_vals), function(k)
            paste0(last_vals$biomarker[k], "=", last_vals$value[k])),
            collapse = ", "))
      }
    )
  })

  # ══════════════════════════════════════════════════════════════
  #  MULTI-METHOD PREDICTION ENGINE
  #  Approaches offered through input$sel_methods, drawn from the joint
  #  modeling literature:
  #    1. Two-Stage BLUP + GAM-Cox (Tsiatis & Davidian 2004)
  #    2. True JM via JM package (Rizopoulos 2012, ML-EM)
  #    3. Landmark LOCF (van Houwelingen & Putter 2012)
  #    4. Landmarking 2.0 BLUP+Slope (Putter & van Houwelingen 2022)
  #    5. Bayesian JM via JMbayes2 (Rizopoulos et al. 2023)
  #    6. jmBIG — scalable Bayesian JM for big data
  #    7. JMbdirect — bidirectional joint model
  # ══════════════════════════════════════════════════════════════

  # Reactive holders: a status message (starts empty) and the latest
  # prediction results (start as NULL, i.e. "nothing computed yet").
  pred_status  <- reactiveVal("")
  pred_results <- reactiveVal(NULL)

  # Availability probes for the optional joint-model packages. Each probe is
  # a reactive that calls requireNamespace() (load without attaching).
  # NOTE(review): in the visible prediction handler, methods 6 and 7 call
  # requireNamespace() directly instead of has_jmBIG() / has_JMbd().
  optional_pkg_probe <- function(pkg) {
    force(pkg)
    reactive(requireNamespace(pkg, quietly = TRUE))
  }
  has_JM_pkg <- optional_pkg_probe("JM")
  has_JMb2   <- optional_pkg_probe("JMbayes2")
  has_jmBIG  <- optional_pkg_probe("jmBIG")
  has_JMbd   <- optional_pkg_probe("JMbdirect")

  observeEvent(input$btn_predict, {
    req(rv$fit_done, input$sel_patient, rv$lme_fits, input$sel_methods)
    pid <- as.numeric(input$sel_patient)
    ps <- rv$patient_summary[rv$patient_summary$patient_id == pid, ]
    current_state <- ps$current_state
    landmark <- input$pred_landmark
    horizon  <- input$pred_horizon
    markers  <- names(rv$lme_fits)
    sel_methods <- input$sel_methods

    # Possible transitions from current state
    possible_tr <- names(rv$gam_fits)[
      startsWith(names(rv$gam_fits), paste0(current_state, " →"))]
    if (length(possible_tr) == 0) {
      showNotification(paste("No models for transitions from", current_state), type = "warning")
      return()
    }

    all_results <- list()
    status_msgs <- c()
    jmbd_already_fitted <- FALSE  # JMbdirect fits once for all transitions

    withProgress(message = "Computing predictions...", value = 0, {

      for (tr in possible_tr) {
        to_state <- strsplit(tr, " → ")[[1]][2]
        from_state <- current_state
        n_methods <- length(sel_methods)

        # ── METHOD 1: Two-Stage BLUP + GAM-Cox ──
        if ("twostage" %in% sel_methods) {
          incProgress(0.1, detail = paste("Two-Stage:", tr))
          tryCatch({
            gf <- rv$gam_fits[[tr]]
            ed <- rv$eta_dfs[[tr]]
            eta_cols <- grep("^eta_", names(ed), value = TRUE)

            t_rel <- seq(0.05, horizon, length.out = 60)
            t_abs <- landmark + t_rel
            pred_df <- data.frame(time_abs = t_abs)
            for (mk in markers) {
              if (is.null(rv$lme_fits[[mk]])) next
              cm <- coef(rv$lme_fits[[mk]])
              mk_clean <- gsub("[^A-Za-z0-9]","",mk)
              pc <- as.character(pid)
              if (pc %in% rownames(cm))
                pred_df[[paste0("eta_",mk_clean)]] <- cm[pc,1] + cm[pc,2]*t_abs
            }
            sr <- rv$surv_df[rv$surv_df$patient_id == pid,][1,]
            pred_df$age_baseline <- as.numeric(sr$age_baseline)
            pred_df$sex <- as.numeric(sr$sex)
            for (ec in eta_cols) if (!(ec %in% names(pred_df))) pred_df[[ec]] <- 0

            lp <- predict(gf, newdata = pred_df, type = "link")
            lp_train <- predict(gf, type = "link")
            lp_mean <- mean(lp_train, na.rm = TRUE)

            # PROPER baseline survival: use Breslow estimator from training data
            # Sort event data by time
            ord <- order(ed$time_in_state)
            t_sorted <- ed$time_in_state[ord]
            s_sorted <- ed$status[ord]
            lp_sorted <- lp_train[ord]
            n_total <- length(t_sorted)

            # Breslow estimator: cumulative baseline hazard H0(t) = sum d_j / S_j
            uniq_times <- sort(unique(t_sorted[s_sorted == 1]))
            H0_times <- numeric(length(uniq_times))
            for (j in seq_along(uniq_times)) {
              at_risk <- which(t_sorted >= uniq_times[j])
              dj <- sum(t_sorted == uniq_times[j] & s_sorted == 1)
              Sj <- sum(exp(lp_sorted[at_risk] - lp_mean))
              H0_times[j] <- dj / max(Sj, 0.001)
            }
            H0_cumul <- cumsum(H0_times)

            # Patient's hazard ratio (centered)
            rh <- exp(lp - lp_mean)

            # Interpolate H0 at prediction times
            H0_at_t <- approx(uniq_times, H0_cumul, xout = t_rel,
                               method = "constant", rule = 2, f = 0)$y
            risk <- 1 - exp(-H0_at_t * rh)
            risk <- pmin(pmax(risk, 0), 0.999)

            all_results[[paste0("M1_", tr)]] <- data.frame(
              time = t_abs, risk = risk, method = "1. Two-Stage BLUP + GAM-Cox",
              transition = tr, to_state = to_state, stringsAsFactors = FALSE)
            status_msgs <- c(status_msgs, paste("✓ Two-Stage:", tr))
          }, error = function(e) {
            status_msgs <<- c(status_msgs, paste("✗ Two-Stage:", tr, "-", e$message))
          })
        }

        # ── METHOD 2: True JM (ML via JM pkg) ──
        if ("true_jm" %in% sel_methods) {
          incProgress(0.15, detail = paste("True JM:", tr))
          if (!has_JM_pkg()) {
            status_msgs <- c(status_msgs, "✗ True JM: JM package not installed")
          } else {
            tryCatch({
              mk <- markers[1]  # Use primary marker
              # Prepare per-transition data
              ev_rows <- rv$surv_df[rv$surv_df$transition==tr & rv$surv_df$status==1,]
              cens_rows <- rv$surv_df[rv$surv_df$state_from==from_state & rv$surv_df$status==0,]
              comp_rows <- rv$surv_df[rv$surv_df$state_from==from_state &
                                       rv$surv_df$status==1 & rv$surv_df$state_to!=to_state,]
              comp_rows$status <- 0L
              base_cols <- c("patient_id","start_time","stop_time","status")
              # Dynamically include available covariates
              cov_cols <- intersect(c("age_baseline","sex"), names(rv$surv_df))
              cols <- c(base_cols, cov_cols)
              ad <- rbind(ev_rows[,cols], cens_rows[,cols], comp_rows[,cols])
              ad <- ad[!duplicated(ad$patient_id),]
              ad$time_in_state <- pmax(as.numeric(ad$stop_time)-as.numeric(ad$start_time), 0.01)
              ad$id <- factor(ad$patient_id)
              # Ensure numeric covariates
              if ("age_baseline" %in% names(ad)) ad$age_baseline <- as.numeric(ad$age_baseline)
              if ("sex" %in% names(ad)) ad$sex <- as.numeric(ad$sex)

              mk_long <- rv$long_df[rv$long_df$biomarker == mk &
                                      rv$long_df$patient_id %in% ad$patient_id, ]
              mk_long$id <- factor(mk_long$patient_id)
              mk_long$time <- mk_long$visit_time_years
              mk_long$y <- mk_long$value

              lme_jm <- lme(y ~ time, random = ~ time | id, data = mk_long,
                             control = lmeControl(opt="optim", maxIter=200))
              # Build Cox formula dynamically based on available covariates
              cox_covs <- intersect(c("age_baseline","sex"), names(ad))
              cox_fml <- if (length(cox_covs) > 0)
                as.formula(paste("Surv(time_in_state, status) ~", paste(cox_covs, collapse=" + ")))
              else
                as.formula("Surv(time_in_state, status) ~ 1")
              cox_jm <- coxph(cox_fml, data = ad, x = TRUE)

              jm_obj <- JM::jointModel(lme_jm, cox_jm, timeVar = "time",
                                        method = "piecewise-PH-aGH")

              # Dynamic prediction for this patient
              p_long <- rv$long_df[rv$long_df$patient_id == pid &
                                     rv$long_df$biomarker == mk &
                                     rv$long_df$visit_time_years <= (landmark + horizon), ]
              p_long$id <- factor(pid, levels = levels(mk_long$id))
              p_long$time <- p_long$visit_time_years
              p_long$y <- p_long$value

              t_pred <- seq(landmark + 0.1, landmark + horizon, length.out = 60)
              sf <- JM::survfitJM(jm_obj, newdata = p_long, idVar = "id",
                                   survTimes = t_pred - landmark, last.time = landmark)
              surv_probs <- sf$summaries[[1]][, "Mean"]
              risk <- 1 - surv_probs
              risk <- pmin(pmax(risk, 0), 0.999)

              all_results[[paste0("M2_", tr)]] <- data.frame(
                time = t_pred, risk = risk, method = "2. True JM (ML-EM)",
                transition = tr, to_state = to_state, stringsAsFactors = FALSE)
              status_msgs <- c(status_msgs, paste("✓ True JM:", tr))
            }, error = function(e) {
              status_msgs <<- c(status_msgs, paste("✗ True JM:", tr, "-", e$message))
            })
          }
        }

        # ── METHOD 3: Landmark (LOCF) ──
        if ("landmark" %in% sel_methods) {
          incProgress(0.1, detail = paste("Landmark LOCF:", tr))
          tryCatch({
            ev_rows <- rv$surv_df[rv$surv_df$transition==tr & rv$surv_df$status==1,]
            cens_rows <- rv$surv_df[rv$surv_df$state_from==from_state & rv$surv_df$status==0,]
            comp_rows <- rv$surv_df[rv$surv_df$state_from==from_state &
                                     rv$surv_df$status==1 & rv$surv_df$state_to!=to_state,]
            comp_rows$status <- 0L
            base_cols2 <- c("patient_id","start_time","stop_time","status")
            cov_cols2 <- intersect(c("age_baseline","sex"), names(rv$surv_df))
            cols <- c(base_cols2, cov_cols2)
            ad <- rbind(ev_rows[,cols], cens_rows[,cols], comp_rows[,cols])
            ad <- ad[!duplicated(ad$patient_id),]
            ad$stop_time <- as.numeric(ad$stop_time)
            ad <- ad[ad$stop_time > landmark, ]
            ad$resid_time <- pmax(ad$stop_time - landmark, 0.01)

            for (mk in markers) {
              mk_clean <- gsub("[^A-Za-z0-9]","",mk)
              mk_long <- rv$long_df[rv$long_df$biomarker==mk &
                                      rv$long_df$visit_time_years <= landmark,]
              lv <- mk_long %>% group_by(patient_id) %>%
                slice_max(visit_time_years, n=1) %>% ungroup() %>%
                select(patient_id, value)
              names(lv)[2] <- paste0("locf_", mk_clean)
              ad <- merge(ad, lv, by="patient_id", all.x=TRUE)
            }
            locf_cols <- grep("^locf_", names(ad), value=TRUE)
            ad <- ad[complete.cases(ad[,locf_cols]),]

            fml_covs <- intersect(c("age_baseline","sex"), names(ad))
            fml <- paste0("Surv(resid_time, status) ~ ",
                           paste(c(fml_covs, locf_cols), collapse=" + "))
            cox_lm <- coxph(as.formula(fml), data = ad)

            # Predict for patient
            p_row <- ad[ad$patient_id == pid, ][1,]
            if (nrow(p_row) > 0 && !is.na(p_row$patient_id)) {
              t_pred <- seq(0.05, horizon, length.out = 60)
              sf <- survfit(cox_lm, newdata = p_row)
              sp <- approx(sf$time, sf$surv, xout = t_pred, rule = 2)$y
              risk <- 1 - sp
              risk <- pmin(pmax(risk, 0), 0.999)

              all_results[[paste0("M3_", tr)]] <- data.frame(
                time = landmark + t_pred, risk = risk, method = "3. Landmark (LOCF)",
                transition = tr, to_state = to_state, stringsAsFactors = FALSE)
              status_msgs <- c(status_msgs, paste("✓ Landmark LOCF:", tr))
            }
          }, error = function(e) {
            status_msgs <<- c(status_msgs, paste("✗ Landmark LOCF:", tr, "-", e$message))
          })
        }

        # ── METHOD 4: Landmarking 2.0 (BLUP + Slope) ──
        if ("landmark_blup" %in% sel_methods) {
          incProgress(0.1, detail = paste("LM 2.0:", tr))
          tryCatch({
            ev_rows <- rv$surv_df[rv$surv_df$transition==tr & rv$surv_df$status==1,]
            cens_rows <- rv$surv_df[rv$surv_df$state_from==from_state & rv$surv_df$status==0,]
            comp_rows <- rv$surv_df[rv$surv_df$state_from==from_state &
                                     rv$surv_df$status==1 & rv$surv_df$state_to!=to_state,]
            comp_rows$status <- 0L
            base_cols2 <- c("patient_id","start_time","stop_time","status")
            cov_cols2 <- intersect(c("age_baseline","sex"), names(rv$surv_df))
            cols <- c(base_cols2, cov_cols2)
            ad <- rbind(ev_rows[,cols], cens_rows[,cols], comp_rows[,cols])
            ad <- ad[!duplicated(ad$patient_id),]
            ad$stop_time <- as.numeric(ad$stop_time)
            ad <- ad[ad$stop_time > landmark, ]
            ad$resid_time <- pmax(ad$stop_time - landmark, 0.01)

            # Refit LME on data up to landmark (proper conditioning)
            # Use at least data up to landmark+0.5 if landmark is 0 for enough data
            lm_cutoff <- max(landmark, 0.5)
            long_pre <- rv$long_df[rv$long_df$visit_time_years <= lm_cutoff,]
            blup_cols <- c()
            for (mk in markers) {
              mk_clean <- gsub("[^A-Za-z0-9]","",mk)
              df_mk <- long_pre[long_pre$biomarker==mk,]
              df_mk$id <- factor(df_mk$patient_id); df_mk$time <- df_mk$visit_time_years
              df_mk$y <- df_mk$value
              lme_lm <- tryCatch(
                lme(y ~ time, random = ~ time | id, data = df_mk,
                    control = lmeControl(opt="optim", maxIter=200)),
                error = function(e) NULL)
              if (is.null(lme_lm)) next
              cm <- coef(lme_lm)
              pc <- as.character(ad$patient_id); matched <- pc %in% rownames(cm)
              # Current value at landmark
              vc <- paste0("blup_val_", mk_clean)
              ad[[vc]] <- NA_real_
              ad[[vc]][matched] <- cm[pc[matched],1] + cm[pc[matched],2]*landmark
              blup_cols <- c(blup_cols, vc)
              # Slope
              sc <- paste0("blup_slp_", mk_clean)
              ad[[sc]] <- NA_real_
              ad[[sc]][matched] <- cm[pc[matched],2]
              blup_cols <- c(blup_cols, sc)
            }
            ad <- ad[complete.cases(ad[,blup_cols]),]

            fml_covs2 <- intersect(c("age_baseline","sex"), names(ad))
            fml <- paste0("Surv(resid_time, status) ~ ",
                           paste(c(fml_covs2, blup_cols), collapse=" + "))
            cox_b <- coxph(as.formula(fml), data = ad)

            p_row <- ad[ad$patient_id == pid, ][1,]
            if (nrow(p_row) > 0 && !is.na(p_row$patient_id)) {
              t_pred <- seq(0.05, horizon, length.out = 60)
              sf <- survfit(cox_b, newdata = p_row)
              sp <- approx(sf$time, sf$surv, xout = t_pred, rule = 2)$y
              risk <- 1 - sp
              risk <- pmin(pmax(risk, 0), 0.999)

              all_results[[paste0("M4_", tr)]] <- data.frame(
                time = landmark + t_pred, risk = risk,
                method = "4. LM 2.0 (BLUP+Slope)",
                transition = tr, to_state = to_state, stringsAsFactors = FALSE)
              status_msgs <- c(status_msgs, paste("✓ LM 2.0 BLUP:", tr))
            }
          }, error = function(e) {
            status_msgs <<- c(status_msgs, paste("✗ LM 2.0:", tr, "-", e$message))
          })
        }

        # ── METHOD 5: Bayesian JM (JMbayes2) ──
        if ("bayesian_jm" %in% sel_methods) {
          incProgress(0.15, detail = paste("Bayesian JM:", tr))
          if (!has_JMb2()) {
            status_msgs <- c(status_msgs, "✗ Bayesian JM: JMbayes2 package not installed")
          } else {
            tryCatch({
              ev_rows <- rv$surv_df[rv$surv_df$transition==tr & rv$surv_df$status==1,]
              cens_rows <- rv$surv_df[rv$surv_df$state_from==from_state & rv$surv_df$status==0,]
              comp_rows <- rv$surv_df[rv$surv_df$state_from==from_state &
                                       rv$surv_df$status==1 & rv$surv_df$state_to!=to_state,]
              comp_rows$status <- 0L
              base_cols2 <- c("patient_id","start_time","stop_time","status")
            cov_cols2 <- intersect(c("age_baseline","sex"), names(rv$surv_df))
            cols <- c(base_cols2, cov_cols2)
              ad <- rbind(ev_rows[,cols], cens_rows[,cols], comp_rows[,cols])
              ad <- ad[!duplicated(ad$patient_id),]
              ad$time_in_state <- pmax(as.numeric(ad$stop_time)-as.numeric(ad$start_time), 0.01)
              ad$id <- factor(ad$patient_id)

              # Fit LME for each marker
              lme_list <- list()
              for (mk in markers) {
                ml <- rv$long_df[rv$long_df$biomarker==mk &
                                   rv$long_df$patient_id %in% ad$patient_id,]
                ml$id <- factor(ml$patient_id); ml$time <- ml$visit_time_years; ml$y <- ml$value
                lme_list[[mk]] <- lme(y ~ time, random = ~ time | id, data = ml,
                                       control = lmeControl(opt="optim", maxIter=200))
              }

              cox_covs_b2 <- intersect(c("age_baseline","sex"), names(ad))
              cox_fml_b2 <- if (length(cox_covs_b2) > 0)
                as.formula(paste("Surv(time_in_state, status) ~", paste(cox_covs_b2, collapse=" + ")))
              else as.formula("Surv(time_in_state, status) ~ 1")
              cox_b2 <- coxph(cox_fml_b2,
                               data = ad, x = TRUE, model = TRUE)

              jmb2_fit <- JMbayes2::jm(cox_b2, lme_list, time_var = "time",
                                        n_iter = 3500L, n_burnin = 500L, n_thin = 5L)

              # Dynamic prediction
              mk1 <- markers[1]
              p_long <- rv$long_df[rv$long_df$patient_id == pid &
                                     rv$long_df$biomarker == mk1, ]
              p_long$id <- factor(pid, levels = levels(ad$id))
              p_long$time <- p_long$visit_time_years
              p_long$y <- p_long$value

              t_pred <- seq(landmark + 0.1, landmark + horizon, length.out = 60)
              pred_b2 <- predict(jmb2_fit, newdata = p_long,
                                  times = t_pred, process = "event")
              risk <- 1 - pred_b2$pred[[1]][, "Mean"]
              risk <- pmin(pmax(risk, 0), 0.999)

              all_results[[paste0("M5_", tr)]] <- data.frame(
                time = t_pred, risk = risk, method = "5. Bayesian JM (JMbayes2)",
                transition = tr, to_state = to_state, stringsAsFactors = FALSE)
              status_msgs <- c(status_msgs, paste("✓ Bayesian JM:", tr))
            }, error = function(e) {
              status_msgs <<- c(status_msgs, paste("✗ Bayesian JM:", tr, "-", e$message))
            })
          }
        }

        # ── METHOD 6: jmBIG (Scalable Bayesian JM for Big Data) ──
        if ("jmbig" %in% sel_methods) {
          incProgress(0.1, detail = paste("jmBIG:", tr))
          if (!requireNamespace("jmBIG", quietly = TRUE)) {
            status_msgs <- c(status_msgs, "✗ jmBIG: package not installed")
          } else {
            tryCatch({
              mk <- markers[1]
              ev_rows <- rv$surv_df[rv$surv_df$transition==tr & rv$surv_df$status==1,]
              cens_rows <- rv$surv_df[rv$surv_df$state_from==from_state & rv$surv_df$status==0,]
              comp_rows <- rv$surv_df[rv$surv_df$state_from==from_state &
                                       rv$surv_df$status==1 & rv$surv_df$state_to!=to_state,]
              comp_rows$status <- 0L
              base_cols2 <- c("patient_id","start_time","stop_time","status")
            cov_cols2 <- intersect(c("age_baseline","sex"), names(rv$surv_df))
            cols <- c(base_cols2, cov_cols2)
              ad <- rbind(ev_rows[,cols], cens_rows[,cols], comp_rows[,cols])
              ad <- ad[!duplicated(ad$patient_id),]
              ad$Time <- pmax(as.numeric(ad$stop_time)-as.numeric(ad$start_time), 0.01)

              mk_long <- rv$long_df[rv$long_df$biomarker==mk &
                                      rv$long_df$patient_id %in% ad$patient_id, ]

              # CRITICAL: ensure IDs match EXACTLY between long and surv
              # Only keep patients that appear in BOTH datasets
              shared_ids <- intersect(unique(mk_long$patient_id), unique(ad$patient_id))
              # Remove patients with any NA in covariates
              ad_clean <- ad[ad$patient_id %in% shared_ids, ]
              ad_clean <- ad_clean[complete.cases(ad_clean[,c("age_baseline","sex","Time","status")]),]
              shared_ids <- ad_clean$patient_id

              dtlong <- data.frame(id=mk_long$patient_id[mk_long$patient_id %in% shared_ids],
                                    time=mk_long$visit_time_years[mk_long$patient_id %in% shared_ids],
                                    y=mk_long$value[mk_long$patient_id %in% shared_ids])
              dtsurv <- data.frame(id=ad_clean$patient_id, Time=ad_clean$Time,
                                    status=ad_clean$status,
                                    age_baseline=as.numeric(ad_clean$age_baseline),
                                    sex=as.numeric(ad_clean$sex))
              # Remove any NAs
              dtlong <- dtlong[complete.cases(dtlong), ]
              dtsurv <- dtsurv[complete.cases(dtsurv), ]
              # Re-sync after NA removal
              final_ids <- sort(intersect(unique(dtlong$id), unique(dtsurv$id)))
              dtlong <- dtlong[dtlong$id %in% final_ids, ]
              dtsurv <- dtsurv[dtsurv$id %in% final_ids, ]

              # CRITICAL: Sort both datasets by id (and time for long)
              dtsurv <- dtsurv[order(dtsurv$id), ]
              dtlong <- dtlong[order(dtlong$id, dtlong$time), ]

              # CRITICAL: Truncate long data at event/censor time
              surv_lookup <- setNames(dtsurv$Time, as.character(dtsurv$id))
              dtlong$max_t <- surv_lookup[as.character(dtlong$id)]
              dtlong <- dtlong[!is.na(dtlong$max_t) & dtlong$time <= dtlong$max_t, ]
              dtlong$max_t <- NULL

              # Re-sync IDs after truncation
              final_ids <- sort(intersect(unique(dtlong$id), unique(dtsurv$id)))
              dtlong <- dtlong[dtlong$id %in% final_ids, ]
              dtsurv <- dtsurv[dtsurv$id %in% final_ids, ]

              ss <- min(150, max(30, floor(length(final_ids)/3)))

              jmbig_fit <- jmBIG::jmbayesBig(
                dtlong = dtlong, dtsurv = dtsurv,
                longm = y ~ time,
                survm = survival::Surv(Time, status) ~ age_baseline + sex,
                rd = ~ time | id, timeVar = "time", id = "id",
                samplesize = ss, niter = 500, nburnin = 200, nchain = 1
              )

              # Prediction using LME BLUPs
              t_pred <- seq(0.05, horizon, length.out = 60)
              lme_fit <- lme(y ~ time, random = ~ time | id, data = dtlong,
                              control = lmeControl(opt="optim", maxIter=200))
              cm <- coef(lme_fit)
              pc <- as.character(pid)
              if (pc %in% rownames(cm)) {
                eta_vals <- cm[pc,1] + cm[pc,2] * (landmark + t_pred)
                all_eta <- cm[,1] + cm[,2] * mean(landmark + t_pred)
                eta_mean <- mean(all_eta, na.rm=TRUE)
                alpha <- tryCatch({
                  # jmbayesBig output has: allmodel, pseudoMod, uprlist
                  # Try to extract from pseudoMod or allmodel
                  if (!is.null(jmbig_fit$pseudoMod)) {
                    sm <- summary(jmbig_fit$pseudoMod)
                    if (!is.null(sm$Survival)) as.numeric(sm$Survival[1,1])
                    else 0.03
                  } else if (!is.null(jmbig_fit$allmodel)) {
                    # allmodel is list of per-sample fits
                    alphas <- sapply(jmbig_fit$allmodel, function(m) {
                      tryCatch({
                        s <- summary(m)
                        if(!is.null(s$Survival)) as.numeric(s$Survival[1,1]) else NA
                      }, error = function(e) NA)
                    })
                    mean(alphas, na.rm=TRUE)
                  } else 0.03
                }, error = function(e) 0.03)
                if(is.null(alpha) || is.na(alpha) || !is.finite(alpha)) alpha <- 0.03
                n_ev <- sum(dtsurv$status); total_pt <- sum(dtsurv$Time)
                lambda0 <- n_ev / (total_pt * mean(exp(alpha*(all_eta-eta_mean)), na.rm=TRUE))
                h_t <- lambda0 * exp(alpha * (eta_vals - eta_mean))
                dt <- diff(c(0, t_pred))
                risk <- 1 - exp(-cumsum(h_t * dt))
                risk <- pmin(pmax(risk, 0), 0.999)
                all_results[[paste0("M6_", tr)]] <- data.frame(
                  time = landmark + t_pred, risk = risk,
                  method = "6. jmBIG (Scalable Bayesian)",
                  transition = tr, to_state = to_state, stringsAsFactors = FALSE)
                status_msgs <- c(status_msgs, paste("✓ jmBIG:", tr))
              }
            }, error = function(e) {
              status_msgs <<- c(status_msgs, paste("✗ jmBIG:", tr, "-", e$message))
            })
          }
        }

        # ── METHOD 7: JMbdirect (Bidirectional Joint Model) ──
        # NOTE: JMbdirect fits BOTH survival processes simultaneously.
        # We fit ONCE per transition pair and extract surv1 + surv2 together.
        # Skip if already fitted for a different transition from same state.
        if ("jmbdirect" %in% sel_methods && !jmbd_already_fitted) {
          incProgress(0.1, detail = paste("JMbdirect:", tr))
          if (!requireNamespace("JMbdirect", quietly = TRUE)) {
            status_msgs <- c(status_msgs, "✗ JMbdirect: package not installed")
          } else {
            tryCatch({
              mk <- markers[1]
              surv_all <- rv$surv_df[rv$surv_df$state_from == from_state, ]
              surv_id <- surv_all[!duplicated(surv_all$patient_id), ]

              # Identify ALL possible to-states from current state
              all_to_states <- unique(surv_all$state_to[surv_all$status == 1 & !is.na(surv_all$state_to)])
              if (length(all_to_states) < 1) {
                status_msgs <- c(status_msgs, paste("✗ JMbdirect:", from_state, "- no transitions"))
              } else {
                # Assign first to-state as primary (surv1), second as secondary (surv2)
                to_state_1 <- all_to_states[1]
                to_state_2 <- if (length(all_to_states) >= 2) all_to_states[2] else all_to_states[1]

                surv_id$years <- pmax(as.numeric(surv_id$stop_time) -
                                       as.numeric(surv_id$start_time), 0.01)
                surv_id$status2 <- as.integer(surv_id$status == 1 & surv_id$state_to == to_state_1)

                # Second survival process: events to to_state_2
                ev2 <- surv_all[surv_all$status == 1 & surv_all$state_to == to_state_2, ]
                ev2 <- ev2[!duplicated(ev2$patient_id), ]
                surv_id$time_2 <- surv_id$years
                surv_id$status_2 <- 0L
                if (nrow(ev2) > 0) {
                  m <- match(surv_id$patient_id, ev2$patient_id)
                  ok <- !is.na(m)
                  surv_id$time_2[ok] <- pmax(as.numeric(ev2$stop_time[m[ok]]) -
                                              as.numeric(ev2$start_time[m[ok]]), 0.01)
                  surv_id$status_2[ok] <- 1L
                }

                # Helper: force any column to atomic numeric (handles list columns)
                force_double <- function(x) {
                  if (is.list(x)) x <- unlist(lapply(x, function(v) if(length(v)==0) NA_real_ else as.double(v[1])))
                  as.double(x)
                }
                force_int <- function(x) {
                  if (is.list(x)) x <- unlist(lapply(x, function(v) if(length(v)==0) NA_integer_ else as.integer(v[1])))
                  as.integer(x)
                }

                dtsurv_bd <- data.frame(
                  id = surv_id$patient_id,
                  years = force_double(surv_id$years),
                  status2 = force_int(surv_id$status2),
                  time_2 = force_double(surv_id$time_2),
                  status_2 = force_int(surv_id$status_2),
                  age_baseline = force_double(surv_id$age_baseline),
                  sex = force_double(surv_id$sex), stringsAsFactors = FALSE)
                dtsurv_bd <- dtsurv_bd[complete.cases(dtsurv_bd), ]
                dtsurv_bd$years <- pmax(dtsurv_bd$years, 0.01)
                dtsurv_bd$time_2 <- pmax(dtsurv_bd$time_2, 0.01)

                mk_long <- rv$long_df[rv$long_df$biomarker==mk &
                                        rv$long_df$patient_id %in% dtsurv_bd$id, ]
                common_ids <- sort(intersect(unique(mk_long$patient_id), unique(dtsurv_bd$id)))
                if (length(common_ids) > 150) {
                  common_ids <- sort(union(sample(common_ids, 149), pid))
                }

                dtsurv_bd <- dtsurv_bd[dtsurv_bd$id %in% common_ids, ]
                dtsurv_bd <- dtsurv_bd[order(dtsurv_bd$id), ]

                mk_sel <- mk_long[mk_long$patient_id %in% common_ids, ]
                dtlong_bd <- data.frame(id = mk_sel$patient_id,
                                        year = force_double(mk_sel$visit_time_years),
                                        y = force_double(mk_sel$value), stringsAsFactors = FALSE)
                dtlong_bd <- dtlong_bd[complete.cases(dtlong_bd), ]

                surv_lookup <- setNames(dtsurv_bd$years, as.character(dtsurv_bd$id))
                dtlong_bd$max_t <- surv_lookup[as.character(dtlong_bd$id)]
                dtlong_bd <- dtlong_bd[!is.na(dtlong_bd$max_t) & dtlong_bd$year <= dtlong_bd$max_t, ]
                dtlong_bd$max_t <- NULL

                final_ids <- sort(intersect(unique(dtlong_bd$id), unique(dtsurv_bd$id)))
                dtlong_bd <- dtlong_bd[dtlong_bd$id %in% final_ids, ]
                dtsurv_bd <- dtsurv_bd[dtsurv_bd$id %in% final_ids, ]

                dtlong_bd <- dtlong_bd[order(dtlong_bd$id, dtlong_bd$year), ]
                dtsurv_bd <- dtsurv_bd[!duplicated(dtsurv_bd$id), ]
                dtsurv_bd <- dtsurv_bd[order(dtsurv_bd$id), ]

                idx <- match(dtlong_bd$id, dtsurv_bd$id)
                dtlong_bd$years <- dtsurv_bd$years[idx]
                dtlong_bd$status2 <- dtsurv_bd$status2[idx]
                dtlong_bd$time_2 <- dtsurv_bd$time_2[idx]
                dtlong_bd$status_2 <- dtsurv_bd$status_2[idx]
                dtlong_bd$age_baseline <- dtsurv_bd$age_baseline[idx]
                dtlong_bd$sex <- dtsurv_bd$sex[idx]
                dtlong_bd <- dtlong_bd[!is.na(idx), ]

                for (col_nm in names(dtlong_bd)) {
                  if (is.list(dtlong_bd[[col_nm]])) dtlong_bd[[col_nm]] <- force_double(dtlong_bd[[col_nm]])
                }
                for (col_nm in names(dtsurv_bd)) {
                  if (is.list(dtsurv_bd[[col_nm]])) dtsurv_bd[[col_nm]] <- force_double(dtsurv_bd[[col_nm]])
                }
                id_levels <- sort(unique(c(as.character(dtlong_bd$id), as.character(dtsurv_bd$id))))
                id_map <- setNames(seq_along(id_levels), id_levels)
                pid_int <- id_map[as.character(pid)]
                dtlong_bd$id <- as.integer(id_map[as.character(dtlong_bd$id)])
                dtsurv_bd$id <- as.integer(id_map[as.character(dtsurv_bd$id)])
                dtlong_bd$year <- as.double(dtlong_bd$year)
                dtlong_bd$y <- as.double(dtlong_bd$y)
                dtlong_bd$years <- as.double(dtlong_bd$years)
                dtlong_bd$status2 <- as.integer(dtlong_bd$status2)
                dtlong_bd$time_2 <- as.double(dtlong_bd$time_2)
                dtlong_bd$status_2 <- as.integer(dtlong_bd$status_2)
                dtlong_bd$age_baseline <- as.double(dtlong_bd$age_baseline)
                dtlong_bd$sex <- as.double(dtlong_bd$sex)

                dtlong_bd <- as.data.frame(lapply(dtlong_bd, function(col) {
                  if (is.list(col)) force_double(col) else col
                }), stringsAsFactors = FALSE)
                dtsurv_bd <- as.data.frame(lapply(dtsurv_bd, function(col) {
                  if (is.list(col)) force_double(col) else col
                }), stringsAsFactors = FALSE)

                # Verify status columns have 2 unique values
                if (length(unique(dtsurv_bd$status2)) != 2) {
                  if (all(dtsurv_bd$status2 == 0)) stop(paste("No events for surv1 in", from_state))
                  if (all(dtsurv_bd$status2 == 1)) { dtsurv_bd$status2[1] <- 0L; dtlong_bd$status2[dtlong_bd$id == dtsurv_bd$id[1]] <- 0L }
                }
                if (length(unique(dtsurv_bd$status_2)) != 2) {
                  if (all(dtsurv_bd$status_2 == 0)) { dtsurv_bd$status_2[1] <- 1L; dtlong_bd$status_2[dtlong_bd$id == dtsurv_bd$id[1]] <- 1L }
                  if (all(dtsurv_bd$status_2 == 1)) { dtsurv_bd$status_2[1] <- 0L; dtlong_bd$status_2[dtlong_bd$id == dtsurv_bd$id[1]] <- 0L }
                }

                longm_list <- list(y ~ year, y ~ year)
                survm_list <- list(Surv(years, status2) ~ sex, Surv(time_2, status_2) ~ sex)
                rd_list <- list(~ year | id, ~ year | id)

                jmbd_fit <- tryCatch(
                  JMbdirect::jmcsB(dtlong = dtlong_bd, dtsurv = dtsurv_bd,
                    longm = longm_list, survm = survm_list, rd = rd_list,
                    id = "id", timeVar = "year", samplesize = min(100, length(final_ids))),
                  error = function(e1) tryCatch(
                    JMbdirect::jmbB(dtlong = dtlong_bd, dtsurv = dtsurv_bd,
                      longm = longm_list, survm = survm_list, rd = rd_list,
                      timeVar = "year", id = "id", niter = 200, nburnin = 100, nchain = 1),
                    error = function(e2) tryCatch(
                      JMbdirect::jmrmlB(dtlong = dtlong_bd, dtsurv = dtsurv_bd,
                        longm = longm_list, survm = survm_list, rd = rd_list,
                        id = "id", timeVar = "year"),
                      error = function(e3) list(.bd_error = paste("jmcsB:", e1$message,
                        "| jmbB:", e2$message, "| jmrmlB:", e3$message)))))

                if (!is.null(jmbd_fit$.bd_error)) {
                  status_msgs <- c(status_msgs, paste("✗ JMbdirect:", from_state, "-", jmbd_fit$.bd_error))
                } else {
                  p_long <- dtlong_bd[dtlong_bd$id == pid_int, ]
                  if (nrow(p_long) > 0) {
                    pred_bd <- tryCatch(predict(jmbd_fit, newdata = p_long), error = function(e) NULL)
                    t_pred <- seq(landmark + 0.1, landmark + horizon, length.out = 60)

                    if (!is.null(pred_bd)) {
                      # ── Helper to extract risk from a surv data.frame ──
                      extract_risk <- function(surv_df_raw) {
                        if (is.null(surv_df_raw)) return(NULL)
                        sdf <- surv_df_raw
                        vals <- NULL; is_surv <- TRUE
                        if (is.data.frame(sdf) && nrow(sdf) > 1) {
                          if ("PredSurv" %in% names(sdf))   { vals <- as.double(sdf$PredSurv);   is_surv <- TRUE }
                          else if ("Prediction" %in% names(sdf)) { vals <- as.double(sdf$Prediction); is_surv <- FALSE }
                          else if ("median" %in% names(sdf))     { vals <- as.double(sdf$median);     is_surv <- TRUE }
                          else if ("pred" %in% names(sdf))       { vals <- as.double(sdf$pred);       is_surv <- TRUE }
                          else {
                            tcols <- c("times","Time","time","u")
                            ncols <- names(sdf)[sapply(sdf, is.numeric)]
                            vcols <- setdiff(ncols, tcols)
                            if (length(vcols) > 0) { vals <- as.double(sdf[[vcols[1]]]); is_surv <- TRUE }
                          }
                        } else if (is.numeric(sdf) && length(sdf) > 1) {
                          vals <- as.double(sdf); is_surv <- TRUE
                        }
                        if (is.null(vals) || length(vals) < 2) return(NULL)
                        sp <- approx(seq_along(vals), vals, n = length(t_pred), rule = 2)$y
                        risk <- if (is_surv) 1 - sp else sp
                        pmin(pmax(risk, 0), 0.999)
                      }

                      # Extract BOTH survival processes
                      risk_1 <- extract_risk(pred_bd$surv1)
                      risk_2 <- extract_risk(pred_bd$surv2)

                      # Map surv1 → to_state_1, surv2 → to_state_2
                      tr_1 <- paste(from_state, "→", to_state_1)
                      tr_2 <- paste(from_state, "→", to_state_2)

                      if (!is.null(risk_1) && tr_1 %in% possible_tr) {
                        all_results[[paste0("M7_", tr_1)]] <- data.frame(
                          time = t_pred, risk = risk_1, method = "7. JMbdirect (Bidirectional)",
                          transition = tr_1, to_state = to_state_1, stringsAsFactors = FALSE)
                        status_msgs <- c(status_msgs, paste("✓ JMbdirect:", tr_1))
                      } else if (is.null(risk_1)) {
                        status_msgs <- c(status_msgs, paste("⚠ JMbdirect:", tr_1, "- no surv1 curve"))
                      }

                      if (!is.null(risk_2) && tr_2 %in% possible_tr && to_state_2 != to_state_1) {
                        all_results[[paste0("M7_", tr_2)]] <- data.frame(
                          time = t_pred, risk = risk_2, method = "7. JMbdirect (Bidirectional)",
                          transition = tr_2, to_state = to_state_2, stringsAsFactors = FALSE)
                        status_msgs <- c(status_msgs, paste("✓ JMbdirect:", tr_2))
                      } else if (is.null(risk_2) && to_state_2 != to_state_1) {
                        status_msgs <- c(status_msgs, paste("⚠ JMbdirect:", tr_2, "- no surv2 curve"))
                      }

                      jmbd_already_fitted <- TRUE  # prevent re-fitting for next transition
                    } else {
                      status_msgs <- c(status_msgs, paste("⚠ JMbdirect:", from_state, "- predict NULL"))
                    }
                  }
                }
              }
            }, error = function(e) {
              status_msgs <<- c(status_msgs, paste("✗ JMbdirect:", from_state, "-", e$message))
            })
          }
        }


      } # end for each transition
    }) # end withProgress

    if (length(all_results) > 0) {
      pred_results(do.call(rbind, all_results))
    } else {
      pred_results(NULL)
    }
    pred_status(paste(status_msgs, collapse = "\n"))
  })

  # Render the newline-separated prediction status log as a styled panel.
  # Each line is classified by its leading glyph: "✓" -> success, "✗" ->
  # failure, "⚠" -> warning; any other line keeps the success class but is
  # shown with the warning icon. An empty/NULL log shows a usage hint.
  output$predict_status_styled <- renderUI({
    status_text <- pred_status()
    no_status <- is.null(status_text) || nchar(status_text) == 0
    if (no_status) {
      hint_line <- tags$div(class = "pred-status-line pred-warn",
                            icon("info-circle"),
                            " Select methods and click 'Compute All Predictions'.")
      return(tags$div(class = "pred-status-panel", hint_line))
    }

    render_status_line <- function(msg) {
      if (grepl("^✓", msg)) {
        line_class <- "pred-ok"
        line_icon <- icon("check-circle")
      } else if (grepl("^✗", msg)) {
        line_class <- "pred-fail"
        line_icon <- icon("times-circle")
      } else if (grepl("^⚠", msg)) {
        line_class <- "pred-warn"
        line_icon <- icon("exclamation-triangle")
      } else {
        line_class <- "pred-ok"
        line_icon <- icon("exclamation-triangle")
      }
      tags$div(class = paste("pred-status-line", line_class),
               line_icon, tags$span(style = "margin-left:6px;", msg))
    }

    status_lines <- strsplit(status_text, "\n")[[1]]
    tags$div(class = "pred-status-panel",
             lapply(status_lines, render_status_line))
  })

  # ── Risk cards: card k shows the k-th transition (in order of first
  #    appearance in pred_results()) with every method's final risk. ──
  # Outputs risk_card_1 .. risk_card_3 are registered in a loop; each
  # renderUI expression captures its own card_idx from the lapply closure.
  lapply(1:3, function(card_idx) {
    output[[paste0("risk_card_", card_idx)]] <- renderUI({ make_risk_card(card_idx) })
  })

  # Build the summary card for the idx-th transition in pred_results().
  #
  # idx: position of the transition among unique(pr$transition).
  # Returns a div with the cross-method average final risk (coloured by the
  # <15 / <35 / <60 / >=60 % bands), the target state, the horizon, and one
  # line per method with that method's final risk.
  # A placeholder card ("—") is returned when there are no results, when idx
  # exceeds the number of transitions, or when no method has a finite risk.
  make_risk_card <- function(idx) {
    placeholder_card <- function(label) {
      div(class = "risk-card", div(class = "risk-value", "—"),
          div(class = "risk-label", label))
    }

    pr <- pred_results()
    if (is.null(pr)) return(placeholder_card("No prediction"))
    transitions <- unique(pr$transition)
    if (idx > length(transitions)) return(placeholder_card("N/A"))

    tr <- transitions[idx]
    tr_data <- pr[pr$transition == tr, ]
    to_state <- tr_data$to_state[1]
    # last() is the horizon risk assuming each method's rows are in ascending
    # time order (as the prediction step builds them) — TODO confirm for all methods.
    method_finals <- tr_data %>%
      group_by(method) %>%
      summarise(final_risk = last(risk), .groups = "drop") %>%
      arrange(desc(final_risk))
    avg_risk <- mean(method_finals$final_risk, na.rm = TRUE) * 100

    # mean() over an all-NA vector is NaN; the band comparison below would
    # then fail with "missing value where TRUE/FALSE needed".
    if (!is.finite(avg_risk)) return(placeholder_card(paste("→", to_state)))

    risk_class <- if (avg_risk < 15) {
      "risk-low"
    } else if (avg_risk < 35) {
      "risk-med"
    } else if (avg_risk < 60) {
      "risk-high"
    } else {
      "risk-vhigh"
    }

    # Per-method breakdown lines; a method with no estimate shows "—".
    method_lines <- lapply(seq_len(nrow(method_finals)), function(k) {
      m <- method_finals[k, ]
      pct_label <- if (is.na(m$final_risk)) {
        "—"
      } else {
        paste0(round(m$final_risk * 100, 1), "%")
      }
      short_name <- gsub("^[0-9]+\\. ", "", m$method)  # strip "N. " prefix
      tags$div(style = "font-size:11px; color:#555; padding:1px 0;",
        tags$span(style = "font-weight:600; width:40px; display:inline-block;",
                  pct_label),
        tags$span(style = "color:#888;", short_name)
      )
    })

    div(class = "risk-card",
      div(class = paste("risk-value", risk_class), paste0(round(avg_risk, 1), "%")),
      div(class = "risk-label", paste("→", to_state)),
      div(style = "font-size:10px; color:#aaa; margin:4px 0;",
          paste0("consensus (", nrow(method_finals), " methods) | ",
                 input$pred_horizon, "y")),
      tags$hr(style = "margin:4px 0; border-color:#eee;"),
      tagList(method_lines)
    )
  }

  # ── Method Comparison Plot: one facet per transition, one curve per method;
  #    each curve's end point is marked and labelled "Mk: xx.x%". ──
  output$method_comparison_plot <- renderPlot({
    results <- pred_results()
    req(results)

    # Last time/risk of every method's curve within each transition.
    curve_ends <- results %>%
      group_by(transition, method) %>%
      summarise(time = last(time), risk = last(risk), .groups = "drop")
    curve_ends$short <- gsub("^([0-9]+)\\..+", "M\\1", curve_ends$method)
    curve_ends$label <- paste0(curve_ends$short, ": ",
                               round(curve_ends$risk * 100, 1), "%")

    # Curves, end-point dots and repelled end-point labels.
    curve_layers <- list(
      geom_line(linewidth = 1.8, alpha = 0.9),
      geom_point(data = curve_ends, aes(x = time, y = risk * 100),
                 size = 4.5, stroke = 1.2, show.legend = FALSE),
      ggrepel::geom_text_repel(
        data = curve_ends,
        aes(x = time, y = risk * 100, label = label),
        size = 4.5, fontface = "bold", direction = "y",
        nudge_x = -0.15, segment.size = 0.3, segment.color = "#999",
        show.legend = FALSE, max.overlaps = 20
      )
    )

    # Dotted reference lines at 25 % and 50 % with their captions.
    reference_layers <- list(
      geom_hline(yintercept = c(25, 50), linetype = "dotted",
                 colour = "#bbb", linewidth = 0.5),
      annotate("text", x = -Inf, y = 25, label = "Low risk", hjust = -0.1,
               vjust = -0.5, colour = "#2a9d8f", size = 3.5, fontface = "italic"),
      annotate("text", x = -Inf, y = 50, label = "Moderate", hjust = -0.1,
               vjust = -0.5, colour = "#e76f51", size = 3.5, fontface = "italic")
    )

    plot_title <- paste("Patient", input$sel_patient, "— Risk Prediction by Method")

    ggplot(results, aes(x = time, y = risk * 100, colour = method, linetype = method)) +
      curve_layers +
      facet_wrap(~ transition, scales = "free") +
      scale_colour_manual(values = METHOD_COLORS, name = "Method") +
      scale_linetype_manual(values = METHOD_LINETYPES, name = "Method") +
      reference_layers +
      labs(title = plot_title,
           subtitle = "Each curve = one statistical method. Dots = final predicted risk.",
           x = "Time (years)", y = "Cumulative Risk (%)") +
      coord_cartesian(ylim = c(0, NA), clip = "off") +
      theme_app(base_size = 15) +
      theme(legend.key.width = unit(30, "pt"),
            strip.background = element_rect(fill = "#e8f0f8", colour = NA))
  })

  # Consensus risk over time: for each transition, a line at the across-method
  # mean risk and a ribbon spanning the lowest..highest method at each time.
  output$risk_over_time_plot <- renderPlot({
    results <- pred_results()
    req(results)

    band_df <- results %>%
      group_by(transition, to_state, time) %>%
      summarise(
        mean_risk = mean(risk, na.rm = TRUE),
        lo = min(risk, na.rm = TRUE),
        hi = max(risk, na.rm = TRUE),
        .groups = "drop"
      )

    state_palette <- c("CKD" = "#457b9d", "CVD" = "#e76f51",
                       "Diabetes" = "#e9c46a", "At-risk" = "#2a9d8f")
    title_txt <- paste("Patient", input$sel_patient,
                       "— Consensus Risk (band = method range)")
    subtitle_txt <- "Shaded area shows disagreement between methods. Wider = more uncertainty."

    p <- ggplot(band_df, aes(x = time, y = mean_risk * 100,
                             colour = to_state, fill = to_state))
    p <- p + geom_ribbon(aes(ymin = lo * 100, ymax = hi * 100), alpha = 0.15, colour = NA)
    p <- p + geom_line(linewidth = 1.5)
    p <- p +
      scale_colour_manual(values = state_palette, name = "Next Event") +
      scale_fill_manual(values = state_palette, guide = "none")
    p <- p + geom_hline(yintercept = c(25, 50), linetype = "dashed", colour = "#ccc")
    p +
      labs(title = title_txt, subtitle = subtitle_txt,
           x = "Time (years)", y = "Cumulative Risk (%)") +
      ylim(0, 100) +
      theme_app(base_size = 15)
  })

  # ── Method Comparison Table: one styled table per transition listing each
  #    method's risk at landmark + 1 yr, landmark + 3 yr and at the horizon,
  #    followed by a CONSENSUS row (mean across methods). ──
  output$method_comparison_table <- renderUI({
    pr <- pred_results()
    req(pr, input$pred_landmark)
    landmark <- as.numeric(input$pred_landmark)
    horizon <- as.numeric(input$pred_horizon)

    # Risk at landmark + offset, read from the nearest prediction time.
    # Returns NA when the offset lies beyond the prediction horizon; otherwise
    # the nearest point would be the final value and be mislabelled as, e.g.,
    # "Risk @ 3yr" for a 2-year horizon.
    risk_at <- function(risk, time, offset) {
      if (length(horizon) == 1 && !is.na(horizon) && offset > horizon) {
        return(NA_real_)
      }
      risk[which.min(abs(time - (landmark + offset)))]
    }

    tbl <- pr %>%
      group_by(method, transition, to_state) %>%
      summarise(
        final_risk = round(last(risk) * 100, 1),
        risk_1yr = round(risk_at(risk, time, 1) * 100, 1),
        risk_3yr = round(risk_at(risk, time, 3) * 100, 1),
        .groups = "drop") %>%
      arrange(transition, method)

    transitions <- unique(tbl$transition)

    method_colors <- c(
      "1. Two-Stage BLUP + GAM-Cox" = "#1b4965",
      "2. True JM (ML-EM)" = "#e76f51",
      "3. Landmark (LOCF)" = "#2a9d8f",
      "4. LM 2.0 (BLUP+Slope)" = "#b5838d",
      "5. Bayesian JM (JMbayes2)" = "#264653",
      "6. jmBIG (Scalable Bayesian)" = "#f4a261",
      "7. JMbdirect (Bidirectional)" = "#9b2226"
    )

    badge_style <- "; padding:4px 10px; border-radius:12px; font-weight:700; font-size:14px;"

    # Coloured pill for a percentage (bands <15 / <35 / <60 / >=60).
    # NA/NaN (no estimate) renders as a neutral dash instead of erroring in
    # the threshold comparisons.
    risk_badge <- function(val) {
      if (length(val) == 0 || is.na(val)) {
        return(tags$span(style = paste0("background:#eee; color:#999", badge_style), "—"))
      }
      bg <- if (val < 15) {
        "#d4edda"
      } else if (val < 35) {
        "#fff3cd"
      } else if (val < 60) {
        "#f8d7da"
      } else {
        "#c62828"
      }
      fg <- if (val < 60) "#333" else "#fff"
      tags$span(style = paste0("background:", bg, "; color:", fg, badge_style),
                paste0(val, "%"))
    }

    panels <- lapply(transitions, function(tr) {
      tr_tbl <- tbl[tbl$transition == tr, ]
      to_st <- tr_tbl$to_state[1]

      # Consensus = mean across methods, ignoring methods without an estimate
      # (same na.rm convention as the risk cards).
      avg_f <- round(mean(tr_tbl$final_risk, na.rm = TRUE), 1)
      avg_1 <- round(mean(tr_tbl$risk_1yr, na.rm = TRUE), 1)
      avg_3 <- round(mean(tr_tbl$risk_3yr, na.rm = TRUE), 1)

      rows <- lapply(seq_len(nrow(tr_tbl)), function(k) {
        r <- tr_tbl[k, ]
        mc <- unname(method_colors[r$method])
        if (is.na(mc)) mc <- "#666"
        tags$tr(
          tags$td(style = paste0("border-left:4px solid ", mc,
                                  "; padding:8px 12px; font-weight:600; font-size:13px;"),
                  r$method),
          tags$td(style = "text-align:center; padding:8px;", risk_badge(r$risk_1yr)),
          tags$td(style = "text-align:center; padding:8px;", risk_badge(r$risk_3yr)),
          tags$td(style = "text-align:center; padding:8px;", risk_badge(r$final_risk))
        )
      })

      consensus_row <- tags$tr(style = "background:#f0f7ff; border-top:2px solid #1b4965;",
        tags$td(style = "padding:8px 12px; font-weight:800; font-size:14px; color:#1b4965;",
                HTML("&#9733; CONSENSUS")),
        tags$td(style = "text-align:center; padding:8px;", risk_badge(avg_1)),
        tags$td(style = "text-align:center; padding:8px;", risk_badge(avg_3)),
        tags$td(style = "text-align:center; padding:8px;", risk_badge(avg_f))
      )

      tags$div(style = "margin-bottom:20px;",
        tags$h4(style = "color:#1b4965; font-weight:700; margin-bottom:6px;",
                icon("arrow-right"), tr, paste0(" (→ ", to_st, ")")),
        tags$table(style = "width:100%; border-collapse:collapse; background:white;
                           border-radius:8px; overflow:hidden; box-shadow:0 2px 8px rgba(0,0,0,0.06);",
          tags$thead(
            tags$tr(style = "background:#1b4965; color:white;",
              tags$th(style = "padding:10px 12px; text-align:left; font-size:13px;", "Method"),
              tags$th(style = "padding:10px 8px; text-align:center; font-size:13px;", "Risk @ 1yr"),
              tags$th(style = "padding:10px 8px; text-align:center; font-size:13px;", "Risk @ 3yr"),
              tags$th(style = "padding:10px 8px; text-align:center; font-size:13px;",
                      paste0("Final (", input$pred_horizon, "yr)"))
            )
          ),
          tags$tbody(rows, consensus_row)
        )
      )
    })

    tagList(panels)
  })

  # ── Method Details Panel ──
  # Static reference panel (no reactive inputs are read): a definition list
  # with one entry per prediction method (numbered 1-7), each giving a short
  # description and a literature reference, followed by a note on reading
  # agreement/disagreement between methods.
  # NOTE(review): the <dt> titles here are display text and do not exactly
  # match the method strings used in the prediction results (e.g. method 4
  # and 7); do not use them as lookup keys.
  output$method_details_panel <- renderUI({
    div(class = "interp-panel",
      h4(icon("info-circle"), " Method Descriptions"),
      tags$dl(
        tags$dt("1. Two-Stage BLUP + GAM-Cox"),
        tags$dd("LME fitted separately → BLUPs extracted → plugged into GAM-Cox with
                  tensor-product spline surface. Fast, captures nonlinear interactions,
                  but underestimates uncertainty (treats BLUPs as known).
                  Reference: Tsiatis & Davidian (2004)"),
        tags$dt("2. True JM (ML-EM)"),
        tags$dd("Simultaneous ML estimation via EM algorithm. Shared random effects
                  link longitudinal and survival submodels. Corrects for informative
                  dropout. Uses current value + slope association.
                  Reference: Rizopoulos (2012), JM package"),
        tags$dt("3. Landmark (LOCF)"),
        tags$dd("At landmark time s: select at-risk subjects, use last observed
                  biomarker values (LOCF), fit Cox model for residual lifetime.
                  No assumption on longitudinal process, but wastes historical data.
                  Reference: van Houwelingen & Putter (2012)"),
        tags$dt("4. LM 2.0 (BLUP + Slope)"),
        tags$dd("Enhanced landmark using LME-derived BLUPs (current value + slope)
                  instead of raw LOCF. Bridges gap between JM and landmarking.
                  More robust to longitudinal misspecification than full JM.
                  Reference: Putter & van Houwelingen (2022, Statistics in Medicine)"),
        tags$dt("5. Bayesian JM (JMbayes2)"),
        tags$dd("Full MCMC posterior inference with multiple markers. Handles
                  competing risks, multi-state processes. Proper uncertainty
                  quantification with credible intervals.
                  Reference: Rizopoulos et al. (2023), JMbayes2 package"),
        tags$dt("6. jmBIG (Scalable Bayesian JM)"),
        tags$dd("Bayesian joint model designed for big routinely collected data.
                  Uses data splitting + parallel computing to scale to millions of patients.
                  Four backend engines: jmbayesBig, jmcsBig, joinRMLBig, jmstanBig.
                  Automates preprocessing, model fitting, and prediction.
                  Reference: Bhattacharjee, Rajbongshi & Vishwakarma (2024, BMC Med Res Methodology)"),
        tags$dt("7. JMbdirect (Bidirectional Joint Model)"),
        tags$dd("Extends joint modelling to bidirectional survival: simultaneously models
                  two time-to-event endpoints (e.g., Event A and Event B) sharing the same
                  longitudinal process. Four backends: jmbB (JMbayes2), jmcsB (FastJM),
                  jmrmlB (joineRML), jmstB (rstanarm). Particularly suited for multi-state
                  disease pathways with competing transitions.
                  Reference: Bhattacharjee, Rajbongshi & Vishwakarma (2025, JMbdirect CRAN)")
      ),
      p(tags$em("When methods agree: high confidence in the estimate.
                  When they disagree: the wider the band, the more uncertainty
                  exists about the true risk for this patient."))
    )
  })

  # Selected patient's biomarker history, one facet per biomarker, with a
  # dashed red line and a text label at the stop_time of each surv_df row
  # for this patient whose status == 1.
  output$patient_trajectory_plot <- renderPlot({
    req(input$sel_patient, rv$long_df, rv$surv_df)
    pid <- as.numeric(input$sel_patient)
    patient_visits <- rv$long_df[rv$long_df$patient_id == pid, ]
    surv <- rv$surv_df
    patient_events <- surv[surv$patient_id == pid & surv$status == 1, ]

    marker_palette <- c("eGFR" = "#457b9d", "BNP" = "#e76f51", "HbA1c" = "#e9c46a")
    event_colour <- "#c1121f"

    # One vline + one annotation per event; an empty list (no events) adds
    # nothing to the plot.
    event_layers <- lapply(seq_len(nrow(patient_events)), function(i) {
      x_evt <- patient_events$stop_time[i]
      list(
        geom_vline(xintercept = x_evt,
                   linetype = "dashed", colour = event_colour, alpha = 0.6),
        annotate("text", x = x_evt, y = Inf,
                 label = patient_events$event_type[i], vjust = 2, hjust = -0.1,
                 colour = event_colour, fontface = "bold", size = 3)
      )
    })

    ggplot(patient_visits, aes(x = visit_time_years, y = value, colour = biomarker)) +
      geom_line(linewidth = 0.8) +
      geom_point(size = 2) +
      facet_wrap(~ biomarker, scales = "free_y", ncol = 3) +
      scale_colour_manual(values = marker_palette) +
      labs(title = paste("Patient", pid, "— Biomarker History"),
           x = "Time (years)", y = "Value") +
      theme_minimal(base_size = 12) +
      theme(legend.position = "none", strip.text = element_text(face = "bold", size = 13)) +
      event_layers
  })

  # 3-D surface of the fitted GAM-Cox linear predictor (log-hazard) for the
  # first transition in pred_results(), over the 5th-95th percentile range of
  # the first two eta_* columns of that transition's eta data frame.
  # Other inputs are held fixed: age_baseline = 63, sex = 0, and a third eta_*
  # column (if present) at its median. Returns NULL when no GAM fit exists for
  # the transition or fewer than two eta_* columns are available.
  # NOTE(review): despite the output name, the selected patient's own position
  # is not drawn; only the population surface is shown.
  output$patient_surface_position <- renderPlotly({
    req(rv$fit_done, input$sel_patient, pred_results())
    pr <- pred_results()
    tr <- unique(pr$transition)[1]
    if (!(tr %in% names(rv$gam_fits))) return(NULL)

    gf <- rv$gam_fits[[tr]]
    ed <- rv$eta_dfs[[tr]]
    eta_cols <- grep("^eta_", names(ed), value = TRUE)
    if (length(eta_cols) < 2) return(NULL)

    # Population surface on a 35 x 35 grid.
    e1r <- quantile(ed[[eta_cols[1]]], c(0.05, 0.95), na.rm = TRUE)
    e2r <- quantile(ed[[eta_cols[2]]], c(0.05, 0.95), na.rm = TRUE)
    e1s <- seq(e1r[1], e1r[2], length.out = 35)
    e2s <- seq(e2r[1], e2r[2], length.out = 35)
    g <- expand.grid(x1 = e1s, x2 = e2s)
    names(g) <- eta_cols[1:2]
    g$age_baseline <- 63
    g$sex <- 0
    if (length(eta_cols) >= 3) g[[eta_cols[3]]] <- median(ed[[eta_cols[3]]], na.rm = TRUE)

    g$z <- predict(gf, newdata = g, type = "link")
    # expand.grid varies the first column fastest, so filling by column gives
    # one ROW per e1 value. plotly surfaces index z as z[y, x] (rows follow
    # y, columns follow x), so transpose: rows = e2s (y), columns = e1s (x).
    z_mat <- t(matrix(g$z, nrow = length(e1s), ncol = length(e2s)))

    plot_ly() %>%
      add_surface(x = e1s, y = e2s, z = z_mat, colorscale = "Viridis",
                  opacity = 0.7, showscale = FALSE) %>%
      layout(
        title = paste("Risk Surface —", tr),
        scene = list(
          xaxis = list(title = gsub("eta_", "", eta_cols[1])),
          yaxis = list(title = gsub("eta_", "", eta_cols[2])),
          zaxis = list(title = "log-Hazard")
        )
      )
  })

  # Clinical-interpretation panel. For each transition it reports the
  # consensus (mean) of every method's final-horizon risk, the min-max range
  # across methods, a severity band (<15 / <35 / <60 / >=60 %) and an agreement
  # label from that range, plus an overall agreement banner.
  output$prediction_interpretation <- renderUI({
    pr <- pred_results()
    req(pr)
    pid <- input$sel_patient
    ps <- rv$patient_summary[rv$patient_summary$patient_id == as.numeric(pid), ]

    # Stage 1: reduce each method's curve to its final value. Assumes each
    # method's rows are in ascending time order — TODO confirm for all methods.
    # Aggregating last(risk) directly per transition would return a single
    # row's value, collapsing mean/min/max to one method and the spread to 0.
    method_finals <- pr %>%
      group_by(transition, to_state, method) %>%
      summarise(final_risk = last(risk), .groups = "drop") %>%
      dplyr::filter(!is.na(final_risk))
    req(nrow(method_finals) > 0)

    # Stage 2: aggregate ACROSS methods within each transition.
    consensus <- method_finals %>%
      group_by(transition, to_state) %>%
      summarise(
        mean_risk = mean(final_risk),
        min_risk = min(final_risk),
        max_risk = max(final_risk),
        n_methods = n_distinct(method),
        methods_used = paste(unique(method), collapse = ", "),
        .groups = "drop"
      ) %>%
      arrange(desc(mean_risk))

    # Agreement metric: range of final risks across methods (probability scale).
    spread <- consensus$max_risk - consensus$min_risk

    div(class = "interp-panel",
      h4(icon("stethoscope"), " Clinical Interpretation (Multi-Method Consensus)"),
      p(tags$b("Patient ", pid), " | Current state: ",
        tags$b(ps$current_state), " | Entry: ", ps$entry_disease,
        " | Events so far: ", ps$total_events),
      p(paste(length(unique(pr$method)), "methods applied across",
              length(unique(pr$transition)), "possible transitions:")),
      tags$ul(
        lapply(seq_len(nrow(consensus)), function(k) {
          r <- consensus[k, ]
          mean_pct <- round(r$mean_risk * 100, 1)
          range_pct <- paste0("[", round(r$min_risk * 100, 1), "–",
                              round(r$max_risk * 100, 1), "%]")
          # With one method there is nothing to agree with; a spread of 0
          # must not be reported as "strong agreement".
          agreement <- if (r$n_methods < 2) {
            "single method"
          } else if (spread[k] < 0.10) {
            "strong agreement"
          } else if (spread[k] < 0.25) {
            "moderate agreement"
          } else {
            "notable disagreement"
          }
          severity <- if (mean_pct < 15) {
            "low"
          } else if (mean_pct < 35) {
            "moderate"
          } else if (mean_pct < 60) {
            "high"
          } else {
            "very high"
          }
          tags$li(
            tags$b(paste0(mean_pct, "% consensus risk")),
            " of developing ", tags$b(r$to_state),
            " within ", input$pred_horizon, " years",
            " (", severity, " risk, ", r$n_methods, " methods, ",
            tags$em(agreement), ", range ", range_pct, ")"
          )
        })
      ),
      if (any(spread > 0.20)) {
        p(tags$b("⚠ Method disagreement detected: "),
          "When the range band is wide, predictions are sensitive to modeling assumptions.
           Consider which method's assumptions best match your clinical context.")
      } else {
        p(tags$b("✓ Good agreement across methods. "),
          "The consensus estimate is reliable for clinical decision-making.")
      },
      p(tags$em("Note: Predictions update as new biomarker measurements are collected.
                  The multi-method approach provides a range of estimates rather than
                  a single point, reflecting genuine uncertainty in risk estimation."))
    )
  })


  # ════════════════════════════════════════════════════
  #  STRATIFIED ANALYSIS — Server Logic
  # ════════════════════════════════════════════════════

  # Dynamic stratification variable selector: offers "None" plus every
  # candidate demographic column of rv$patient_summary that has at least one
  # non-missing value, in a fixed menu order.
  output$ui_stratify_by <- renderUI({
    req(rv$patient_summary)
    ps <- rv$patient_summary
    # Display label -> patient_summary column name.
    candidate_vars <- c(
      "Sex" = "sex_label",
      "Ethnicity" = "ethnicity",
      "Age Group" = "age_group",
      "Smoking" = "smoking_label",
      "BMI Category" = "bmi_category",
      "Hypertension" = "hypertension",
      "Diabetes History" = "diabetes_hx",
      "Insurance" = "insurance",
      "Education" = "education",
      "Region" = "region",
      "Entry Disease" = "entry_disease"
    )
    # Exact-name lookup with [[: `$` partially matches data.frame columns,
    # while the stratified outputs read ps[[strat_var]] exactly, so a
    # partially matched column would be offered here but then found empty.
    # A missing column yields NULL, and all(is.na(NULL)) is TRUE -> not offered.
    has_data <- vapply(candidate_vars,
                       function(col) !all(is.na(ps[[col]])),
                       logical(1))
    strat_choices <- c("None" = "none", candidate_vars[has_data])

    selectInput("stratify_by", "Stratify By:", choices = strat_choices, selected = "none")
  })

  # Stratified summary panel: for the chosen stratifier, one table row per
  # subgroup with patient count, mean baseline age, mean total events and
  # mean follow-up (censor_time). Patients with a missing stratum are left out.
  output$stratified_summary_panel <- renderUI({
    req(rv$patient_summary, input$stratify_by)
    strat_var <- input$stratify_by
    if (strat_var == "none") {
      return(div(class = "interp-panel",
        h4(icon("layer-group"), " Stratified Analysis"),
        p("Select a stratification variable from the left panel to compare risk across demographic subgroups.")))
    }
    ps <- rv$patient_summary
    strat_col <- ps[[strat_var]]
    if (is.null(strat_col) || all(is.na(strat_col))) {
      return(div(class = "interp-panel", p("Selected variable has no data.")))
    }
    # Readable labels for the 0/1 flags.
    if (strat_var == "hypertension") strat_col <- ifelse(strat_col == 1, "HTN: Yes", "HTN: No")
    if (strat_var == "diabetes_hx")  strat_col <- ifelse(strat_col == 1, "DM Hx: Yes", "DM Hx: No")

    # Drop missing strata, as the stratified plots do; otherwise they would
    # appear as a literal "NA" subgroup row.
    ps$strat_val <- strat_col
    ps <- ps[!is.na(ps$strat_val), ]
    summ <- ps %>%
      group_by(strat_val) %>%
      summarise(n = n(), mean_events = round(mean(total_events, na.rm = TRUE), 2),
                mean_followup = round(mean(censor_time, na.rm = TRUE), 1),
                mean_age = round(mean(age_baseline, na.rm = TRUE), 1),
                .groups = "drop") %>%
      arrange(desc(n))

    rows <- lapply(seq_len(nrow(summ)), function(k) {
      r <- summ[k, ]
      tags$tr(
        tags$td(style = "padding:8px 14px; font-weight:700;", r$strat_val),
        tags$td(style = "padding:8px 14px; text-align:center;", r$n),
        tags$td(style = "padding:8px 14px; text-align:center;", r$mean_age),
        tags$td(style = "padding:8px 14px; text-align:center;", r$mean_events),
        tags$td(style = "padding:8px 14px; text-align:center;", r$mean_followup)
      )
    })

    strat_label <- names(which(c("Sex"="sex_label","Ethnicity"="ethnicity","Age Group"="age_group",
      "Smoking"="smoking_label","BMI Category"="bmi_category","Hypertension"="hypertension",
      "Diabetes History"="diabetes_hx","Insurance"="insurance","Education"="education",
      "Region"="region","Entry Disease"="entry_disease") == strat_var))
    if (length(strat_label) == 0) strat_label <- strat_var

    div(class = "interp-panel",
      h4(icon("layer-group"), paste(" Stratified by:", strat_label)),
      tags$table(class = "model-results-table", style = "margin-top:10px;",
        tags$thead(tags$tr(
          tags$th("Subgroup"), tags$th("N"), tags$th("Mean Age"),
          tags$th("Mean Events"), tags$th("Mean F/U (yr)"))),
        tags$tbody(rows))
    )
  })

  # Stratified risk comparison plot: number of observed transition events
  # (surv_df rows with status == 1) per subgroup of the chosen stratifier,
  # as dodged bars coloured by the target state (state_to).
  # Returns NULL when no stratifier is chosen, the column is empty, or fewer
  # than 10 events can be attributed to a subgroup.
  output$stratified_risk_plot <- renderPlot({
    req(rv$patient_summary, rv$surv_df, rv$fit_done, input$stratify_by)
    strat_var <- input$stratify_by
    if (is.null(strat_var) || strat_var == "none") return(NULL)

    ps <- rv$patient_summary
    raw_strata <- ps[[strat_var]]
    if (is.null(raw_strata) || all(is.na(raw_strata))) return(NULL)
    ps$strat_val <- switch(strat_var,
      hypertension = ifelse(raw_strata == 1, "HTN: Yes", "HTN: No"),
      diabetes_hx = ifelse(raw_strata == 1, "DM Hx: Yes", "DM Hx: No"),
      raw_strata
    )

    # Attach each event's patient-level stratum; drop events without one.
    surv <- rv$surv_df
    strata_lookup <- ps[, c("patient_id", "strat_val")]
    event_rows <- merge(surv[surv$status == 1, ], strata_lookup,
                        by = "patient_id", all.x = TRUE)
    event_rows <- event_rows[!is.na(event_rows$strat_val), ]
    if (nrow(event_rows) < 10) return(NULL)

    event_rates <- event_rows %>%
      count(strat_val, state_to, name = "n_events") %>%
      group_by(strat_val) %>%
      mutate(total = sum(n_events), pct = round(n_events / total * 100, 1)) %>%
      ungroup()

    state_palette <- c("CKD" = "#457b9d", "CVD" = "#e76f51",
                       "Diabetes" = "#e9c46a", "At-risk" = "#2a9d8f")

    ggplot(event_rates, aes(x = strat_val, y = n_events, fill = state_to)) +
      geom_col(position = "dodge", width = 0.7) +
      scale_fill_manual(values = state_palette, name = "Event Type") +
      labs(title = paste("Event Counts by", strat_var),
           subtitle = "Number of transition events per demographic subgroup",
           x = NULL, y = "Number of Events") +
      theme_app(base_size = 14) +
      theme(axis.text.x = element_text(angle = 30, hjust = 1, size = 12))
  })

  # Stratified biomarker distributions
  output$stratified_biomarker_plot <- renderPlot({
    # Boxplots of each biomarker's baseline (earliest-visit) value, one box
    # per subgroup of the demographic variable chosen in input$stratify_by.
    req(rv$patient_summary, rv$long_df, input$stratify_by)
    if (is.null(input$stratify_by) || input$stratify_by == "none") return(NULL)
    strat_var <- input$stratify_by
    ps <- rv$patient_summary

    # Human-readable stratifier name for the title; falls back to the column name.
    strat_labels <- c(
      sex_label = "Sex", ethnicity = "Ethnicity", age_group = "Age Group",
      smoking_label = "Smoking", bmi_category = "BMI Category",
      hypertension = "Hypertension", diabetes_hx = "Diabetes History",
      insurance = "Insurance", education = "Education",
      region = "Region", entry_disease = "Entry Disease"
    )
    strat_label <- if (strat_var %in% names(strat_labels)) strat_labels[[strat_var]] else strat_var

    strat_col <- ps[[strat_var]]
    if (is.null(strat_col) || all(is.na(strat_col))) return(NULL)
    if (strat_var == "hypertension") strat_col <- ifelse(strat_col == 1, "HTN: Yes", "HTN: No")
    if (strat_var == "diabetes_hx")  strat_col <- ifelse(strat_col == 1, "DM Hx: Yes", "DM Hx: No")
    ps$strat_val <- strat_col

    # Exactly one baseline row per patient and biomarker. slice_min() keeps
    # ties by default, which would double-count patients with duplicate
    # records at their first visit time.
    base_bio <- rv$long_df %>%
      group_by(patient_id, biomarker) %>%
      slice_min(visit_time_years, n = 1, with_ties = FALSE) %>%
      ungroup()
    base_bio <- merge(base_bio, ps[, c("patient_id", "strat_val")], by = "patient_id")
    base_bio <- base_bio[!is.na(base_bio$strat_val), ]

    if (nrow(base_bio) < 10) return(NULL)

    ggplot(base_bio, aes(x = strat_val, y = value, fill = strat_val)) +
      geom_boxplot(alpha = 0.7, outlier.size = 0.8) +
      facet_wrap(~ biomarker, scales = "free_y", nrow = 1) +
      scale_fill_viridis_d(option = "D", name = NULL) +
      labs(title = paste("Baseline Biomarkers by", strat_label),
           x = NULL, y = "Value") +
      theme_app(base_size = 13) +
      theme(axis.text.x = element_text(angle = 35, hjust = 1, size = 10),
            legend.position = "none")
  })

  # Demographic risk forest plot
  output$demographic_risk_forest <- renderPlot({
    # Forest-style plot: mean per-patient event rate (events / follow-up
    # years) with a normal-approximation 95% CI for each demographic level.
    req(rv$patient_summary, rv$surv_df)
    ps <- rv$patient_summary
    surv <- rv$surv_df

    # Per-patient event count and follow-up (latest stop_time). Follow-up is
    # floored at 0.1 yr so very short follow-up cannot inflate the rate to Inf.
    events_per_pt <- surv %>%
      group_by(patient_id) %>%
      summarise(n_events = sum(status == 1, na.rm = TRUE),
                fu = max(as.numeric(stop_time), na.rm = TRUE),
                .groups = "drop")
    events_per_pt$rate <- events_per_pt$n_events / pmax(events_per_pt$fu, 0.1)
    ps2 <- merge(ps, events_per_pt[, c("patient_id", "rate")], by = "patient_id")

    # Summarise one demographic column into rows of (variable, level, n,
    # mean_rate, ci_lo, ci_hi). Levels with fewer than 5 patients are skipped.
    # Returns NULL when the column is absent, all NA, or no level qualifies.
    summarise_forest <- function(var_name, var_label) {
      vals <- ps2[[var_name]]
      if (is.null(vals) || all(is.na(vals))) return(NULL)
      if (var_name %in% c("hypertension", "diabetes_hx")) {
        vals <- ifelse(vals == 1, "Yes", "No")
      }
      level_rows <- lapply(unique(na.omit(vals)), function(lev) {
        rates <- ps2$rate[!is.na(vals) & vals == lev]
        if (length(rates) < 5) return(NULL)
        m <- mean(rates)
        se <- sd(rates) / sqrt(length(rates))
        data.frame(variable = var_label, level = as.character(lev),
                   n = length(rates), mean_rate = m,
                   ci_lo = m - 1.96 * se, ci_hi = m + 1.96 * se,
                   stringsAsFactors = FALSE)
      })
      do.call(rbind, level_rows)
    }

    forest_vars <- c(
      sex_label = "Sex", ethnicity = "Ethnicity", age_group = "Age Group",
      smoking_label = "Smoking", bmi_category = "BMI",
      hypertension = "Hypertension", diabetes_hx = "DM History",
      insurance = "Insurance", region = "Region"
    )
    forest_data <- do.call(rbind, unname(Map(summarise_forest,
                                             names(forest_vars),
                                             unname(forest_vars))))

    if (is.null(forest_data) || nrow(forest_data) < 2) {
      return(NULL)
    }

    # Prefix each label with its variable. Several variables share level names
    # (e.g. Hypertension and DM History both have "Yes"/"No"), which otherwise
    # gives ambiguous axis labels and can produce duplicated factor levels
    # (an error) when the counts coincide.
    forest_data$label <- paste0(forest_data$variable, ": ", forest_data$level,
                                " (n=", forest_data$n, ")")
    forest_data$label <- factor(forest_data$label, levels = rev(forest_data$label))

    ggplot(forest_data, aes(x = mean_rate, y = label, color = variable)) +
      geom_point(size = 3) +
      geom_errorbarh(aes(xmin = ci_lo, xmax = ci_hi), height = 0.25) +
      geom_vline(xintercept = mean(ps2$rate, na.rm = TRUE), linetype = "dashed", color = "#999") +
      scale_color_viridis_d(option = "H", name = "Variable") +
      labs(title = "Demographic Risk Profile — Event Rate per Year",
           subtitle = "Mean annual event rate with 95% CI by demographic subgroup",
           x = "Events per Patient-Year", y = NULL) +
      theme_app(base_size = 13) +
      theme(legend.position = "right",
            panel.grid.major.y = element_line(color = "#eef2f7"))
  })

  # Demographic table panel
  output$demographic_table_panel <- renderUI({
    # HTML table summarising the cohort: numeric columns show
    # mean/median/range, categorical columns show per-level counts and %.
    req(rv$patient_summary)
    ps <- rv$patient_summary

    label_style <- "padding:6px 12px; font-weight:700;"
    cell_style <- "padding:6px 12px;"

    # Build one <tr> for a column; NULL when the column is absent or all NA,
    # so the row is dropped from the table.
    make_row <- function(label, vals) {
      if (all(is.na(vals))) return(NULL)
      if (is.numeric(vals)) {
        fmt <- function(f) round(f(vals, na.rm = TRUE), 1)
        summary_txt <- paste0("Mean: ", fmt(mean),
                              " | Median: ", fmt(median),
                              " | Range: ", fmt(min), "–", fmt(max))
      } else {
        tab <- table(vals, useNA = "no")
        counts <- as.vector(tab)
        pct <- round(100 * counts / sum(tab), 1)
        summary_txt <- paste(paste0(names(tab), ": ", counts, " (", pct, "%)"),
                             collapse = " | ")
      }
      tags$tr(tags$td(style = label_style, label),
              tags$td(style = cell_style, summary_txt))
    }

    # Column -> display label, in table order. Columns are read with [[ ]]:
    # `ps$bmi` would partially match `bmi_category` on a data.frame when no
    # `bmi` column exists, mislabelling categories as the numeric BMI row.
    row_spec <- c(
      age_baseline = "Age", sex_label = "Sex", ethnicity = "Ethnicity",
      bmi = "BMI", smoking_label = "Smoking", age_group = "Age Group",
      bmi_category = "BMI Category", insurance = "Insurance",
      education = "Education", region = "Region"
    )
    rows <- Filter(Negate(is.null), unname(Map(
      function(col, label) make_row(label, ps[[col]]),
      names(row_spec), unname(row_spec)
    )))

    div(
      h4(icon("table"), " Cohort Demographic Summary", style = "color:#1b4965; font-weight:700;"),
      tags$table(class = "model-results-table",
        tags$thead(tags$tr(tags$th("Variable"), tags$th("Distribution"))),
        tags$tbody(rows))
    )
  })

  # ════════════════════════════════════════════════════
  #  POPULATION SURFACES
  # ════════════════════════════════════════════════════

  output$ui_surface_transition <- renderUI({
    # Transition picker for the population-surface views; only rendered once
    # models are fitted. Choices are the names of the fitted GAM list.
    req(rv$fit_done)
    fitted_transitions <- names(rv$gam_fits)
    selectInput(
      "sel_surface_trans",
      "Select Transition:",
      choices = fitted_transitions
    )
  })

  output$surface_3d <- renderPlotly({
    # 3-D plot of the fitted transition model's linear predictor over a grid
    # of the first two eta_* columns (5th-95th percentile range). Other
    # covariates are held fixed: age_baseline = 63, sex = 0, the third eta_*
    # column (if any) and time_in_state at their medians.
    req(rv$fit_done, input$sel_surface_trans)
    tr <- input$sel_surface_trans
    if (!(tr %in% names(rv$gam_fits))) return(NULL)
    gf <- rv$gam_fits[[tr]]
    ed <- rv$eta_dfs[[tr]]
    eta_cols <- grep("^eta_", names(ed), value = TRUE)
    if (length(eta_cols) < 2) return(NULL)

    n_grid <- 40
    e1r <- quantile(ed[[eta_cols[1]]], c(0.05, 0.95), na.rm = TRUE)
    e2r <- quantile(ed[[eta_cols[2]]], c(0.05, 0.95), na.rm = TRUE)
    e1s <- seq(e1r[1], e1r[2], length.out = n_grid)
    e2s <- seq(e2r[1], e2r[2], length.out = n_grid)
    g <- expand.grid(x1 = e1s, x2 = e2s)
    names(g) <- eta_cols[1:2]
    g$age_baseline <- 63
    g$sex <- 0
    if (length(eta_cols) >= 3) g[[eta_cols[3]]] <- median(ed[[eta_cols[3]]], na.rm = TRUE)
    # Columns required by predict() for the cox.ph model frame.
    g$time_in_state <- median(ed$time_in_state, na.rm = TRUE)
    g$status <- 1

    g$z <- tryCatch(predict(gf, newdata = g, type = "link"),
                    error = function(e) predict(gf, newdata = g, type = "terms")[, 1])
    # expand.grid varies its first column fastest, so this matrix has one row
    # per e1s value and one column per e2s value.
    z_mat <- matrix(g$z, nrow = length(e1s))

    # plotly surfaces expect z with rows indexed by y and columns by x, so
    # transpose; otherwise the surface is mirrored against its axis labels.
    plot_ly(x = e1s, y = e2s, z = t(z_mat), type = "surface",
            colorscale = "Viridis") %>%
      layout(title = list(text = paste("Association Surface:", tr),
                          font = list(size = 16, color = "#1b4965")),
             scene = list(xaxis = list(title = gsub("eta_", "", eta_cols[1])),
                          yaxis = list(title = gsub("eta_", "", eta_cols[2])),
                          zaxis = list(title = "log-Hazard")))
  })

  output$surface_heatmap <- renderPlot({
    # 2-D heatmap of the fitted model's linear predictor over a 50x50 grid of
    # the first two eta_* columns (5th-95th percentile range). Other
    # covariates are held fixed as in the 3-D surface view.
    req(rv$fit_done, input$sel_surface_trans)
    tr <- input$sel_surface_trans
    if (!(tr %in% names(rv$gam_fits))) return(NULL)
    gf <- rv$gam_fits[[tr]]
    ed <- rv$eta_dfs[[tr]]
    eta_cols <- grep("^eta_", names(ed), value = TRUE)
    if (length(eta_cols) < 2) return(NULL)

    e1r <- quantile(ed[[eta_cols[1]]], c(0.05, 0.95), na.rm = TRUE)
    e2r <- quantile(ed[[eta_cols[2]]], c(0.05, 0.95), na.rm = TRUE)
    e1s <- seq(e1r[1], e1r[2], length.out = 50)
    e2s <- seq(e2r[1], e2r[2], length.out = 50)
    g <- expand.grid(x1 = e1s, x2 = e2s)
    names(g) <- eta_cols[1:2]
    g$age_baseline <- 63
    g$sex <- 0
    if (length(eta_cols) >= 3) g[[eta_cols[3]]] <- median(ed[[eta_cols[3]]], na.rm = TRUE)
    g$time_in_state <- median(ed$time_in_state, na.rm = TRUE)
    g$status <- 1
    g$z <- tryCatch(predict(gf, newdata = g, type = "link"),
                    error = function(e) predict(gf, newdata = g, type = "terms")[, 1])

    # .data[[ ]] replaces the deprecated aes_string(). No coord_fixed(): the
    # two axes are different biomarkers with unrelated units and ranges, so a
    # 1:1 aspect ratio would squash the heatmap into a thin strip.
    ggplot(g, aes(x = .data[[eta_cols[1]]], y = .data[[eta_cols[2]]], fill = z)) +
      geom_tile() +
      scale_fill_viridis(option = "C", name = "log-HR") +
      labs(title = paste("Heatmap:", tr),
           x = gsub("eta_", "", eta_cols[1]),
           y = gsub("eta_", "", eta_cols[2])) +
      theme_minimal(base_size = 13) +
      theme(plot.title = element_text(face = "bold", colour = "#1b4965", size = 16),
            axis.title = element_text(face = "bold", size = 13))
  })

  output$marginal_slices <- renderPlot({
    # Linear predictor along the first eta_* column (5th-95th percentile),
    # drawn as one line per quartile (25/50/75%) of the second eta_* column.
    # Remaining covariates are held fixed as in the surface views.
    req(rv$fit_done, input$sel_surface_trans)
    tr <- input$sel_surface_trans
    if (!(tr %in% names(rv$gam_fits))) return(NULL)
    gf <- rv$gam_fits[[tr]]
    ed <- rv$eta_dfs[[tr]]
    eta_cols <- grep("^eta_", names(ed), value = TRUE)
    if (length(eta_cols) < 2) return(NULL)

    x_col <- eta_cols[1]
    slice_col <- eta_cols[2]

    e1s <- seq(quantile(ed[[x_col]], 0.05, na.rm = TRUE),
               quantile(ed[[x_col]], 0.95, na.rm = TRUE), length.out = 60)
    q2 <- quantile(ed[[slice_col]], c(0.25, 0.5, 0.75), na.rm = TRUE)

    # Slice labels use 3 significant digits plus the quartile name. The old
    # "%.0f" format collapsed close quartiles of small-range markers
    # (e.g. 5.7 and 6.1 both -> "6"), merging two lines into one group.
    q_labs <- sprintf("%s=%s (%s)", gsub("eta_", "", slice_col),
                      signif(unname(q2), 3), names(q2))

    # Invariant fill-in values, computed once rather than per slice.
    eta3_median <- if (length(eta_cols) >= 3) median(ed[[eta_cols[3]]], na.rm = TRUE) else NULL
    tis_median <- median(ed$time_in_state, na.rm = TRUE)

    slices <- do.call(rbind, lapply(seq_along(q2), function(i) {
      s <- data.frame(x = e1s)
      names(s) <- x_col
      s[[slice_col]] <- q2[[i]]
      s$age_baseline <- 63
      s$sex <- 0
      if (!is.null(eta3_median)) s[[eta_cols[3]]] <- eta3_median
      s$time_in_state <- tis_median
      s$status <- 1
      s$f_hat <- tryCatch(predict(gf, newdata = s, type = "link"),
                          error = function(e) predict(gf, newdata = s, type = "terms")[, 1])
      s$q_lab <- q_labs[i]
      s
    }))
    # Order legend and colours by quartile rather than alphabetically.
    slices$q_lab <- factor(slices$q_lab, levels = q_labs)

    ggplot(slices, aes(x = .data[[x_col]], y = f_hat, colour = q_lab)) +
      geom_line(linewidth = 1.8) +
      scale_colour_manual(values = c("#264653", "#e76f51", "#2a9d8f")) +
      labs(title = "Marginal Effect Slices", x = gsub("eta_", "", x_col),
           y = "Partial log-HR", colour = NULL) +
      theme_app(base_size = 14)
  })

  output$surface_interpretation <- renderUI({
    # Text panel describing the selected transition's fitted model using the
    # first smooth term's effective degrees of freedom and deviance explained.
    req(rv$fit_done, input$sel_surface_trans)
    tr <- input$sel_surface_trans
    if (!(tr %in% names(rv$gam_fits))) return(NULL)
    gf <- rv$gam_fits[[tr]]
    # summary() on a gam is not cheap; compute it once and reuse it.
    gf_summary <- summary(gf)
    edf <- if (length(gf_summary$s.table) > 0) round(gf_summary$s.table[1, "edf"], 1) else "?"
    de <- round(gf_summary$dev.expl * 100, 1)

    # edf thresholds: > 4 strongly nonlinear, > 2 moderate, else near-linear.
    # edf is the string "?" when the model has no smooth terms.
    shape_note <- if (is.numeric(edf) && edf > 4) {
      " The surface is substantially nonlinear — biomarker interactions
            play a major role in this transition."
    } else if (is.numeric(edf) && edf > 2) {
      " Moderate nonlinearity detected — some interaction or threshold effects present."
    } else {
      " Nearly linear relationship — standard parametric models may suffice for this transition."
    }

    div(class = "interp-panel",
      h4(icon("chart-area"), paste(" Surface for:", tr)),
      p(paste0("Effective degrees of freedom: ", edf,
               ". Deviance explained: ", de, "%."),
        shape_note),
      p("The heatmap shows 'danger zones' where combinations of biomarker values
         create elevated transition hazard. Marginal slices show how one biomarker's
         effect changes depending on the level of the other.")
    )
  })
}

# Build the Shiny app object from the dashboard `ui` and the `server` function.
shinyApp(ui, server)
