## ----include = FALSE----------------------------------------------------------
# Vignette-wide chunk options: merge each chunk's source and output into a
# single block, and prefix every printed output line with "#>".
knitr::opts_chunk$set(collapse = TRUE, comment = "#>")

## ----include = FALSE, eval = FALSE--------------------------------------------
#  #' @srrstats {G1.5} These examples are based on classical test problems for
#  #' nested sampling used within the nestle package. These problems are
#  #' also used within ernest's test suite to validate its behaviour.
#  #' @srrstats {G5.1} Tests and test data are provided here for users to
#  #' interact with.
#  #' @srrstats {BS1.1} Demonstrates how to enter data through a likelihood
#  #' function.

## ----setup, message = FALSE---------------------------------------------------
library(ernest)
library(posterior)
library(ggplot2)

## -----------------------------------------------------------------------------
# Log-likelihood for two Gaussian blobs: an equal mixture of two isotropic
# Gaussians centred on `mu1` and `mu2` with standard deviation `sigma`
# (normalising constants are omitted).
sigma <- 0.1
mu1 <- c(1, 1)
mu2 <- -c(1, 1)
# Precision matrix shared by both blobs; derived from `sigma` so that editing
# the constant above actually changes the likelihood.
sigma_inv <- diag(2) / sigma^2

# `x` may be a single 2-vector or a matrix with one point per row; returns one
# log-likelihood value per point.
gaussian_blobs_loglik <- function(x) {
  dx1 <- -0.5 * mahalanobis(x, mu1, sigma_inv, inverted = TRUE)
  dx2 <- -0.5 * mahalanobis(x, mu2, sigma_inv, inverted = TRUE)
  # Combine the two components with a numerically stable log-sum-exp.
  matrixStats::colLogSumExps(rbind(dx1, dx2))
}

# Uniform prior over [-5, 5] in each dimension
prior <- create_uniform_prior(lower = -5, upper = 5, names = c("A", "B"))

# Wrap the vectorised log-likelihood, build a 100-live-point sampler with a
# fixed seed, then run it to completion without a progress bar.
blob_sampler <- create_likelihood(vectorized_fn = gaussian_blobs_loglik) |>
  ernest_sampler(prior, nlive = 100, seed = 42L)
blob_result <- blob_sampler |>
  generate(show_progress = FALSE)

## -----------------------------------------------------------------------------
blob_result |>
  summary()
# Final log-evidence estimate, computed from 500 simulated draws.
tail(calculate(blob_result, ndraws = 500)$log_evidence, n = 1)

## ----include = FALSE----------------------------------------------------------
rm(list = c("blob_sampler", "blob_result"))

## -----------------------------------------------------------------------------
# Log-likelihood for the eggbox function.
#
# `x` is either a single 2-vector or a matrix with one point per row (only the
# first two columns are used). Returns one value per point.
eggbox_loglik <- function(x) {
  tmax <- 5.0 * pi
  # Promote a bare parameter vector to a one-row matrix so both input forms
  # share the same code path.
  if (!is.matrix(x)) dim(x) <- c(1, length(x))
  # Map each coordinate from [0, 1] onto [-tmax, tmax]; a scalar shift needs
  # no sweep().
  t <- 2.0 * tmax * x - tmax
  (2.0 + cos(t[, 1] / 2.0) * cos(t[, 2] / 2.0))^5.0
}

# Uniform prior over [0, 1] in each dimension.
# NOTE(review): no bounds are passed, so [0, 1] relies on the defaults of
# create_uniform_prior() — confirm.
eggbox_prior <- create_uniform_prior(
  names = c("A", "B")
)

## ----echo = FALSE-------------------------------------------------------------
# Evaluate the eggbox log-likelihood on a regular 101 x 101 grid over the unit
# square and draw it as a heat map.
eggbox_sample <- expand.grid(
  "A" = seq(0, 1, by = 0.01),
  "B" = seq(0, 1, by = 0.01)
)
# eggbox_loglik() accepts a matrix with one point per row, so the whole grid
# is evaluated in one vectorised call instead of ~10,000 scalar calls.
eggbox_sample$logl <- eggbox_loglik(as.matrix(eggbox_sample[c("A", "B")]))

