First, I think this is really cool. It's great to see novel generative architectures.
Here are my thoughts on the statistics behind this. First, let D be the data sample. Start with
the expectation of -Log[P(D)] (standard generative model objective).
We then condition on the model output at step N and marginalize over it:
= - expectation of Log[Sum over model outputs at step N {P(D | model output at step N) * P(model output at step N)}]
Now use Jensen's inequality to transform this to
<= - expectation of Sum over model outputs at step N{Log[P(D | model output at step N) * P(model output at step N)]}
Apply Log product to sum rule
= - expectation of Sum over model outputs at step N {Log(P(D | model output at step N)) + Log(P(model output at step N))}
If we assume there is some normally distributed observation noise, we can transform the first term into the standard L2 objective (up to constants).
= expectation of Sum over model outputs at step N {L2 distance(D, model output at step N) - Log(P(model output at step N))}
Apply linearity of expectation
= Sum over model outputs at step N [expectation of{L2 distance(D, model output at step N)}]
- Sum over model outputs at step N [expectation of {Log(P(model output at step N))}]
and the summations can be replaced with sampling
= expectation of {L2 distance(D, model output at step N)}
- expectation of {Log(P(model output at step N))}
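To compress the notation, here is a restatement of the bound so far (writing z_N for the model output at step N, reading the sampled sums above as expectations over z_N, and assuming Gaussian observation noise with variance sigma^2):

```latex
-\,\mathbb{E}_{D}\big[\log P(D)\big]
  \;\le\; \mathbb{E}_{D,\,z_N}\big[-\log P(D \mid z_N)\big]
        \;-\; \mathbb{E}_{z_N}\big[\log P(z_N)\big],
\qquad
-\log P(D \mid z_N) \;=\; \frac{\lVert D - z_N \rVert_2^2}{2\sigma^2} + \text{const.}
```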
Now, focusing on just the - expectation of Log(P(sampled model output at step N)) term.
= - expectation of Log[P(model output at step N)]
and condition on the prior step (marginalizing over it) to get
= - expectation of Log[Sum over possible samples at step N - 1 of
(P(sample at step N | sample at step N - 1) * P(sample at step N - 1))]
Now, each conditional P(sample at step T | sample at step T - 1) is approximately equal to 1/K. This is enforced by the Split-and-Prune operations, which try to keep each output sampled at roughly equal frequency.
So this is approximately equal to
≃ - expectation of Log[Sum over possible samples at N-1 of (1/K * P(possible sample at step N - 1))]
And you get an upper bound by only considering the actual sample.
<= - expectation of Log[1/K * P(sample at step N - 1)]
And applying some log rules you get
= Log(K) - expectation of Log[P(sample at step N - 1)]
Now, you have (approximately) expectation of -Log[P(sample at step N)] <= Log(K) - expectation of Log[P(sample at step N - 1)]. You can repeatedly apply this transformation until step 0 to get
(approximately) expectation of -Log[P(sample at step N)] <= N * Log(K) - expectation of Log[P(sample at step 0)]
and assume that P(sample at step 0) is 1 (i.e., the step-0 sample is fixed) to get
expectation of -Log[P(sample at step N)] <= N * Log(K)
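Written compactly (with z_t the sample at step t, the approximate 1/K selection probability from above, and P(z_0) = 1), the recursion telescopes as:

```latex
-\,\mathbb{E}\big[\log P(z_N)\big]
  \;\lesssim\; \log K - \mathbb{E}\big[\log P(z_{N-1})\big]
  \;\lesssim\; 2\log K - \mathbb{E}\big[\log P(z_{N-2})\big]
  \;\lesssim\; \cdots
  \;\lesssim\; N\log K - \mathbb{E}\big[\log P(z_0)\big]
  \;=\; N\log K.
```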
Plugging this back into the main objective, we get (assuming the Split-and-Prune is perfect)
expectation of -Log[P(D)] <= expectation of {L2 distance(D, sampled model output at step N)} + N * Log(K)
And this makes sense. You are providing the model with an additional Log_2(K) bits of information every time you perform an argmin operation, so in total you have provided the model with N * Log_2(K) bits of information. However, this term is a constant, so the gradient-based optimizer can ignore it.
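To make the bound concrete, here is a minimal numerical sketch of the selection-and-loss logic as I understand it. Everything here (the candidate generator, the shapes, the values of K, N and dim) is a hypothetical stand-in rather than the paper's actual architecture; the point is only that the trainable part of the bound is the L2 terms, while N * Log_2(K) is a constant offset.

```python
import numpy as np

# Toy sketch of the bound: at each of N steps a layer proposes K candidates,
# the argmin-by-L2 candidate is selected (injecting log2(K) bits of information),
# and the L2 distance to the data is the trainable part of the NLL upper bound.
rng = np.random.default_rng(0)
N, K, dim = 10, 64, 32          # hypothetical depth, candidates per step, data dim

D = rng.normal(size=dim)        # stand-in data sample
z = np.zeros(dim)               # deterministic step-0 sample (P = 1)
l2_losses = []

for step in range(N):
    # Stand-in candidate generator: K perturbations of the current sample.
    candidates = z + 0.5 * rng.normal(size=(K, dim))
    # argmin selection: pick the candidate closest to D.
    z = candidates[np.argmin(((candidates - D) ** 2).sum(axis=1))]
    l2_losses.append(((z - D) ** 2).sum())

# Trainable terms of the bound vs. the constant information term.
print("final L2:", l2_losses[-1])
print("constant offset (bits):", N * np.log2(K))
```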
So, given this analysis my conclusions are:
1) The Split-and-Prune is a load-bearing component of the architecture with regard to its statistical correctness (a toy sketch of the balancing it performs is included after this list). I'm not entirely sure how this fits with the gradient-based optimizer. Is it working with the gradient-based optimizer, fighting it, or somewhere in the middle? I think the answer to this question will strongly affect this approach's scalability. It will also need a more in-depth analysis of how deviations from perfect splitting affect the upper bound on the loss.
2) With regard to statistical correctness, only the L2 distance between the output at step N and D matters; the L2 losses in the middle layers can be considered auxiliary losses. Maybe the final L2 loss, or the L2 losses deeper in the model, should be weighted more heavily? In final evaluation, the intermediate L2 losses can be ignored.
3) Future possibilities could include some sort of RL to determine the number of samples K and the depth N on a dynamic basis. Even a split with K = 2 increases the NLL upper bound by Log_2(2) = 1 bit. For many samples, beyond a given depth the increase in loss from the additional information outweighs the decrease in L2 loss. This also points to another difficulty: it is hard to give the model fractional bits of information in this Discrete Distribution Network architecture, whereas diffusion models and autoregressive models can handle fractional bits. This could be another point of future development.
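On point 1, here is a toy, count-based sketch of the kind of frequency balancing I have in mind for Split-and-Prune. This is purely my guess at the mechanism (split the over-selected candidate, prune the under-selected one so usage stays near 1/K); the actual operation in the paper may differ.

```python
import numpy as np

# Toy frequency-balancing sketch (my guess at the mechanism, not the paper's rule):
# track how often each of the K candidates wins the argmin, then periodically
# split the over-used candidate and prune the under-used one so usage stays near 1/K.
rng = np.random.default_rng(0)
K, dim, steps = 8, 4, 2000

candidates = rng.normal(size=(K, dim))   # hypothetical candidate outputs of one layer
counts = np.zeros(K)

for t in range(steps):
    target = rng.normal(size=dim)                     # stand-in data sample
    k = np.argmin(((candidates - target) ** 2).sum(axis=1))
    counts[k] += 1

    if (t + 1) % 200 == 0:                            # periodic Split-and-Prune
        hi, lo = np.argmax(counts), np.argmin(counts)
        # Split: clone the over-selected candidate with a small perturbation,
        # overwriting (pruning) the under-selected slot.
        candidates[lo] = candidates[hi] + 0.01 * rng.normal(size=dim)
        counts[hi] /= 2                               # the clone will share these wins
        counts[lo] = counts[hi]

print("selection frequencies:", counts / counts.sum())  # should drift toward 1/K
```

If the balancing holds, the P(sample at step T | sample at step T - 1) ≈ 1/K assumption in the derivation is justified; how far the real operation deviates from it is exactly the slack that would loosen the N * Log(K) term.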