
[Scala/Spark] Get topics' words from the LDA model.

Some time ago I had to move from sparklyr to Scala for better integration with Spark and easier collaboration with other developers on my team. Interestingly, the transition was much easier than I expected, because Spark’s DataFrame API is somewhat similar to dplyr: there’s a groupBy function, agg instead of summarise, and so on. You can also use good old SQL to operate on data frames. Anyway, in this post I’ll show how to fit a very simple LDA (Latent Dirichlet allocation) model and then extract information about each topic’s words. For some reason, this is a bit more complicated than I expected…

LDA is a topic model that extracts abstract topics from a collection of documents: each document is treated as a mixture of topics, and each topic as a probability distribution over words. For example, if a document is mostly about machine learning in R (about 90%) and only a small part of it is about Python, R-related words like dplyr, caret or mlr should be more likely to appear in it than their Python counterparts. I don’t want to dive into the details of the model; for more information, please refer to the Wikipedia pages https://en.wikipedia.org/wiki/Topic_model and https://en.wikipedia.org/wiki/Latent_Dirichlet_allocation.
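To make the mixture idea a bit more concrete, here is a toy calculation in plain Scala (no Spark). Everything in it is made up for illustration: the four-word vocabulary, the topic-word distributions `phi`, and the document's topic proportions `theta`. The only thing borrowed from LDA is the assumption that the probability of a word in a document is the topic-weighted sum: p(word | doc) = sum over topics k of theta(k) * phi(k)(word).

```scala
// Toy illustration of LDA's mixture assumption; all numbers are hypothetical.
object TopicMixture {
  val vocab: Vector[String] = Vector("dplyr", "caret", "pandas", "sklearn")

  // Hypothetical topic-word distributions: each vector sums to 1.
  val phi: Map[String, Vector[Double]] = Map(
    "R"      -> Vector(0.50, 0.40, 0.05, 0.05),
    "Python" -> Vector(0.05, 0.05, 0.50, 0.40)
  )

  // Hypothetical topic proportions of one document: 90% R, 10% Python.
  val theta: Map[String, Double] = Map("R" -> 0.9, "Python" -> 0.1)

  // p(word | doc) = sum over topics of theta(topic) * phi(topic)(word)
  def wordProb(word: String): Double = {
    val w = vocab.indexOf(word)
    theta.toSeq.map { case (topic, share) => share * phi(topic)(w) }.sum
  }

  def main(args: Array[String]): Unit =
    vocab.foreach(w => println(f"$w%-8s ${wordProb(w)}%.3f"))
}
```

With 90% of the document assigned to the R topic, dplyr gets probability 0.455 while pandas gets only 0.095.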

Fitting LDA in Spark.

Spark’s documentation contains sample code for fitting LDA (https://spark.apache.org/docs/2.2.0/ml-clustering.html#latent-dirichlet-allocation-lda), but for my purposes I’ll use my own version.

At the beginning of the code, add the necessary import declarations and create some data.

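import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel}
import org.apache.spark.ml.clustering.LDA
import org.apache.spark.sql.functions.{col, explode, split, udf}
import scala.collection.mutable.WrappedArray
// needed for toDF; assumes a SparkSession named spark, as in spark-shell
import spark.implicits._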

val txt = Array("A B B C", "A B D D", "A C D")

txt is a local Scala array, so for now we are not using Spark at all. The next listing shows how to distribute the array to the cluster (most of the time the data is read from HDFS, so other functions are used) and then preprocess it for fitting. LDA does not take raw text; it requires a DataFrame with a column of sparse vectors containing the number of occurrences of each vocabulary word in each document. Fortunately, this transformation is easy, because Spark provides CountVectorizer, which does exactly that.

// distribute the local array as a one-column DataFrame
val txtDf      = spark.sparkContext.parallelize(txt).toDF("txt")
// split each line into an array of words
val txtDfSplit = txtDf.withColumn("txt", split(col("txt"), " "))

// create sparse vectors with the number
// of occurrences of each word using CountVectorizer
val cvModel = new CountVectorizer()
  .setInputCol("txt")
  .setOutputCol("features")
  .setVocabSize(4) // keep at most the 4 most frequent terms
  .setMinDF(2)     // a term must appear in at least 2 documents
  .fit(txtDfSplit)

val txtDfTrain = cvModel.transform(txtDfSplit)
txtDfTrain.show(false) // show the DataFrame content without truncation

// +------------+-------------------------+
// |txt         |features                 |
// +------------+-------------------------+
// |[A, B, B, C]|(4,[1,2,3],[2.0,1.0,1.0])|
// |[A, B, D, D]|(4,[0,1,2],[2.0,1.0,1.0])|
// |[A, C, D]   |(4,[0,2,3],[1.0,1.0,1.0])|
// +------------+-------------------------+
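Each features cell is printed as (size, [indices], [counts]): for example, (4,[1,2,3],[2.0,1.0,1.0]) says that vocabulary entries 1, 2 and 3 occur 2, 1 and 1 times in that document. To show roughly what CountVectorizer computes, here is a simplified, Spark-free sketch written under my own assumptions; it is not Spark's implementation, and the object and method names are hypothetical. Terms are filtered by document frequency (minDF), ranked by total count, and the top vocabSize are kept. Ties are broken alphabetically here, whereas Spark doesn't promise any particular order for ties, so the indices can differ from the Spark output above even though the counts match.

```scala
// Simplified local sketch of CountVectorizer's fit and transform steps.
// Assumption: ties in total count are broken alphabetically (Spark's order may differ).
object CountVectorizerSketch {
  // Keep terms that appear in at least minDF documents, rank them by total
  // count (descending), and keep the first vocabSize of them.
  def fitVocabulary(docs: Seq[Seq[String]], vocabSize: Int, minDF: Int): Vector[String] = {
    val docFreq   = docs.flatMap(_.distinct).groupBy(identity).map { case (t, ts) => t -> ts.size }
    val termCount = docs.flatten.groupBy(identity).map { case (t, ts) => t -> ts.size }
    termCount.toVector
      .filter { case (t, _) => docFreq(t) >= minDF }
      .sortBy { case (t, n) => (-n, t) }
      .take(vocabSize)
      .map(_._1)
  }

  // Turn one document into (size, sorted indices, counts), like a sparse vector.
  def transform(vocab: Vector[String], doc: Seq[String]): (Int, Vector[Int], Vector[Double]) = {
    val index = vocab.zipWithIndex.toMap
    val counts = doc
      .flatMap(t => index.get(t).toList) // drop out-of-vocabulary terms
      .groupBy(identity)
      .toVector
      .map { case (i, is) => (i, is.size.toDouble) }
      .sortBy(_._1)
    (vocab.size, counts.map(_._1), counts.map(_._2))
  }

  def main(args: Array[String]): Unit = {
    val docs  = Seq("A B B C", "A B D D", "A C D").map(_.split(" ").toSeq)
    val vocab = fitVocabulary(docs, vocabSize = 4, minDF = 2)
    println(vocab)
    docs.foreach(d => println(transform(vocab, d)))
  }
}
```

With alphabetical tie-breaking the vocabulary comes out as A, B, D, C (A, B and D all occur three times), and the first document becomes (4, [0, 1, 3], [1.0, 2.0, 1.0]). To see the order Spark actually chose, print cvModel.vocabulary.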

Then fitting the LDA model is just one line:

val lda = new LDA().setK(2).setMaxIter(10).fit(txtDfTrain)

To get the words related to each topic, use the describeTopics method. However, it doesn’t return the words themselves, only their indices in the vocabulary created by CountVectorizer:

// vocabulary created by CountVectorizer
val vocab = spark.sparkContext.broadcast(cvModel.vocabulary) 
// describeTopics output:
lda.describeTopics(4).show

// +-----+------------+--------------------+
// |topic| termIndices|         termWeights|
// +-----+------------+--------------------+
// |    0|[2, 1, 0, 3]|[0.29972753357517...|
// |    1|[1, 3, 0, 2]|[0.27815048189882...|
// +-----+------------+--------------------+
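Note that the weights of each topic sum to about 1. My understanding, which is an assumption rather than documented behavior, is that describeTopics takes the model's term-topic matrix (one row per term, one column per topic, as exposed by topicsMatrix), normalizes each topic's column to sum to 1, sorts the terms by weight and keeps the first maxTermsPerTopic. Here is a local sketch of that logic on a made-up 4 x 2 matrix; the object name and the matrix values are hypothetical.

```scala
// Local sketch of how describeTopics is assumed to rank terms per topic.
object DescribeTopicsSketch {
  // topics: one row per term, one column per topic (unnormalized weights).
  // For each topic: normalize its column to sum to 1, sort terms by weight
  // (descending), keep the first maxTermsPerTopic.
  def describeTopics(topics: Array[Array[Double]],
                     maxTermsPerTopic: Int): Seq[(Seq[Int], Seq[Double])] = {
    val k = topics.head.length
    (0 until k).map { t =>
      val column = topics.map(row => row(t))
      val total  = column.sum
      val (weights, terms) = column
        .map(_ / total)
        .zipWithIndex
        .sortBy { case (w, _) => -w } // stable sort: ties keep term order
        .take(maxTermsPerTopic)
        .unzip
      (terms.toSeq, weights.toSeq)
    }
  }

  // A made-up 4-term x 2-topic matrix.
  val matrix: Array[Array[Double]] = Array(
    Array(2.0, 1.0), // term 0
    Array(3.0, 4.0), // term 1
    Array(4.0, 1.0), // term 2
    Array(1.0, 2.0)  // term 3
  )

  def main(args: Array[String]): Unit =
    describeTopics(matrix, 4).zipWithIndex.foreach { case ((terms, weights), topic) =>
      println(s"topic $topic: terms=$terms weights=$weights")
    }
}
```

For this matrix the sketch returns term indices [2, 1, 0, 3] with weights [0.4, 0.3, 0.2, 0.1] for topic 0 and [1, 3, 0, 2] for topic 1, where terms 0 and 2 tie at 0.125 and keep their original order because the sort is stable.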

User-defined function.

To map the indices to the corresponding words, we need a custom function. Plain Scala functions can’t operate on DataFrame columns directly, so we have to wrap one in Spark’s user-defined function (udf for short). It takes a WrappedArray of integers (because the termIndices cell is, in fact, a WrappedArray[Int]) and returns the array of words looked up in the vocabulary.

val toWords = udf( (x : WrappedArray[Int]) => { x.map(i => vocab.value(i)) })
val topics = lda.describeTopics(4)
        .withColumn("topicWords", toWords(col("termIndices")))
topics.select("topicWords").show

// +------------+
// |  topicWords|
// +------------+
// |[A, B, D, C]|
// |[B, C, D, A]|
// +------------+
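Because the function wrapped by udf is an ordinary Scala function, its logic can be checked on local data without a cluster. The sketch below assumes the vocabulary order Array("D", "B", "A", "C"), which I inferred from the CountVectorizer output earlier in the post (the vocabulary itself is not printed there), and uses Seq instead of WrappedArray so it runs without Spark; toWordsFn is a hypothetical name.

```scala
// Checking the toWords logic on plain Scala collections.
object ToWordsCheck {
  // Vocabulary order inferred from the CountVectorizer output (an assumption).
  val vocab: Array[String] = Array("D", "B", "A", "C")

  // Same logic as the toWords udf, as a plain function.
  val toWordsFn: Seq[Int] => Seq[String] = indices => indices.map(i => vocab(i))

  def main(args: Array[String]): Unit = {
    println(toWordsFn(Seq(2, 1, 0, 3))) // termIndices of topic 0
    println(toWordsFn(Seq(1, 3, 0, 2))) // termIndices of topic 1
  }
}
```

Applied to the termIndices printed by describeTopics, it returns A, B, D, C and B, C, D, A, the same words as the topicWords column.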

The last task might be to extract the data in a tidy format, with one row per topic-word pair:

// +-----+----+-------------------+
// |topic|word|             weight|
// +-----+----+-------------------+
// |    0|   A|0.29972753357517407|
// |    0|   B| 0.2459495088520157|
// |    0|   D|0.22775669710607507|
// |    0|   C|0.22656626046673525|
// |    1|   B| 0.2781504818988264|
// |    1|   C| 0.2744447750388326|
// |    1|   D| 0.2457615751570422|
// |    1|   A|0.20164316790529874|
// +-----+----+-------------------+

This requires a slightly more complicated udf. It zips each word with its weight, so that one column holds (word, weight) pairs. Then it’s easy to explode the result (explode is the equivalent of unnest in tidyr) and split each pair into two columns: one for the word and one for the weight.
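The same reshaping can be sketched on plain Scala collections: zipping words with weights mirrors the udf, and flatMap plays the role of explode. TopicRow and TidyRow are hypothetical case classes standing in for DataFrame rows, the vocabulary order Array("D", "B", "A", "C") is inferred from the CountVectorizer output shown earlier, and the weights are copied from the tidy table above.

```scala
// Local sketch of "zip words with weights, then explode" on Scala collections.
object TidySketch {
  // Vocabulary order inferred from the CountVectorizer output (an assumption).
  val vocab: Array[String] = Array("D", "B", "A", "C")

  // Hypothetical stand-ins for a describeTopics row and a tidy result row.
  case class TopicRow(topic: Int, termIndices: Seq[Int], termWeights: Seq[Double])
  case class TidyRow(topic: Int, word: String, weight: Double)

  def tidy(rows: Seq[TopicRow]): Seq[TidyRow] =
    rows.flatMap { r =>
      // like the udf: look up each word and pair it with its weight
      val pairs = r.termIndices.map(i => vocab(i)).zip(r.termWeights)
      // like explode + getField("_1")/getField("_2"): one output row per pair
      pairs.map { case (word, weight) => TidyRow(r.topic, word, weight) }
    }

  val topics: Seq[TopicRow] = Seq(
    TopicRow(0, Seq(2, 1, 0, 3),
      Seq(0.29972753357517407, 0.2459495088520157, 0.22775669710607507, 0.22656626046673525)),
    TopicRow(1, Seq(1, 3, 0, 2),
      Seq(0.2781504818988264, 0.2744447750388326, 0.2457615751570422, 0.20164316790529874))
  )

  def main(args: Array[String]): Unit = tidy(topics).foreach(println)
}
```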

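// zip each word with its weight; Spark stores each Scala tuple
// as a struct with fields _1 (word) and _2 (weight)
val wordsWithWeights = udf( (x : WrappedArray[Int],
                             y : WrappedArray[Double]) => 
    { x.map(i => vocab.value(i)).zip(y)}
)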

val topics2 = lda.describeTopics(4)
    .withColumn("topicWords", 
      wordsWithWeights(col("termIndices"), col("termWeights"))
    )
topics2.select("topicWords").show(false)

// +--------------------------------------------+
// |topicWords                                  |
// +--------------------------------------------+
// |[[A,0.299], [B,0.245], [D,0.227], [C,0.226]]|
// |[[B,0.278], [C,0.274], [D,0.245], [A,0.201]]|
// +--------------------------------------------+


val topics2exploded = topics2
    .select("topic", "topicWords")
    .withColumn("topicWords", explode(col("topicWords")))
topics2exploded.show

// +-----+--------------------+
// |topic|          topicWords|
// +-----+--------------------+
// |    0|[A,0.299727533575...|
// |    0|[B,0.245949508852...|
// |    0|[D,0.227756697106...|
// |    0|[C,0.226566260466...|
// |    1|[B,0.278150481898...|
// |    1|[C,0.274444775038...|
// |    1|[D,0.245761575157...|
// |    1|[A,0.201643167905...|
// +-----+--------------------+

val finalTopic = topics2exploded
  .select(
    col("topic"), 
    col("topicWords").getField("_1").as("word"), 
    col("topicWords").getField("_2").as("weight")
  )
finalTopic.show

// +-----+----+-------------------+
// |topic|word|             weight|
// +-----+----+-------------------+
// |    0|   A|0.29972753357517407|
// |    0|   B| 0.2459495088520157|
// |    0|   D|0.22775669710607507|
// |    0|   C|0.22656626046673525|
// |    1|   B| 0.2781504818988264|
// |    1|   C| 0.2744447750388326|
// |    1|   D| 0.2457615751570422|
// |    1|   A|0.20164316790529874|
// +-----+----+-------------------+
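If you are on Spark 2.2 or later (typedLit was added in 2.2), the same tidy table can probably be built without any udf. The following is a self-contained sketch of that alternative, not the approach used above: posexplode returns each element of termIndices together with its position, typedLit turns the vocabulary into an array column, and indexing one column with another (col(...)(otherCol)) looks up an array element by a 0-based index. It reruns the toy pipeline with a local SparkSession; the object and method names are hypothetical.

```scala
import org.apache.spark.ml.clustering.LDA
import org.apache.spark.ml.feature.CountVectorizer
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, posexplode, split, typedLit}

object TidyTopicsNoUdf {
  // Fits the same toy model as in the post and returns (topic, word, weight) rows.
  def tidyTopics(spark: SparkSession): DataFrame = {
    import spark.implicits._

    val txtDfSplit = Seq("A B B C", "A B D D", "A C D").toDF("txt")
      .withColumn("txt", split(col("txt"), " "))

    val cvModel = new CountVectorizer()
      .setInputCol("txt")
      .setOutputCol("features")
      .setVocabSize(4)
      .setMinDF(2)
      .fit(txtDfSplit)

    val lda = new LDA().setK(2).setMaxIter(10).fit(cvModel.transform(txtDfSplit))

    // the vocabulary as a literal array<string> column
    val vocabCol = typedLit(cvModel.vocabulary.toSeq)

    lda.describeTopics(4)
      // one row per term: "pos" is its rank in the topic, "termIndex" its vocabulary index
      .select(col("topic"), col("termWeights"),
        posexplode(col("termIndices")).as(Seq("pos", "termIndex")))
      .select(
        col("topic"),
        vocabCol(col("termIndex")).as("word"),        // 0-based array lookup
        col("termWeights")(col("pos")).as("weight"))  // weight at the same position
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("lda-tidy").getOrCreate()
    try tidyTopics(spark).show()
    finally spark.stop()
  }
}
```

The result has the same topic, word and weight columns as finalTopic; the weights themselves depend on the fitted model.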

And that’s all for this post. I hope you find it helpful :)
