## ----setup, include = FALSE---------------------------------------------------
# Vignette-wide knitr chunk defaults: collapsed "#>" output, 7 x 5.5 inch
# centred figures, and warnings hidden.
knitr::opts_chunk$set(
    collapse = TRUE,
    comment = "#>",
    fig.width = 7,
    fig.height = 5.5,
    fig.align = "center",
    warning = FALSE
)
# knitr records each chunk's plots on a fresh graphics device and discards
# par() settings when the chunk ends, so without global.par = TRUE the
# par(bty = "n") below would never reach the plotting chunks.
knitr::opts_knit$set(global.par = TRUE)
# Fix the random stream for all simulated data below.
set.seed(42)
# Default to no box around base-graphics plots.
par(bty = "n")
library(yaap)

## ----helpers, include = FALSE-------------------------------------------------
# Okabe-Ito colours for fitted archetypes (blue, vermillion, bluish green,
# reddish purple).
pal_aa <- c(
    "#0072B2",
    "#D55E00",
    "#009E73",
    "#CC79A7"
)
# Translucent black for data points: a darker shade for emphasised points
# and a lighter one for background points.
pal_data_mid <- adjustcolor("black", alpha.f = 0.45)
pal_data <- adjustcolor("black", alpha.f = 0.20)

# Draw `n` points uniformly from the (k - 1)-simplex, one per row: normalising
# i.i.d. Exp(1) draws by their row sum gives Dirichlet(1, ..., 1) rows, each
# of which sums to one.
sample_simplex <- function(n, k) {
    draws <- stats::rexp(n * k)
    weights <- matrix(draws, nrow = n, ncol = k)
    sweep(weights, 1L, rowSums(weights), FUN = "/")
}

# Return the last recorded entry of a fit's `loss` trace: the final element
# if the trace is a vector, or the final row if it is tabular.
final_loss <- function(fit) {
    loss_trace <- fit[["loss"]]
    utils::tail(loss_trace, n = 1L)
}

# Rescale every row of `X` (coerced to a matrix) to unit Euclidean length.
# An all-zero row has no direction and comes back as NaN (0 / 0).
unit_rows <- function(X) {
    mat <- as.matrix(X)
    row_norms <- sqrt(rowSums(mat * mat))
    mat / row_norms
}

## ----metric-simplex-data------------------------------------------------------
# Define the signal archetypes.
# Three vertices of a triangle in the first two features; the last two
# columns are zero, so features 3-4 carry no signal, only noise (below).
A_signal_true <- rbind(
    left  = c(-1,      0, 0, 0),
    right = c(1,       0, 0, 0),
    top   = c(0, sqrt(3), 0, 0)
)

# Basic dimensions
K <- nrow(A_signal_true)  # number of archetypes (3)
M <- ncol(A_signal_true)  # number of features (4)
N <- 220  # samples

# Archetype compositions for each sample
S_metric_true <- sample_simplex(N, K)
# Noise-free data: each row is a convex combination of the archetypes.
X_metric_mean <- S_metric_true %*% A_signal_true

# Build block-diagonal measurement-noise covariance
Sigma_noise <- matrix(0, nrow = M, ncol = M)
# Signal features: small, moderately correlated noise
Sigma_noise[1:2, 1:2] <- 0.03^2 * matrix(c(1, -0.5, -0.5, 1), nrow = 2)
# Nuisance features: large, weakly correlated noise
Sigma_noise[3:4, 3:4] <- 3^2 * matrix(c(1, 0.2, 0.2, 1), nrow = 2)
# Precision matrix used as metric for AA
G_noise <- solve(Sigma_noise)

# Correlated anisotropic noise
# chol() returns an upper-triangular R with t(R) %*% R == Sigma_noise, so the
# rows of Z %*% R (Z i.i.d. standard normal) have covariance Sigma_noise.
X_noise <- matrix(rnorm(N * M), ncol = M) %*% chol(Sigma_noise)
X_metric <- X_metric_mean + X_noise
colnames(X_metric) <- c("feature_1", "feature_2", "noise_1", "noise_2")

## ----metric-simplex-fit, warning=FALSE----------------------------------------
# Baseline fit: plain Euclidean AA on the raw features (scale = FALSE
# presumably disables any rescaling; confirm in ?run_aa). init = "random"
# with nrep = 20 presumably means 20 random restarts.
set.seed(42)
fit_plain <- run_aa(
    X_metric,
    K = K,
    scale = FALSE,
    sd_threshold = 0,
    init = "random",
    nrep = 20
)
fit_plain

