## ----include = FALSE----------------------------------------------------------
# Chunk defaults: merge source and output into one block and prefix printed
# output lines with "#>".
knitr::opts_chunk$set(comment = "#>", collapse = TRUE)

## ----setup--------------------------------------------------------------------
library(kumquat)
# Use a fixed CRAN mirror so install.packages() never prompts for one.
options(repos = c(CRAN = "https://cloud.r-project.org"))
# Install any packages this vignette needs that are not already available.
# `bundle` is not attached below but is called as bundle::bundle() when the
# random-forest model is prepared, so it must be present too.
local({
  needed <- c(
    "tidyverse", "RColorBrewer", "colorspace", "patchwork", "randomForest",
    "bundle"
  )
  available <- vapply(needed, requireNamespace, logical(1), quietly = TRUE)
  missing <- needed[!available]
  if (length(missing) > 0) {
    install.packages(missing)
  }
})
library(tidyverse)
library(RColorBrewer)
library(colorspace)
library(patchwork)
library(randomForest)

## ----echo=FALSE---------------------------------------------------------------
data("d_vertical")

# Scatter plot of d_vertical, points coloured by class.
# (Alternative palette previously tried:
#  colorspace's scale_colour_discrete_divergingx(palette = "Zissou 1").)
d_vert_plot <- d_vertical |>
  ggplot(aes(x = x, y = y, colour = class)) +
  geom_point() +
  scale_colour_brewer(palette = "Dark2") +
  theme_minimal() +
  theme(aspect.ratio = 1) +
  labs(title = "d_vertical dataset")

## ----echo=FALSE---------------------------------------------------------------
data("d_oblique")

# Scatter plot of d_oblique, points coloured by class.
# (Alternative palette previously tried:
#  colorspace's scale_colour_discrete_divergingx(palette = "Zissou 1").)
d_obl_plot <- d_oblique |>
  ggplot(aes(x = x, y = y, colour = class)) +
  geom_point() +
  scale_colour_brewer(palette = "Dark2") +
  theme_minimal() +
  theme(aspect.ratio = 1) +
  labs(title = "d_oblique dataset")

## ----echo=FALSE---------------------------------------------------------------
data("d_multi")

# Scatter plot of d_multi, points coloured by class.
# (Alternative palette previously tried:
#  colorspace's scale_colour_discrete_divergingx(palette = "Zissou 1").)
d_multi_plot <- d_multi |>
  ggplot(aes(x = x, y = y, colour = class)) +
  geom_point() +
  scale_colour_brewer(palette = "Dark2") +
  theme_minimal() +
  theme(aspect.ratio = 1) +
  labs(title = "d_multi dataset")

## ----echo=FALSE---------------------------------------------------------------
data("d_multitwo")

# Scatter plot of d_multitwo, points coloured by class.
# (Alternative palette previously tried:
#  colorspace's scale_colour_discrete_divergingx(palette = "Zissou 1").)
d_multitwo_plot <- d_multitwo |>
  ggplot(aes(x = x, y = y, colour = class)) +
  geom_point() +
  scale_colour_brewer(palette = "Dark2") +
  theme_minimal() +
  theme(aspect.ratio = 1) +
  labs(title = "d_multitwo dataset")

## -----------------------------------------------------------------------------

# 2 x 2 patchwork grid: vertical | oblique on top, multi | multitwo below.
(d_vert_plot | d_obl_plot) / (d_multi_plot | d_multitwo_plot)

## ----echo=FALSE---------------------------------------------------------------
# randomForest() draws bootstrap samples and random feature subsets, so fix
# the RNG seed: otherwise the fitted model, and every importance value and
# plot derived from it later in the vignette, changes on each build.
set.seed(2024)
# Classification forest predicting `class` from all other columns.
rf_multitwo <- randomForest(class ~ ., data = d_multitwo)
# kumquat() below receives this bundled model rather than the raw fit.
rf_multitwo_bundle <- bundle::bundle(rf_multitwo)

## ----echo=FALSE---------------------------------------------------------------
# Return the row index of the observation in `data` closest to `pt`.
#
# pt:   anything with scalar `x` and `y` components (e.g. a one-row tibble).
# data: a data frame with numeric `x` and `y` columns.
#
# Distance is Euclidean in the (x, y) plane. As with which.min(), ties go to
# the first row, rows with missing coordinates are skipped, and an empty
# `data` yields integer(0).
find_closest <- function(pt, data) {
  # Plain column extraction rather than dplyr::mutate(): under data masking a
  # column named `pt` in `data` would shadow the function argument.
  dist <- sqrt((data[["x"]] - pt[["x"]])^2 + (data[["y"]] - pt[["y"]])^2)
  which.min(dist)
}

## ----echo=FALSE---------------------------------------------------------------
# Choose the d_multitwo observation nearest (0.39, 0.4) and build the kumquat
# object for it from the bundled random-forest model.
pt <- tibble(x = 0.39, y = 0.4)
obs <- find_closest(pt, data = d_multitwo)

# NOTE(review): unique() lists classes in order of first appearance. If
# kumquat() expects them in the model's factor-level order, levels() may be
# the safer choice; confirm against kumquat()'s documentation.
ks <- kumquat(
  rf_multitwo_bundle,
  d_multitwo,
  obs,
  class_names = d_multitwo$class |> unique()
)

## -----------------------------------------------------------------------------
# Importance results from pinch_importance() for the kumquat object `ks`.
# We expect the absolute importance of x to be greater than that of y.
pinch_importance(ks)

## -----------------------------------------------------------------------------
# Show only the first element of plot_interest()'s result (presumably a list
# of plots; confirm against kumquat's documentation).
plot_interest(ks)[[1]]

