# Peering into the Black Box: Visualizing LambdaMART

In the last post, I gave a broad overview of the Learning to Rank domain of machine learning, which has applications in web search, machine translation, and question-answering systems. In this post, we’ll look at a state-of-the-art model used in Learning to Rank called LambdaMART. We’ll take a look at some of the math underlying LambdaMART, then focus on developing ways to visualize the model.

LambdaMART produces a tree ensemble model, a class of models traditionally viewed as ‘black boxes’ since they combine the predictions of tens or hundreds of underlying trees. While viewing a single decision tree is intuitive and interpretable, viewing 100 at once would be overwhelming and uninformative:

Single Tree

Ensemble of Trees

By moving from a single tree to an ensemble of trees, we trade off interpretability for performance.

In this post, we’ll take a look at the trees produced by the LambdaMART training process, and develop a visualization to view the ensemble of trees as a single collective unit. By doing so, we’ll gain back some of the intuitive interpretability that makes tree models appealing. Through the visualization, we’ll peer into the black box model and gain some sense of the factors that the model uses to make predictions.

We use Java and the JForests library to train LambdaMART and parse its output, and d3.js for visualization. To get a sneak peek, the final visualization is here.

LambdaMART Overview

Let’s take a look at some details of LambdaMART. For the full story, check out this paper from Microsoft Research.

At a high level, LambdaMART is an algorithm that uses gradient boosting to directly optimize Learning to Rank specific cost functions such as NDCG. To understand LambdaMART we’ll look at two aspects: Lambda and MART.

MART

LambdaMART is a specific instance of Gradient Boosted Regression Trees, also referred to as Multiple Additive Regression Trees (MART). Gradient Boosting is a technique for forming a model that is a weighted combination of an ensemble of “weak learners”. In our case, each “weak learner” is a decision tree.

Our goal is to find a function $f(x)$ that minimizes the expected loss $L$:

$\hat{f}(x) = \arg\min_{f(x)}E[L(y, f(x))|x]$

for feature vectors $x$ and labels $y$.

To do so we first view $\hat{f}(x)$ as a sum of weak learners $f_m$:

$\hat{f}(x)=\sum_{m=1}^M f_{m}(x)$

Since we are dealing with decision trees, evaluating $f_{m}(x)$ corresponds to passing $x$ down the tree $f_{m}$ until it reaches a leaf node $l$. The predicted value is then equal to a parameter $\gamma_l$.

A decision tree has two types of parameters: regions $R_l$ that determine which leaf node $l$ a training example $x_i$ is assigned to, and leaf outputs $\gamma_l$ that represent the tree’s output for all examples assigned to region $R_l$. Hence we can write each weak learner $f_{m}(x)$ as a parameterized tree $h_{m}(x;R_{m}, \gamma_{m})$, giving:

$\hat{f}(x)=\sum_{m=1}^M f_{m}(x)=\sum_{m=1}^M h_{m}(x; R_{m}, \gamma_{m})$
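To make the sum concrete, here is a minimal Python sketch of evaluating an additive tree ensemble. The `Node` class, the "go left if the feature value is at most the threshold" rule, and the two hand-built trees are assumptions made for illustration; they are not the JForests data structures used in this post.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class Node:
    """One node of a regression tree (hypothetical structure for illustration)."""
    feature: Optional[int] = None      # feature index tested at an inner node
    threshold: Optional[float] = None  # assumed rule: go left if x[feature] <= threshold
    left: Optional["Node"] = None
    right: Optional["Node"] = None
    gamma: Optional[float] = None      # leaf output; set only on leaves

    def is_leaf(self) -> bool:
        return self.gamma is not None


def tree_predict(root: Node, x: Sequence[float]) -> float:
    """Pass x down the tree until it reaches a leaf, then return that leaf's gamma."""
    node = root
    while not node.is_leaf():
        node = node.left if x[node.feature] <= node.threshold else node.right
    return node.gamma


def ensemble_predict(trees: Sequence[Node], x: Sequence[float]) -> float:
    """f_hat(x) is the sum of the individual tree outputs f_m(x)."""
    return sum(tree_predict(tree, x) for tree in trees)


# Two tiny hand-built trees (made-up thresholds and leaf values).
tree_1 = Node(feature=0, threshold=0.5, left=Node(gamma=-1.0), right=Node(gamma=1.0))
tree_2 = Node(
    feature=1, threshold=2.0,
    left=Node(gamma=0.25),
    right=Node(feature=0, threshold=0.8, left=Node(gamma=-0.5), right=Node(gamma=0.5)),
)

x = [0.7, 3.0]
score = ensemble_predict([tree_1, tree_2], x)  # 1.0 from tree_1 plus -0.5 from tree_2
```

Each tree contributes only the output of the single leaf that $x$ lands in, so the ensemble score is a sum of $M$ leaf values.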

Gradient Boosting finds each tree in a stepwise fashion. We start with an initial tree $f_1$, then step to another tree $f_2$, and so on:

$\hat{f}_{1}(x) = f_{1}$

$\hat{f}_{2}(x) = f_{1} + f_{2}$

$\hat{f}_{M}(x) = \sum_{m=1}^{M}f_{m}$

How do we determine the steps? At each step, we’d like the model to change such that the loss decreases as much as possible. The direction of steepest local decrease is the negative gradient of the loss with respect to the current model’s predictions. At step $m$, the gradient for training example $i$ is:

$g_{im}=\left[\frac{\partial L(y_i, f(x_i))}{\partial f(x_i)}\right]_{f=\hat{f}_{m-1}}$

Hence we define the step as $f_m=-\rho_mg_m$. This gives us insight into how Gradient Boosting solves the minimization – the algorithm is performing Gradient Descent in function space.

Now, we know that we want to take a gradient step, but still haven’t said how that translates to finding a tree. To do so, we build a tree that models the gradient; by adding such a tree to the ensemble, we effectively take a gradient step from the previous model. To fit the tree, we use squared error loss, giving the minimization problem:

$\arg\min_{R, \gamma}\sum_{i=1}^N(-g_{im}-F(x_i;R, \gamma))^2$

where $F(x_i;R, \gamma)$ denotes a regression tree with parameters $R, \gamma$.
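The loop described above can be sketched in a few lines of Python. This is a toy version under stated assumptions: the loss is squared error (so the negative gradient is simply the residual), each tree is a depth-1 stump on a single feature, the model starts from zero rather than from an initial tree, and a constant `learning_rate` stands in for the step size $\rho_m$. All function names are hypothetical.

```python
import numpy as np


def fit_stump(x, target):
    """Fit a depth-1 regression tree to `target` by minimizing squared error.

    Returns (threshold, gamma_left, gamma_right).
    """
    best = None
    xs = np.unique(x)
    # Candidate thresholds are midpoints between consecutive distinct values.
    for threshold in (xs[:-1] + xs[1:]) / 2.0:
        left = x <= threshold
        gamma_left, gamma_right = target[left].mean(), target[~left].mean()
        sse = ((target[left] - gamma_left) ** 2).sum() \
            + ((target[~left] - gamma_right) ** 2).sum()
        if best is None or sse < best[0]:
            best = (sse, threshold, gamma_left, gamma_right)
    return best[1:]


def stump_predict(stump, x):
    threshold, gamma_left, gamma_right = stump
    return np.where(x <= threshold, gamma_left, gamma_right)


def boost(x, y, n_trees=20, learning_rate=0.5):
    """Gradient boosting for L = 0.5 * (y - f)^2, starting from f = 0."""
    prediction = np.zeros_like(y, dtype=float)
    stumps, losses = [], []
    for _ in range(n_trees):
        gradient = prediction - y         # dL/df at the current model
        stump = fit_stump(x, -gradient)   # the tree models the negative gradient
        prediction += learning_rate * stump_predict(stump, x)
        stumps.append(stump)
        losses.append(0.5 * np.sum((y - prediction) ** 2))
    return stumps, losses


x = np.linspace(0.0, 1.0, 50)
y = np.sin(2 * np.pi * x)
stumps, losses = boost(x, y)
```

The recorded `losses` never increase from one step to the next, which is the "gradient descent in function space" behaviour: every new tree nudges the predictions along the direction that lowers the loss.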

Let’s take a step back. We’ve just set up a framework for finding an ensemble of trees that, when added together, minimizes a loss function. It’s a “framework” in the sense that we can use it if we merely supply gradients of the loss function at each training point. This is where the “Lambda” part of LambdaMART comes into play.

For further reading, Chapter 10 of Elements of Statistical Learning provides a great and thorough overview. also provides a good intro of Gradient Boosting.

Lambda

In ranking, the quantity that we’ll most likely care about optimizing is a metric such as NDCG, MAP, or MRR. Unfortunately, these metrics depend on the model’s scores only through the resulting sort order, so as functions of the scores they are either flat or discontinuous. We therefore can’t use them directly in our Gradient Boosting “framework”, since it’s unclear how we can provide the gradients at each training point.
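A small Python sketch shows the problem for NDCG. It assumes the common gain $2^{rel}-1$ and a $\log_2$ position discount; the scores and labels are made up.

```python
import math


def dcg(relevances):
    """Discounted cumulative gain of relevance labels listed in ranked order."""
    return sum((2 ** rel - 1) / math.log2(rank + 2)
               for rank, rel in enumerate(relevances))


def ndcg(scores, labels):
    """NDCG of the ranking induced by sorting documents by descending score."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    ranked = [labels[i] for i in order]
    ideal = dcg(sorted(labels, reverse=True))
    return dcg(ranked) / ideal if ideal > 0 else 0.0


labels = [2, 0, 1]                         # made-up relevance labels
base = ndcg([0.9, 0.5, 0.40], labels)
nudged = ndcg([0.9, 0.5, 0.45], labels)    # score moved, order unchanged
swapped = ndcg([0.9, 0.5, 0.55], labels)   # document 2 overtakes document 1
```

Nudging a score without changing the order leaves NDCG exactly where it was (`nudged == base`), while a slightly larger nudge that swaps two documents makes it jump (`swapped > base`). The metric is flat almost everywhere and discontinuous at the swaps, so its gradient with respect to the scores carries no useful signal.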

To address this, LambdaMART uses an idea from a model called LambdaRank. For each pair of training points $i,j$, we compute a value $\lambda_{i,j}$ that acts as the gradient we need. Intuitively, $\lambda_{i,j}$ can be thought of as a force that moves documents up and down the ranked list. For instance, if document $i$ is less relevant than document $j$, then document $i$ will receive a push of size $|\lambda_{i,j}|$ downwards in the ranked list, and $j$ will be pushed upwards by the same amount. The LambdaMART paper reports that empirically, the $\lambda$’s have been used successfully to optimize NDCG, MAP, and MRR.

Hence in LambdaMART we use the Gradient Boosting framework with the $\lambda_i$’s acting as the gradients $g_i$, where each document’s $\lambda_i$ accumulates the pairwise $\lambda_{i,j}$ over every pair that includes document $i$. This leads to update rules for the leaf values $\gamma_l$ that depend on the $\lambda$ values of the training instances that fall in leaf $l$.
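As a rough illustration, the per-document forces for a single query could be computed as below. The sketch follows the pairwise form described in the LambdaMART paper (a logistic term scaled by the NDCG change from swapping the pair), but the function names, the `sigma` parameter default, and the sign convention (positive means "push toward the top", i.e. the negative of the gradient) are choices made here for readability.

```python
import math


def gain(label):
    return 2 ** label - 1


def discount(rank):
    """Position discount for a 0-based rank."""
    return 1.0 / math.log2(rank + 2)


def compute_pushes(scores, labels, sigma=1.0):
    """Accumulate pairwise lambda magnitudes into one signed push per document."""
    n = len(scores)
    order = sorted(range(n), key=lambda i: scores[i], reverse=True)
    rank_of = {doc: rank for rank, doc in enumerate(order)}
    ideal = sum(gain(l) * discount(r)
                for r, l in enumerate(sorted(labels, reverse=True)))
    pushes = [0.0] * n
    if ideal == 0:
        return pushes
    for i in range(n):
        for j in range(n):
            if labels[i] <= labels[j]:
                continue  # only pairs where i is more relevant than j
            # |change in NDCG| if documents i and j swapped rank positions
            delta_ndcg = abs((gain(labels[i]) - gain(labels[j]))
                             * (discount(rank_of[i]) - discount(rank_of[j]))) / ideal
            magnitude = sigma * delta_ndcg / (1.0 + math.exp(sigma * (scores[i] - scores[j])))
            pushes[i] += magnitude  # the more relevant document is pushed up
            pushes[j] -= magnitude  # the less relevant one is pushed down equally
    return pushes


scores = [0.2, 0.9, 0.5]   # made-up current model scores
labels = [2, 0, 1]         # made-up relevance labels
pushes = compute_pushes(scores, labels)
```

In this example the most relevant document currently has the lowest score, so it receives a positive push, while the irrelevant document at the top receives a negative one. Because every pair contributes equal and opposite amounts, the pushes sum to zero.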

We’ve observed that LambdaMART trains an ensemble of trees sequentially. What does this sequence of trees look like? Let’s look at an ensemble of 19 trees trained with LambdaMART on the OHSUMED dataset:

At each inner node, the small number denotes the feature label, and the larger number denotes the threshold. The number at each leaf node denotes the leaf output.

We can see that each tree is somewhat similar to the previous tree. This makes sense given that each tree $t_i$ depends on the model state $\hat{f}_{i-1}$, since it is fit to the gradient of the loss with respect to that model’s predictions.

The step-by-step animation gives us insight into what is happening during training. It’s a great way to visualize and gain an intuition of the gradient boosting process – we see consecutive trees being produced that depend on the previous tree. However, we are still viewing the LambdaMART model as a group of separate trees; we still need to develop a way of viewing the final model cohesively.

Does the sequence animation tell us anything that could help with developing a single, cohesive model view? We can start with the observation that often, the trees look “roughly similar”.

Sometimes two trees will have the same feature at a given node position; for instance the root favors features 7, 9, and 11. Trees may not even change from step-to-step or may share a similar overall structure, such as trees #3 and #4 – notice how the numbers change, but the structure remains the same. Finally, all of the trees only draw from a subset of the total feature set; for instance feature 4 is never used by any of the trees.

Grouping Similar Nodes

Given these observations, we can think about how to build a unified visualization that takes advantage of reused features and similar structure between trees. Let’s consider the root node. We can count the frequency of each feature that occurs as the root node across the entire ensemble, leading to a single node that shows that feature 7 was used as the feature for the root node for 3 trees, feature 9 for 5 trees, and so on:

At each tree level $h$, there are $2^h$ possible node positions, which we can label with ordered pairs (level, index). We can collect the feature frequencies for each of these positions across the ensemble. If a position doesn’t appear in a tree, we mark it as ‘DNE’, and if a position is a leaf, we mark it as ‘Leaf’, such as position (3, 6):
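The counting step can be sketched in Python as follows. The nested-dict tree format, the three toy trees, and the convention that the children of position (level, index) sit at (level + 1, 2·index) and (level + 1, 2·index + 1) are assumptions made for this example, not the format produced by JForests.

```python
from collections import Counter, defaultdict


def collect_positions(tree, level=0, index=0, out=None):
    """Map each (level, index) position of one tree to its feature id or 'Leaf'."""
    if out is None:
        out = {}
    if "leaf" in tree:
        out[(level, index)] = "Leaf"
    else:
        out[(level, index)] = tree["feature"]
        collect_positions(tree["left"], level + 1, 2 * index, out)
        collect_positions(tree["right"], level + 1, 2 * index + 1, out)
    return out


def position_counts(trees):
    """Count, for every possible position, how many trees use each feature there."""
    per_tree = [collect_positions(tree) for tree in trees]
    max_level = max(level for positions in per_tree for level, _ in positions)
    counts = defaultdict(Counter)
    for level in range(max_level + 1):
        for index in range(2 ** level):
            for positions in per_tree:
                # A position missing from a tree is recorded as 'DNE'.
                counts[(level, index)][positions.get((level, index), "DNE")] += 1
    return dict(counts)


# Three toy trees (made-up features, thresholds, and leaf values).
trees = [
    {"feature": 7, "threshold": 0.3, "left": {"leaf": -0.2}, "right": {"leaf": 0.4}},
    {"feature": 7, "threshold": 0.5,
     "left": {"feature": 9, "threshold": 1.0, "left": {"leaf": 0.1}, "right": {"leaf": 0.3}},
     "right": {"leaf": -0.1}},
    {"feature": 9, "threshold": 2.0, "left": {"leaf": 0.0}, "right": {"leaf": 0.2}},
]
counts = position_counts(trees)
```

Here `counts[(0, 0)]` records feature 7 twice and feature 9 once, and `counts[(2, 0)]` records one 'Leaf' and two 'DNE' entries, since only the second tree reaches that depth.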

Heatmap Tree

At each node, we have a list of feature (or DNE or Leaf) counts. Heatmaps provide a natural way of visualizing this count information. We can make each node of the ordered-pair-tree into a heatmap, giving rise to the “Heatmap Tree”:
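One way to turn a node’s counts into heatmap cells is to scale each count by the ensemble size and hand the result to the front end as JSON. The sketch below is a guess at a convenient shape for that data, not the format used by the actual d3.js visualization; the field names are hypothetical, and apart from the 3 and 5 mentioned above for features 7 and 9, the root counts are made up so that they add up to 19 trees.

```python
import json


def heatmap_cells(node_counts, n_trees):
    """Turn {label: count} for one node into cells with a 0-to-1 color intensity."""
    cells = [
        {"label": str(label), "count": count, "intensity": count / n_trees}
        for label, count in node_counts.items()
    ]
    # Most frequently used feature first; ties broken by label for stable output.
    return sorted(cells, key=lambda cell: (-cell["count"], cell["label"]))


# Illustrative root-node counts for a 19-tree ensemble (partly made up).
root_counts = {7: 3, 9: 5, 11: 8, 25: 2, 29: 1}
cells = heatmap_cells(root_counts, n_trees=19)

# A JSON payload that a d3.js front end could bind to one node of the tree.
payload = json.dumps({"position": [0, 0], "cells": cells})
```

Scaling by the number of trees keeps the color scale comparable across nodes and across ensembles of different sizes.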

The Heatmap Tree

To view the interactive d3.js visualization, see this link.

The Heatmap Tree shows the entire ensemble as a single unit. Each number is a feature, and the color scale tells us how many trees use that feature at the given node:

By looking at the Heatmap Tree we can get a sense of which features the tree uses when it classifies instances. The ensemble uses features 7, 9, 11, 25, and 29 to perform the initial split, with 11 being used the most often. Further down the tree, we see features 7 and 9 again, along with common features such as 33 and 37. We can easily see that most of the trees in the ensemble are between 5 and 7 levels deep, and that while there are 46 total features, at a given node location the most variation we see is 9 features.

The Heatmap Tree can bring together hundreds of trees, such as this visualization of a 325-tree ensemble:

Tuning Parameters

The Heatmap Tree also lets us see how tuning parameters affect the final model. For instance, as we decrease the learning rate from 0.10 to 0.001, we see the ensemble size fluctuate:

Learning Rate 0.10

Learning Rate 0.01

Learning Rate 0.001

Notice how in the 0.01 case, the Heatmap Tree concisely summarizes a 111-tree ensemble.

When we use feature subsampling, we see the number of features at each node increase (in general). Each tree has a different limited subset of features to choose from, leading to a spread in the overall distribution of chosen features:

Feature subsampling 0.6

Combining feature subsampling with training data subsampling makes this ensemble more crowded:

Feature subsampling 0.6, Data subsampling 0.6

Note that these parameter trends do not necessarily generalize. However, the Heatmap Tree captures all of the trees and features in a single structure, and gives us insight into the structural results of our parameter choices.

Limitations

The Heatmap Tree, unfortunately, has its limits. With wide variation of features and many leaves, the tree becomes crowded:

Since the number of possible node locations at a given level grows exponentially with depth, the Heatmap Tree also struggles to visualize deep trees.

Expansions

Another nice aspect of decision trees is that we can visualize how a test instance gets classified; we simply show the path it takes from root to leaf.

How could we visualize the classification process in an ensemble, via the Heatmap Tree or otherwise? With the Heatmap Tree, we would need to be able to simultaneously visualize tens or hundreds of paths, since there would be an individual path for every tree in the ensemble. One idea is to have weighted edges on the Heatmap Tree; an edge would become thicker each time the edge is used when classifying an instance.
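The weighted-edge idea could be prototyped along these lines. As before, the nested-dict tree format, the "left if at most the threshold" rule, the (level, index) position convention, and the toy trees are assumptions made for the sketch.

```python
from collections import Counter


def path_edges(tree, x):
    """Return the edges, as (parent_position, child_position), that x follows in one tree."""
    edges = []
    level, index, node = 0, 0, tree
    while "leaf" not in node:
        go_left = x[node["feature"]] <= node["threshold"]
        child_index = 2 * index + (0 if go_left else 1)
        edges.append(((level, index), (level + 1, child_index)))
        node = node["left"] if go_left else node["right"]
        level, index = level + 1, child_index
    return edges


def edge_weights(trees, x):
    """Count how many trees route x along each edge of the shared position grid."""
    weights = Counter()
    for tree in trees:
        weights.update(path_edges(tree, x))
    return weights


# Three toy trees (made-up features, thresholds, and leaf values).
trees = [
    {"feature": 7, "threshold": 0.3, "left": {"leaf": -0.2}, "right": {"leaf": 0.4}},
    {"feature": 7, "threshold": 0.5,
     "left": {"feature": 9, "threshold": 1.0, "left": {"leaf": 0.1}, "right": {"leaf": 0.3}},
     "right": {"leaf": -0.1}},
    {"feature": 9, "threshold": 2.0, "left": {"leaf": 0.0}, "right": {"leaf": 0.2}},
]
x = {7: 0.4, 9: 1.5}   # one instance, keyed by feature id
weights = edge_weights(trees, x)
```

Two of the three trees send this instance from the root to its left child, so that edge would be drawn twice as thick as the edge to the right child.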

Another next step is to test the generalizability of this visualization; could it work for any gradient boosting model? What would a Heatmap Tree of a random forest look like?

Conclusion

We’ve taken a close look at LambdaMART and gradient boosting. We’ve devised a way to capture the complexity of gradient boosted tree ensembles in a cohesive way. In doing so we bought back some of the interpretability that we lost by moving from a single tree to an ensemble, gained insight into the training process, and made the black box LambdaMART model a bit more transparent.

To see an interactive d3 Heatmap Tree, visit this link.