# Metric fit: identical settings, except that the noise precision matrix
# G_noise is passed as `scale`, i.e. used as the metric for AA.
set.seed(42)  # same seed, so presumably the same random starts as fit_plain
fit_metric <- run_aa(
    X_metric,
    K = K,
    scale = G_noise,
    sd_threshold = 0,
    init = "random",
    nrep = 20
)
fit_metric

## ----metric-simplex-loss, echo=FALSE------------------------------------------
# All permutations of 1..k as the rows of an integer matrix, in lexicographic
# order (for k = 3: 123, 132, 213, 231, 312, 321).
permutation_rows <- function(k) {
    if (k <= 1L) {
        return(matrix(seq_len(k), nrow = 1L))
    }
    tails <- permutation_rows(k - 1L)
    blocks <- lapply(seq_len(k), function(first) {
        rest <- setdiff(seq_len(k), first)
        # Relabel each (k - 1)-permutation onto the elements other than `first`.
        cbind(first, matrix(rest[as.vector(tails)], nrow = nrow(tails)),
              deparse.level = 0)
    })
    do.call(rbind, blocks)
}

# Find the ordering of the fitted archetypes that best matches the true ones.
#
# fit:   fitted model; its archetypes are read with coordinates(fit), one per row.
# truth: matrix of true archetypes, one per row (any number of archetypes).
# G:     metric matrix for the squared error; identity by default.
#
# Returns an integer vector `ix` such that coordinates(fit)[ix, ] row j is
# paired with truth[j, ]. Every permutation is scored by the summed squared
# G-norm of the residuals, trace(E G E^T), and the first minimiser is kept.
# The search is exhaustive (k! permutations), which is fine for small k.
match_to_truth <- function(fit, truth, G = diag(ncol(truth))) {
    A_fit <- coordinates(fit)
    perms <- permutation_rows(nrow(truth))
    error <- apply(perms, 1, function(ix) {
        E <- A_fit[ix, , drop = FALSE] - truth
        sum(diag(E %*% G %*% t(E)))
    })
    perms[which.min(error), ]
}

# Total squared error between the (optimally reordered) fitted archetypes and
# the true archetypes, measured in the metric G.
noise_metric_archetype_error <- function(fit, truth, G) {
    # Reorder the fitted archetypes so that row j is paired with truth[j, ].
    paired <- coordinates(fit)[match_to_truth(fit, truth, G), , drop = FALSE]
    residual <- paired - truth
    # Sum of squared G-norms of the per-archetype residuals, i.e. the trace
    # of residual %*% G %*% t(residual).
    sum(diag(residual %*% G %*% t(residual)))
}

# Centred log-ratio transform (base 2) applied to each row of a composition
# matrix. Entries are floored at 1e-6 first, so exact zeros give a large
# negative log instead of -Inf.
clr <- function(x) {
    log_parts <- log2(pmax(x, 1e-6))
    sweep(log_parts, 1L, rowMeans(log_parts))
}

# Root-mean-square difference between fitted and true compositions on the
# CLR scale, after reordering the fitted archetype columns to match `truth`
# (the matching uses the archetype coordinates under metric G).
composition_rmse <- function(fit, truth, S_true, G) {
    ix <- match_to_truth(fit, truth, G)
    S_fit <- compositions(fit)[, ix, drop = FALSE]
    residual <- clr(S_fit) - clr(S_true)
    sqrt(mean(residual^2))
}

# Summary table: one row per fit (plain vs metric) with the final loss, the
# archetype error in the noise metric, and the CLR composition RMSE.
# NOTE(review): the kable below labels four columns ("Final loss", "R2", ...),
# so final_loss() presumably returns a one-row table holding both loss and R2
# here (rbind() also needs its column names to agree across the two rows);
# confirm against the structure of run_aa()'s `loss` element.
res <- rbind(
    plain = cbind(
        final_loss(fit_plain),
        data.frame(
            noise_metric_archetype_error = noise_metric_archetype_error(
                fit_plain, A_signal_true, G_noise
            ),
            composition_rmse = composition_rmse(
                fit_plain, A_signal_true, S_metric_true, G_noise
            )
        )
    ),
    metric = cbind(
        final_loss(fit_metric),
        data.frame(
            noise_metric_archetype_error = noise_metric_archetype_error(
                fit_metric, A_signal_true, G_noise
            ),
            composition_rmse = composition_rmse(
                fit_metric, A_signal_true, S_metric_true, G_noise
            )
        )
    )
)
knitr::kable(
    as.data.frame(res),
    digits = c(0, 3, 2, 2),
    col.names = c("Final loss", "R2", "Noise-metric archetype error", "Composition RMSE")
)

