Does your model get better at task T when you rank by estimated probability p(T)?
To understand what to optimize in a ranking model
Summary
Often when building a recommender system, you have multiple labels / tasks / user actions that you want to optimize. A conventional approach is to use multi-task learning to estimate the probability of each task and then rank items by a weighted sum of these estimates.
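As a minimal sketch of that conventional ranker, assuming the per-task probabilities have already been predicted for each candidate item, the scoring and sorting step could look like this. The task names, weights, and probabilities are hypothetical.

```python
from typing import Dict, List


def rank_items(
    task_probs: Dict[str, Dict[str, float]],  # item_id -> {task_name -> predicted p(task)}
    task_weights: Dict[str, float],  # task_name -> weight in the final score (hypothetical)
) -> List[str]:
    """Return item ids sorted by sum_k w_k * p_k(item), highest score first."""

    def score(item: str) -> float:
        # Weighted sum of the per-task probability estimates; missing tasks count as 0.
        return sum(w * task_probs[item].get(task, 0.0) for task, w in task_weights.items())

    return sorted(task_probs, key=score, reverse=True)


if __name__ == "__main__":
    probs = {
        "item_a": {"click": 0.30, "share": 0.01},
        "item_b": {"click": 0.10, "share": 0.08},
        "item_c": {"click": 0.20, "share": 0.02},
    }
    weights = {"click": 1.0, "share": 5.0}  # hypothetical task weights
    print(rank_items(probs, weights))
```

Changing the weights reorders the items, which is why choosing what each p(T) term should optimize matters.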
Please take a look at this video by Andrej Karpathy for a lucid explanation of the utility and complexity of multi-task learning.
In this post, we ask: if our ranking model is good at predicting the probability of a task T, does ranking items by p(T) actually increase the occurrence of T in expectation?
Conventional ranker
Will selecting items based on p(T) lead to a higher observation of task T?
At first the answer seems obvious: “Of course! If your model is better than random at predicting the occurrence of a label (task) T, then ranking items by the model-predicted probability p(T) should lead to a higher observation of T.” But let’s try to prove this.
As mentioned in Categorical Reparameterization with Gumbel-Softmax (ICLR 2017), the expected value of T for the top item when ranked by p(T) can be written in terms of the Gumbel-Softmax relaxation.
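Assuming the standard Gumbel-Softmax form from that paper, with \(p_i\) the predicted p(T) of item i, \(t_i\) its label, \(g_i\) the Gumbel noise and \(\tau\) the temperature, a reconstruction of this expectation is:

\[
\mathbb{E}\left[t_{\text{top}}\right] \approx \mathbb{E}_{g}\left[\sum_i t_i \, \frac{\exp\left((\log p_i + g_i)/\tau\right)}{\sum_j \exp\left((\log p_j + g_j)/\tau\right)}\right]
\]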
Here g_i is i.i.d. Gumbel(0, 1) noise, drawn independently of p(T) and t, and τ is the softmax temperature.
In this Google Colab, I verified this assumption and found that the expected value of t under ranking is indeed correlated with the weighted sum under Gumbel-Softmax.
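A minimal NumPy sketch of this kind of check (an illustration, not the Colab itself). It assumes ranking is done by log p(T) plus i.i.d. Gumbel(0, 1) noise; under that assumption the Gumbel-max trick gives P(item i is on top) = p_i / Σ_j p_j exactly. The sketch compares a Monte Carlo estimate of E[t_top] with that closed form and with a low-temperature Gumbel-Softmax weighted sum. The probabilities and labels are made up, and the function names are illustrative.

```python
import numpy as np


def closed_form_top_label_rate(p: np.ndarray, t: np.ndarray) -> float:
    """Gumbel-max identity: P(top = i) = p_i / sum_j p_j, so E[t_top] = sum_i t_i p_i / sum_j p_j."""
    return float(np.dot(t, p) / np.sum(p))


def empirical_top_label_rate(p: np.ndarray, t: np.ndarray, n_samples: int = 200_000, seed: int = 0) -> float:
    """Rank by log p_i + g_i with g_i ~ Gumbel(0, 1) i.i.d., then average the label of the top item."""
    rng = np.random.default_rng(seed)
    scores = np.log(p) + rng.gumbel(size=(n_samples, len(p)))
    top = np.argmax(scores, axis=1)  # index of the top-ranked item in each draw
    return float(t[top].mean())


def gumbel_softmax_weighted_sum(
    p: np.ndarray, t: np.ndarray, tau: float = 0.01, n_samples: int = 200_000, seed: int = 1
) -> float:
    """Monte Carlo estimate of E_g[sum_i t_i * softmax((log p + g) / tau)_i]."""
    rng = np.random.default_rng(seed)
    z = (np.log(p) + rng.gumbel(size=(n_samples, len(p)))) / tau
    z -= z.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    w = np.exp(z)
    w /= w.sum(axis=1, keepdims=True)  # softmax weights per draw
    return float((w @ t).mean())


if __name__ == "__main__":
    p = np.array([0.6, 0.3, 0.1])  # hypothetical predicted p(T) per item
    t = np.array([1.0, 0.0, 1.0])  # hypothetical realized labels
    print(closed_form_top_label_rate(p, t))
    print(empirical_top_label_rate(p, t))
    print(gumbel_softmax_weighted_sum(p, t))
```

The per-item p_i are independent probabilities and need not sum to 1, which is why the normalizing denominator appears in the closed form.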
Normally we train ranking models to minimize normalized binary cross entropy.
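The standard per-example binary cross entropy is shown below, together with one common normalization (an assumption about which normalization is meant here): the loss averaged over N examples, divided by the entropy of the average label rate \(\bar{t}\):

\[
\mathcal{L}_{\text{BCE}} = -\left[t \log p(T) + (1 - t)\log\left(1 - p(T)\right)\right],
\qquad
\text{NE} = \frac{\frac{1}{N}\sum_{n=1}^{N} \mathcal{L}_{\text{BCE},n}}{-\left[\bar{t}\log\bar{t} + (1-\bar{t})\log(1-\bar{t})\right]}
\]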
Here, t is the true label (0 or 1) and p(T) is the predicted probability of the positive class.
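To make the loss concrete, here is a small sketch of binary cross entropy and of the normalization that divides by the entropy of the average label rate. Whether a given production system normalizes exactly this way is an assumption.

```python
import numpy as np


def binary_cross_entropy(t, p, eps: float = 1e-12) -> float:
    """Mean over examples of -[t log p + (1 - t) log(1 - p)]."""
    t = np.asarray(t, dtype=float)
    p = np.clip(np.asarray(p, dtype=float), eps, 1.0 - eps)  # avoid log(0)
    return float(-np.mean(t * np.log(p) + (1.0 - t) * np.log(1.0 - p)))


def normalized_cross_entropy(t, p, eps: float = 1e-12) -> float:
    """BCE divided by the entropy of the average label rate (the base rate)."""
    base = float(np.clip(np.mean(np.asarray(t, dtype=float)), eps, 1.0 - eps))
    base_entropy = -(base * np.log(base) + (1.0 - base) * np.log(1.0 - base))
    return binary_cross_entropy(t, p, eps) / float(base_entropy)


if __name__ == "__main__":
    labels = [1, 0, 0, 1]
    print(normalized_cross_entropy(labels, [0.5] * 4))  # constant base-rate prediction
    print(normalized_cross_entropy(labels, [0.9, 0.1, 0.1, 0.9]))  # a better model
```

A constant prediction equal to the base rate gives NE = 1, so values below 1 indicate the model beats that trivial predictor.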
The problem I am trying to reconcile is that, even if we ignore the normalizing denominator in the Gumbel-Softmax expression and the Gumbel noise, t_i is multiplied by exp(log p(T)) = p(T) in one case and by log p(T) in the other. Both increase with p(T), so they are directionally aligned, but is minimizing cross-entropy loss really the optimal way to maximize the expected value of T?
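One way to see the mismatch is to differentiate each per-item term with respect to p(T). The expected-value term weights every positive equally, while the log-likelihood term weights a positive by 1/p(T), so it pushes hardest on positives the model currently considers unlikely:

\[
\frac{\partial}{\partial p(T)}\Big[t_i \, p(T)\Big] = t_i,
\qquad
\frac{\partial}{\partial p(T)}\Big[t_i \log p(T)\Big] = \frac{t_i}{p(T)}
\]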
Deriving update for Gumbel expectation with REINFORCE
Suppose we have a prediction model for task T given item i and user u, and l is the logit output by the model.
As discussed in the section above, the probability that item i is selected by the recommender system for user u is roughly (ignoring Gumbel and tau):
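A reconstruction from the Gumbel-Softmax expression with the noise dropped and τ = 1, writing the predicted probability as \(p_i = \sigma(l)\) (σ is the sigmoid) and using \(l_j\) (notation assumed here) for the logit of candidate item j for the same user:

\[
P(i \mid u) \approx \frac{\exp(\log p_i)}{\sum_j \exp(\log p_j)} = \frac{\sigma(l)}{Z},
\qquad
Z = \sum_j \sigma(l_j)
\]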
Let’s try to maximize the expected value of task T using REINFORCE with a baseline (Sutton & Barto, Chapter 13.4). The REINFORCE update multiplies the gradient of the log probability of the action (showing item i) by the reward. For a task, the reward is often binary (success or failure).
Let's denote r as the binary reward for task T given item i and user u. The update step with REINFORCE and a baseline b would be:
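In the standard form from Sutton & Barto (Chapter 13.4), with step size \(\alpha\), model parameters \(\theta\), and \(P(i \mid u)\) the probability of showing item i to user u:

\[
\theta \leftarrow \theta + \alpha \cdot \left(r - b\right) \cdot \nabla_\theta \log P(i \mid u)
\]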
A decent choice of baseline, b, could be the overall average rate of task T.
Now, breaking down the components:
Log Probability:
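Taking the log of the selection probability \(P(i \mid u) \approx \sigma(l)/Z\), with σ the sigmoid:

\[
\log P(i \mid u) = \log \sigma(l) - \log Z
\]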
where Z is the normalization constant.
Gradient of Log Probability:
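Applying the chain rule to \(\log P(i \mid u) = \log \sigma(l) - \log Z\):

\[
\nabla_\theta \log P(i \mid u) = \frac{1}{\sigma(l)} \frac{d\sigma(l)}{dl} \nabla_\theta l - \nabla_\theta \log Z,
\qquad
\frac{d\sigma(l)}{dl} = \sigma(l)\left(1 - \sigma(l)\right)
\]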
The gradient involves the derivative of the sigmoid function with respect to l and the gradient of l with respect to the model parameters \(\theta\).
Update Step (Putting it all together):
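Substituting the gradient of \(\log P(i \mid u) = \log \sigma(l) - \log Z\) into the REINFORCE update with baseline:

\[
\theta \leftarrow \theta + \alpha \cdot \left(r - b\right) \cdot \left(\nabla_\theta \log \sigma(l) - \nabla_\theta \log Z\right)
\]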
The baseline b is subtracted to reduce the variance of the gradient estimate without biasing it (Sutton & Barto, Chapter 13.4).
Ignoring normalization Z, and working out the derivative of sigmoid:
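\(\theta \leftarrow \theta + \alpha \cdot \left(r - b\right) \cdot \nabla_\theta \log \sigma(l)\)
Hence, using \(\frac{d}{dl} \log \sigma(l) = 1 - \sigma(l)\):
\(\theta \leftarrow \theta + \alpha \cdot \left(r - b\right) \cdot \left(1 - \sigma(l)\right) \cdot \nabla_\theta l\)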
Intuition
This update nudges the parameters to increase the probability of showing item i to user u when task T succeeds (r = 1 > b) and to decrease it when T fails (r = 0 < b), with the baseline reducing the variance of the update.
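A minimal sketch comparing the two per-example updates at the level of the logit l: the shared factor \(\nabla_\theta l\) is dropped, Z is ignored, and the baseline value is hypothetical. The function names are illustrative, and a central finite difference is included to sanity-check both closed forms.

```python
import math


def sigmoid(l: float) -> float:
    """Logistic function sigma(l)."""
    return 1.0 / (1.0 + math.exp(-l))


def reinforce_logit_step(l: float, r: float, b: float) -> float:
    """(r - b) * d/dl log sigma(l) = (r - b) * (1 - sigma(l)); the normalizer Z is ignored."""
    return (r - b) * (1.0 - sigmoid(l))


def bce_logit_step(l: float, t: float) -> float:
    """-d/dl BCE(t, sigma(l)) = t - sigma(l), i.e. the gradient-descent direction."""
    return t - sigmoid(l)


def central_difference(f, x: float, h: float = 1e-6) -> float:
    """Numerical derivative of f at x, used to check the closed forms above."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


if __name__ == "__main__":
    b = 0.1  # hypothetical baseline: overall rate of task T
    for l in (-2.0, 0.0, 2.0):
        for r in (0.0, 1.0):
            print(
                f"l={l:+.1f} r={r:.0f} "
                f"reinforce={reinforce_logit_step(l, r, b):+.3f} "
                f"bce={bce_logit_step(l, r):+.3f}"
            )
```

With b = 0.1 and r = 0, the REINFORCE step is only -0.1 * (1 - σ(l)), whereas the BCE step is -σ(l), so the two objectives weight successes and failures quite differently.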
Deriving update for binary classification
Please see the derivation explained by ChatGPT here.
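For reference, with \(p(T) = \sigma(l)\), the gradient of the per-example binary cross entropy with respect to the logit, and the corresponding gradient-descent update, are:

\[
\frac{\partial \mathcal{L}_{\text{BCE}}}{\partial l} = \sigma(l) - t
\quad\Rightarrow\quad
\theta \leftarrow \theta + \alpha \cdot \left(t - \sigma(l)\right) \cdot \nabla_\theta l
\]

Note that the factor \(t - \sigma(l)\) pushes down every item with t = 0 in proportion to its predicted probability, independent of any baseline.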
Conclusion and next steps
Maximizing expected occurrence of a task T is not the same as minimizing binary cross entropy of its prediction.
In a future post, we will explore adding a REINFORCE-based reward-maximization term to the traditional binary cross-entropy loss and measure whether this leads to a higher occurrence of the label/task in test data.
Disclaimer: These are the personal opinions of the author(s). Any assumptions, opinions stated here are theirs and not representative of their current or any prior employer(s). Apart from publicly available information, any other information here is not claimed to refer to any company including ones the author(s) may have worked in or been associated with.