## -----------------------------------------------------------------------------
library(dplyr)
library(tidyr)
library(tibble)
library(purrr)
library(stringr)
library(ggplot2)
library(lpSolve) # For matching fitted and trust topics
library(tictoc) # For checking timing

library(tmfast)
library(stm) # Used for comparison
library(tidytext) # Used for stm tidiers

## -----------------------------------------------------------------------------
## Corpus dimensions
k = 10 # number of topics, which is also the number of journals
Mj = 100 # documents drawn per journal
M = k * Mj # documents in the whole corpus
vocab = M # vocabulary length, set equal to the corpus size

## Document lengths: negative binomial with the mean and size given here
mu = 300
size = 10
sqrt(mu + (mu * mu) / size) # standard deviation implied by mu and size

## Dirichlet parameters for the topic-document distributions ...
topic_scale = 10
topic_peak = 0.8

## ... and for the word-topic distributions
word_beta = 0.1

## -----------------------------------------------------------------------------
## Fix the RNG state so that every random draw below is reproducible.
## The seed is 1997, i.e. the value of the arithmetic expression 2022 - 6 - 19.
set.seed(1997)

## -----------------------------------------------------------------------------
## theta: true topic-document distributions.
## For each journal j in 1..k, Mj rows are drawn from a Dirichlet whose alpha
## comes from peak_alpha() with the peak (topic_peak, .8 by default) placed at
## position j; the per-journal matrices are then stacked with rbind.
theta = reduce(
      map(
            seq_len(k),
            \(journal) {
                  alpha = peak_alpha(
                        k,
                        journal,
                        peak = topic_peak,
                        scale = topic_scale
                  )
                  rdirichlet(Mj, alpha)
            }
      ),
      rbind
)

## Long format: one row per document-topic pair, with an integer document id
theta_df = as_tibble(
      theta,
      rownames = 'doc',
      .name_repair = tmfast:::make_colnames
)
theta_df = mutate(theta_df, doc = as.integer(doc))
theta_df = pivot_longer(
      theta_df,
      starts_with('V'),
      names_to = 'topic',
      values_to = 'prob'
)

## Heatmap of the true topic probabilities, document by topic
theta_df |>
      ggplot(aes(x = doc, y = topic, fill = prob)) +
      geom_tile()

## -----------------------------------------------------------------------------
## phi: true word-topic distributions; phi_j is the word distribution of topic j
phi = rdirichlet(k, word_beta, k = vocab)

## Heatmap of word probabilities by topic (no breaks drawn on the word axis)
ggplot(
      pivot_longer(
            as_tibble(
                  phi,
                  rownames = 'topic',
                  .name_repair = tmfast:::make_colnames
            ),
            starts_with('V'),
            names_to = 'word',
            values_to = 'prob'
      ),
      aes(x = topic, y = word, fill = prob)
) +
      geom_tile() +
      scale_y_discrete(breaks = NULL)

## Zipf's law: probability against within-topic rank on log-log axes,
## keeping only ranks below half of the vocabulary length
phi |>
      as_tibble(
            rownames = 'topic',
            .name_repair = function(nms) str_c('word', 1:vocab)
      ) |>
      pivot_longer(
            starts_with('word'),
            names_to = 'word',
            values_to = 'prob'
      ) |>
      group_by(topic) |>
      mutate(rank = rank(desc(prob))) |>
      filter(rank < vocab / 2) |>
      arrange(topic, rank) |>
      ggplot(aes(x = rank, y = prob, color = topic)) +
      geom_line() +
      scale_y_log10() +
      scale_x_log10()

## -----------------------------------------------------------------------------
## N: one negative-binomial length per document (N_i is the length of doc i)
N = rnbinom(n = M, size = size, mu = mu)

## Quick checks on the drawn lengths: five-number summary, SD, histogram
N |> summary()
N |> sd()
hist(N)

## ----cache = TRUE-------------------------------------------------------------
## Draw the simulated corpus from the lengths and the two true distributions,
## timing the call with tic()/toc()
tic()
corpus = draw_corpus(N, theta, phi)
toc()

## dtm: the corpus with its count column n replaced by log1p(n)
dtm = corpus |>
      mutate(n = log1p(n))

## -----------------------------------------------------------------------------
## Fit tmfast on the log-transformed counts, timing the call.
## The second argument (2, 3, k, 2k) is presumably the set of topic counts to
## fit; later code retrieves results for k from this object.
tic()
fitted = dtm |>
      tmfast(c(2, 3, k, 2 * k))
toc()

## -----------------------------------------------------------------------------
## Structure of the fitted object, two levels deep, followed by the full
## structure of its varimax element named '10'
fitted |> str(max.level = 2L)
fitted$varimax$`10` |> str()

## -----------------------------------------------------------------------------
## Scree plot for the fitted object
fitted |> screeplot()

## -----------------------------------------------------------------------------
## Variance coverage: running sum of the squared sdev values, as a share of
## the total variance stored on the fitted object. Computed once and reused
## for both the printed vector and the plot.
cum_var = cumsum(fitted$sdev^2) / fitted$totalvar
cum_var