## ----metric-simplex-plot, echo=FALSE, fig.cap = "Metric Gaussian AA with known correlated measurement noise."----
# Show the two signal features only; the nuisance features are not drawn.
plot(
    X_metric[, 1:2], frame.plot = FALSE,
    asp = 1, pch = 19, cex = 0.45, col = pal_data,
    xlab = "feature 1", ylab = "feature 2", main = "Metric Gaussian AA"
)
# Dotted true triangle and true vertices. A_signal_true has four columns;
# polygon() and points() take x and y from its first two.
polygon(A_signal_true[c(1, 3, 2), ], border = "black", lty = 3, lwd = 1.4)
points(A_signal_true, pch = 15, cex = 1.1, col = "black")
# Fitted archetypes of both fits, shown in the two signal features.
points(coordinates(fit_plain)[, 1:2], pch = 4, cex = 1.45, lwd = 1.7, col = pal_aa[2])
points(coordinates(fit_metric)[, 1:2], pch = 1, cex = 1.65, lwd = 1.8, col = pal_aa[1])
legend(
    "topright",
    legend = c("Data", "True archetypes", "Plain Euclidean", "Metric"),
    pch = c(19, 15, 4, 1),
    col = c(pal_data, "black", pal_aa[2], pal_aa[1]),
    pt.cex = c(0.7, 1.1, 1.3, 1.5),
    bty = "n"
)

## ----functional-aa-fit, eval=requireNamespace("fda", quietly = TRUE)----------
N <- 90  # number of curves
t_fd <- seq(0, 1, length.out = 80)  # time points for functional data
# Three archetype curves evaluated on the time grid: two Gaussian-shaped
# bumps (centred at 0.18 and 0.82) and a linear ramp.
A_fd_true <- rbind(
    early_peak = exp(-60 * (t_fd - 0.18)^2),
    late_peak  = exp(-60 * (t_fd - 0.82)^2),
    ramp       = 0.15 + 0.85 * t_fd
)
# Archetype compositions for each sample
S_fd_true <- sample_simplex(N, 3)
# i.i.d. Gaussian noise (sd 0.025) at every time point of every curve.
X_fd_noise <- matrix(rnorm(N * length(t_fd), sd = 0.025), nrow = N)
# N x length(t_fd) matrix: one noisy sampled curve per row.
X_fd <- S_fd_true %*% A_fd_true + X_fd_noise

# Convert the sampled curves to an fda::fd object
# (14 B-spline basis functions on [0, 1]; the transpose puts time points in
# rows and curves in columns for Data2fd).
basis_fd <- fda::create.bspline.basis(rangeval = c(0, 1), nbasis = 14)
fd_data <- fda::Data2fd(argvals = t_fd, y = t(X_fd), basisobj = basis_fd)

# Fit AA directly to the fd object.
fit_fd <- run_aa(fd_data, K = 3, sd_threshold = 0)
fit_fd

## ----functional-aa-plot, eval=requireNamespace("fda", quietly = TRUE), fig.cap = "Functional Gaussian AA fitted directly to an fda::fd object."----

A_fit_fd <- coordinates(fit_fd)
class(A_fit_fd)  # class of the fitted archetype curves (drawn with plot() below)

# 25 evenly spaced curves for plotting. The indices are rounded explicitly:
# seq() with length.out yields fractional values, which `[` would otherwise
# truncate silently.
sample_ix <- as.integer(round(seq(1, nrow(X_fd), length.out = 25)))

# Plot sampled curves, then overlay archetype curves
matplot(
    t_fd,
    t(X_fd[sample_ix, ]),
    type = "l", lty = 1, col = pal_data,
    xlab = "t", ylab = "Value", main = "Functional AA"
)
invisible( # plot.fd prints a "done" message that we want to suppress
    plot(A_fit_fd, lty = 1, lwd = 2, col = pal_aa[1:3], add = TRUE)
)
legend(
    "top",
    legend = anames(fit_fd),
    lty = 1, lwd = 2, col = pal_aa[1:3],
    bty = "n", horiz = TRUE
)

