Poisson Variational Autoencoder
一言でいうと: Poisson Variational Autoencoderは、VAEを生物学的なspike countsに近づけるために、Poisson潜在変数とsoft thresholdによるreparameterized samplingを使う離散潜在VAEである。
Poisson Variational Autoencoder (P-VAE) は、PoissonVAE論文で提案された、潜在変数をPoisson-distributed spike countsとして扱うVAEである。 posterior rateを 、prior rateを と置き、encoderがpriorからの入力依存のずれ を出す。
P-VAEの要点は、Poisson processの待ち時間表現を使うreparameterization trickである。 Exponential分布からinter-event timesをサンプルし、累積到着時刻が単位時間内に入るかをsigmoidで近似する。
が小さいほどhard Poisson countsに近づく。 訓練時は のrelaxed countsを使い、テスト時は の整数countsを使う。
実装メモ
公式実装は hadivafaii/PoissonVAE である。
base/distributions.py の Poisson.rsample() では、Exponential samplesを累積し、soft indicatorを合計してspike countsを作る。
x = self._exp.rsample((self.n_exp,))
times = torch.cumsum(x, dim=0)
logits = (1 - times) / self.temp
indicator = fn(logits)
z = indicator.sum(0).to(dtype=self.rate.dtype)P-VAE本体のmain/vae.pyでは、log_r + log_dr が に対応する。
KL項は を log_dr から直接計算する。
f = 1 + torch.exp(log_dr) * (log_dr - 1)
kl = torch.exp(log_r) * f