Peering into the Black Box: Visualizing LambdaMART

In the last post, I gave a broad overview of the Learning to Rank domain of machine learning, which has applications in web search, machine translation, and question-answering systems. In this post, we’ll look at a state-of-the-art Learning to Rank model called LambdaMART. We’ll take a look at some of the math underlying LambdaMART, then focus on developing ways to visualize the model.

LambdaMART produces a tree ensemble model, a class of models traditionally viewed as ‘black boxes’ since they combine the predictions of tens or hundreds of underlying trees. While viewing a single decision tree is intuitive and interpretable, viewing 100 at once would be overwhelming and uninformative:

[Figure: Single Tree]

[Figure: Ensemble of Trees]

By moving from a single tree to an ensemble of trees, we trade off interpretability for performance.

In this post, we’ll take a look at the trees produced by the LambdaMART training process and develop a visualization that views the ensemble of trees as a single collective unit. By doing so, we’ll regain some of the intuitive interpretability that makes tree models appealing. Through the visualization, we’ll peer into the black box model and get a sense of the factors the model uses to make predictions.

We use Java and the JForests library to train LambdaMART and parse its output, and d3.js for visualization. To get a sneak peek, the final visualization is here.

LambdaMART Overview

Let’s take a look at some details of LambdaMART. For the full story, check out this paper from Microsoft Research.

At a high level, LambdaMART is an algorithm that uses gradient boosting to directly optimize cost functions specific to Learning to Rank, such as NDCG. To understand LambdaMART we’ll look at its two parts: Lambda and MART.

MART

LambdaMART is a specific instance of Gradient Boosted Regression Trees, also referred to as Multiple Additive Regression Trees (MART). Gradient Boosting is a technique for forming a model that is a weighted combination of an ensemble of “weak learners”. In our case, each “weak learner” is a decision tree.

Our goal is to find a function f(x) that minimizes the expected loss L:

\hat{f}(x) = \arg\min_{f(x)}E[L(y, f(x))|x]

for feature vectors x and labels y.

To do so we first view \hat{f}(x) as a sum of weak learners f_m:

\hat{f}(x)=\sum_{m=1}^M f_{m}(x)

Since we are dealing with decision trees, evaluating f_{m}(x) corresponds to passing x down the tree f_{m} until it reaches a leaf node l. The predicted value is then equal to a parameter \gamma_l.

Decision trees have two types of parameters: regions R_l, which assign an example x to a leaf node l, and leaf outputs \gamma_l, which give the tree’s output for all examples falling in region R_l. Hence we can write each weak learner f_{m}(x) as a parameterized tree h_{m}(x; R_{m}, \gamma_{m}), giving:

\hat{f}(x)=\sum_{m=1}^M f_{m}(x)=\sum_{m=1}^M h_{m}(x; R_{m}, \gamma_{m})
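To make this concrete, here is a minimal Java sketch of how such an additive ensemble produces a prediction: each tree is walked from its root to a leaf, and the leaf outputs are summed. The Node and Ensemble classes below are illustrative only, not the actual JForests API.

class Node {
    int featureIndex;   // feature tested at this inner node
    double threshold;   // split threshold
    Node left, right;   // children; both null at a leaf
    double leafOutput;  // gamma_l, used only when this node is a leaf

    boolean isLeaf() { return left == null && right == null; }
}

class Ensemble {
    Node[] trees;  // the weak learners f_1 ... f_M

    // Evaluate a single tree by walking from the root down to a leaf.
    static double evaluateTree(Node node, double[] x) {
        while (!node.isLeaf()) {
            node = (x[node.featureIndex] <= node.threshold) ? node.left : node.right;
        }
        return node.leafOutput;
    }

    // The ensemble's prediction is the sum of the weak learners' outputs.
    double predict(double[] x) {
        double sum = 0.0;
        for (Node root : trees) {
            sum += evaluateTree(root, x);
        }
        return sum;
    }
}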

Gradient Boosting finds each tree in a stepwise fashion. We start with an initial tree f_1, then step to another tree f_2, and so on:

\hat{f}_{1}(x) = f_{1}(x)

\hat{f}_{2}(x) = f_{1}(x) + f_{2}(x)

\vdots

\hat{f}_{M}(x) = \sum_{m=1}^{M}f_{m}(x)

How do we determine the steps? At each step, we’d like the model to change such that the loss decreases as much as possible. The direction of steepest descent is given by the gradient of the loss with respect to the current model’s predictions, evaluated at the model built so far:

g_{im}=\left.\frac{\partial L[f(x_i), y_i]}{\partial f(x_i)}\right|_{f=\hat{f}_{m-1}}

Hence we define the step as f_m = -\rho_m g_m, where \rho_m is a step size. This gives us insight into how Gradient Boosting solves the minimization: the algorithm performs Gradient Descent in function space.

Now, we know that we want to take a gradient step, but we still haven’t said how that translates into finding a tree. To do so, we build a tree that models the negative gradient; by adding such a tree to the ensemble, we effectively take a gradient step from the previous model. To fit the tree, we use squared error loss, giving the minimization problem:

\arg\min_{R, \gamma}\sum_{i=1}^N\left(-g_{im}-F(x_i; R, \gamma)\right)^2

where F(x_i; R, \gamma) denotes a regression tree with parameters R and \gamma.
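Putting the pieces together, the boosting loop fits each new tree to the pseudo-residuals (the negative gradients) and adds it to the model. The sketch below reuses the illustrative Node and Ensemble classes from the earlier sketch; fitRegressionTree is a stub standing in for a real tree learner, and none of these names correspond to JForests’ actual API.

class GradientBoosting {

    // dL[f(x_i), y_i] / df(x_i), evaluated at the current prediction for x_i.
    interface LossGradient {
        double at(double prediction, double label);
    }

    static Node[] train(double[][] X, double[] y, int numTrees,
                        double learningRate, LossGradient grad) {
        Node[] trees = new Node[numTrees];
        double[] f = new double[X.length];  // current model predictions, initially 0

        for (int m = 0; m < numTrees; m++) {
            // 1. Pseudo-residuals: the negative gradient at each training point.
            double[] residuals = new double[X.length];
            for (int i = 0; i < X.length; i++) {
                residuals[i] = -grad.at(f[i], y[i]);
            }

            // 2. Fit a regression tree to the pseudo-residuals under squared error.
            trees[m] = fitRegressionTree(X, residuals);

            // 3. Take the step: add the (shrunken) tree's predictions to the model.
            for (int i = 0; i < X.length; i++) {
                f[i] += learningRate * Ensemble.evaluateTree(trees[m], X[i]);
            }
        }
        return trees;
    }

    // Stub: a real implementation would greedily grow the regression tree here.
    static Node fitRegressionTree(double[][] X, double[] targets) {
        throw new UnsupportedOperationException("illustrative stub");
    }
}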

Let’s take a step back. We’ve just set up a framework for finding an ensemble of trees that, when added together, minimizes a loss function. It’s a “framework” in the sense that it works for any loss function, so long as we can supply the gradient of that loss at each training point. This is where the “Lambda” part of LambdaMART comes into play.

For further reading, Chapter 10 of The Elements of Statistical Learning provides a thorough overview. This paper also provides a good introduction to Gradient Boosting.

Lambda

In ranking, the cost functions we care about optimizing are typically NDCG, MAP, or MRR. Unfortunately, these functions aren’t differentiable at all points with respect to the model scores (they depend only on the sorted order of the documents), so we can’t use them directly in our Gradient Boosting “framework”: it’s unclear how to provide the gradients at each training point.

To address this, LambdaMART uses an idea from a model called LambdaRank. For each pair of documents i, j, we compute a value \lambda_{i,j} that acts as the gradient we need. Intuitively, \lambda_{i,j} can be thought of as a force that moves documents up and down the ranked list: if document i is more relevant than document j, then i receives a push of size |\lambda_{i,j}| upwards in the ranked list, and j is pushed downwards by the same amount. The LambdaMART paper reports that, empirically, these \lambda’s have been used successfully to optimize NDCG, MAP, and MRR.
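Concretely, for a pair where document i is more relevant than document j, the overview paper linked above defines the lambda (roughly) as the RankNet gradient scaled by the change in NDCG obtained by swapping the two documents; here s_i and s_j are the model’s current scores and \sigma is a shape parameter:

\lambda_{i,j} = \frac{-\sigma}{1 + e^{\sigma(s_i - s_j)}}\left|\Delta \text{NDCG}_{i,j}\right|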

Hence in LambdaMART we use the Gradient Boosting framework with the \lambda_i’s acting as the gradients g_i. This leads to update rules for the leaf values \gamma_l that depend on the \lambda values of the training instances that fall in leaf l.
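Roughly, the overview paper derives these leaf values with a Newton step: writing w_i = \partial\lambda_i / \partial s_i, the output of leaf l in tree m is approximately

\gamma_{lm} = \frac{\sum_{x_i \in R_{lm}} \lambda_i}{\sum_{x_i \in R_{lm}} w_i}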

Visualizing Gradient Boosting

We’ve observed that LambdaMART trains an ensemble of trees sequentially. What does this sequence of trees look like? Let’s look at an ensemble of 19 trees trained with LambdaMART on the OHSUMED dataset:

[Animation: the 19-tree ensemble, shown one tree at a time]

At each inner node, the small number denotes the feature label, and the larger number denotes the threshold. The number at each leaf node denotes the leaf output.

We can see that each tree is somewhat similar to the previous tree. This makes sense, given that each tree f_m depends on the model state \hat{f}_{m-1}: it is fit to the gradient of the loss at that model’s predictions.

The step-by-step animation gives us insight into what is happening during training. It’s a great way to visualize and gain an intuition of the gradient boosting process – we see consecutive trees being produced that depend on the previous tree. However, we are still viewing the LambdaMART model as a group of separate trees; we still need to develop a way of viewing the final model cohesively.

Does the sequence animation tell us anything that could help with developing a single, cohesive model view? We can start with the observation that often, the trees look “roughly similar”.

Sometimes two trees will have the same feature at a given node position; for instance, the root favors features 7, 9, and 11. Trees may not change at all from step to step, or may share a similar overall structure, such as trees #3 and #4: notice how the numbers change, but the structure remains the same. Finally, the trees only draw from a subset of the total feature set; for instance, feature 4 is never used by any of the trees.

Grouping Similar Nodes

Given these observations, we can think about how to build a unified visualization that takes advantage of reused features and similar structure across trees. Let’s consider the root node. We can count how often each feature occurs at the root across the entire ensemble, leading to a single node showing that feature 7 was used at the root in 3 trees, feature 9 in 5 trees, and so on:

[Figure: feature frequencies aggregated at the root position]

At each tree level h (with the root at level 0), there are 2^h possible node positions. We can collect the feature frequencies for each of these positions across the ensemble. If a position doesn’t appear in a tree, we mark it as ‘DNE’, and if a position is a leaf, we mark it as ‘Leaf’, such as position (3, 6):

[Figure: feature frequencies collected at every node position, including ‘DNE’ and ‘Leaf’ entries]
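One way to compute these counts is to walk every tree once, indexing node positions as (level, index) pairs where the children of position (h, k) sit at (h+1, 2k) and (h+1, 2k+1). The sketch below reuses the illustrative Node class from earlier; it is not the code behind the actual visualization, just one plausible way to build the tallies.

import java.util.HashMap;
import java.util.Map;

class PositionCounter {

    // Maps "level:index" -> (feature label -> number of trees with that label there).
    static Map<String, Map<String, Integer>> countPositions(Node[] trees, int maxDepth) {
        Map<String, Map<String, Integer>> counts = new HashMap<>();
        for (Node root : trees) {
            tally(root, 0, 0, maxDepth, counts);
        }
        return counts;
    }

    private static void tally(Node node, int level, int index, int maxDepth,
                              Map<String, Map<String, Integer>> counts) {
        if (level > maxDepth) return;
        // Positions missing from this tree count as "DNE", leaves as "Leaf",
        // and inner nodes contribute their feature index.
        String label = (node == null) ? "DNE"
                     : node.isLeaf()  ? "Leaf"
                     : Integer.toString(node.featureIndex);
        counts.computeIfAbsent(level + ":" + index, k -> new HashMap<>())
              .merge(label, 1, Integer::sum);

        // Recurse on both child positions so missing slots are tallied as DNE.
        Node left  = (node == null || node.isLeaf()) ? null : node.left;
        Node right = (node == null || node.isLeaf()) ? null : node.right;
        tally(left,  level + 1, 2 * index,     maxDepth, counts);
        tally(right, level + 1, 2 * index + 1, maxDepth, counts);
    }
}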

Heatmap Tree

At each node position, we now have a list of feature (or DNE or Leaf) counts. Heatmaps provide a natural way of visualizing this count information. We can turn each node of this position tree into a heatmap, giving rise to the “Heatmap Tree”:

[Figure: The Heatmap Tree]

To view the interactive d3.js visualization, see this link.

The Heatmap Tree shows the entire ensemble as a single unit. Each number is a feature, and the color scale tells us how many trees use that feature at the given node:

[Figure: Heatmap Tree detail]

By looking at the Heatmap Tree we can get a sense of which features the ensemble uses when it scores instances. The ensemble uses features 7, 9, 11, 25, and 29 to perform the initial split, with 11 used most often. Further down the tree, we see features 7 and 9 again, along with common features such as 33 and 37. We can also easily see that most of the trees in the ensemble are between 5 and 7 levels deep, and that while there are 46 features in total, at any given node position we see at most 9 distinct features.

The Heatmap Tree can bring together hundreds of trees, such as this visualization of a 325 tree ensemble:

[Figure: Heatmap Tree of a 325-tree ensemble]

Tuning Parameters

The Heatmap Tree also lets us see how tuning parameters affect the final model. For instance, as we decrease the learning rate from 0.10 to 0.001, we see the ensemble size fluctuate:

[Figure: Learning Rate 0.10]

[Figure: Learning Rate 0.01]

[Figure: Learning Rate 0.001]

Notice how in the 0.01 case, the Heatmap Tree concisely summarizes a 111-tree ensemble.

When we use feature subsampling, we see the number of features at each node increase (in general). Each tree has a different limited subset of features to choose from, leading to a spread in the overall distribution of chosen features:

[Figure: Feature subsampling 0.6]

Feature subsampling combined with training data subsampling makes this ensemble even more crowded:

[Figure: Feature subsampling 0.6, Data subsampling 0.6]

Note that these parameter trends do not necessarily generalize. However, the Heatmap Tree captures all of the trees and features in a single structure, and gives us insight into the structural results of our parameter choices.

Limitations

The Heatmap Tree, unfortunately, has its limits. With wide variation in features and many leaves, the tree becomes crowded:

[Figure: a crowded Heatmap Tree]

Since the number of possible node positions at a given level increases exponentially with depth, the visualization also suffers when the ensemble contains deep trees.

Expansions

Another nice aspect of decision trees is that we can visualize how a test instance gets classified; we simply show the path it takes from root to leaf.

How could we visualize the classification process in an ensemble, via the Heatmap Tree or otherwise? With the Heatmap Tree, we would need to simultaneously visualize tens or hundreds of paths, since there is an individual path for every tree in the ensemble. One idea is to give the Heatmap Tree weighted edges: an edge would become thicker each time it is crossed while classifying an instance.
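As a starting point, here is a rough sketch of how those edge weights could be accumulated, again using the illustrative Node class from earlier: run the instance through every tree, and increment a counter for each (parent position, child position) edge the root-to-leaf path crosses.

import java.util.HashMap;
import java.util.Map;

class PathWeights {

    // Keys look like "1:0->2:1", i.e. parent position -> child position.
    static Map<String, Integer> edgeCounts(Node[] trees, double[] x) {
        Map<String, Integer> counts = new HashMap<>();
        for (Node node : trees) {
            int level = 0, index = 0;
            while (!node.isLeaf()) {
                boolean goLeft = x[node.featureIndex] <= node.threshold;
                int childIndex = goLeft ? 2 * index : 2 * index + 1;
                String edge = level + ":" + index + "->" + (level + 1) + ":" + childIndex;
                counts.merge(edge, 1, Integer::sum);
                node = goLeft ? node.left : node.right;
                level++;
                index = childIndex;
            }
        }
        return counts;
    }
}

The Heatmap Tree’s edges could then be drawn with stroke widths proportional to these counts.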

Another next step is to test the generalizability of this visualization; could it work for any gradient boosting model? What would a Heatmap Tree of a random forest look like?

Conclusion

We’ve taken a close look at LambdaMART and gradient boosting, and devised a visualization that captures the complexity of gradient boosted tree ensembles as a single cohesive structure. In doing so we bought back some of the interpretability that we lost by moving from a single tree to an ensemble, gained insight into the training process, and made the black box LambdaMART model a bit more transparent.

To see an interactive d3 Heatmap Tree, visit this link.

References and Further Reading

From RankNet to LambdaRank to LambdaMART: An Overview

Gradient Boosting Machines: A Tutorial

Tree Ensembles for Learning to Rank, Yasser Ganjisaffar, PhD thesis, 2011


These Are Your Tweets on LDA (Part II)

In the last post, I gave an overview of Latent Dirichlet Allocation (LDA), and walked through an application of LDA on @BarackObama’s tweets. The final product was a set of word clouds, one per topic, that showed the weighted words that defined the topic.

In this post, we’ll develop a dynamic visualization that incorporates multiple topics, allowing us to gain a high-level view of the topics and also drill down to see the words that define each topic. Through a simple web interface, we’ll also be able to view data from different Twitter users.

Click here for an example of the finished product.

As before, all of the code is available on GitHub. The visualization-related code is found in the viz/static directory.

Harnessing the Data

In the last post, we downloaded tweets for a user and found 50 topics that occur in the user’s tweets, along with the top 20 words for each topic. We also found the composition of topics across all of the tweets, allowing us to rank the topics by prominence. For our visualization, we’ll display the 10 highest-ranked topics for a given Twitter username.

We need a visualization that can show multiple groupings of data. Each of the 10 groupings has 20 words, so we’d also like a visualization that avoids information overload. Finally, we’d like to incorporate the frequency we have for each word.

d3

A good fit for these requirements is d3.js’s Zoomable Pack Layout, which gives us a high-level view of each grouping as a bubble. Upon clicking a bubble, we can see the data that comprises the bubble, as well as each data point’s relative weight:

 

[Figure: d3 Zoomable Pack Layout]

In our case, each top-level bubble is a topic, and each inner bubble is a word, with its relative size determined by the word’s frequency.

Since the d3 visualization takes JSON as input, in order to plug in our LDA output data we simply create a toJSON() method in TopicModel.java that writes the data associated with the top 10 topics to a JSON file. The ‘name’ of each topic is simply the most probable word in the topic.

Now, when the LDA process (the main() method in TopicModel.java) is run for a twitter user, the code will create a corresponding JSON file in viz/json. The JSON structure:

{
 "name":"",
 "children":[
    {
     "name": {topic_1_name},
     "children":[
       {
        "name": {topic_1_word_1},
        "size": {topic_1_word_1_freq}
       },
       {
        "name": {topic_1_word_2},
        "size": {topic_1_word_2_freq}
       },
       {
        "name": {topic_1_word_3},
        "size": {topic_1_word_3_freq}
       },
....
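For reference, here is a rough sketch of how such a method might assemble that structure; the Topic and Word classes and their field names are illustrative stand-ins, not the actual types in TopicModel.java, and string escaping is omitted for brevity.

import java.util.List;

class TopicJson {
    static class Word { String name; int frequency; }
    static class Topic { String name; List<Word> topWords; }

    // Build the nested name/children/size structure that the d3 pack layout expects.
    static String toJson(List<Topic> topTopics) {
        StringBuilder sb = new StringBuilder("{\"name\":\"\",\"children\":[");
        for (int t = 0; t < topTopics.size(); t++) {
            Topic topic = topTopics.get(t);
            if (t > 0) sb.append(',');
            sb.append("{\"name\":\"").append(topic.name).append("\",\"children\":[");
            for (int w = 0; w < topic.topWords.size(); w++) {
                Word word = topic.topWords.get(w);
                if (w > 0) sb.append(',');
                sb.append("{\"name\":\"").append(word.name)
                  .append("\",\"size\":").append(word.frequency).append('}');
            }
            sb.append("]}");
        }
        return sb.append("]}").toString();
    }
}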

Javascripting

Now, we make slight modifications to the JavaScript code embedded in the given d3 visualization. Our goal is to be able to toggle between results for different Twitter users; we’d like to switch from investigating the @nytimes topics to getting a sense of what @KingJames tweets about.

To do so, we add a drop-down to index.html, such that each time a user is selected on the drop-down, their corresponding JSON is loaded by the show() function in viz.js. Hence we also change the show() function to reload the visualization each time it is called.

Making The Visualizations Visible

To run the code locally, navigate to the viz/static directory and start an HTTP server to serve the content, e.g. with Python 2’s SimpleHTTPServer (under Python 3, the equivalent is python -m http.server):

cd {project_root}/viz/static
python -m SimpleHTTPServer

then navigate to http://localhost:8000/index.html to see the visualization.

By selecting nytimes, we see the following visualization, which gives a sense of the topics:

[Figure: @nytimes topics]

Upon clicking the ‘gaza’ topic, we see the top words that comprise the topic:

[Figure: the 'gaza' topic]

I’ve also used Heroku to put an example of the finished visualization, with data from 10 different Twitter usernames, here:

http://ldaviz.herokuapp.com/

Have fun exploring the various topics!