## ----kernel-data--------------------------------------------------------------
# Choose the sample sizes
n_outer <- 120
n_inner <- 80

# Simulate an outer ring
# Uniform angles; radii uniform on [1, 1.12].
theta_outer <- runif(n_outer, 0, 2 * pi)
r_outer <- 1 + 0.12 * runif(n_outer)
X_outer <- cbind(r_outer * cos(theta_outer), r_outer * sin(theta_outer))

# Simulate an inner core
# Uniform angles; radii uniform on [0, 0.15].
theta_inner <- runif(n_inner, 0, 2 * pi)
r_inner <- 0.15 * runif(n_inner)
X_inner <- cbind(r_inner * cos(theta_inner), r_inner * sin(theta_inner))

# Combine the two groups
X_ring <- rbind(X_outer, X_inner)
colnames(X_ring) <- c("x", "y")

## ----kernel-fit, warning=FALSE------------------------------------------------
# Linear AA with 7 archetypes on the raw ring/core coordinates.
set.seed(42)
fit_linear <- run_aa(
    X_ring,
    K = 7,
    scale = FALSE,
    init = "random",
    nrep = 20
)
fit_linear

# Kernel AA with a Laplace kernel (sigma = 0.35 is presumably the kernel
# bandwidth), using the same seed as the linear fit.
# NOTE(review): unlike fit_linear, `scale = FALSE` is not passed here, so the
# two fits may see differently scaled inputs; confirm run_aa()'s default.
set.seed(42)
fit_kernel <- run_aa(
    X_ring,
    K = 7,
    method = "kernel",
    kernel = "laplace",
    kernel_args = list(sigma = 0.35),
    init = "random",
    nrep = 20
)
fit_kernel

## ----kernel-plot, fig.cap = "Linear AA archetypes and kernel-AA input-space proxy archetypes."----
# Raw data without axes, overlaid with the archetype coordinates of both
# fits (for the kernel fit these are the input-space proxy archetypes named
# in the figure caption).
plot(
    X_ring, frame.plot = FALSE, axes = FALSE,
    col = pal_data, asp = 1, pch = 19, cex = 0.45,
    main = "Linear AA vs kernel-AA proxies", xlab = "", ylab = ""
)
points(coordinates(fit_linear), pch = 4, cex = 1.25, lwd = 1.5, col = pal_aa[1])
points(coordinates(fit_kernel), pch = 1, cex = 1.45, lwd = 1.6, col = pal_aa[2])
legend(
    "topleft",
    legend = c("Data", "Linear AA", "Kernel-AA"),
    pch = c(19, 4, 1),
    col = c(pal_data, pal_aa[1:2]),
    pt.cex = c(0.7, 1.2, 1.4),
    bty = "n"
)

## ----paa-binomial-data--------------------------------------------------------
N <- 120
# Per-item "yes" probabilities: rows are archetypes, columns are 6 items.
P_true <- rbind(
    broad_yes = c(0.85, 0.80, 0.75, 0.20, 0.15, 0.10),
    middle    = c(0.25, 0.75, 0.35, 0.70, 0.35, 0.65),
    broad_no  = c(0.10, 0.15, 0.20, 0.80, 0.85, 0.75)
)
K <- nrow(P_true)
M <- ncol(P_true)
S_bin <- sample_simplex(N, K)
# Each sample's item probabilities are a convex mix of the archetype rows.
P_bin <- S_bin %*% P_true
# One Bernoulli draw per cell. as.vector() and matrix() both use column-major
# order, so each draw lands in the cell whose probability generated it.
X_bin <- matrix(rbinom(N * M, size = 1, prob = as.vector(P_bin)), nrow = N)
colnames(X_bin) <- paste0("item_", seq_len(M))
head(X_bin)

## ----paa-binomial-fit---------------------------------------------------------
# Probabilistic AA with a binomial likelihood on the 0/1 responses. No
# set.seed() here, so any randomness in run_aa() continues the current stream.
fit_binomial <- run_aa(
    X_bin,
    K = K,
    method = "paa",
    family = "binomial"
)

fit_binomial