# ggplot2 is already attached by the setup chunk.
ggplot(eggbox_sample, aes(A, B, fill = logl)) +
  geom_tile() +
  scale_fill_viridis_c("Log-Lik.")
rm(eggbox_sample)

## -----------------------------------------------------------------------------
# Sample the eggbox posterior using the multi-ellipsoid sampler, with a fixed
# seed, then run it and summarise the result.
egg_sampler <- ernest_sampler(
  eggbox_loglik,
  eggbox_prior,
  sampler = multi_ellipsoid(),
  seed = 42L
)
egg_result <- egg_sampler |>
  generate(show_progress = FALSE)
egg_result |>
  summary()

## -----------------------------------------------------------------------------
egg_result |> visualize(.which = "trace")

## ----include = FALSE----------------------------------------------------------
rm(list = c("egg_sampler", "egg_result"))

## -----------------------------------------------------------------------------
y <- 1e+08 + c(0.2, rep(c(0.1, 0.3), times = 4), 0.1)

## ----include = FALSE----------------------------------------------------------
# Reference 2.5% / 50% / 97.5% quantiles for `mu` and `sigma`, laid out with
# the same columns as the summarise_draws() output so the two tables can be
# rbind()-ed for comparison.
# NOTE(review): the `mu` median here is 1e8 + 0.2, but mean(y) for the
# 10-point `y` defined above is 1e8 + 0.19 — confirm which dataset these
# reference values were computed from.
ref <- vctrs::data_frame(
  variable = c("mu", "sigma"),
  `2.5%` = c(100000000.13281908588, 0.069871704416342),
  `50%` = c(100000000.20000000000, 0.103462818336964),
  `97.5%` = c(100000000.26718091412, 0.175493354741336)
)

## -----------------------------------------------------------------------------
# Build a Gaussian log-likelihood for fixed observations `data`.
#
# Returns a function of `theta = c(mean, sd)` giving the summed normal
# log-density of `data`, or -Inf when the standard deviation is not positive.
gaussian_log_lik <- function(data) {
  # Evaluate `data` now so later changes to the caller's variable are not seen.
  force(data)

  function(theta) {
    centre <- theta[1]
    spread <- theta[2]
    if (spread > 0) {
      sum(stats::dnorm(data, mean = centre, sd = spread, log = TRUE))
    } else {
      -Inf
    }
  }
}

# Likelihood closure over `y`, with a box prior: mu within 1 of 1e8 and
# sigma in [0.01, 1].
log_lik <- gaussian_log_lik(y)
prior <- create_uniform_prior(
  names = c("mu", "sigma"),
  lower = c(1e+08 - 1, 0.01),
  upper = c(1e+08 + 1, 1)
)
nist_sampler <- ernest_sampler(log_lik, prior, seed = 42L)
nist_result <- nist_sampler |>
  generate(show_progress = FALSE)

## -----------------------------------------------------------------------------
# Convert the run to draws, resample to an unweighted sample, and summarise
# each parameter by its 2.5%, 50% and 97.5% quantiles.
post <- nist_result |>
  as_draws() |>
  resample_draws() |>
  summarise_draws(\(x) quantile(x, probs = c(0.025, 0.5, 0.975)))
post

## ----echo = FALSE-------------------------------------------------------------
# Stack the sampled quantiles (`post`) on top of the reference values (`ref`)
# and tag each row with its source for a side-by-side comparison.
all <- rbind(post, ref)
# Size the tags from each table rather than assuming two rows apiece.
all$src <- rep(c("est", "act"), times = c(nrow(post), nrow(ref)))
all

ggplot(all, aes(x = src)) +
  geom_crossbar(aes(ymin = `2.5%`, y = `50%`, ymax = `97.5%`)) +
  facet_grid(rows = vars(variable), scales = "free_y") +
  scale_y_continuous("Value") +
  scale_x_discrete(
    "Data Source",
    breaks = c("act", "est"),
    # "est" rows are the sampler's estimates, so label them as such.
    labels = c("Actual", "Estimated")
  )

