---
title: "Simple unit tests with built-in datasets"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{data-examples}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
  %\VignetteDepends{tidyverse, RColorBrewer, colorspace, patchwork, randomForest}
---

```{r, include = FALSE}
# Chunk defaults: merge source and output into one block, and prefix printed
# output with "#>".
knitr::opts_chunk$set(collapse = TRUE, comment = "#>")
```

```{r setup}
library(kumquat)

# Packages the examples below rely on (mirrors VignetteDepends). Any that are
# not installed are fetched from CRAN first, then all of them are attached in
# this order.
# NOTE(review): installing packages from inside a vignette is discouraged
# (e.g. by CRAN policy); consider declaring these under Suggests instead.
vignette_pkgs <- c(
  "tidyverse", "RColorBrewer", "colorspace", "patchwork", "randomForest"
)

# quietly = TRUE keeps "Loading required namespace" messages out of the output.
is_installed <- vapply(
  vignette_pkgs, requireNamespace, logical(1), quietly = TRUE
)
missing_pkgs <- vignette_pkgs[!is_installed]

if (length(missing_pkgs) > 0) {
  # Supply the repository per call instead of changing the global `repos`
  # option for the rest of the session.
  install.packages(
    missing_pkgs,
    repos = c(CRAN = "https://cloud.r-project.org")
  )
}

for (pkg in vignette_pkgs) {
  library(pkg, character.only = TRUE)
}
```

The package includes four sample datasets whose decision boundaries vary in complexity.

The datasets are:

- `d_vertical`
- `d_oblique`
- `d_multi`
- `d_multitwo`

All of these datasets are designed for a binary classification task. Each contains two numeric variables (`x`, `y`) and one categorical variable (`class`). In this section, we plot each dataset with points coloured by `class`, which makes the shape of its decision boundary visible.

<!-- ## `d_vertical` dataset -->

```{r, echo=FALSE}
# Load the d_vertical example data and build (without printing) a scatter
# plot of y against x, coloured by class.
data("d_vertical")

# Alternative palette: scale_colour_discrete_divergingx(palette = "Zissou 1")
d_vert_plot <- d_vertical |>
  ggplot(mapping = aes(x = x, y = y, colour = class)) +
  geom_point() +
  labs(title = "d_vertical dataset") +
  scale_color_brewer(palette = "Dark2") +
  theme_minimal() +
  theme(aspect.ratio = 1)
```

<!-- ## `d_oblique` dataset -->

```{r, echo=FALSE}
# Load the d_oblique example data and build (without printing) a scatter
# plot of y against x, coloured by class. The dataset name is quoted in
# data() and passed positionally to ggplot(), matching the sibling chunks.
data("d_oblique")

d_obl_plot <- ggplot(d_oblique, aes(x = x, y = y, colour = class)) +
  geom_point() +
  # Alternative palette: scale_colour_discrete_divergingx(palette = "Zissou 1")
  scale_color_brewer(palette = "Dark2") +
  theme_minimal() +
  theme(aspect.ratio = 1) +
  labs(title = "d_oblique dataset")
```

<!-- ## `d_multi` dataset -->

```{r, echo=FALSE}
# Load the d_multi example data and build (without printing) its
# class-coloured scatter plot with a square panel.
data("d_multi")

d_multi_plot <- ggplot(
  data = d_multi,
  mapping = aes(x = x, y = y, colour = class)
) +
  geom_point() +
  labs(title = "d_multi dataset") +
  # Alternative palette: scale_colour_discrete_divergingx(palette = "Zissou 1")
  scale_color_brewer(palette = "Dark2") +
  theme_minimal() +
  theme(aspect.ratio = 1)
```

<!-- ## `d_multitwo` dataset -->

```{r, echo=FALSE}
# Load the d_multitwo example data and build (without printing) its
# class-coloured scatter plot with a square panel.
data("d_multitwo")

# Alternative palette: scale_colour_discrete_divergingx(palette = "Zissou 1")
d_multitwo_plot <- d_multitwo |>
  ggplot(aes(x = x, y = y, colour = class)) +
  geom_point() +
  scale_color_brewer(palette = "Dark2") +
  labs(title = "d_multitwo dataset") +
  theme_minimal() +
  theme(aspect.ratio = 1)
```

```{r}
# Arrange the four dataset plots in a 2 x 2 patchwork grid: `+` places plots side by side, `/` stacks the rows.
(d_vert_plot + d_obl_plot) / (d_multi_plot + d_multitwo_plot)
```

## Testing kumquat with the given datasets

### Models

```{r, echo=FALSE}
# Fit a random forest predicting `class` from all other columns of
# d_multitwo. randomForest() uses random bootstrap samples and random
# candidate variables at each split, so fix the RNG seed to keep the fitted
# model (and the importance values reported below) stable across builds.
# NOTE(review): kumquat() may consume random numbers as well; this seed also
# fixes the RNG state it starts from.
set.seed(2024)
rf_multitwo <- randomForest(class ~ ., data = d_multitwo)

# Wrap the model with bundle::bundle() before passing it to kumquat()
# (presumably kumquat expects a bundled model - confirm).
# NOTE(review): `bundle` is not listed in VignetteDepends or the setup chunk;
# confirm it is available via kumquat's dependencies.
rf_multitwo_bundle <- bundle::bundle(rf_multitwo)
```

```{r, echo=FALSE}
# Find the row of `data` nearest (Euclidean distance in the x-y plane) to a
# single query point.
#
# pt:   a one-row data frame / tibble (or list) with numeric `x` and `y`.
# data: a data frame / tibble with numeric columns `x` and `y`.
#
# Returns the integer row index of the nearest point. On ties the first
# index wins; integer(0) is returned if `data` has no rows or every distance
# is NA.
find_closest <- function(pt, data) {
  # A multi-row `pt` would otherwise be recycled silently against `data`.
  stopifnot(length(pt$x) == 1L, length(pt$y) == 1L)
  # Squared distance has the same minimiser as distance, so sqrt() is not
  # needed. Plain `$` access (instead of dplyr::mutate) avoids data masking,
  # where a column named `pt` in `data` would shadow the argument.
  sq_dist <- (data$x - pt$x)^2 + (data$y - pt$y)^2
  which.min(sq_dist)
}
```


```{r, echo=FALSE}
# Pick the d_multitwo observation closest to (0.39, 0.4) and run kumquat()
# on it with the bundled random forest.
pt <- tibble(x = 0.39, y = 0.4)
obs <- pt |> find_closest(d_multitwo)

ks <- kumquat(
  rf_multitwo_bundle, d_multitwo, obs,
  class_names = unique(d_multitwo[["class"]])
)
```

```{r}
# We expect x to have a larger absolute importance than y.
ks |>
  pinch_importance()
```

## Visualizing kumquat with the given datasets

```{r}
# Build the interest plot output for `ks` and display only its first element
# (presumably the first plot - confirm against plot_interest()'s docs).
local({
  interest <- plot_interest(ks)
  interest[[1]]
})
```

