GSoC 2022 Retrospective

During the summer of 2022, I participated in Google Summer of Code (GSoC). I was fortunate to be mentored by Kevin P. Murphy and Scott Linderman. My main contributions were to the probml/ssm-jax and probml/pyprobml repos.

Update (2022.10.17): the ssm-jax library has been renamed to dynamax. Some of the links below may break in the coming days; if so, replacing ssm-jax with dynamax should do the trick.

Main Contributions

The majority of my time was spent building the inference portion of the ssm-jax library.

Starting from the Kalman filter for linear-Gaussian state-space models (LGSSMs), I climbed the tower of abstraction, eventually arriving at the most general formulation of Gaussian filtering: the conditional-moments Gaussian filter (CMGF).

The Kalman filter computes the exact filtering posterior in closed form for state-space models (SSMs) in which all distributions involved are linear-Gaussian. While useful, most real-world SSMs do not have linear-Gaussian state dynamics and emission models.
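
To make the closed-form updates concrete, here is a minimal sketch of one Kalman filter step in JAX. The function name and argument layout are my own for illustration and do not match the ssm-jax/dynamax API; F, Q, H, and R denote the dynamics matrix, process-noise covariance, emission matrix, and observation-noise covariance.

```python
import jax.numpy as jnp

def kalman_step(mu, Sigma, y, F, Q, H, R):
    """One predict + update step of a linear-Gaussian Kalman filter (illustrative)."""
    # Predict: push the Gaussian belief through the linear dynamics.
    mu_pred = F @ mu
    Sigma_pred = F @ Sigma @ F.T + Q
    # Update: condition on the observation y; everything stays Gaussian.
    S = H @ Sigma_pred @ H.T + R               # innovation covariance
    K = Sigma_pred @ H.T @ jnp.linalg.inv(S)   # Kalman gain (inv used for clarity)
    mu_post = mu_pred + K @ (y - H @ mu_pred)
    Sigma_post = (jnp.eye(mu.shape[0]) - K @ H) @ Sigma_pred
    return mu_post, Sigma_post
```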

I learned about and implemented extensions of the Kalman filter to non-linear dynamics and emission models, namely the extended Kalman filter (EKF), the unscented Kalman filter (UKF), and the Gauss-Hermite Kalman filter (GHKF). I created a toy example demonstrating the differences in the scope (and accuracy) of their Gaussian approximations to non-linear transformations, as shown below.

EKF-approximated moments.
UKF-approximated moments.
GHKF-approximated moments.
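
The difference between the two cheaper approximations can be stated in a few lines of JAX: the EKF linearizes the non-linear function at the mean via its Jacobian, while the UKF propagates a deterministic set of sigma points through the function and re-fits a Gaussian. This is an illustrative sketch with hypothetical function names, not the ssm-jax implementation:

```python
import jax
import jax.numpy as jnp

def ekf_moments(f, mu, Sigma):
    """EKF: linearize f at the mean, then push the Gaussian through the linearization."""
    F = jax.jacfwd(f)(mu)              # Jacobian of the non-linearity at mu
    return f(mu), F @ Sigma @ F.T

def ukf_moments(f, mu, Sigma, alpha=1e-3, beta=2.0, kappa=0.0):
    """UKF: propagate 2n+1 deterministically chosen sigma points through f."""
    n = mu.shape[0]
    lam = alpha**2 * (n + kappa) - n
    L = jnp.linalg.cholesky((n + lam) * Sigma)     # columns are sigma-point offsets
    sigmas = jnp.vstack([mu[None, :], mu + L.T, mu - L.T])
    w_m = jnp.concatenate([jnp.array([lam / (n + lam)]),
                           jnp.full(2 * n, 1.0 / (2 * (n + lam)))])
    w_c = w_m.at[0].add(1 - alpha**2 + beta)
    ys = jax.vmap(f)(sigmas)
    mean = w_m @ ys
    diff = ys - mean
    return mean, (w_c[:, None] * diff).T @ diff
```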

Next, I implemented the general Gaussian filter (GGF), which unifies the non-linear extensions of the KF into a single filter that forms Gaussian approximations to the joint distributions via moment matching. In addition, I implemented its iterated posterior-linearization extension, which computes expectations with respect to the previous iteration's posterior distribution. However, the GGF is still limited in its applicability by the restriction that the emission model must be Gaussian.
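
The moment-matching idea is easy to prototype: approximate the joint distribution of the state and its non-linear image by a Gaussian whose moments are computed by integrating against the current Gaussian belief. The sketch below uses plain Monte Carlo for these integrals purely for illustration (ssm-jax uses deterministic rules such as the unscented transform or Gauss-Hermite quadrature), and the function name is hypothetical:

```python
import jax
import jax.numpy as jnp

def mc_joint_moments(f, mu, Sigma, key, num_samples=10_000):
    """Moment-match the joint of (z, f(z)) for z ~ N(mu, Sigma) via Monte Carlo.

    The cross-covariance Cov[z, f(z)] is what the general Gaussian filter
    needs to form its update (the analogue of the Kalman gain).
    """
    zs = jax.random.multivariate_normal(key, mu, Sigma, (num_samples,))
    ys = jax.vmap(f)(zs)
    y_mean = ys.mean(axis=0)
    y_cov = jnp.cov(ys, rowvar=False)
    cross = ((zs - mu).T @ (ys - y_mean)) / num_samples
    return y_mean, y_cov, cross
```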

I reached the peak of the tower of abstraction for SSMs by implementing the conditional-moments Gaussian filter (CMGF), which relaxes the assumption of a Gaussian emission model. I created a series of demos demonstrating the utility of CMGF, starting with online binary logistic regression. As shown below, the CMGF-inferred weights rapidly converge to their MAP estimates.
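
The key relaxation is that CMGF only needs the first two conditional moments of the emission distribution rather than a Gaussian likelihood. For the Bernoulli emission in online logistic regression these moments are available in closed form; the function names below are hypothetical:

```python
import jax.numpy as jnp
from jax.nn import sigmoid

# For online logistic regression, the latent "state" is the weight vector w,
# and each observation y_t ~ Bernoulli(sigmoid(w . x_t)). CMGF never needs
# the full likelihood, only these two conditional moments:

def emission_mean(w, x):
    """E[y | w] for a Bernoulli emission."""
    return sigmoid(w @ x)

def emission_cov(w, x):
    """Cov[y | w] = p(1 - p) for a Bernoulli emission."""
    p = sigmoid(w @ x)
    return p * (1.0 - p)
```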

Next, I demonstrated that CMGF performs almost identically to many-pass SGD (in terms of average 10-fold cross-validation accuracy) and significantly better than single-pass SGD when applied to multinomial logistic regression.

I also demonstrated that CMGF can accurately infer latent states under a Poisson likelihood, as shown below.
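
A Poisson emission fits the same template, since its conditional mean and variance coincide. Again a hypothetical sketch, assuming a log-linear rate exp(h·z):

```python
import jax.numpy as jnp

# For a Poisson emission y_t ~ Poisson(exp(h . z_t)), CMGF needs only:

def poisson_emission_mean(z, h):
    """E[y | z] under a log-linear Poisson rate."""
    return jnp.exp(h @ z)

def poisson_emission_cov(z, h):
    """Cov[y | z]: the Poisson variance equals its mean."""
    return jnp.exp(h @ z)
```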

Finally, I demonstrated that CMGF can be used to train MLP classifiers in a single pass. As shown in the video embedded below, CMGF is able to train an MLP (with two hidden layers) to accurately perform binary classification on a relatively complex training dataset.
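
Under the hood, training an MLP this way amounts to treating the flattened network weights as the latent state of an SSM with (near-)identity dynamics, so each labeled example is absorbed by one filtering update. A sketch of that reduction, with an illustrative two-hidden-layer network rather than the actual demo code:

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Treat the MLP weights as the latent state: flatten the parameter pytree
# into a single vector so the filter can maintain a Gaussian belief over it.
params = {"hidden1": {"w": jnp.zeros((2, 16)), "b": jnp.zeros(16)},
          "hidden2": {"w": jnp.zeros((16, 16)), "b": jnp.zeros(16)},
          "out":     {"w": jnp.zeros((16, 1)),  "b": jnp.zeros(1)}}
flat_w, unravel = ravel_pytree(params)

def emission_mean(w, x):
    """E[y | w]: the MLP's predicted class probability for input x."""
    p = unravel(w)
    h = jnp.tanh(x @ p["hidden1"]["w"] + p["hidden1"]["b"])
    h = jnp.tanh(h @ p["hidden2"]["w"] + p["hidden2"]["b"])
    return jax.nn.sigmoid(h @ p["out"]["w"] + p["out"]["b"])

# Static dynamics w_t = w_{t-1} mean each example (x_t, y_t) is absorbed by
# a single CMGF update step -- i.e., one pass over the data.
```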

Textbook Section Co-Author

I was also lucky enough to co-author a section of Kevin's upcoming sequel to his extremely popular Bayesian machine learning textbook, Probabilistic Machine Learning: Advanced Topics. The section, 8.7: General Gaussian Filtering, covers Gaussian moment matching, statistical linear regression, iterated posterior linearization, and the conditional-moments Gaussian filter.

Conclusion

As a beginner open-source contributor, I found GSoC 2022 to be a thoroughly fulfilling and enjoyable learning experience. At the beginning of the summer, I barely understood how Kalman filters worked, and it was incredibly satisfying to gradually discover the subtle differences among the myriad extensions of Kalman filters. Witnessing my implementation of the most general version of them all, the CMGF, perform well in various challenging demos has been very exciting.

I feel extremely grateful to have been a member of the ssm-jax team, and I learned so much from everyone I worked with. Our incredible mentors, Kevin Murphy and Scott Linderman, fostered a warm and inclusive environment that always encouraged me to challenge myself without ever feeling overwhelmed, despite my relative lack of knowledge and experience. I greatly look forward to continuing our work (and friendship) with the rest of the team.

List of PRs

| Repo | Issue # | PR # | Description |
| --- | --- | --- | --- |
| probml-notebooks | 698 | 54 | Convert LeNet1989 to JAX notebook. |
| pyprobml | 698 | 715 | Convert LeNet1989 to JAX .py file. |
| probml-notebooks | 708 | 61 | Translate Random Priors Ensemble demo to JAX. |
| probml-notebooks | 708 | 69 | Optimize and improve Random Priors Ensemble demo. |
| JSL | 736 | 35 | Implement fixed-lag smoothing for HMM. |
| JSL | N/A | 57 | Replace scipy.special.logit with jnp.log. |
| ssm-jax | 8 | 26 | Reimplement hmm_posterior_sample(). |
| ssm-jax | 8 | 28 | Implement test_hmm_posterior_sample() to compare with full joint probs. |
| ssm-jax | 9 | 29 | Reimplement hmm_fixed_lag_smoother(). |
| ssm-jax | 40 | 42 | Implement EKF. |
| ssm-jax | 32 | 77 | Implement EKF-MLP training demo. |
| ssm-jax | 79 | 88 | Create ekf_spiral.py demo. |
| ssm-jax | 63 | 111 | Implement UKF. |
| ssm-jax | 113 | 115 | Create ukf_spiral.py demo. |
| pyprobml | 1017 | 1018 | Add ekf_vs_ukf.ipynb demo notebook. |
| ssm-jax | 118 | 122 | Implement GGSSM/GGF. |
| pyprobml | 1082 | 1083 | Add GHKF comparison to ekf_vs_ukf.ipynb. |
| ssm-jax | 139 | 135 | Implement iterated EKF and iterated EKS. |
| ssm-jax | 144 | 154 | Rename classes to more descriptive names. |
| ssm-jax | 140 | 156 | Implement CMGF. |
| ssm-jax | 158 | 159 | Create CMGF online logistic regression demo. |
| pyprobml | 1104 | 1105 | Fix ADF logistic regression conversion issue. |
| ssm-jax | 178 | 193 | Create CMGF multinomial logistic regression demo. |
| ssm-jax | 190 | 200 | Create CMGF Poisson likelihood inference demo. |
| ssm-jax | 191 | 201 | Create CMGF-trained MLP classifier demo. |