Dynamax for Multiple Time Series problems

Hey folks,

Just wondering if any of you have worked on multiple time series problems using the Dynamax library. I’m thinking of a problem with 100 time series that have trend and seasonality components (weekly, yearly, holidays), which you want to model using an SSM while accounting for any correlation between them. The closest example I’ve seen is the Time Series Forecasting notebook in the numpyro documentation, but that is for a univariate time series problem. Any ideas / resources / examples are more than welcome!


I haven’t used dynamax specifically, but here’s an example with TensorFlow. If all the time series are independent, it’d be quite nice to use a library like Dynamax, as you could parallelize quite easily.

https://bayesiancomputationbook.com/markdown/chp_06.html#state-space-models

I’m curious, what are you working on?

Hey @RavinKumar! I’m working on something similar to the M5 Forecasting competition where I have hierarchical sales data at the item-store-region-date level. There are thousands of items, hundreds of stores and tens of regions. Forecasting is the main task, but equally important is understanding the correlations between time series, which led me to choose a Bayesian framework.

For this problem, I’ve experimented quite a bit with PyMC after reading @AlexAndorra’s (amazing) post on estimating latent presidential popularity across time with a Markov chain. I started by aggregating the data to the region level. However, the moment I start going to deeper hierarchical levels, PyMC does not scale well, and most of the time the model does not sample at all. I checked Sandra’s Scalable Bayesian Modelling repo, where I tried to leverage the power of GPUs, but I kept running into out-of-memory errors. I also created a post on the PyMC forum, trying to learn more from the community about scaling PyMC models to large datasets, but it seems this is an area of active development.

Overall, I don’t have many features (global level, item effect, region effect, weekly seasonality, yearly seasonality, special days effect). After attending the Dynamax book club, I started thinking that maybe I could reframe the problem using an SSM. I know I could use a deep learning model, but the interpretable and adaptive nature of SSMs makes them more interesting. My first attempt came from the Time Series Forecasting notebook I mentioned before, where they implement an SSM called SGT (Seasonal Global Trend). I managed to extend it to handle multiple correlated time series (check here), but this model can only handle one type of seasonality at a time, and I wasn’t able to break this constraint. Then I found Pyro’s tutorials on forecasting with SSMs. I’m currently working through them, and I’m surprised to see they can handle multiple time series of modest size, with the model running in a matter of minutes. They are using stochastic variational inference (SVI), and it seems this is related to what @ckrapu mentioned in this post when he had to train a big model.

So going back to the book club and Dynamax, I’m trying to learn how to leverage this library as well as JAX to tackle this interesting problem using SSMs. Still, there are lots of things to learn and try. I would be happy to hear your thoughts.

Hey @jroberayalas, I’ve been thinking about using dynamax for M5 data too. I think this is better suited to sts-jax than to pure dynamax. I was planning on contributing this as an example, but I’d be really keen to collaborate if you’re interested. Happy to set up a quick call.

Following the notation of the README, we have

y_t = H_t z_t + u_t + \epsilon_t, \qquad \epsilon_t \sim \mathcal{N}(0, \sigma^2_t)

z_{t+1} = F_t z_t + R_t \eta_t, \qquad \eta_t \sim \mathcal{N}(0, Q_t)

We can make some simplifying assumptions:

  • \sigma_t = \sigma_h: static measurement noise.
  • Q_t = \sigma_q^2. This is now a vector of disturbance variances, the same length as the number of states \text{length}(z_t) = N_z, so each state has independent disturbances (no covariance between the AR effects). And as there are as many disturbances as states, \eta_t is the same length as z_t, so R_t is the identity.
  • F_t z_t = f \cdot z_t, where f is the diagonal of some matrix F (no cross-correlations between states). This should save us from having to invert F in the Kalman filter, because inverting a diagonal matrix is easy. We can either set all elements of f to 1 (perfect correlation between timesteps) or estimate the correlation of each effect.
  • u_t = x_t^T \beta: regression component from external inputs (see the sketch after this list for how these pieces fit together).
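
Concretely, one transition step under these assumptions is just elementwise operations, which is where JAX should shine. A minimal sketch (all names are illustrative, not dynamax API):

```python
import jax.numpy as jnp
import jax.random as jr

# A minimal sketch of one transition step under the assumptions above.
key = jr.PRNGKey(0)
N_z = 14                                   # e.g. 2 stores x (1 AR + 6 seasonal) states

f = jnp.ones(N_z)                          # diagonal of F; all ones = random-walk states
sigma_q = 0.1 * jnp.ones(N_z)              # per-state disturbance std devs (diagonal Q)

z_t = jnp.zeros(N_z)                       # current state
eta_t = sigma_q * jr.normal(key, (N_z,))   # eta_t ~ N(0, diag(sigma_q^2)), with R_t = I
z_next = f * z_t + eta_t                   # elementwise, because F is diagonal
```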

Let’s say, following the M5 context, we have N_s stores and N_t timesteps. Let’s say each store has its own AR effect, and each store has a seasonality effect. Using seasonal dummies, we can get weekly seasonality. So we have N_z = N_s (1 + 6), because we only need 6 dummies for a seasonality of 7.

Putting this together, and using bold for vectors and matrices because I prefer that, we have

\mathbf{y_t} = \mathbf{H_t} \mathbf{z_t} + \mathbf{u_t} + \mathbf{\epsilon_t}, \qquad \mathbf{\epsilon_t} \sim \mathcal{N}(0, \sigma_h^2)

\mathbf{z_{t+1}} = \mathbf{f} \cdot \mathbf{z_t} + \mathbf{\eta_t}, \qquad \mathbf{\eta_t} \sim \mathcal{N}(0, \mathbf{\sigma_q^2})

\mathbf{u_t} = \mathbf{X_t} \mathbf{\beta}

And the form of z_t is something like

\begin{bmatrix} \text{AR effect, store 1} \\ \text{seasonal effect 1, store 1} \\ ... \\ \text{seasonal effect 6, store 1} \\ ... \\ \text{AR effect, store } N_s \\ \text{seasonal effect 1, store } N_s \\ ... \\ \text{seasonal effect 6, store } N_s \end{bmatrix}

Then \mathbf{H_t}, which has shape N_s \times N_z and sums the relevant states down to the observation dimension \text{length}(y_t) = N_s, is something like this (for N_s = 2, and a day of the week such that the third seasonal component is selected):

\begin{bmatrix} 1 & 0 & 0 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 & 1 & 0 & 0 & 1 & 0 & 0 & 0 & 0 \end{bmatrix}

The 1st column means we select store 1’s AR component. Col 4 selects the 3rd seasonal effect for store 1. Col 8 means we select store 2’s AR component, and col 11 selects the 3rd seasonal effect for store 2.
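
In code, building this selection matrix for a given day of the week could look something like the following (a JAX sketch; the function name and the convention that day 0 is the reference day with no dummy are my own assumptions):

```python
import jax.numpy as jnp

def build_emission_matrix(n_stores: int, day_of_week: int) -> jnp.ndarray:
    """Build H_t for the block layout [AR, seasonal 1..6] per store.

    day_of_week is in 0..6; day 0 is the reference day absorbed by the AR/level
    term, and days 1..6 pick out seasonal dummies 1..6.
    """
    n_states_per_store = 7  # 1 AR effect + 6 seasonal dummies
    block = jnp.zeros(n_states_per_store).at[0].set(1.0)  # always select the AR effect
    if day_of_week > 0:
        block = block.at[day_of_week].set(1.0)            # select this day's dummy
    # Block-diagonal stacking: each store only reads its own states
    return jnp.kron(jnp.eye(n_stores), block[None, :])

# Reproduces the 2-store example above (third seasonal component active)
H_t = build_emission_matrix(n_stores=2, day_of_week=3)
```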

At the moment, this is all independent except for the common \sigma_h. But we can extend this to share information between stores in a few ways:

  • Common AR and seasonal effects between stores
  • Hierarchical priors on \mathbf{\sigma_q} and \beta
  • Allowing covariance between disturbances (matrix Q), or cross-correlations between AR steps (matrix \mathbf{F}); see the sketch below
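
For the last bullet, here is a rough sketch (with made-up numbers) of what a dense \mathbf{Q} could look like, correlating the AR disturbances across stores while leaving the seasonal states independent:

```python
import jax.numpy as jnp

# Illustrative numbers only: a dense disturbance covariance Q that correlates
# the AR shocks across stores while keeping the seasonal dummies independent.
n_stores, n_seasonal = 2, 6
n_states = n_stores * (1 + n_seasonal)

ar_corr = jnp.array([[1.0, 0.6],
                     [0.6, 1.0]])          # correlation between store-level AR shocks
ar_scale = jnp.array([0.3, 0.5])           # per-store AR disturbance std devs
ar_cov = ar_corr * jnp.outer(ar_scale, ar_scale)

seasonal_var = 0.05 ** 2                   # shared variance for the seasonal dummy states

# AR states sit at the start of each store's block in z_t
ar_idx = jnp.arange(n_stores) * (1 + n_seasonal)
Q = jnp.eye(n_states) * seasonal_var
Q = Q.at[ar_idx[:, None], ar_idx[None, :]].set(ar_cov)
```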

That said, I don’t know how easy it is to control priors in dynamax, as it isn’t a PPL. Does anyone have experience with this?
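
For what it’s worth, here is roughly how I’d expect the pieces to plug into dynamax’s LinearGaussianSSM for a maximum-likelihood fit. This is a sketch from memory, so please double-check the exact API against the docs; it also side-steps the time-varying \mathbf{H_t}, which may need the lower-level filtering functions:

```python
import jax.random as jr
from dynamax.linear_gaussian_ssm import LinearGaussianSSM

n_stores, n_states = 2, 14  # N_s and N_z from above

# Build the model and get default parameters; the structured F (diagonal f),
# Q, and H from above would then be written into `params` before fitting.
lgssm = LinearGaussianSSM(state_dim=n_states, emission_dim=n_stores)
params, props = lgssm.initialize(jr.PRNGKey(0))

# `emissions` would be an array of shape (num_timesteps, n_stores) of sales.
# fitted_params, log_probs = lgssm.fit_em(params, props, emissions, num_iters=50)
```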

Let me know what you all think, but I’d be keen to contribute this.