## ----paa-profile, echo=FALSE, fig.cap = "Binomial PAA archetype response probabilities."----
# Widen the right margin and allow drawing outside the plot region (xpd) so
# the legend can sit to the right of the plot (negative inset). par() returns
# the previous values, which are restored after plotting.
old_par <- par(mar = c(5.1, 4.1, 4.1, 8.5), xpd = TRUE)
plot(
    fit_binomial,
    what = "profiles",
    col = pal_aa[1:3],
    ylim = c(0, 1),
    main = "Binomial archetype probabilities",
    args.legend = list(
        x = "topright",
        inset = c(-0.32, 0),
        bty = "n"
    )
)
par(old_par)

## ----paa-predict--------------------------------------------------------------
# Archetype compositions for the first five observed rows.
S_new <- predict(fit_binomial, X_bin[1:5, ], type = "compositions")
round(S_new, 3)

# Reconstruction of the same five rows from the fitted model (presumably the
# fitted response probabilities).
P_new <- predict(fit_binomial, X_bin[1:5, ], type = "reconstruction")
round(P_new, 2)

## ----paa-multinomial----------------------------------------------------------
# Define term probabilities for three document types
# (each row sums to one, i.e. is a distribution over the 5 terms).
Topic_true <- rbind(
    visual = c(0.55, 0.25, 0.10, 0.05, 0.05),
    text   = c(0.05, 0.10, 0.55, 0.25, 0.05),
    mixed  = c(0.15, 0.25, 0.15, 0.20, 0.25)
)
# Basic dimensions
K <- nrow(Topic_true)
M <- ncol(Topic_true)
N <- 80
# Archetype compositions
S_txt <- sample_simplex(N, K)
# Per-document term probabilities: convex mixes of the archetype rows.
Theta <- S_txt %*% Topic_true

# sample(40, ...) draws from 1:40, so document lengths range from 40 to 79.
totals <- 39 + sample(40, N, replace = TRUE) # total term counts per document
X_txt <- matrix(0L, nrow = N, ncol = M)     # observed term counts
# One multinomial draw per document (row), using that document's length
# totals[i] and term probabilities Theta[i, ].
for (i in seq_len(N)) {
    X_txt[i, ] <- as.integer(rmultinom(1, totals[i], Theta[i, ]))
}
colnames(X_txt) <- paste0("term_", seq_len(ncol(X_txt)))

# Probabilistic AA with a multinomial likelihood; tol = 1e-6 is presumably
# the convergence tolerance.
fit_multi <- run_aa(
    X_txt,
    K = K,
    method = "paa",
    family = "multinomial",
    tol = 1e-6
)
fit_multi

# Fitted term probabilities; the row sums of the first three rows are printed
# (presumably each equals 1).
P_hat <- fitted(fit_multi)
head(round(P_hat, 2), 3)
rowSums(P_hat[1:3, ])

# Scale each document's fitted probabilities by its observed length:
# rowSums(X_txt) has length N and recycles down the columns, so row i is
# multiplied by document i's total count.
expected_counts <- rowSums(X_txt) * P_hat
head(round(expected_counts, 1), 3)

## ----directional-data---------------------------------------------------------
# Three archetype directions in R^3, one per row. The third is not orthogonal
# to the second (it has a 0.15 component along y).
A_dir_true <- rbind(
    c(1, 0, 0),
    c(0, 1, 0),
    c(0, 0.15, 1)
)
K <- nrow(A_dir_true)
N <- 120
S_dir <- sample_simplex(N, K)
# Mix the directions and rescale each mixture to unit length.
X_dir <- unit_rows(S_dir %*% A_dir_true)
# Random sign per row: multiplying by the length-N vector `flip` scales row i
# by flip[i], so roughly half the points are reflected through the origin.
flip <- sample(c(-1, 1), nrow(X_dir), replace = TRUE)
X_dir_flip <- X_dir * flip
colnames(X_dir_flip) <- c("x", "y", "z")

## ----directional-fit----------------------------------------------------------
# Baseline: ordinary (Euclidean) AA on the sign-flipped unit vectors.
set.seed(42)
fit_euclidean_dir <- run_aa(
    X_dir_flip,
    K = K,
    sd_threshold = 0,
    init = "random",
    nrep = 20
)
fit_euclidean_dir

# Directional AA on the same data with the same seed.
set.seed(42)
fit_directional <- run_aa(
    X_dir_flip,
    K = K,
    method = "directional",
    sd_threshold = 0,
    init = "random",
    nrep = 20
)
fit_directional

## ----directional-directions---------------------------------------------------
# Fitted archetype directions, rounded for display.
round(fit_directional[["directions"]], 3)
# Squared Euclidean norms of the first six fitted rows (1 for points that
# lie on the unit sphere).
round(rowSums(fitted(fit_directional)^2), 6)[1:6]

