The “old” load-balancing loss. Instead of training the router with explicitly labeled data for each expert, a load-balancing + load-concentration loss induces modularity from the data itself.
Insight: we want to maximize the mutual information between tokens and modules. For the router \(m \sim g\qty(\cdot \mid x)\) (“which module \(m\) should we assign, given token \(x\)”), we write:
\begin{equation} \ell_{MI} = \underbrace{\sum_{m=1}^{N} p\qty(m) \log p\qty(m)}_{-H\qty(m)} - \frac{1}{|X|} \sum_{x \in X} \underbrace{\sum_{m=1}^{N} g\qty(m \mid x) \log g\qty(m \mid x)}_{-H\qty(m \mid x)} \end{equation}
that is, we want to maximize the entropy of the marginal distribution (“without knowing the token, we shouldn’t be opinionated about which module gets used”) while minimizing the entropy of the conditional distribution (“given a token, we should know which module to use”). Minimizing \(\ell_{MI}\) therefore maximizes the mutual information \(I\qty(x; m) = H\qty(m) - H\qty(m \mid x)\).
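A minimal sketch of this loss, assuming the router’s per-token distributions \(g\qty(\cdot \mid x)\) are given as rows of a matrix and the marginal \(p\qty(m)\) is estimated by averaging over the batch (the function name and the `eps` smoothing constant are illustrative choices, not from the source):

```python
import numpy as np

def mi_loss(router_probs):
    """Mutual-information loss for a router.

    router_probs: (|X|, N) array; row x is g(. | x), each row sums to 1.
    Returns -H(m) + H(m|x); minimizing it maximizes I(x; m).
    """
    eps = 1e-12  # avoids log(0) on hard (one-hot) routing decisions

    # Marginal p(m), estimated by averaging the conditionals over the batch.
    p_m = router_probs.mean(axis=0)

    # First term: sum_m p(m) log p(m) = -H(m).
    neg_H_m = np.sum(p_m * np.log(p_m + eps))

    # Second term: average conditional entropy H(m|x)
    # = -(1/|X|) sum_x sum_m g(m|x) log g(m|x).
    H_m_given_x = -np.mean(
        np.sum(router_probs * np.log(router_probs + eps), axis=1)
    )
    return neg_H_m + H_m_given_x
```

As a sanity check: a router that assigns each token deterministically to a distinct module (identity matrix) attains the minimum \(-\log N\), while a router that is uniform everywhere scores \(0\), since the two entropy terms cancel.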
At fine-tuning time (i.e., when we do have task-specific data), we can directly minimize the entropy of the marginal module distribution:
\begin{equation} \ell_{\text{concentration}} = - \sum_{m=1}^{N} p\qty(m) \log p\qty(m) \end{equation}
since we want each specialized fine-tuning mix to use one module.
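Continuing the sketch above, the concentration loss is just the marginal entropy \(H\qty(m)\) with its sign flipped relative to the first term of \(\ell_{MI}\) (again, the function name and `eps` constant are illustrative):

```python
import numpy as np

def concentration_loss(router_probs):
    """H(m): entropy of the marginal module distribution.

    Minimizing it pushes the router to commit to a single module
    for the current fine-tuning data mix.
    """
    eps = 1e-12  # avoids log(0) when some modules are never selected
    p_m = router_probs.mean(axis=0)  # marginal p(m) over the batch
    return -np.sum(p_m * np.log(p_m + eps))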
