Poisson Variational Autoencoder

@inproceedings{NEURIPS2024_4f3cb957,
 author = {Vafaii, Hadi and Galor, Dekel and Yates, Jacob},
 booktitle = {Advances in Neural Information Processing Systems},
 doi = {10.52202/079017-1426},
 editor = {A. Globerson and L. Mackey and D. Belgrave and A. Fan and U. Paquet and J. Tomczak and C. Zhang},
 pages = {44871--44906},
 publisher = {Curran Associates, Inc.},
 title = {Poisson Variational Autoencoder},
 url = {https://proceedings.neurips.cc/paper_files/paper/2024/file/4f3cb9576dc99d62b80726690453716f-Paper-Conference.pdf},
 volume = {37},
 year = {2024}
}

BibTeX source

一言でいうと: PoissonVAEは、VAEの潜在変数を連続GaussianではなくPoisson spike countsとして扱うために、Poisson reparameterization trickと予測符号化的なrate parameterizationを導入する離散潜在VAEである。

背景と目的 (Background & Objective)

標準的なVAEは、潜在変数にGaussianなどの連続分布を使う。 しかし、生物学的ニューロンはスパイク数として離散的に発火し、短い時間窓ではPoisson-likeな統計を示すことが多い。 この論文は、perception as inference、rate coding、predictive coding、sparse codingをVAEの枠組みで結び、潜在変数を離散spike countsとして扱うPoissonVAEを提案する。

提案手法 (Proposed Method)

P-VAEでは、潜在変数 をPoisson分布からサンプルされるspike countsとして扱う。 posteriorとpriorはそれぞれ次で定義される。

は学習可能なprior firing ratesであり、 はencoderが出す入力依存のdeviationである。 この残差的parameterizationにより、feedbackの期待 とfeedforwardのずれ が要素ごとの積で結びつく。

Poisson分布は離散分布なので、そのままでは通常のreparameterization trickを使えない。 P-VAEはhomogeneous Poisson processの待ち時間表現を使い、Exponential分布からinter-event timesをサンプルし、累積到着時刻が単位時間内に入るかをsoft indicatorで近似する。

はtemperatureであり、 でhard thresholdに近づく。 訓練時は のrelaxed Poissonを使い、テスト時は にして整数Poisson samplesを使う。

P-VAEの目的関数は次である。

KL項 はfiring ratesを罰するため、代謝コストやsparse codingのactivity penaltyに似た役割を持つ。 線形decoder とovercomplete latent space を仮定すると、P-VAEはamortized sparse codingとして解釈できる。

公式実装メモ

公式実装は hadivafaii/PoissonVAE で公開されている。 Poisson reparameterizationは base/distributions.pyPoisson.rsample() に実装されている。

x = self._exp.rsample((self.n_exp,))
times = torch.cumsum(x, dim=0)
logits = (1 - times) / self.temp
indicator = fn(logits)
z = indicator.sum(0).to(dtype=self.rate.dtype)

出典: base/distributions.py

P-VAE本体では、encoder出力 log_dr をprior log-rate log_r に加え、posterior rateを としてPoisson分布へ渡す。 これは論文の に対応する。

log_r = self.log_rate.expand(len(x), -1)
log_dr = self.encode(x)
dist = self.Dist(
    log_rate=softclamp_upper(log_r + log_dr, 5.0),
    indicator_approx=self.cfg.indicator_approx,
    n_exp=self.n_exp,
    temp=t,
)

出典: main/vae.py

KL項も閉形式で実装されている。 log_dr なので、1 + exp(log_dr) * (log_dr - 1) と同じ形である。

def loss_kl(self, log_dr):
    log_r = self.log_rate.expand(len(log_dr), -1)
    f = 1 + torch.exp(log_dr) * (log_dr - 1)
    kl = torch.exp(log_r) * f
    return kl

出典: main/vae.py

実験と評価 (Experiments & Evaluation)

比較対象は、PoissonVAE、Categorical VAE、Gaussian VAE、Laplace VAEである。 実験は、van_Hateren_DatasetCIFAR10から作ったCIFAR16×16、MNISTで行われる。

Poisson reparameterizationの評価では、線形decoderで計算できるexact gradientsと、Monte Carlo samplingやstraight-through estimatorを比較する。 P-VAEのMonte Carlo推定はexact gradientsに近い一方、straight-through estimatorは大きく性能を落とす。

ModelGradientvan Hateren linvan Hateren convCIFAR16×16 linCIFAR16×16 convMNIST linMNIST conv
P-VAEExact0.6 ± .50.7 ± .10.0 ± .10.5 ± .10.1 ± .10.9 ± .5
P-VAEMonte Carlo0.1 ± .17.3 ± .10.0 ± .09.1 ± .10.5 ± .68.1 ± .3
P-VAEStraight-through0.0 ± .110.5 ± .10.2 ± .012.5 ± .10.7 ± .411.8 ± .2
G-VAEExact0.1 ± .10.1 ± .10.0 ± .10.1 ± .10.1 ± .20.4 ± .1
G-VAEMonte Carlo0.0 ± .00.0 ± .00.0 ± .00.0 ± .00.1 ± .20.3 ± .1

Sparse codingとの比較では、P-VAEはGabor-likeなbasis elementsを学習し、LCAやISTAに似た辞書を得る。

posterior collapseの評価では、P-VAEはcontinuous VAEより多くのactive neuronsを保つ。

Modelvan Hateren linearvan Hateren convCIFAR16×16 linearCIFAR16×16 convMNIST linearMNIST conv
P-VAE0.984 ± .0110.819 ± .0410.999 ± .0020.928 ± .0450.537 ± .0080.426 ± .011
L-VAE0.188 ± .0000.222 ± .0030.193 ± .0030.230 ± .0000.027 ± .0000.034 ± .002
G-VAE0.218 ± .0030.246 ± .0000.105 ± .0080.246 ± .0000.027 ± .0000.031 ± .000

下流分類では、MNISTの5,000 validation samplesをtrain/testに分け、KNNで少数ラベル時の性能を比較する。 の潜在次元では、P-VAEは で0.815、G-VAEは0.705であり、G-VAEが同程度の精度に達するには が必要になる。

ModelKNN N=200KNN N=1,000KNN N=5,000Shattering dim.
P-VAE0.815 ± .0020.919 ± .0010.946 ± .0170.797 ± .009
C-VAE0.705 ± .0020.800 ± .0020.853 ± .0400.795 ± .006
L-VAE0.757 ± .0030.869 ± .0020.924 ± .0280.751 ± .008
G-VAE0.673 ± .0030.813 ± .0020.891 ± .0330.758 ± .007
G-VAE +relu0.694 ± .0030.817 ± .0030.877 ± .0450.762 ± .007
G-VAE +exp0.642 ± .0030.784 ± .0020.863 ± .0320.737 ± .008

貢献と限界点 (Contributions & Limitations)

主な貢献は、Poisson-distributed latent variablesをVAEに組み込み、離散spike countsを微分可能に扱うreparameterization trickを示した点である。 また、KL項が自然にfiring-rate penaltyになり、線形decoderではamortized sparse codingと対応することを示している。

限界は、Poisson spike countsが短時間窓の神経発火の近似であり、皮質活動の長時間窓ではPoissonから外れる場合がある点である。 さらに、現在のencoderではLCA/ISTAに対するamortization gapが残り、P-VAEのdictionaryを使ってLCA inferenceを行うとMSEが改善する。 このため、P-VAEの潜在表現と辞書は有望だが、推論ネットワークの表現力や反復推論との統合は未解決である。

関連リンク