## ----directional-plot, echo=FALSE, fig.cap = "Orthographic view of polarity-flipped spherical data."----
# Wireframe parameters in radians: a 200-point circle parameter, plus
# latitude and longitude lines every 30 degrees between -60 and 60.
grid_t <- seq(0, 2 * pi, length.out = 200)
grid_lat <- seq(-60, 60, by = 30) * pi / 180
grid_lon <- seq(-60, 60, by = 30) * pi / 180

# Viewing angles: a tilt about the x axis and a turn about the z axis.
view_tilt <- 25 * pi / 180
view_turn <- 15 * pi / 180

# Rotation about x: identity except for the (y, z) block, filled
# column-major as [cos, sin; -sin, cos] transposed.
rot_x <- diag(3)
rot_x[2:3, 2:3] <- c(cos(view_tilt), sin(view_tilt),
                     -sin(view_tilt), cos(view_tilt))

# Rotation about z: identity except for the (x, y) block.
rot_z <- diag(3)
rot_z[1:2, 1:2] <- c(cos(view_turn), sin(view_turn),
                     -sin(view_turn), cos(view_turn))
# Rotate 3-D points (one per row) into the viewing frame: rot_x is applied
# first, then rot_z. Columns 1-2 of the result are the orthographic screen
# coordinates; column 3 is the depth used to split front from back.
project_sphere <- function(X) {
    view <- rot_z %*% rot_x
    pts <- as.matrix(X)
    pts %*% t(view)
}

# Empty square canvas for the orthographic view of the unit sphere.
plot(
    NA, NA, axes = FALSE,
    xlim = c(-1.05, 1.05), ylim = c(-1.05, 1.05), asp = 1,
    main = "Directional AA on the unit sphere", xlab = "", ylab = ""
)
# Sphere outline: the unit circle, lightly shaded.
polygon(
    cos(grid_t), sin(grid_t),
    border = adjustcolor("black", 0.35), col = adjustcolor("black", 0.03),
    lwd = 1.2
)
# Circles of latitude, rotated into the viewing frame.
for (lat in grid_lat) {
    grid_line <- cbind(
        cos(lat) * cos(grid_t),
        cos(lat) * sin(grid_t),
        sin(lat)
    )
    grid_line <- project_sphere(grid_line)
    lines(grid_line[, 1:2], col = adjustcolor("black", 0.12), lwd = 0.8)
}
# Meridians from pole to pole. The latitude parameter is the same for every
# meridian, so it is computed once outside the loop.
grid_lat_dense <- seq(-pi / 2, pi / 2, length.out = 200)
for (lon in grid_lon) {
    grid_line <- cbind(
        cos(grid_lat_dense) * cos(lon),
        cos(grid_lat_dense) * sin(lon),
        sin(grid_lat_dense)
    )
    grid_line <- project_sphere(grid_line)
    lines(grid_line[, 1:2], col = adjustcolor("black", 0.12), lwd = 0.8)
}

# Rotate data and archetype directions into the view. Euclidean archetypes
# are rescaled to unit length so that they lie on the sphere.
X_dir_view <- project_sphere(X_dir_flip)
A_euclidean_view <- project_sphere(unit_rows(coordinates(fit_euclidean_dir)))
A_directional_view <- project_sphere(fit_directional[["directions"]])
# Points with non-negative depth are on the visible (front) hemisphere;
# draw the back hemisphere first, in a lighter shade.
front <- X_dir_view[, 3] >= 0
points(X_dir_view[!front, 1:2], pch = 19, cex = 0.48, col = pal_data)
points(X_dir_view[front, 1:2], pch = 19, cex = 0.52, col = pal_data_mid)
points(
    A_euclidean_view[, 1:2],
    pch = 4, cex = 1.45, lwd = 1.7, col = pal_aa[2]
)
points(
    A_directional_view[, 1:2],
    pch = 16, cex = 1.25, col = pal_aa[1]
)
legend(
    "bottomleft",
    legend = c("Back hemisphere", "Front hemisphere", "Euclidean AA", "Directional AA"),
    pch = c(19, 19, 4, 16),
    col = c(pal_data, pal_data_mid, pal_aa[2], pal_aa[1]),
    pt.cex = c(0.7, 0.8, 1.2, 1.1),
    bty = "n"
)

