GSoC 2022 Retrospective
During the summer of 2022, I participated in Google Summer of Code (GSoC). I was fortunate enough to be mentored by Kevin P. Murphy and Scott Linderman. My main contributions were to the probml/ssm-jax and probml/pyprobml repos.
Update (2022.10.17): the `ssm-jax` library has been renamed `dynamax`. Some of the links below may break in the coming days, in which case replacing `ssm-jax` with `dynamax` should do the trick.
Main Contributions
The majority of my time was spent building the inference portion of the `ssm-jax` library.
Starting from the Kalman filter for linear-Gaussian state-space models (LGSSMs), I climbed the tower of abstraction, eventually arriving at the most general formulation of Gaussian filtering: the conditional-moments Gaussian filter (CMGF).
Kalman filters give the exact, closed-form posterior for state-space models (SSMs) in which all the distributions involved are linear-Gaussian. While useful, most real-world SSMs do not have linear-Gaussian dynamics or emission models.
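For readers unfamiliar with the mechanics, here is a minimal sketch of a single Kalman filter predict/update step. The variable names follow standard textbook notation rather than the library's API, so treat this as an illustration, not the implementation:

```python
import jax.numpy as jnp

def kalman_step(m, P, F, Q, H, R, y):
    """One predict/update step of the Kalman filter for a linear-Gaussian SSM.

    m, P -- previous posterior mean and covariance
    F, Q -- dynamics matrix and process-noise covariance
    H, R -- emission matrix and observation-noise covariance
    y    -- current observation
    """
    # Predict: push the previous posterior through the linear dynamics.
    m_pred = F @ m
    P_pred = F @ P @ F.T + Q
    # Update: condition the Gaussian prediction on the new observation.
    S = H @ P_pred @ H.T + R               # innovation covariance
    K = jnp.linalg.solve(S, H @ P_pred).T  # Kalman gain: P_pred H^T S^{-1}
    m_post = m_pred + K @ (y - H @ m_pred)
    P_post = P_pred - K @ S @ K.T
    return m_post, P_post
```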
I learned about and implemented extensions of the Kalman filter to non-linear dynamics and emission models, namely the extended Kalman filter (EKF), the unscented Kalman filter (UKF), and the Gauss-Hermite Kalman filter (GHKF). I created a toy example demonstrating the differences in the scope (and accuracy) of their Gaussian approximations to non-linear transformations, as shown below.
*(Figure: Gaussian approximations produced by the EKF, UKF, and GHKF for a non-linear transformation.)*
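To give a flavor of how these filters differ, here is an illustrative sketch (not the library's implementation) of how the EKF and UKF each approximate the Gaussian pushforward of a non-linearity `g`: the EKF linearizes `g` at the mean using its Jacobian, while the UKF propagates a small set of sigma points through `g` and moment-matches. The GHKF follows the same recipe as the UKF but replaces sigma points with Gauss-Hermite quadrature nodes.

```python
import jax
import jax.numpy as jnp

def ekf_approx(g, m, P):
    """EKF-style approximation: linearize g at the mean via the Jacobian."""
    G = jax.jacobian(g)(m)
    return g(m), G @ P @ G.T

def ukf_approx(g, m, P, alpha=1.0, beta=2.0, kappa=1.0):
    """UKF-style approximation: push sigma points through g, then moment-match."""
    d = m.shape[0]
    lam = alpha**2 * (d + kappa) - d
    L = jnp.linalg.cholesky((d + lam) * P)
    # Sigma points: the mean plus/minus scaled columns of the matrix square root.
    sigmas = jnp.vstack([m[None, :], m + L.T, m - L.T])
    wm = jnp.concatenate([jnp.array([lam / (d + lam)]),
                          jnp.full(2 * d, 1.0 / (2 * (d + lam)))])
    wc = wm.at[0].add(1 - alpha**2 + beta)
    ys = jax.vmap(g)(sigmas)
    mean = wm @ ys
    diff = ys - mean
    cov = (wc[:, None] * diff).T @ diff
    return mean, cov

# Example: a polar-to-Cartesian transform, a standard non-linear test case.
g = lambda x: jnp.array([x[0] * jnp.cos(x[1]), x[0] * jnp.sin(x[1])])
m = jnp.array([1.0, jnp.pi / 4])
P = jnp.diag(jnp.array([0.1, 0.2]))
print(ekf_approx(g, m, P))
print(ukf_approx(g, m, P))
```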
Next, I implemented the general Gaussian filter (GGF), which unifies the non-linear extensions of the KF into a single filter that forms Gaussian approximations to the relevant joint distributions via moment-matching. In addition, I implemented its iterated posterior-linearization extension, which recomputes the expectations with respect to the previous iteration's posterior distribution. However, the GGF is still limited in its applicability: the emission model has to be Gaussian.
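The core idea can be sketched as follows: moment-match the joint Gaussian of (x, g(x)), then read off the statistically linearized model g(x) ≈ A x + b with residual covariance Ω. The Monte Carlo version below is only an illustration of the principle (deterministic quadrature rules such as Gauss-Hermite are used in practice); the names and signatures are mine, not the library's.

```python
import jax
import jax.numpy as jnp

def statistical_linear_regression(g, m, P, key, num_samples=10_000):
    """Moment-match the joint of (x, g(x)) for x ~ N(m, P) by Monte Carlo,
    then read off the statistically linearized model g(x) ~ A x + b."""
    xs = jax.random.multivariate_normal(key, m, P, (num_samples,))
    ys = jax.vmap(g)(xs)
    mu = ys.mean(axis=0)                          # E[g(x)]
    C = ((xs - m).T @ (ys - mu)) / num_samples    # Cov[x, g(x)]
    S = ((ys - mu).T @ (ys - mu)) / num_samples   # Cov[g(x)]
    A = jnp.linalg.solve(P, C).T                  # A = C^T P^{-1}
    b = mu - A @ m
    Omega = S - A @ P @ A.T                       # residual covariance
    return A, b, Omega

key = jax.random.PRNGKey(0)
g = lambda x: jnp.sin(x)  # any non-linear map works here
A, b, Omega = statistical_linear_regression(g, jnp.zeros(2), jnp.eye(2), key)
```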
I reached the peak of the tower of abstraction for SSMs by implementing the conditional-moments Gaussian filter (CMGF), which relaxes the assumption of a Gaussian emission model. I created a series of demos showcasing the utility of CMGF, starting with online binary logistic regression. As shown below, the CMGF-inferred weights rapidly converge to their MAP estimates.
*(Figure: CMGF-inferred weights converging to their MAP estimates during online logistic regression.)*
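The key ingredient CMGF needs from the emission model is just its first two conditional moments. For a Bernoulli emission with a logistic link these are available in closed form; the sketch below is illustrative and does not follow the library's exact API:

```python
import jax.numpy as jnp
from jax.nn import sigmoid

# For online logistic regression, the latent state z is the weight vector,
# the dynamics are the identity (the weights do not move on their own), and
# the emission for input x is Bernoulli(sigmoid(x @ z)). CMGF only needs
# the first two conditional moments of y given z:

def emission_mean(z, x):
    # E[y | z] = sigmoid(x^T z) for a Bernoulli emission with logistic link.
    return jnp.atleast_1d(sigmoid(x @ z))

def emission_cov(z, x):
    # Cov[y | z] = p (1 - p), the Bernoulli variance.
    p = sigmoid(x @ z)
    return jnp.atleast_2d(p * (1.0 - p))
```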
Next, I demonstrated that CMGF performs almost identically to multi-pass SGD (in terms of average 10-fold cross-validation accuracy) and significantly better than single-pass SGD when applied to multinomial logistic regression.
I also demonstrated that CMGF can accurately infer latent states under a Poisson likelihood, as shown below.
*(Figure: CMGF latent-state inference under a Poisson likelihood.)*
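As with the logistic case, a Poisson emission only has to expose its conditional mean and covariance, both of which are closed-form under a log link. A minimal illustrative sketch (again, names are mine, not the library's):

```python
import jax.numpy as jnp

# For a Poisson emission with a log link, y_t ~ Poisson(exp(C z_t)), the
# conditional moments are again closed-form, since a Poisson's variance
# equals its mean:

def poisson_mean(z, C):
    # E[y | z] = exp(C z), elementwise.
    return jnp.exp(C @ z)

def poisson_cov(z, C):
    # Cov[y | z] = diag(exp(C z)).
    return jnp.diag(jnp.exp(C @ z))
```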
Finally, I demonstrated that CMGF can be used to train MLP classifiers in a single pass. As shown in the video embedded below, CMGF trains an MLP with two hidden layers to accurately perform binary classification on a relatively complex training dataset.
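The trick that makes this possible is to cast training itself as state-space inference: the flattened network weights become the latent state, the dynamics are (approximately) the identity, and the emission is a Bernoulli whose conditional mean is the network's sigmoid output. The sketch below illustrates this formulation with hypothetical names and layer sizes; it is not the demo's actual code.

```python
import jax
import jax.numpy as jnp

def mlp_apply(w_flat, x, sizes=(2, 16, 16, 1)):
    """Unflatten w_flat into weights/biases and run a forward pass."""
    out, i = x, 0
    layers = list(zip(sizes[:-1], sizes[1:]))
    for j, (m, n) in enumerate(layers):
        W = w_flat[i:i + m * n].reshape(m, n)
        i += m * n
        b = w_flat[i:i + n]
        i += n
        out = out @ W + b
        if j < len(layers) - 1:  # tanh on hidden layers only
            out = jnp.tanh(out)
    return out

# The SSM for single-pass training: the latent state is the flat weight
# vector, the dynamics are the identity (optionally with a little noise),
# and the emission is Bernoulli with the network's sigmoid output as its
# conditional mean -- exactly the non-Gaussian setting CMGF handles.
def emission_mean(w_flat, x):
    return jax.nn.sigmoid(mlp_apply(w_flat, x))
```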
Textbook Section Co-Author
I was also lucky enough to co-author a section of Kevin's upcoming sequel to his extremely popular Bayesian machine learning textbook, Probabilistic Machine Learning: Advanced Topics. The section, 8.7: General Gaussian Filtering, covers Gaussian moment-matching, statistical linear regression, iterated posterior linearization, and the conditional-moments Gaussian filter.
Conclusion
As a beginner open-source contributor, I found GSoC 2022 to be a thoroughly fulfilling and enjoyable learning experience. At the beginning of the summer, I barely understood how Kalman filters worked, and it was incredibly satisfying to gradually discover the subtle differences among the myriad extensions of Kalman filters. Witnessing my implementation of the most general version of them all, the CMGF, perform well in various challenging demos has been very exciting.
I feel extremely grateful to have been a member of the `ssm-jax` team, and I learned so much from everyone I worked with. Our incredible mentors, Kevin Murphy and Scott Linderman, fostered a warm and inclusive environment that always encouraged me to challenge myself without ever feeling overwhelmed, despite my relative lack of knowledge and experience. I greatly look forward to continuing to work (and be friends) with the rest of the team in the future.
List of PRs
| Repo | Issue # | PR # | Description |
|---|---|---|---|
| probml-notebooks | 698 | 54 | Convert LeNet1989 to JAX notebook. |
| pyprobml | 698 | 715 | Convert LeNet1989 to JAX `.py` file. |
| probml-notebooks | 708 | 61 | Translate Random Priors Ensemble demo to JAX. |
| probml-notebooks | 708 | 69 | Optimize and improve Random Priors Ensemble demo. |
| JSL | 736 | 35 | Implement fixed-lag smoothing for HMM. |
| JSL | N/A | 57 | Fix `scipy.special.logit` to `jnp.log`. |
| ssm-jax | 8 | 26 | Reimplement `hmm_posterior_sample()`. |
| ssm-jax | 8 | 28 | Implement `test_hmm_posterior_sample()` to compare with full joint probs. |
| ssm-jax | 9 | 29 | Reimplement `hmm_fixed_lag_smoother()`. |
| ssm-jax | 40 | 42 | Implement EKF. |
| ssm-jax | 32 | 77 | Implement EKF-MLP training demo. |
| ssm-jax | 79 | 88 | Create `ekf_spiral.py` demo. |
| ssm-jax | 63 | 111 | Implement UKF. |
| ssm-jax | 113 | 115 | Create `ukf_spiral.py` demo. |
| pyprobml | 1017 | 1018 | Add `ekf_vs_ukf.ipynb` demo notebook. |
| ssm-jax | 118 | 122 | Implement GGSSM/GGF. |
| pyprobml | 1082 | 1083 | Add GHKF comparison to `ekf_vs_ukf.ipynb`. |
| ssm-jax | 139 | 135 | Implement iterated EKF and iterated EKS. |
| ssm-jax | 144 | 154 | Rename classes to more descriptive names. |
| ssm-jax | 140 | 156 | Implement CMGF. |
| ssm-jax | 158 | 159 | Create CMGF online logistic regression demo. |
| pyprobml | 1104 | 1105 | Fix ADF logistic regression conversion issue. |
| ssm-jax | 178 | 193 | Create CMGF multinomial logistic regression demo. |
| ssm-jax | 190 | 200 | Create CMGF Poisson likelihood inference demo. |
| ssm-jax | 191 | 201 | Create CMGF-trained MLP-classifier demo. |