Discriminator Rejection Sampling


This note summarizes the main idea behind Discriminator Rejection Sampling (DRS) [1] for GANs.

1 Rejection sampling

Suppose there is a true data distribution with density \(p_d (x)\) that is hard to sample from, and a “proposal” distribution with density \(p_g (x)\) that is easy to sample from. Assume there is a constant \(M < \infty\) such that:

\(\displaystyle Mp_g (x) \geq p_d (x), \quad \forall x \in \operatorname{supp} (p_d)\)

Then we can generate samples from \(p_d\) by drawing \(x \sim p_g\) and accepting it with probability (otherwise we reject \(x\) and draw again):

\(\displaystyle \frac{p_d (x)}{Mp_g (x)}\) (1)
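As a concrete illustration of Equation 1, here is a minimal sketch with densities chosen purely for the example (they are not from [1]). The target is \(p_d(x) = 6x(1 - x)\), the Beta(2, 2) density on \([0, 1]\), and the proposal is Uniform(0, 1). The ratio \(p_d/p_g\) peaks at \(x = 0.5\) with value 1.5, so \(M = 1.5\) satisfies the bound. Each accepted sample costs on average \(M\) proposals.

```python
import random


def target_density(x: float) -> float:
    """Example target p_d: the Beta(2, 2) density 6x(1 - x) on [0, 1]."""
    return 6.0 * x * (1.0 - x) if 0.0 <= x <= 1.0 else 0.0


def proposal_density(x: float) -> float:
    """Example proposal p_g: the Uniform(0, 1) density."""
    return 1.0 if 0.0 <= x <= 1.0 else 0.0


# p_d(x) / p_g(x) is largest at x = 0.5, where it equals 1.5.
M = 1.5


def rejection_sample(n: int, rng: random.Random) -> tuple[list[float], int]:
    """Draw n samples from p_d; also return how many proposals were used."""
    samples: list[float] = []
    proposals = 0
    while len(samples) < n:
        x = rng.random()  # x ~ p_g = Uniform(0, 1)
        proposals += 1
        accept_prob = target_density(x) / (M * proposal_density(x))  # Equation 1
        if rng.random() < accept_prob:
            samples.append(x)
    return samples, proposals


rng = random.Random(0)
samples, proposals = rejection_sample(20_000, rng)
print(sum(samples) / len(samples))  # close to 0.5, the Beta(2, 2) mean
print(len(samples) / proposals)     # close to 1 / M, about 0.667
```

The acceptance rate is close to \(1/M\). This is why a tight \(M\) matters: a loose bound wastes proportionally more proposals.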

2 Discriminator rejection sampling for GANs

In a GAN, we can treat the generator as the proposal distribution \(p_g\), which is trained to match the true distribution \(p_d\). The idea behind discriminator rejection sampling is to use rejection sampling to obtain samples from \(p_d\) even when \(p_g\) does not exactly match \(p_d\). Unfortunately, Equation 1 requires the ratio of two densities that we do not have. A GAN generator does not usually provide a tractable density \(p_g (x)\), and even if it did, the true density \(p_d (x)\) is almost never available. Instead, we can estimate the ratio \(\frac{p_d}{p_g}\) directly using the discriminator.

In a GAN, the discriminator \(D\) takes an input \(x \sim p_{\operatorname{mix}}\), where \(p_{\operatorname{mix}}\) is an equal mixture of the real data distribution \(p_d\) and the generator distribution \(p_g\). Its job is to output the probability that \(x\) came from \(p_d\). It is trained by maximizing the binary classification objective (the negative of the binary cross-entropy loss):

\(\displaystyle \max_D \mathbb{E}_{x \sim p_d} [\log D (x)] +\mathbb{E}_{x \sim p_g} [\log (1 - D (x))]\)

Let's assume we can train \(D\) to be the optimal discriminator. What does that look like? Since \(p_{\operatorname{mix}}\) is an equal mixture of \(p_d\) and \(p_g\), its density is:

\(\displaystyle p_{\operatorname{mix}} (x) = \frac{1}{2} p_d (x) + \frac{1}{2} p_g (x)\)

The optimal discriminator outputs the posterior probability that \(x \sim p_{\operatorname{mix}}\) is real, which by Bayes' rule is:

\begin{eqnarray*} D (x) & = & p (\operatorname{real}|x)\\ & = & \frac{p (x|\operatorname{real}) p (\operatorname{real})}{p_{\operatorname{mix}} (x)}\\ & = & \frac{p_d (x) \times \frac{1}{2}}{\frac{1}{2} p_d (x) + \frac{1}{2} p_g (x)}\\ & = & \frac{p_d (x)}{p_d (x) + p_g (x)}\\ & = & \frac{1}{1 + \frac{p_g (x)}{p_d (x)}} \hspace{3cm} \text{(2)} \end{eqnarray*}
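To connect Equation 2 back to the training objective, note that the objective is an integral over \(x\) of \(p_d(x) \log D(x) + p_g(x) \log(1 - D(x))\). This integrand can be maximized separately at each point. The sketch below brute-forces that pointwise maximum on a grid for one hypothetical pair of density values and compares it with \(\frac{p_d(x)}{p_d(x) + p_g(x)}\). The values 0.75 and 0.25 are arbitrary stand-ins for \(p_d(x)\) and \(p_g(x)\) at a single \(x\).

```python
import math


def pointwise_objective(d: float, pd: float, pg: float) -> float:
    """Integrand of the discriminator objective at a single x:
    p_d(x) log D(x) + p_g(x) log(1 - D(x))."""
    return pd * math.log(d) + pg * math.log(1.0 - d)


def best_d_on_grid(pd: float, pg: float, steps: int = 100_000) -> float:
    """Brute-force the value of D(x) in (0, 1) that maximizes the integrand."""
    grid = (i / steps for i in range(1, steps))
    return max(grid, key=lambda d: pointwise_objective(d, pd, pg))


# Hypothetical density values at one point x.
pd, pg = 0.75, 0.25
d_star = best_d_on_grid(pd, pg)
print(d_star)          # 0.75
print(pd / (pd + pg))  # 0.75, as predicted by Equation 2
```

Because the integrand is strictly concave in \(D(x)\), this maximizer is unique. This is the standard argument that the optimal discriminator is the one in Equation 2.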

Usually, the discriminator's output is the sigmoid of a logit \(\tilde{D} (x)\):

\(\displaystyle D (x) = \sigma (\tilde{D} (x)) = \frac{1}{1 + e^{- \tilde{D} (x)}}\) (3)

Comparing Equations 2 and 3, we see that \(e^{- \tilde{D} (x)} = \frac{p_g (x)}{p_d (x)}\), so:

\(\displaystyle \frac{p_d (x)}{p_g (x)} = \exp (\tilde{D} (x))\)
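This identity can be checked numerically in a case where the optimal logit has a closed form. The sketch assumes, purely for illustration, \(p_d = \mathcal{N}(0, 1)\) and \(p_g = \mathcal{N}(0, 2^2)\). Then \(\tilde{D}(x) = \log p_d(x) - \log p_g(x) = \log 2 - \frac{3x^2}{8}\). The helper names are hypothetical.

```python
import math


def normal_pdf(x: float, mu: float, sigma: float) -> float:
    """Density of N(mu, sigma^2) at x."""
    z = (x - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))


def p_d(x: float) -> float:
    return normal_pdf(x, 0.0, 1.0)  # assumed "real" distribution N(0, 1)


def p_g(x: float) -> float:
    return normal_pdf(x, 0.0, 2.0)  # assumed generator distribution N(0, 2^2)


def optimal_logit(x: float) -> float:
    """Closed form of log p_d(x) - log p_g(x) for these two Gaussians."""
    return math.log(2.0) - 3.0 * x * x / 8.0


def sigmoid(t: float) -> float:
    return 1.0 / (1.0 + math.exp(-t))


for x in [-2.0, -0.5, 0.0, 1.0, 3.0]:
    d_opt = p_d(x) / (p_d(x) + p_g(x))                  # Equation 2
    assert math.isclose(sigmoid(optimal_logit(x)), d_opt)  # Equation 3
    assert math.isclose(math.exp(optimal_logit(x)), p_d(x) / p_g(x))
print("sigmoid(logit) matches Equation 2 and exp(logit) = p_d / p_g")
```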

This gives us the density ratio needed for rejection sampling (Equation 1). It remains to find a value for \(M\). The smallest valid choice is:

\(\displaystyle M = \max_x \frac{p_d (x)}{p_g (x)} = \max_x \exp (\tilde{D} (x))\)

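In practice, we draw many samples \(x \sim p_g\) and take the maximum of \(\exp (\tilde{D} (x))\) over them as our estimate of \(M\). Since this is a maximum over a finite sample, it can underestimate the true \(M\). We can then do rejection sampling by drawing \(x \sim p_g\) and accepting it with probability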

\(\displaystyle \frac{\exp (\tilde{D} (x))}{M}\)
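Putting the pieces together, here is a minimal sketch of the procedure above. It is not the implementation from [1]. It assumes a perfectly trained discriminator for the toy pair \(p_d = \mathcal{N}(0, 1)\), \(p_g = \mathcal{N}(0, 2^2)\), whose logit is \(\log 2 - \frac{3x^2}{8}\), and a hypothetical generator that samples \(\mathcal{N}(0, 2^2)\) directly. \(M\) is estimated from a burn-in batch of generator samples. Because that estimate can fall below the true maximum, the acceptance probability is clipped at 1.

```python
import math
import random


def logit(x: float) -> float:
    """Stand-in for a trained discriminator's logit D~(x). Assumes a perfect
    discriminator for p_d = N(0, 1) vs p_g = N(0, 2^2), so that
    D~(x) = log(p_d(x) / p_g(x)) = log 2 - 3x^2 / 8."""
    return math.log(2.0) - 3.0 * x * x / 8.0


def sample_generator(rng: random.Random) -> float:
    """Hypothetical generator: draws x ~ p_g = N(0, 2^2)."""
    return rng.gauss(0.0, 2.0)


def estimate_m(rng: random.Random, n_burn_in: int) -> float:
    """Estimate M = max_x exp(D~(x)) as a maximum over generator samples."""
    return max(math.exp(logit(sample_generator(rng))) for _ in range(n_burn_in))


def drs(n: int, rng: random.Random, m: float) -> tuple[list[float], int]:
    """Rejection-sample n points; also return the number of proposals used."""
    samples: list[float] = []
    proposals = 0
    while len(samples) < n:
        x = sample_generator(rng)
        proposals += 1
        # M is only an estimate, so exp(D~(x)) / M can exceed 1: clip it.
        accept_prob = min(1.0, math.exp(logit(x)) / m)
        if rng.random() < accept_prob:
            samples.append(x)
    return samples, proposals


rng = random.Random(0)
m_hat = estimate_m(rng, 10_000)
samples, proposals = drs(20_000, rng, m_hat)
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(m_hat)                     # slightly below the true maximum of 2
print(var)                       # close to 1, the variance of p_d
print(len(samples) / proposals)  # close to 1 / M = 0.5
```

The accepted samples have variance close to 1 rather than the generator's 4. This is the behaviour Equation 1 predicts when the discriminator's logit equals the true log density ratio. With a real, imperfect discriminator the output would only approximate \(p_d\).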


[1] Samaneh Azadi, Catherine Olsson, Trevor Darrell, Ian Goodfellow, and Augustus Odena. Discriminator rejection sampling. arXiv preprint arXiv:1810.06758, 2018.