Hey @jroberayalas, I’ve been thinking about using dynamax for M5 data too. I think this is a better fit for sts-jax than for pure dynamax. I was planning on contributing this as an example, but I’d be really keen to collaborate if you’re interested. Happy to set up a quick call.
Following the notation of the README, we have
y_t = H_t z_t + u_t + \epsilon_t, \qquad \epsilon_t \sim \mathcal{N}(0, \sigma^2_t)
z_{t+1} = F_t z_t + R_t \eta_t, \qquad \eta_t \sim \mathcal{N}(0, Q_t)
We can make some simplifying assumptions:
- \sigma_t = \sigma_h: static measurement noise.
- Q_t = \sigma_q^2. This is now a vector of disturbance variances, the same length as the number of states, \text{length}(z_t) = N_z, so each state has an independent disturbance (no covariance in the AR effects). As there are as many disturbances as states, \eta_t is the same length as z_t, so R_t is the identity.
- F_t z_t = f \cdot z_t, where f is the diagonal of some matrix F (no cross-correlations between states). This should keep the Kalman filter cheap, because multiplying by a diagonal F is just an elementwise product. We can either set all elements of f to 1 (perfect correlation between timesteps) or estimate the correlation of each effect.
- u_t = x_t^T \beta: regression component from external inputs.
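To make the diagonal assumptions concrete, here is a minimal NumPy sketch of one Kalman predict step with F = diag(f), R = I and Q = diag(\sigma_q^2), checked against the dense formula. The function names and the random test values are my own placeholders, not anything from dynamax or sts-jax. Note that the matrix inverse in the standard Kalman update is of the innovation covariance (size N_s \times N_s), so the saving here is in the predict step.

```python
import numpy as np

def predict_diagonal(mean, cov, f, sigma_q2):
    """One Kalman predict step assuming F = diag(f), R = I, Q = diag(sigma_q2)."""
    mean_pred = f * mean                         # F @ mean as an elementwise product
    cov_pred = (f[:, None] * cov) * f[None, :]   # F @ cov @ F.T without dense matmuls
    cov_pred[np.diag_indices_from(cov_pred)] += sigma_q2  # add diagonal Q
    return mean_pred, cov_pred

def predict_dense(mean, cov, F, Q):
    """Reference implementation with full matrices."""
    return F @ mean, F @ cov @ F.T + Q

# Hypothetical values, only used to compare the two versions.
rng = np.random.default_rng(0)
n_z = 14
f = rng.uniform(0.5, 1.0, size=n_z)
sigma_q2 = rng.uniform(0.1, 0.3, size=n_z)
mean = rng.normal(size=n_z)
A = rng.normal(size=(n_z, n_z))
cov = A @ A.T + np.eye(n_z)  # symmetric positive definite

m1, P1 = predict_diagonal(mean, cov, f, sigma_q2)
m2, P2 = predict_dense(mean, cov, np.diag(f), np.diag(sigma_q2))
agree = np.allclose(m1, m2) and np.allclose(P1, P2)
```

The same expressions should carry over to `jax.numpy`, except for the in-place diagonal update, which would need `cov_pred + jnp.diag(sigma_q2)` or similar.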
Let’s say, following the M5 context, we have N_s stores and N_t timesteps. Suppose each store has its own AR effect and its own seasonality effect. Using seasonal dummies, we can model weekly seasonality. So we have N_z = N_s (1 + 6), because we only need 6 dummies for a seasonality of 7.
Putting this together, and using bold for vectors and matrices because I prefer that, we have
\mathbf{y_t} = \mathbf{H_t} \mathbf{z_t} + \mathbf{u_t} + \mathbf{\epsilon_t}, \qquad \mathbf{\epsilon_t} \sim \mathcal{N}(0, \sigma_h^2 \mathbf{I})
\mathbf{z_{t+1}} = \mathbf{f} \cdot \mathbf{z_t} + \mathbf{\eta_t}, \qquad \mathbf{\eta_t} \sim \mathcal{N}(0, \text{diag}(\mathbf{\sigma_q^2}))
\mathbf{u_t} = \mathbf{X_t} \mathbf{\beta}
And the form of z_t is something like
\begin{bmatrix}
\text{AR effect, store 1} \\
\text{seasonal effect 1, store 1} \\
... \\
\text{seasonal effect 6, store 1} \\
... \\
\text{AR effect, store } N_s \\
\text{seasonal effect 1, store } N_s \\
... \\
\text{seasonal effect 6, store } N_s
\end{bmatrix}
Then \mathbf{H_t}, which has shape N_s \times N_z, maps the states to the observation dimension \text{length}(y_t) = N_s by summing each store's active components. For N_s = 2, on a day of the week where the third seasonal component is selected, it looks something like
\begin{bmatrix}
1 & 0 & 0 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \\
0 & 0 & 0 & 0 & 0 & 0 & 0 & 1 & 0 & 0 & 1 & 0 & 0 & 0
\end{bmatrix}
Column 1 selects store 1’s AR component, and column 4 selects the 3rd seasonal effect for store 1. Likewise, column 8 selects store 2’s AR component, and column 11 selects the 3rd seasonal effect for store 2.
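As a quick check of that layout, here is a small NumPy sketch that builds \mathbf{H_t} for any number of stores. The function name `build_H` is hypothetical, and I'm assuming the seventh day of the week is the baseline, where no seasonal dummy is selected (`seasonal_idx=None`).

```python
import numpy as np

N_SEASONAL = 6  # 6 dummies are enough for a weekly period of 7

def build_H(n_stores, seasonal_idx):
    """Build H_t for the per-store state layout [AR, seasonal 1..6] repeated per store.

    seasonal_idx: 1..6 selects that seasonal dummy; None means the baseline day
    (an assumption of this sketch), where only the AR state is observed.
    """
    block = 1 + N_SEASONAL                      # states per store
    H = np.zeros((n_stores, n_stores * block))
    for s in range(n_stores):
        offset = s * block                      # first state of store s
        H[s, offset] = 1.0                      # AR effect
        if seasonal_idx is not None:
            H[s, offset + seasonal_idx] = 1.0   # active seasonal dummy
    return H

H = build_H(n_stores=2, seasonal_idx=3)
# 1-based column positions of the ones in each row
cols = [(np.flatnonzero(row) + 1).tolist() for row in H]
```

For two stores and the third seasonal component, `cols` comes out as `[[1, 4], [8, 11]]`, matching the columns described above.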
At the moment, this is all independent except for the common \sigma_h. But we can extend this to share information between stores in a few ways:
- Common AR and seasonal effects between stores
- Hierarchical priors on \mathbf{\sigma_q}, \beta
- Allowing covariance between disturbances (matrix Q), or cross-correlations between AR steps (matrix \mathbf{F})
That said, I don’t know how easy it is to control priors in dynamax, as it isn’t a probabilistic programming language (PPL). Does anyone have experience with this?
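The first extension in the list (common seasonal effects between stores) doesn't need priors at all, only a different state layout. As a rough sketch, assuming the state is ordered as N_s per-store AR effects followed by 6 shared seasonal effects (my own ordering, with a hypothetical function name), the state shrinks to N_z = N_s + 6 and \mathbf{H_t} becomes an identity block next to a single column of ones:

```python
import numpy as np

def build_H_shared_seasonal(n_stores, seasonal_idx, n_seasonal=6):
    """H_t when every store has its own AR state but all stores share seasonal states.

    Assumed state layout: [AR store 1, ..., AR store N_s, seasonal 1, ..., seasonal 6].
    seasonal_idx: 1..6 selects the shared dummy; None means the baseline day.
    """
    H = np.zeros((n_stores, n_stores + n_seasonal))
    H[:, :n_stores] = np.eye(n_stores)              # each store reads its own AR state
    if seasonal_idx is not None:
        H[:, n_stores + seasonal_idx - 1] = 1.0     # every store reads the same dummy
    return H

H_shared = build_H_shared_seasonal(n_stores=2, seasonal_idx=3)
```

With this layout the stores share information through the seasonal states, while f and \mathbf{\sigma_q^2} can stay diagonal.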
Let me know what you all think. I’d be keen to contribute this.