6 LDA Inference

6.1 General Overview

We have talked about LDA as a generative model, but now it is time to flip the problem around. What if I have a bunch of documents and I want to infer topics? To solve this problem we will work under the assumption that the documents were generated by a generative model similar to the ones in the previous section. Under this assumption, we need to solve Equation (6.1). (NOTE: The derivation for LDA inference via Gibbs Sampling is taken from (Darling 2011), (Heinrich 2008) and (Steyvers and Griffiths 2007).)

\[ \begin{equation} p(\theta, \phi, z|w, \alpha, \beta) = {p(\theta, \phi, z, w|\alpha, \beta) \over p(w|\alpha, \beta)} \tag{6.1} \end{equation} \] The left side of Equation (6.1) defines the following:
The joint probability of the document topic distributions (\(\theta\)), the word distribution of each topic (\(\phi\)), and the topic assignments (\(z\)), given all words (in all documents) and the hyperparameters \(\alpha\) and \(\beta\). In particular, we are interested in estimating the probability of each topic (z) for a given word (w), given our prior assumptions (i.e. the hyperparameters), for all words and topics. From this we can infer \(\phi\) and \(\theta\).

Equation (6.1) is based on the following statistical property:

\[ p(A, B | C) = {p(A,B,C) \over p(C)} \]

All the variables used in this section were outlined at the beginning of Chapter 5 if you need a refresher.

Let’s take a step back from the math and map out the ‘variables we know’ versus the ‘variables we don’t know’ with regard to the inference problem:

Known Parameters

  • Documents (d in D): We have a fixed set of documents in which we want to identify the topic structures.
  • Words (w in W): We have a collection of words and word counts for each document.
  • Vocabulary (W): The unique list of words across all documents.
  • Hyperparameters:
    • \(\overrightarrow{\alpha}\): Our prior assumption about the topic distribution of our documents. This book will only use symmetric \(\alpha\) values; in other words, we assume all topics are equally probable in any given document (similar to the naive assumption of a fair die). We will supply the \(\alpha\) value for inference. (A short simulation illustrating the effect of \(\alpha\) follows this list.)
      • Higher \(\overrightarrow{\alpha}\) - We assume documents will have similar, close to uniform distributions of topics.
      • Lower \(\overrightarrow{\alpha}\) - We assume document topic distributions vary more drastically.
    • \(\overrightarrow{\beta}\): Our prior assumption about the word distribution of each topic.
      • Higher \(\overrightarrow{\beta}\): Word distributions in each topic are closer to uniform, i.e. each word is equally likely in each topic.
      • Lower \(\overrightarrow{\beta}\): Word distributions vary more from topic to topic.
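To build intuition for \(\overrightarrow{\alpha}\), here is a minimal simulation sketch using rdirichlet() from the MCMCpack package (also loaded in the code example later in this chapter). The values 10 and 0.1 are arbitrary choices for illustration, not values we will use for inference.

library(MCMCpack)

set.seed(42)
k <- 3 # assume three topics
# high alpha: each document's topic mixture sits close to uniform (1/3, 1/3, 1/3)
round(rdirichlet(3, rep(10, k)), 2)
# low alpha: each document's topic mixture tends to concentrate on one or two topics
round(rdirichlet(3, rep(0.1, k)), 2)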

And on to the parts we don’t know….

Unknown (Latent) Parameters

  • Number of Topics (k): We need to specify the number of topics we assume are present in the documents. However we don’t know the real number of topics in the corpus. Methods for estimating the number of topics in a corpus are outside the scope of this book. If you are interested in learning more see (Griffiths and Steyvers 2004) and (Teh et al. 2005).
  • Document Topic Mixture (\(\theta\)): We need to determine the topic distribution in each document.
  • Word Distribution of Each Topic (\(\phi\)): We need to know the distribution of words in each topic. Obviously some words are going to occur very often in a topic while others may have zero probability of occurring in a topic.
  • Word topic assignment (z): This is actually the main thing we need to infer. To be clear, if we know the topic assignment of every word in every document, then we can derive the document topic mixture, \(\theta\), and the word distribution, \(\phi\), of each topic.

6.2 Mathematical Derivations for Inference

Back to the math…

The derivation connecting equation (6.1) to the actual Gibbs sampling solution to determine z for each word in each document, \(\overrightarrow{\theta}\), and \(\overrightarrow{\phi}\) is very complicated and I’m going to gloss over a few steps. For complete derivations see (Heinrich 2008) and (Carpenter 2010).

As stated previously, the main goal of inference in LDA is to determine the topic of each word, \(z_{i}\) (topic of word i), in each document.

\[ \begin{equation} p(z_{i}|z_{\neg i}, \alpha, \beta, w) \tag{6.2} \end{equation} \]

Notice that we are interested in identifying the topic of the current word, \(z_{i}\), based on the topic assignments of all other words (not including the current word i), which is signified as \(z_{\neg i}\).

\[ \begin{equation} \begin{aligned} p(z_{i}|z_{\neg i}, \alpha, \beta, w) &= {p(z_{i},z_{\neg i}, w, | \alpha, \beta) \over p(z_{\neg i},w | \alpha, \beta)}\\ &\propto p(z_{i}, z_{\neg i}, w | \alpha, \beta)\\ &\propto p(z,w|\alpha, \beta) \end{aligned} \tag{6.3} \end{equation} \]

You may notice \(p(z,w|\alpha, \beta)\) looks very similar to the definition of the generative process of LDA from the previous chapter (equation (5.1)). The only difference is the absence of \(\theta\) and \(\phi\). This means we can swap in equation (5.1) and integrate out \(\theta\) and \(\phi\).

\[ \begin{equation} \begin{aligned} p(w,z|\alpha, \beta) &= \int \int p(z, w, \theta, \phi|\alpha, \beta)d\theta d\phi\\ &= \int \int p(\phi|\beta)p(\theta|\alpha)p(z|\theta)p(w|\phi_{z})d\theta d\phi \\ &= \int p(z|\theta)p(\theta|\alpha)d \theta \int p(w|\phi_{z})p(\phi|\beta)d\phi \end{aligned} \tag{6.4} \end{equation} \]

As with the previous Gibbs sampling examples in this book we are going to expand equation (6.3), plug in our conjugate priors, and get to a point where we can use a Gibbs sampler to estimate our solution.

Below we solve the first term of Equation (6.4), utilizing the conjugate prior relationship between the multinomial and Dirichlet distributions. The integrand is an unnormalized Dirichlet density, so the result is a ratio of multivariate Beta functions whose parameter is the number of words assigned to each topic in the current document d plus the corresponding \(\alpha\) value.

\[ \begin{equation} \begin{aligned} \int p(z|\theta)p(\theta|\alpha)d\theta &= \int \prod_{i}\theta_{d,z_{d,i}}{1\over B(\alpha)}\prod_{k}\theta_{d,k}^{\alpha_{k}-1}\,d\theta_{d} \\ &={1\over B(\alpha)} \int \prod_{k}\theta_{d,k}^{n_{d,k} + \alpha_{k}-1}\,d\theta_{d} \\ &={B(n_{d,.} + \alpha) \over B(\alpha)} \end{aligned} \tag{6.5} \end{equation} \]
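Here \(B(\cdot)\) is the multivariate Beta function, the normalizing constant of the Dirichlet distribution. The step from the second to the third line uses the fact that the integrand is an unnormalized Dirichlet density, so the integral equals its normalizing constant:

\[ B(\alpha) = {\prod_{k=1}^{K}\Gamma(\alpha_{k}) \over \Gamma\left(\sum_{k=1}^{K}\alpha_{k}\right)}, \qquad \int \prod_{k}\theta_{k}^{n_{k}+\alpha_{k}-1}d\theta = B(n + \alpha) \]

The same identity gives the result for the second term below.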

Similarly, we can expand the second term of Equation (6.4) and find a solution with the same form: for each topic, a ratio of multivariate Beta functions whose parameter is the number of times each word in the vocabulary is assigned to that topic across all documents plus the corresponding \(\beta\) value.

\[ \begin{equation} \begin{aligned} \int p(w|\phi_{z})p(\phi|\beta)d\phi &= \int \prod_{d}\prod_{i}\phi_{z_{d,i},w_{d,i}} \prod_{k}{1 \over B(\beta)}\prod_{w}\phi_{k,w}^{\beta_{w}-1}\,d\phi_{k}\\ &= \prod_{k}{1\over B(\beta)} \int \prod_{w}\phi_{k,w}^{n_{k,w} + \beta_{w}-1}\,d\phi_{k}\\ &=\prod_{k}{B(n_{k,.} + \beta) \over B(\beta)} \end{aligned} \tag{6.6} \end{equation} \]

This leaves us with the following:

\[ \begin{equation} \begin{aligned} p(w,z|\alpha, \beta) &= \prod_{d}{B(n_{d,.} + \alpha) \over B(\alpha)} \prod_{k}{B(n_{k,.} + \beta) \over B(\beta)} \end{aligned} \tag{6.7} \end{equation} \]

The equation necessary for Gibbs sampling can be derived by utilizing (6.7). This is accomplished via the chain rule and the definition of conditional probability.

The chain rule is outlined in Equation (6.8)

\[ \begin{equation} p(A,B,C,D) = P(A)P(B|A)P(C|A,B)P(D|A,B,C) \tag{6.8} \end{equation} \]

The conditional probability property utilized is shown in (6.9)

\[ \begin{equation} P(B|A) = {P(A,B) \over P(A)} \tag{6.9} \end{equation} \]

Applying these two properties to Equation (6.7), and dropping the terms that do not depend on \(z_{i}\), gives Equation (6.10). We will use Equation (6.10) in the example below to complete the LDA inference task on a set of simulated documents.

\[ \begin{equation} \begin{aligned} p(z_{i}|z_{\neg i}, w) &= {p(w,z)\over {p(w,z_{\neg i})}} = {p(z)\over p(z_{\neg i})}{p(w|z)\over p(w_{\neg i}|z_{\neg i})p(w_{i})}\\ \\ &\propto \prod_{d}{B(n_{d,.} + \alpha) \over B(n_{d,\neg i} + \alpha)} \prod_{k}{B(n_{k,.} + \beta) \over B(n_{k,\neg i} + \beta)}\\ \\ &\propto {\Gamma(n_{d,k} + \alpha_{k}) \Gamma(\sum_{k=1}^{K} n_{d,\neg i}^{k} + \alpha_{k}) \over \Gamma(n_{d,\neg i}^{k} + \alpha_{k}) \Gamma(\sum_{k=1}^{K} n_{d,k}+ \alpha_{k})} {\Gamma(n_{k,w} + \beta_{w}) \Gamma(\sum_{w=1}^{W} n_{k,\neg i}^{w} + \beta_{w}) \over \Gamma(n_{k,\neg i}^{w} + \beta_{w}) \Gamma(\sum_{w=1}^{W} n_{k,w}+ \beta_{w})}\\ \\ &\propto (n_{d,\neg i}^{k} + \alpha_{k}) {n_{k,\neg i}^{w} + \beta_{w} \over \sum_{w} n_{k,\neg i}^{w} + \beta_{w}} \end{aligned} \tag{6.10} \end{equation} \]
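The simplification in the final step uses the Gamma function recurrence \(\Gamma(x+1) = x\Gamma(x)\). For example, since \(n_{d,k} = n_{d,\neg i}^{k} + 1\) for the topic currently being considered,

\[ {\Gamma(n_{d,k} + \alpha_{k}) \over \Gamma(n_{d,\neg i}^{k} + \alpha_{k})} = n_{d,\neg i}^{k} + \alpha_{k} \]

The remaining ratios of Gamma functions either simplify the same way (giving the denominator in the last line) or do not depend on \(z_{i}\) and are dropped under proportionality.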

To calculate our word distributions in each topic we will use Equation (6.11).

\[ \begin{equation} \phi_{k,w} = { n^{(w)}_{k} + \beta_{w} \over \sum_{w=1}^{W} n^{(w)}_{k} + \beta_{w}} \tag{6.11} \end{equation} \]

The topic distribution in each document is calculated using Equation (6.12).

\[ \begin{equation} \theta_{d,k} = {n^{(k)}_{d} + \alpha_{k} \over \sum_{k=1}^{K}n_{d}^{k} + \alpha_{k}} \tag{6.12} \end{equation} \]
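In the code example later in this chapter these two estimates are computed from the count matrices returned by the Gibbs sampler. As a minimal standalone sketch (the matrix names n_topic_term and n_doc_topic are hypothetical and used only for illustration), Equations (6.11) and (6.12) amount to smoothing the counts with \(\beta\) and \(\alpha\) and normalizing each row:

# sketch of Equations (6.11) and (6.12): smooth the counts, then normalize rows
# n_topic_term: k x W matrix of word counts per topic (hypothetical name)
# n_doc_topic : D x k matrix of topic counts per document (hypothetical name)
estimate_phi <- function(n_topic_term, beta){
  V <- ncol(n_topic_term)
  (n_topic_term + beta) / (rowSums(n_topic_term) + V * beta) # rows sum to 1
}
estimate_theta <- function(n_doc_topic, alpha){
  K <- ncol(n_doc_topic)
  (n_doc_topic + alpha) / (rowSums(n_doc_topic) + K * alpha) # rows sum to 1
}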

What if I don’t want to generate documents? What if my goal is to infer which topics are present in each document and which words belong to each topic? This is where LDA inference comes into play.

Before going through any derivations of how we infer the document topic distributions and the word distributions of each topic, I want to go over the process of inference more generally.

The General Idea of the Inference Process

  1. Initialization: Randomly assign a topic to each word in each document by sampling from a uniform multinomial distribution.
  2. Gibbs Sampling:
    • For i iterations:
      • For each document d in documents:
        • For each word in document d:
          • Assign a topic to the current word based on the probability of each topic given the topic assignments of all other words (excluding the current word), as shown in Equation (6.10). A minimal sketch of this per-token update follows this list.
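Before turning to the full code example, here is a minimal R sketch of the per-token update in Equation (6.10). It assumes the current token has already been removed from all of the count matrices, and the variable names (n_doc_topic, n_topic_term, n_topic_total) are hypothetical, used only for illustration; the sampler actually used below is written with Rcpp.

# one Gibbs step for a single token (Equation (6.10))
# d            : document index of the current token
# w_i          : word (vocabulary) index of the current token
# n_doc_topic  : D x k matrix of topic counts per document (current token removed)
# n_topic_term : k x W matrix of word counts per topic (current token removed)
# n_topic_total: length-k vector of total word counts per topic (current token removed)
sample_token_topic <- function(d, w_i, n_doc_topic, n_topic_term, n_topic_total,
                               alpha, beta){
  V <- ncol(n_topic_term) # vocabulary size
  p_k <- (n_doc_topic[d, ] + alpha) *          # document part of Equation (6.10)
    (n_topic_term[, w_i] + beta) /             # word part of Equation (6.10)
    (n_topic_total + V * beta)
  sample(seq_along(p_k), size = 1, prob = p_k) # sample() normalizes p_k internally
}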

6.3 Animal Farm - Code Example

Now let’s revisit the animal example from the first section of the book and break down what we see. This time we will also be taking a look at the code used to generate the example documents as well as the inference code.

rm(list = ls())
library(MCMCpack)
library(tidyverse)
library(Rcpp)
library(knitr)
library(kableExtra) 
library(lsa) 


get_topic <- function(k){ 
  # sample a single topic uniformly at random from k topics
  which(rmultinom(1, size = 1, rep(1/k, k))[,1] == 1)
} 

get_word <- function(theta, phi){
  # sample a topic from the document's topic mixture theta
  topic <- which(rmultinom(1, 1, theta) == 1)
  # sample a word from that topic's word distribution
  new_word <- which(rmultinom(1, 1, phi[topic, ]) == 1)
  return(c(new_word, topic))  
}
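As a quick check of these helper functions, here is a hypothetical call with a made-up two-topic, three-word vocabulary; the values are arbitrary and only for illustration.

# toy document topic mixture and topic-word distributions (arbitrary values)
toy_theta <- c(0.8, 0.2)
toy_phi   <- rbind(c(0.7, 0.2, 0.1), # word distribution for topic 1
                   c(0.1, 0.2, 0.7)) # word distribution for topic 2
get_word(toy_theta, toy_phi) # returns c(word index, topic index)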

cppFunction(
  'List gibbsLda(NumericVector topic, NumericVector doc_id, NumericVector word,
                 NumericMatrix n_doc_topic_count, NumericMatrix n_topic_term_count,
                 NumericVector n_topic_sum, NumericVector n_doc_word_count){

    int alpha = 1;
    int beta  = 1;
    int cs_topic, cs_doc, cs_word, new_topic;
    int n_topics = max(topic) + 1;
    int vocab_length = n_topic_term_count.ncol();
    double p_sum = 0, num_doc, denom_doc, denom_term, num_term;
    NumericVector p_new(n_topics);
    IntegerVector topic_sample(n_topics);

    // 100 sweeps through every token in the corpus
    for (int iter = 0; iter < 100; iter++){
      for (int j = 0; j < word.size(); ++j){

        // current token: its topic, document, and word (vocabulary) index
        cs_topic = topic[j];
        cs_doc   = doc_id[j];
        cs_word  = word[j];

        // decrement counts: remove the current token from all count matrices
        n_doc_topic_count(cs_doc, cs_topic) = n_doc_topic_count(cs_doc, cs_topic) - 1;
        n_topic_term_count(cs_topic, cs_word) = n_topic_term_count(cs_topic, cs_word) - 1;
        n_topic_sum[cs_topic] = n_topic_sum[cs_topic] - 1;

        // compute the (unnormalized) full conditional of each topic, Equation (6.10)
        for (int tpc = 0; tpc < n_topics; tpc++){

          // count of word cs_word in topic tpc + beta
          num_term   = n_topic_term_count(tpc, cs_word) + beta;
          // total word count in topic tpc + vocab_length*beta
          denom_term = n_topic_sum[tpc] + vocab_length*beta;

          // count of topic tpc in document cs_doc + alpha
          num_doc   = n_doc_topic_count(cs_doc, tpc) + alpha;
          // total word count in document cs_doc + n_topics*alpha
          denom_doc = n_doc_word_count[cs_doc] + n_topics*alpha;

          p_new[tpc] = (num_term/denom_term) * (num_doc/denom_doc);
        }

        // normalize the probabilities so they sum to one
        p_sum = std::accumulate(p_new.begin(), p_new.end(), 0.0);
        for (int tpc = 0; tpc < n_topics; tpc++){
          p_new[tpc] = p_new[tpc]/p_sum;
        }

        // sample a new topic from the full conditional distribution
        R::rmultinom(1, p_new.begin(), n_topics, topic_sample.begin());
        for (int tpc = 0; tpc < n_topics; tpc++){
          if (topic_sample[tpc] == 1){
            new_topic = tpc;
          }
        }

        // increment counts: add the current token back under its new topic
        n_doc_topic_count(cs_doc, new_topic) = n_doc_topic_count(cs_doc, new_topic) + 1;
        n_topic_term_count(new_topic, cs_word) = n_topic_term_count(new_topic, cs_word) + 1;
        n_topic_sum[new_topic] = n_topic_sum[new_topic] + 1;

        // update the current state
        topic[j] = new_topic;
      }
    }
    return List::create(
      n_topic_term_count,
      n_doc_topic_count);
  }
  ')
# 3 topics - land, sea, and air
# birds and amphibious animals have crossover between habitats
# fish appear only in the sea topic
# land animals appear only in the land topic

beta <- 1
k <- 3 # number of topics
M <- 100 # let's create 100 documents
alphas <- rep(1,k) # document-topic Dirichlet parameters (symmetric alpha)
xi <- 100 # average document length 
N <- rpois(M, xi) # number of words in each document


# whale1, whale2, FISH1, FISH2,OCTO
sea_animals <- c('\U1F40B', '\U1F433','\U1F41F', '\U1F420', '\U1F419')

# crab, alligator, TURTLE,SNAKE
amphibious  <- c('\U1F980', '\U1F40A', '\U1F422', '\U1F40D')

# CHICKEN, TURKEY, DUCK, PENGUIN
birds       <- c('\U1F413','\U1F983','\U1F426','\U1F427')
# SQUIRREL, ELEPHANT, COW, RAM, CAMEL
land_animals<- c('\U1F43F','\U1F418','\U1F402','\U1F411','\U1F42A')

vocab <- c(sea_animals, amphibious, birds, land_animals)

# relative weights used to build each habitat's word distribution
# 0 - animals that do not appear in the habitat
# 2 - animals shared across habitats
# 4 - animals unique to the habitat
shared <- 2
non_shared <- 4
not_present <- 0

land_phi <- c(rep(not_present, length(sea_animals)),
              rep(shared, length(amphibious)),
              rep(non_shared, 2), # chicken and turkey can't fly (land only)
              rep(shared, 2), # regular bird and penguin
              rep(non_shared, length(land_animals)))
land_phi <- land_phi/sum(land_phi)


sea_phi <- c(rep(non_shared, length(sea_animals)),
             rep(shared, length(amphibious)),
             rep(not_present, 2), # chicken and turkey don't appear in the sea
             rep(shared, 2), # regular bird and penguin
             rep(not_present, length(land_animals)))
sea_phi <- sea_phi/sum(sea_phi)

air_phi <- c(rep(not_present, length(sea_animals)),
             rep(not_present, length(amphibious)),
             rep(not_present, 2), # turkey and chicken can't fly 
             non_shared, # regular bird
             not_present, # penguins can't fly
             rep(not_present, length(land_animals)))
air_phi <- air_phi/sum(air_phi)

# calculate topic word distributions
phi <- matrix(c(land_phi, sea_phi, air_phi), nrow = k, ncol = length(vocab), 
              byrow = TRUE, dimnames = list(c('land', 'sea', 'air')))

theta_samples <- rdirichlet(M, alphas)
thetas <- theta_samples[rep(1:nrow(theta_samples), times = N), ]
new_words <- t(apply(thetas, 1, function(x) get_word(x,phi)))

ds <-tibble(doc_id = rep(1:length(N), times = N), 
            word   = new_words[,1],
            topic  = new_words[,2], 
            theta_a = thetas[,1],
            theta_b = thetas[,2],
            theta_c = thetas[,3]
) 

ds %>% filter(doc_id < 3) %>% group_by(doc_id) %>% summarise(
  tokens = paste(vocab[word], collapse = ' ')
) %>% kable(col.names = c('Document', 'Animals'), 
            caption ="Animals at the First Two Locations")
Table 6.1: Animals at the First Two Locations
Document Animals
1 🐑 🐋 🐂 🐓 🐑 🦃 🦃 🐪 🦃 🐓 🐑 🐂 🐋 🐪 🐢 🐂 🐑 🐓 🐍 🐑 🐟 🐍 🐙 🐿 🐪 🐘 🐑 🐙 🦃 🐘 🐑 🐂 🐂 🐊 🐂 🐘 🐿 🦃 🦃 🐧 🐪 🐍 🐿 🐘 🐍 🐂 🐙 🐦 🐂 🐟 🐍 🐠 🐑 🦃 🐓 🦃 🐂 🐘 🐠 🐓 🐢 🐘 🦃 🐙 🐘 🐍 🐢 🐑 🐍 🐘 🐦 🐧 🐙 🐦 🐂 🐋 🐘 🐠 🐓 🐿 🐘 🐿 🦀 🐓 🐿 🐊 🐓 🦀 🐂 🐪 🐓 🐠 🐪 🐦 🐙 🐠 🐘 🦃 🐘 🐑 🐊 🐑 🐳
2 🐠 🐟 🐋 🐋 🐢 🐳 🦀 🐦 🐘 🐦 🐓 🐦 🐦 🦃 🐦 🐦 🐂 🦀 🐢 🦀 🐊 🐿 🦃 🐳 🐦 🐠 🐪 🐍 🐠 🐪 🐧 🐂 🐿 🐋 🐦 🐦 🐦 🐿 🐋 🐿 🐠 🐦 🐳 🐂 🐦 🐙 🐟 🐑 🐪 🐪 🦀 🐑 🐊 🐦 🐦 🐿 🦃 🐙 🐿 🦃 🐓 🐊 🐦 🐢 🐙 🐊 🦀 🐙 🐍 🐓 🐍 🐳 🐋 🐙 🐧 🐢 🐋 🐘 🐑 🐍 🐑 🐳 🐦 🐠 🐂 🐧 🐙 🐍

The habitat (topic) distributions for the first couple of documents:

Table 6.2: Distribution of Habitats in the First Two Locations
Document Land Sea Air
1 0.8168897 0.1688578 0.0142525
2 0.3860697 0.4442854 0.1696449

With the help of LDA we can go through all of our documents and estimate the topic/word distributions and the topic/document distributions.

Below is the inference code, followed by a comparison of the estimated values with the true values used to generate the documents:

######### Inference ############### 
current_state <- ds %>% dplyr::select(doc_id, word, topic)
current_state$topic <- NA
t <- length(unique(current_state$word))

# n_doc_topic_count  
n_doc_topic_count <- matrix(0, nrow = M, ncol = k)
# document_topic_sum
n_doc_topic_sum  <- rep(0,M)
# topic_term_count
n_topic_term_count <- matrix(0, nrow = k, ncol = t)
# colnames(n_topic_term_count) <- unique(current_state$word)
# topic_term_sum
n_topic_sum  <- rep(0,k)
p <- rep(0, k)
# initialize topics

current_state$topic <- replicate(nrow(current_state),get_topic(k))

# get word, topic, and document counts (used during inference process)
n_doc_topic_count <- current_state %>% group_by(doc_id, topic) %>%
  summarise(
    count = n()
  ) %>% spread(key = topic, value = count) %>% as.matrix()

n_topic_sum <- current_state %>% group_by(topic) %>%
  summarise(
    count = n()
  )  %>% select(count) %>% as.matrix() %>% as.vector()

n_topic_term_count <- current_state %>% group_by(topic, word) %>% 
  summarise(
    count = n()
  ) %>% spread(word, count) %>% as.matrix()



# subtract 1 to convert R's 1-based indices to the 0-based indexing used in C++;
# [,-1] drops the key column (doc_id or topic) from each count matrix
lda_counts <- gibbsLda( current_state$topic-1 , current_state$doc_id-1, current_state$word-1,
                        n_doc_topic_count[,-1], n_topic_term_count[,-1], n_topic_sum, N)


# calculate estimates for phi and theta

# phi estimate (Equation (6.11)); apply() over rows returns the transpose,
# so phi_est is vocabulary x topics
phi_est <- apply(lda_counts[[1]], 1, function(x) (x + beta)/(sum(x) + length(vocab)*beta))
rownames(phi_est) <- vocab
colnames(phi) <- vocab

# theta estimate: smooth counts with alpha, then normalize each document's
# row so it sums to 1 (Equation (6.12))
theta_est <- apply(lda_counts[[2]], 2, function(x) (x + alphas[1])/(sum(x) + k*alphas[1]))
theta_est <- t(apply(theta_est, 1, function(x) x/sum(x)))



colnames(theta_samples) <- c('land', 'sea', 'air')
vector_angles <- cosine(cbind(theta_samples,theta_est))[4:6, 1:3]
estimated_topic_names <- apply(vector_angles, 1, function(x)colnames(vector_angles)[which.max(x)])

phi_table <- as_tibble(t(round(phi,2))[,estimated_topic_names])

phi_table <- cbind(phi_table, as_tibble(round(phi_est, 2)))

names(phi_table)[4:6] <- paste0(estimated_topic_names, ' estimated')
phi_table <- phi_table[, c(4,1,5,2,6,3)]
row.names(phi_table) <- colnames(phi)

kable(round(phi_table, 2), caption = 'True and Estimated Word Distribution for Each Topic')
Table 6.3: True and Estimated Word Distribution for Each Topic
Word  air (estimated)  air (true)  land (estimated)  land (true)  sea (estimated)  sea (true)
🐋 0.00 0 0.00 0.00 0.12 0.12
🐳 0.00 0 0.02 0.00 0.12 0.12
🐟 0.00 0 0.01 0.00 0.12 0.12
🐠 0.00 0 0.01 0.00 0.12 0.12
🐙 0.00 0 0.01 0.00 0.13 0.12
🦀 0.00 0 0.05 0.05 0.06 0.06
🐊 0.01 0 0.04 0.05 0.06 0.06
🐢 0.00 0 0.05 0.05 0.07 0.06
🐍 0.01 0 0.06 0.05 0.05 0.06
🐓 0.01 0 0.09 0.10 0.00 0.00
🦃 0.00 0 0.11 0.10 0.00 0.00
🐦 0.92 1 0.01 0.05 0.08 0.06
🐧 0.01 0 0.05 0.05 0.06 0.06
🐿 0.01 0 0.09 0.10 0.01 0.00
🐘 0.00 0 0.10 0.10 0.00 0.00
🐂 0.00 0 0.11 0.10 0.00 0.00
🐑 0.00 0 0.10 0.10 0.00 0.00
🐪 0.00 0 0.11 0.10 0.00 0.00

The document topic mixture estimates are shown below for the first six documents:

Table 6.4: The Estimated Topic Distributions for the First Six Documents
Location  air (estimated)  air (true)  land (estimated)  land (true)  sea (estimated)  sea (true)
1 0.04 0.01 0.90 0.82 0.06 0.17
2 0.19 0.17 0.48 0.39 0.33 0.44
3 0.31 0.40 0.13 0.18 0.56 0.42
4 0.82 0.64 0.17 0.33 0.01 0.03
5 0.31 0.29 0.31 0.41 0.39 0.30
6 0.35 0.32 0.11 0.15 0.54 0.53