Abstract

In this post, we walk through PyTorch code that computes the Kullback-Leibler (KL) divergence between two full-rank Gaussians. The code may be useful when using a full-rank Gaussian as the variational posterior in a variational autoencoder (VAE).

KL divergence between two full-rank Gaussians

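It is common practice to parameterize the covariance matrix $\boldsymbol\Sigma$ of a $d$-dimensional full-rank Gaussian through its Cholesky decomposition $\boldsymbol\Sigma = \mathbf L \mathbf L^\top$, where $\mathbf L$ is lower triangular, and to store the $D = d(1+d)/2$ lower-triangular entries of $\mathbf L$ as a $D$-dimensional vector. We assume this parameterization here. We require the diagonal of $\mathbf L$ to be positive: this makes $\mathbf L$ nonsingular, so $\boldsymbol\Sigma$ is positive definite, and makes $\mathbf L$ the unique Cholesky factor of $\boldsymbol\Sigma$. We enforce this by exponentiating the vector elements that land on the diagonal of $\mathbf L$. Note that with the row-major ordering of torch.tril_indices used in the code below, these are not the first $d$ elements but the elements at positions $k(k+3)/2$ for $k = 0, \dots, d-1$ (i.e. $0, 2, 5, \dots$).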
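A minimal sketch of this parameterization for a single $3 \times 3$ factor, assuming the row-major ordering of torch.tril_indices (the ordering used by form_cholesky_tril_from_elements in the code section). It shows where the diagonal entries sit in the parameter vector and checks that exponentiating them yields a positive definite $\boldsymbol\Sigma$ whose Cholesky factor is exactly $\mathbf L$. The names theta, n_params and diag_positions are illustrative only.

```python
import torch

torch.manual_seed(0)
d = 3
n_params = d * (d + 1) // 2  # D = d(1 + d)/2 free parameters

# Row-major (row, col) coordinates of the lower triangle.
i, j = torch.tril_indices(d, d)
diag_positions = torch.nonzero(i == j).flatten().tolist()
print(diag_positions)  # [0, 2, 5]: the diagonal entries are not contiguous

# Scatter an unconstrained vector into L, then exponentiate only the diagonal.
theta = torch.randn(n_params, dtype=torch.float64)
L = torch.zeros(d, d, dtype=torch.float64)
L[i, j] = theta
L.diagonal().copy_(L.diagonal().exp())

Sigma = L @ L.T
# Every eigenvalue is positive, so Sigma is positive definite.
print(bool((torch.linalg.eigvalsh(Sigma) > 0).all()))  # True
# Because diag(L) > 0, L is exactly the Cholesky factor of Sigma.
print(torch.allclose(torch.linalg.cholesky(Sigma), L))  # True
```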

Let the two Gaussians be $p(\boldsymbol x) = \mathcal N(\boldsymbol x \mid \boldsymbol\mu_1, \boldsymbol\Sigma_1)$ and $q(\boldsymbol x) = \mathcal N(\boldsymbol x \mid \boldsymbol\mu_2, \boldsymbol\Sigma_2)$. Per The Book of Statistical Proofs, the KL divergence between them is:

\[D_\mathrm{KL}(p \parallel q) = \frac{1}{2}\left((\boldsymbol\mu_2 - \boldsymbol\mu_1)^\top \boldsymbol\Sigma_2^{-1} (\boldsymbol\mu_2 - \boldsymbol\mu_1) + \operatorname{tr}(\boldsymbol\Sigma_2^{-1} \boldsymbol\Sigma_1) - \log \frac{\det \boldsymbol\Sigma_1}{\det \boldsymbol\Sigma_2} - d\right)\,.\]
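Before any reparameterization, the closed form can be checked numerically. The sketch below is a direct, unbatched transcription using an explicit inverse and determinants; kl_closed_form is a hypothetical helper for illustration, compared against torch.distributions.kl_divergence on random, well-conditioned float64 covariances.

```python
import torch
from torch import distributions as D


def kl_closed_form(mu1, Sigma1, mu2, Sigma2):
    """KL(N(mu1, Sigma1) || N(mu2, Sigma2)), transcribed term by term."""
    d = mu1.numel()
    diff = mu2 - mu1
    Sigma2_inv = torch.linalg.inv(Sigma2)
    mahalanobis = diff @ Sigma2_inv @ diff         # (mu2 - mu1)^T Sigma2^{-1} (mu2 - mu1)
    trace_term = torch.trace(Sigma2_inv @ Sigma1)  # tr(Sigma2^{-1} Sigma1)
    log_det_ratio = torch.log(torch.det(Sigma1) / torch.det(Sigma2))
    return 0.5 * (mahalanobis + trace_term - log_det_ratio - d)


torch.manual_seed(0)
d = 3
A1 = torch.randn(d, d, dtype=torch.float64)
A2 = torch.randn(d, d, dtype=torch.float64)
# Adding d * I keeps the random covariances symmetric positive definite.
Sigma1 = A1 @ A1.T + d * torch.eye(d, dtype=torch.float64)
Sigma2 = A2 @ A2.T + d * torch.eye(d, dtype=torch.float64)
mu1 = torch.randn(d, dtype=torch.float64)
mu2 = torch.randn(d, dtype=torch.float64)

kl_direct = kl_closed_form(mu1, Sigma1, mu2, Sigma2)
reference = D.kl_divergence(
    D.MultivariateNormal(mu1, covariance_matrix=Sigma1),
    D.MultivariateNormal(mu2, covariance_matrix=Sigma2))
print(torch.allclose(kl_direct, reference))  # True
```

Explicit inverses and determinants are fine for a sanity check, but they are generally slower and less numerically stable than working with the Cholesky factors, which motivates the rewrite that follows.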

Plugging in our parameterization of the covariance matrices:

\[\begin{aligned} D_\mathrm{KL}(p \parallel q) &= \frac{1}{2}\left((\boldsymbol\mu_2 - \boldsymbol\mu_1)^\top \mathbf L_2^{-\top} \mathbf L_2^{-1} (\boldsymbol\mu_2 - \boldsymbol\mu_1) + \operatorname{tr}((\mathbf L_2 \mathbf L_2^\top)^{-1} (\mathbf L_1 \mathbf L_1^\top)) - \log \frac{\det(\mathbf L_1 \mathbf L_1^\top)}{\det(\mathbf L_2 \mathbf L_2^\top)} - d\right)\\ &= \frac{1}{2}\left((\mathbf L_2^{-1} (\boldsymbol\mu_2 - \boldsymbol\mu_1))^\top (\mathbf L_2^{-1} (\boldsymbol\mu_2 - \boldsymbol\mu_1)) + \operatorname{tr}((\mathbf L_2^{-1} \mathbf L_1)^\top (\mathbf L_2^{-1} \mathbf L_1)) - 2\log\frac{\det\mathbf L_1}{\det\mathbf L_2} - d\right)\,.\\ \end{aligned}\]

We have used the following facts:

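  • the cyclic property of the trace, $\operatorname{tr}(\mathbf A \mathbf B) = \operatorname{tr}(\mathbf B \mathbf A)$;
  • $(\mathbf A \mathbf B)^{-1} = \mathbf B^{-1} \mathbf A^{-1}$, and $\mathbf A^{-\top}$ denotes $(\mathbf A^\top)^{-1} = (\mathbf A^{-1})^\top$;
  • $\det \mathbf A = \det \mathbf A^\top$;
  • $\log\det (\mathbf A \mathbf B) = \log\det(\mathbf A) + \log\det(\mathbf B)$ for matrices with positive determinants.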
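The three rewritten terms can be checked numerically against their original forms. A minimal sketch with random float64 lower-triangular factors (illustrative inputs; delta stands for $\boldsymbol\mu_2 - \boldsymbol\mu_1$):

```python
import torch

torch.manual_seed(0)
d = 4
# Random lower-triangular factors with diagonals in [1, 2).
L1 = torch.randn(d, d, dtype=torch.float64).tril(-1) + torch.diag(torch.rand(d, dtype=torch.float64) + 1.0)
L2 = torch.randn(d, d, dtype=torch.float64).tril(-1) + torch.diag(torch.rand(d, dtype=torch.float64) + 1.0)
delta = torch.randn(d, dtype=torch.float64)  # mu2 - mu1

Sigma1, Sigma2 = L1 @ L1.T, L2 @ L2.T
L2_inv = torch.linalg.inv(L2)

# Quadratic term: delta^T Sigma2^{-1} delta == (L2^{-1} delta)^T (L2^{-1} delta)
lhs_quad = delta @ torch.linalg.inv(Sigma2) @ delta
rhs_quad = (L2_inv @ delta) @ (L2_inv @ delta)

# Trace term: tr(Sigma2^{-1} Sigma1) == tr((L2^{-1} L1)^T (L2^{-1} L1))
lhs_tr = torch.trace(torch.linalg.inv(Sigma2) @ Sigma1)
W = L2_inv @ L1
rhs_tr = torch.trace(W.T @ W)

# Log-det term: log(det Sigma1 / det Sigma2) == 2 log(det L1 / det L2)
lhs_ld = torch.logdet(Sigma1) - torch.logdet(Sigma2)
rhs_ld = 2 * (torch.logdet(L1) - torch.logdet(L2))

print(torch.allclose(lhs_quad, rhs_quad),
      torch.allclose(lhs_tr, rhs_tr),
      torch.allclose(lhs_ld, rhs_ld))  # True True True
```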

It follows that:

\[D_\mathrm{KL}(p \parallel q) = \frac{1}{2}\big(\boldsymbol y^\top \boldsymbol y + \|\mathbf M\|_F^2 - 2 (\operatorname{tr}(\log \mathbf L_1) - \operatorname{tr}(\log \mathbf L_2)) - d\big)\,,\]

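where $\boldsymbol y = \mathbf L_2^{-1}(\boldsymbol\mu_2 - \boldsymbol\mu_1)$ and $\mathbf M = \mathbf L_2^{-1}\mathbf L_1$ are the solutions of the triangular systems $\mathbf L_2 \boldsymbol y = \boldsymbol\mu_2 - \boldsymbol\mu_1$ and $\mathbf L_2 \mathbf M = \mathbf L_1$. Both can therefore be computed by forward substitution, without ever forming $\mathbf L_2^{-1}$.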
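A minimal sketch of the two triangular solves with torch.linalg.solve_triangular, on random float64 inputs (illustrative only). It checks that $\boldsymbol y$ and $\mathbf M$ satisfy their defining systems and agree with multiplying by an explicit $\mathbf L_2^{-1}$, which the solve avoids.

```python
import torch

torch.manual_seed(0)
d = 4
# Random lower-triangular factors with diagonals in [1, 2).
L1 = torch.randn(d, d, dtype=torch.float64).tril(-1) + torch.diag(torch.rand(d, dtype=torch.float64) + 1.0)
L2 = torch.randn(d, d, dtype=torch.float64).tril(-1) + torch.diag(torch.rand(d, dtype=torch.float64) + 1.0)
delta = torch.randn(d, 1, dtype=torch.float64)  # mu2 - mu1 as a column vector

# Forward substitution: solve L2 y = delta and L2 M = L1.
y = torch.linalg.solve_triangular(L2, delta, upper=False)
M = torch.linalg.solve_triangular(L2, L1, upper=False)
print(torch.allclose(L2 @ y, delta), torch.allclose(L2 @ M, L1))  # True True

# Same values as multiplying by the explicit inverse.
L2_inv = torch.linalg.inv(L2)
print(torch.allclose(y, L2_inv @ delta), torch.allclose(M, L2_inv @ L1))  # True True
```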

We have denoted:

  • $\|\cdot\|_F$ as the Frobenius norm of a matrix;
  • $\log \mathbf A$ as the elementwise logarithm of $\mathbf A$; since it only appears inside a trace, only its diagonal entries matter.

We have used the following facts:

  • $\operatorname{tr}(\mathbf A^\top \mathbf A) = \|\mathbf A\|_F^2$;
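  • $\log\det \mathbf L = \operatorname{tr}(\log \mathbf L) = \sum_{i=1}^d \log L_{ii}$ when $\mathbf L$ is a lower triangular matrix with positive diagonal, because the determinant of a triangular matrix is the product of its diagonal entries.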
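Both facts are easy to confirm numerically. The sketch below (random float64 inputs, illustrative only) also hints at why an implementation should sum $\log L_{ii}$ over the diagonal rather than take an elementwise logarithm of the whole matrix: the zeros above the diagonal would turn into -inf.

```python
import torch

torch.manual_seed(0)
d = 4
A = torch.randn(d, d, dtype=torch.float64)

# tr(A^T A) equals the squared Frobenius norm, i.e. the sum of squared entries.
tr_ata = torch.trace(A.T @ A)
print(torch.allclose(tr_ata, torch.linalg.matrix_norm(A, ord="fro") ** 2))  # True
print(torch.allclose(tr_ata, A.square().sum()))  # True

# For lower-triangular L with positive diagonal, log det L = sum_i log L_ii.
L = A.tril(-1) + torch.diag(torch.rand(d, dtype=torch.float64) + 0.5)
print(torch.allclose(torch.logdet(L), L.diagonal().log().sum()))  # True
```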

Code

import torch
from torch import distributions as D


def form_cholesky_tril_from_elements(d, scale_tril_elems):
    """
    Form the lower triangular Cholesky factor from its unconstrained
    elements, exponentiating the entries that land on the diagonal.

    Args:
        d (int): The number of rows/columns in the square matrix.
        scale_tril_elems (torch.Tensor): The Cholesky lower triangular
            elements, of shape (batch_size, (1 + d) * d // 2).

    Returns:
        torch.Tensor: A tensor of shape (batch_size, d, d).
    """
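    batch_size = scale_tril_elems.size(0)
    device = scale_tril_elems.device
    # Row-major (row, col) coordinates of the lower triangle.
    i, j = torch.tril_indices(d, d, device=device)
    # Match the input dtype so float64 parameters are not silently downcast.
    l_mat = torch.zeros(batch_size, d, d, dtype=scale_tril_elems.dtype,
                        device=device)
    l_mat[:, i, j] = scale_tril_elems
    # Exponentiate the diagonal in place so it is strictly positive.
    l_mat_diag = l_mat.diagonal(dim1=1, dim2=2)
    l_mat_diag.copy_(l_mat_diag.exp())
    return l_mat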


d = 3
batch_size = 5


def groundtruth(mean1, scale_tril1, mean2, scale_tril2):
    """Reference KL(p || q) computed by torch.distributions."""
    p = D.MultivariateNormal(loc=mean1, scale_tril=scale_tril1)
    q = D.MultivariateNormal(loc=mean2, scale_tril=scale_tril2)
    return D.kl_divergence(p, q)


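def ours(mean1, scale_tril1, mean2, scale_tril2):
    """KL(p || q) via the triangular-solve formula derived above."""
    d = mean1.size(-1)  # read d from the inputs instead of a global
    # y solves L2 y = mu2 - mu1; y^T y is the Mahalanobis term.
    y = torch.linalg.solve_triangular(
        scale_tril2, (mean2 - mean1).unsqueeze(-1), upper=False).squeeze(-1)
    y2 = y.square().sum(-1)
    # M solves L2 M = L1; ||M||_F^2 is the trace term.
    M = torch.linalg.solve_triangular(scale_tril2, scale_tril1, upper=False)
    M2 = M.square().flatten(-2, -1).sum(-1)
    # For triangular L, log det L is the sum of the log-diagonal.
    return 0.5 * (y2 + M2 - 2 * (
        scale_tril1.diagonal(dim1=-2, dim2=-1).log().sum(-1)
        - scale_tril2.diagonal(dim1=-2, dim2=-1).log().sum(-1)) - d)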


# Randomize p and q's parameterization.
mean1 = torch.randn(batch_size, d)
mean2 = torch.randn(batch_size, d)
scale_tril1 = form_cholesky_tril_from_elements(
    d, torch.randn(batch_size, (1 + d) * d // 2))
scale_tril2 = form_cholesky_tril_from_elements(
    d, torch.randn(batch_size, (1 + d) * d // 2))

# Check against the torch.distributions reference.
assert torch.allclose(groundtruth(mean1, scale_tril1, mean2, scale_tril2),
                      ours(mean1, scale_tril1, mean2, scale_tril2))
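In a variational autoencoder (VAE) the prior is often a standard normal, $q = \mathcal N(\mathbf 0, \mathbf I)$. Setting $\boldsymbol\mu_2 = \mathbf 0$ and $\mathbf L_2 = \mathbf I$ in the formula gives $\boldsymbol y = -\boldsymbol\mu_1$ and $\mathbf M = \mathbf L_1$, so no triangular solve is needed:

\[D_\mathrm{KL}(p \parallel q) = \frac{1}{2}\Big(\|\boldsymbol\mu_1\|^2 + \|\mathbf L_1\|_F^2 - 2\sum_{i=1}^d \log (\mathbf L_1)_{ii} - d\Big)\,.\]

A minimal sketch of this special case, assuming a standard normal prior; kl_to_standard_normal is a hypothetical helper, and the check also confirms that gradients reach the unconstrained parameters as they would during training.

```python
import torch
from torch import distributions as D


def kl_to_standard_normal(mean, scale_tril):
    """KL(N(mean, L L^T) || N(0, I)) for mean (B, d) and scale_tril (B, d, d)."""
    d = mean.size(-1)
    mean_sq = mean.square().sum(-1)              # y^T y with y = -mean
    frob_sq = scale_tril.square().sum((-2, -1))  # ||M||_F^2 with M = L1
    log_det = scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    return 0.5 * (mean_sq + frob_sq - 2 * log_det - d)


torch.manual_seed(0)
batch_size, d = 5, 3
mean = torch.randn(batch_size, d, requires_grad=True)
raw = torch.randn(batch_size, d, d, requires_grad=True)
# Strictly lower part of raw plus an exponentiated diagonal.
scale_tril = raw.tril(-1) + torch.diag_embed(raw.diagonal(dim1=-2, dim2=-1).exp())

kl = kl_to_standard_normal(mean, scale_tril)
prior = D.MultivariateNormal(torch.zeros(d), scale_tril=torch.eye(d))
reference = D.kl_divergence(D.MultivariateNormal(mean, scale_tril=scale_tril), prior)
print(torch.allclose(kl, reference))  # True

kl.sum().backward()
print(mean.grad.shape, raw.grad.shape)  # torch.Size([5, 3]) torch.Size([5, 3, 3])
```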

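We profile both functions with IPython's %timeit magic on the inputs above ($d = 3$, batch size 5). On this setup, our implementation is roughly 3.5 times faster than the torch.distributions baseline: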

%timeit groundtruth(mean1, scale_tril1, mean2, scale_tril2) (baseline):

164 μs ± 178 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

%timeit ours(mean1, scale_tril1, mean2, scale_tril2) (our implementation):

46.2 μs ± 71.6 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)