## Cumulative share against component index.
## seq_along() is used instead of 1:length(): for an empty sdev the latter
## yields c(1, 0), which makes data.frame() fail on mismatched lengths.
data.frame(
      PC = seq_along(cum_var),
      cum_var = cum_var
) |>
      ggplot(aes(PC, cum_var)) +
      geom_line() +
      geom_point()

## ----eval = FALSE-------------------------------------------------------------
# tic()
# corpus |>
#       cast_sparse(doc, word, n) |>
#       stm(K = 0, verbose = FALSE)
# toc()

## -----------------------------------------------------------------------------
## Comparison model: stm with K = k topics, fitted on the untransformed corpus
## counts cast to a sparse document-word matrix; timed with tic()/toc()
tic()
fitted_stm = stm(
      cast_sparse(corpus, doc, word, n),
      K = k,
      verbose = FALSE
)
toc()

## -----------------------------------------------------------------------------
## beta: tidy word-topic table for the k-topic fit (the fitted varimax
## loadings, transformed to probability distributions)
beta = fitted |>
      tidy(k, 'beta')

## -----------------------------------------------------------------------------
## phi_tidy: the true word-topic matrix in a long, tidy-like layout.
## After transposing, rows are tokens; each (token, topic) pair becomes one row
## with its probability in `beta`, copied into `beta_rn`, and labelled 'true'.
phi_tidy = as_tibble(
      t(phi),
      rownames = 'token',
      .name_repair = tmfast:::make_colnames
) |>
      pivot_longer(
            starts_with('V'),
            names_to = 'topic',
            values_to = 'beta'
      ) |>
      mutate(beta_rn = beta, type = 'true')

## Zipfian comparison of fitted and true distributions: probability against
## rank within each (type, topic) group, log-log, ranks below vocab / 2 only
beta |>
      mutate(type = 'fitted') |>
      bind_rows(phi_tidy) |>
      group_by(type, topic) |>
      mutate(rank = rank(desc(beta))) |>
      filter(rank < vocab / 2) |>
      arrange(type, topic, rank) |>
      ggplot(aes(
            x = rank,
            y = beta,
            color = type,
            group = interaction(topic, type)
      )) +
      geom_line() +
      scale_x_log10() +
      scale_y_log10()

## -----------------------------------------------------------------------------
## Entropy of each true word-topic distribution
summarize(
      group_by(phi_tidy, topic),
      entropy = entropy(beta)
)

## Entropy of each fitted word-topic distribution
summarize(
      group_by(beta, topic),
      entropy = entropy(beta)
)

## -----------------------------------------------------------------------------
## Expected entropy for the word-topic Dirichlet settings used above
word_beta |> expected_entropy(k = vocab)

## -----------------------------------------------------------------------------
## Exponent for renormalizing the fitted word-topic distributions, targeting
## the expected entropy of the true word-topic Dirichlet
beta_power = beta |>
      target_power(
            target_entropy = expected_entropy(word_beta, k = vocab),
            group_col = topic,
            p_col = beta
      )
beta_power

## -----------------------------------------------------------------------------
## beta_rn: the k-topic word-topic table again, now with exponent beta_power
beta_rn = fitted |>
      tidy(k, 'beta', exponent = beta_power)

## Per-topic entropies of the renormalized distributions
summarize(
      group_by(beta_rn, topic),
      entropy = entropy(beta)
)

## Zipfian comparison, renormalized fit against truth: probability against
## rank within each (type, topic) group, log-log, ranks below vocab / 2 only
beta_rn |>
      mutate(type = 'fitted') |>
      bind_rows(phi_tidy) |>
      group_by(type, topic) |>
      mutate(rank = rank(desc(beta))) |>
      filter(rank < vocab / 2) |>
      arrange(type, topic, rank) |>
      ggplot(aes(
            x = rank,
            y = beta,
            color = type,
            group = interaction(topic, type)
      )) +
      geom_line() +
      scale_x_log10() +
      scale_y_log10()

## -----------------------------------------------------------------------------
## Hellinger distance of word-topic distributions.
## beta_mx: renormalized fitted distributions as a token-by-topic matrix.
beta_mx = local({
      ## Integer token ids in ascending order; any token in 1..vocab that is
      ## missing for a topic (dropped words) is added with probability 0
      filled = beta_rn |>
            mutate(token = as.integer(token)) |>
            arrange(token) |>
            complete(token = seq_len(vocab), topic, fill = list(beta = 0))

      ## One column per topic, columns sorted by topic name
      wide = pivot_wider(
            filled,
            names_from = 'topic',
            values_from = 'beta',
            values_fill = 0,
            names_sort = TRUE
      )

      ## Tokens become row names; everything left is coerced to a matrix
      as.matrix(column_to_rownames(wide, 'token'))
})

## Distances between the true topics and the (still unmatched) fitted topics
phi |> hellinger(t(beta_mx))

## -----------------------------------------------------------------------------
## Match fitted topics to true topics: solve the assignment problem (lpSolve)
## on the Hellinger distance matrix and show the resulting solution matrix
dist = phi |> hellinger(t(beta_mx))
soln = dist |> lp.assign()
soln$solution

