Poisson Variational Autoencoder

一言でいうと: Poisson Variational Autoencoderは、VAEを生物学的なspike countsに近づけるために、Poisson潜在変数とsoft thresholdによるreparameterized samplingを使う離散潜在VAEである。

Poisson Variational Autoencoder (P-VAE) は、PoissonVAE論文で提案された、潜在変数をPoisson-distributed spike countsとして扱うVAEである。 posterior rateを 、prior rateを と置き、encoderがpriorからの入力依存のずれ を出す。

P-VAEの要点は、Poisson processの待ち時間表現を使うreparameterization trickである。 Exponential分布からinter-event timesをサンプルし、累積到着時刻が単位時間内に入るかをsigmoidで近似する。

が小さいほどhard Poisson countsに近づく。 訓練時は のrelaxed countsを使い、テスト時は の整数countsを使う。

実装メモ

公式実装は hadivafaii/PoissonVAE である。 base/distributions.pyPoisson.rsample() では、Exponential samplesを累積し、soft indicatorを合計してspike countsを作る。

x = self._exp.rsample((self.n_exp,))
times = torch.cumsum(x, dim=0)
logits = (1 - times) / self.temp
indicator = fn(logits)
z = indicator.sum(0).to(dtype=self.rate.dtype)

P-VAE本体のmain/vae.pyでは、log_r + log_dr に対応する。 KL項は log_dr から直接計算する。

f = 1 + torch.exp(log_dr) * (log_dr - 1)
kl = torch.exp(log_r) * f

関連リンク