## Distances after applying the solution matrix to the fitted topics, then a
## summary of the diagonal of that distance matrix
hellinger(phi, soln$solution %*% t(beta_mx))
summary(diag(hellinger(phi, soln$solution %*% t(beta_mx))))

## -----------------------------------------------------------------------------
## beta_stm_mx: word-topic distributions from the stm fit, as a term-by-topic
## matrix
beta_stm_mx = local({
      ## Integer term ids in ascending order; any term in 1..vocab that is
      ## missing for a topic (dropped words) is added with probability 0
      filled = tidy(fitted_stm, matrix = 'beta') |>
            mutate(term = as.integer(term)) |>
            arrange(term) |>
            complete(term = seq_len(vocab), topic, fill = list(beta = 0))

      ## One column per topic, columns sorted by topic name
      wide = pivot_wider(
            filled,
            names_from = 'topic',
            values_from = 'beta',
            values_fill = 0,
            names_sort = TRUE
      )

      ## Terms become row names; everything left is coerced to a matrix
      as.matrix(column_to_rownames(wide, 'term'))
})

## Distances between the true topics and the (still unmatched) stm topics
phi |> hellinger(t(beta_stm_mx))

## Assignment-problem solution matching stm topics to true topics
rotation_stm = lp.assign(hellinger(phi, t(beta_stm_mx)))$solution

## Summary of the diagonal of the distance matrix after matching
summary(diag(hellinger(phi, rotation_stm %*% t(beta_stm_mx))))

## -----------------------------------------------------------------------------
## One example of the Dirichlet alpha used when drawing the true
## topic-document distributions (peak at position 1), and its expected entropy
peak_alpha(k, 1, topic_peak, topic_scale)
peak_alpha(k, 1, topic_peak, topic_scale) |> expected_entropy()

## Renormalization exponent for the fitted topic-document distributions,
## grouped by document and targeting that expected entropy
gamma_power = target_power(
      tidy(fitted, k, 'gamma'),
      document,
      gamma,
      expected_entropy(peak_alpha(k, 1, topic_peak, topic_scale))
)
gamma_power

## gamma_df: tidy topic-document table for the k-topic fit, using the
## topic-matching solution as rotation and gamma_power as exponent
gamma_df = fitted |>
      tidy(
            k,
            'gamma',
            exponent = gamma_power,
            rotation = soln$solution
      )

gamma_df

## -----------------------------------------------------------------------------
## Raster of fitted topic probabilities, document by topic
## (no breaks drawn on the document axis)
ggplot(
      mutate(gamma_df, document = as.integer(document)),
      aes(x = document, y = topic, fill = gamma)
) +
      geom_raster() +
      scale_x_continuous(breaks = NULL)

## One line per document across topics, faceted and coloured by journal.
## Journal index: documents are numbered in consecutive blocks of Mj.
gamma_df |>
      mutate(document = as.integer(document)) |>
      mutate(journal = (document - 1) %/% Mj + 1) |>
      ggplot(aes(
            x = topic,
            y = gamma,
            group = document,
            color = as.factor(journal)
      )) +
      geom_line(alpha = 0.25) +
      facet_wrap(vars(journal), scales = 'free_x') +
      scale_color_discrete(guide = 'none')

## -----------------------------------------------------------------------------
## Hellinger distances between the true (theta_df) and fitted (gamma_df)
## topic-document distributions, returned as a matrix (df = FALSE);
## the diagonal of that matrix is then summarized
summary(
      diag(
            hellinger(
                  theta_df,
                  id1 = doc,
                  topics2 = gamma_df,
                  id2 = document,
                  prob2 = gamma,
                  df = FALSE
            )
      )
)

## -----------------------------------------------------------------------------
## fitted_stm_gamma: stm topic-document distributions as a document-by-topic
## matrix (documents as row names, one column per topic)
fitted_stm_gamma = as.matrix(
      column_to_rownames(
            pivot_wider(
                  tidy(fitted_stm, matrix = 'gamma'),
                  names_from = 'topic',
                  values_from = 'gamma'
            ),
            'document'
      )
)

## Distances to the true theta after multiplying the topic columns by the
## transposed stm matching matrix; summary of the diagonal
summary(diag(hellinger(theta, fitted_stm_gamma %*% t(rotation_stm))))

## -----------------------------------------------------------------------------
## t-SNE projection for the k-topic fit, points coloured by journal
## (journal index: documents are numbered in consecutive blocks of Mj)
fitted |>
      tsne(k) |>
      mutate(journal = (as.integer(document) - 1) %/% Mj + 1) |>
      ggplot(aes(x = x, y = y, color = as.character(journal))) +
      geom_point() +
      labs(title = 't-SNE visualization')

## UMAP projection for the same fit, returned as a data frame (df = TRUE)
fitted |>
      umap(k, df = TRUE) |>
      mutate(journal = (as.integer(document) - 1) %/% Mj + 1) |>
      ggplot(aes(x = x, y = y, color = as.character(journal))) +
      geom_point() +
      labs(title = 'UMAP visualization')

