Moirai forecasting model
For a while now, Large Language Models have been taking over the modeling scene. Alongside the classical NLP-oriented ones, models dedicated to time series processing have started to emerge. Moirai is one of them: a foundation model for time series forecasting developed by Salesforce. It is a Masked Encoder-based Universal Time Series Forecasting Transformer, pre-trained on the LOTSA dataset.
You can find a few tutorials online showing how to interact with MOIRAI; however, I still struggled to understand what happens at each step. Let me therefore break it down for you.
Dataset
The MOIRAI model operates on the PandasDataset structure from the GluonTS package. It can easily be created from a regular pandas data frame, provided that the frame contains an identifier column for your time series (since more than one series can be ingested into the object).
For demonstration purposes let's use the AirPassengers dataset, which can be downloaded e.g. from Kaggle. I changed the column names to "date" and "value".
```python
import pandas as pd

data = pd.read_csv(
    f"{DATA_PATH}/airpassengers.csv",
    parse_dates=["Month"]  # parse the timestamp column before renaming it to "date"
).rename(columns={"Month": "date", "#Passengers": "value"})
data
```

Now let's add an artificial id column:
```python
data["identifier"] = "series_id"
```
Then we need to set the data frame index using the date column:
```python
data = data.set_index("date")
```
And now we can create our data object:
```python
from gluonts.dataset.pandas import PandasDataset

ds = PandasDataset.from_long_dataframe(
    data, target="value", item_id="identifier"
)
ds
```
```
PandasDataset<
    size=1,
    freq=M,
    num_feat_dynamic_real=0,
    num_past_feat_dynamic_real=0,
    num_feat_static_real=0,
    num_feat_static_cat=0,
    static_cardinalities=[]
>
```
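As a side note, the long format makes it easy to feed several series at once: every distinct value of the identifier column becomes a separate series inside the dataset. Here is a minimal sketch with two synthetic series (the names series_a and series_b and all values are made up for illustration):

```python
import pandas as pd

# two synthetic series stacked in long format, sharing one "identifier" column
idx = pd.date_range("1949-01-01", periods=4, freq="MS")
long_df = pd.concat([
    pd.DataFrame({"date": idx, "value": [112, 118, 132, 129], "identifier": "series_a"}),
    pd.DataFrame({"date": idx, "value": [10.0, 12.0, 11.0, 13.0], "identifier": "series_b"}),
]).set_index("date")

# each identifier would become a separate entry in PandasDataset.from_long_dataframe
print(long_df.groupby("identifier").size().to_dict())  # {'series_a': 4, 'series_b': 4}
```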
Model setup
Now, having the data ready, we can move on to the modeling stage. I will show how to perform a train-test forecast as well as a prediction ahead.
Let’s start by defining model parameters.
```python
SIZE = "large"       # model size: choose from {'small', 'base', 'large'}
PATCH_SIZE = 32      # patch size: choose from {"auto", 8, 16, 32, 64, 128}
BATCH_SIZE = 32      # batch size: any positive integer
NUM_SAMPLES = 20     # number of sample paths in the probabilistic forecast for each date
MOIRAI_MODEL = "moirai"
CONTEXT_LENGTH = 100
```
Let's consider a prediction horizon of 3.
```python
from uni2ts.model.moirai import MoiraiForecast, MoiraiModule

prediction_length = 3

model = MoiraiForecast(
    module=MoiraiModule.from_pretrained(f"Salesforce/moirai-1.0-R-{SIZE}"),
    prediction_length=prediction_length,
    context_length=CONTEXT_LENGTH,
    patch_size=PATCH_SIZE,
    num_samples=NUM_SAMPLES,
    target_dim=1,
    feat_dynamic_real_dim=ds.num_feat_dynamic_real,
    past_feat_dynamic_real_dim=ds.num_past_feat_dynamic_real,
)
```
The model weights are downloaded from the Hugging Face hub.
We also pass the number of exogenous features, which can be included in the data in the form of additional columns.
Train-test forecast
To perform a train-test prediction we need to specify the testing set. Contrary to regular train-test splits, where we cut the series at some point, train the model on the first part and evaluate it on the second, here we do zero-shot forecasting: the model is not trained on our data at all; we simply annotate part of the series with prediction windows for the pre-trained model to fill in.
Let's say that we want to perform testing in a cross-validation-like fashion. For that, let's mark the last 12 data points as the test set; with our prediction horizon of 3, this gives us 4 prediction windows.
```python
from gluonts.dataset.split import split

test_size = 12

train_data, test_template = split(
    ds, offset=-test_size
)

test_data = test_template.generate_instances(
    prediction_length=prediction_length,
    windows=test_size // prediction_length,
    distance=prediction_length
)
```
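The window arithmetic above can be sketched in plain pandas. Assuming the series ends in Dec 1960 (as the AirPassengers data does), the four windows start in Jan, Apr, Jul and Oct 1960:

```python
import pandas as pd

test_size, prediction_length = 12, 3
windows = test_size // prediction_length  # 4 non-overlapping windows

series_end = pd.Period("1960-12", freq="M")  # last observation of the series
window_starts = [series_end - test_size + 1 + i * prediction_length
                 for i in range(windows)]

print([str(p) for p in window_starts])  # ['1960-01', '1960-04', '1960-07', '1960-10']
```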
To understand what happens here, we need to look into the GluonTS documentation.
The created training set is exactly the same as our full dataset.
```python
train_data
```
```
TrainingDataset(
    dataset=PandasDataset<
        size=1,
        freq=M,
        num_feat_dynamic_real=0,
        num_past_feat_dynamic_real=0,
        num_feat_static_real=0,
        num_feat_static_cat=0,
        static_cardinalities=[]>,
    splitter=OffsetSplitter(offset=-12)
)
```
```python
list(iter(train_data))
```
```
[{'start': Period('1949-01', 'M'),
  'target': array([112, 118, 132, 129, 121, 135, 148, 148, 136, 119, 104, 118,
                   115, 126, 141, 135, 125, 149, 170, 170, 158, 133, 114, 140,
                   145, 150, 178, 163, 172, 178, 199, 199, 184, 162, 146, 166,
                   171, 180, 193, 181, 183, 218, 230, 242, 209, 191, 172, 194,
                   196, 196, 236, 235, 229, 243, 264, 272, 237, 211, 180, 201,
                   204, 188, 235, 227, 234, 264, 302, 293, 259, 229, 203, 229,
                   242, 233, 267, 269, 270, 315, 364, 347, 312, 274, 237, 278,
                   284, 277, 317, 313, 318, 374, 413, 405, 355, 306, 271, 306,
                   315, 301, 356, 348, 355, 422, 465, 467, 404, 347, 305, 336,
                   340, 318, 362, 348, 363, 435, 491, 505, 404, 359, 310, 337,
                   360, 342, 406, 396, 420, 472, 548, 559, 463, 407, 362, 405,
                   417, 391, 419, 461, 472, 535, 622, 606, 508, 461, 390, 432]),
  'item_id': 'series_id'}]
```
As we can see, we have a monthly series starting in Jan 1949, whose first value is 112 and last value is 432.
The test dataset contains 4 elements: the repeated training data, each copy paired with one of our four 3-element prediction windows, starting in Jan, Apr, Jul and Oct 1960.
```python
list(iter(test_data))
```
```
# long history arrays truncated to first/last values for readability
[({'start': Period('1949-01', 'M'),
   'target': array([112, 118, 132, ..., 407, 362, 405]),
   'item_id': 'series_id'},
  {'start': Period('1960-01', 'M'), 'target': array([417, 391, 419]),
   'item_id': 'series_id'}),
 ({'start': Period('1949-01', 'M'),
   'target': array([112, 118, 132, ..., 417, 391, 419]),
   'item_id': 'series_id'},
  {'start': Period('1960-04', 'M'), 'target': array([461, 472, 535]),
   'item_id': 'series_id'}),
 ({'start': Period('1949-01', 'M'),
   'target': array([112, 118, 132, ..., 461, 472, 535]),
   'item_id': 'series_id'},
  {'start': Period('1960-07', 'M'), 'target': array([622, 606, 508]),
   'item_id': 'series_id'}),
 ({'start': Period('1949-01', 'M'),
   'target': array([112, 118, 132, ..., 622, 606, 508]),
   'item_id': 'series_id'},
  {'start': Period('1960-10', 'M'), 'target': array([461, 390, 432]),
   'item_id': 'series_id'})]
```
However, to obtain predictions we provide only the "input" part of the testing set: the training history repeated 4 times, each copy ending right before its prediction window, plus some additional metadata.
```python
test_data.input
```
```
InputDataset(
    test_data=TestData(
        dataset=PandasDataset<
            size=1,
            freq=M,
            num_feat_dynamic_real=0,
            num_past_feat_dynamic_real=0,
            num_feat_static_real=0,
            num_feat_static_cat=0,
            static_cardinalities=[]>,
        splitter=OffsetSplitter(offset=-12),
        prediction_length=3,
        windows=4,
        distance=3,
        max_history=None)
)
```
```python
list(iter(test_data.input))
```
```
# long history arrays truncated to first/last values for readability
[{'start': Period('1949-01', 'M'),
  'target': array([112, 118, 132, ..., 407, 362, 405]),    # 132 values
  'item_id': 'series_id'},
 {'start': Period('1949-01', 'M'),
  'target': array([112, 118, 132, ..., 417, 391, 419]),    # 135 values
  'item_id': 'series_id'},
 {'start': Period('1949-01', 'M'),
  'target': array([112, 118, 132, ..., 461, 472, 535]),    # 138 values
  'item_id': 'series_id'},
 {'start': Period('1949-01', 'M'),
  'target': array([112, 118, 132, ..., 622, 606, 508]),    # 141 values
  'item_id': 'series_id'}]
```
Now, when running the test prediction, we obtain four 3-element forecasts, each sampled 20 times, as we specified when setting up the model. This nicely simulates a cross-validation scenario.
```python
predictor = model.create_predictor(batch_size=BATCH_SIZE)
forecasts = predictor.predict(test_data.input)
forecasts = list(forecasts)
forecasts
```
```
# sample arrays truncated to first/last rows for readability
[gluonts.model.forecast.SampleForecast(
     info=None, item_id='series_id',
     samples=array([[437.57245, 429.63123, 438.20776],
                    [416.22205, 446.30884, 443.11096],
                    ...,
                    [411.49817, 431.63434, 378.68427]], dtype=float32),
     start_date=Period('1960-01', 'M')),
 gluonts.model.forecast.SampleForecast(
     info=None, item_id='series_id',
     samples=array([[419.3704 , 523.8678 , 553.00525],
                    [437.50406, 443.61523, 503.47278],
                    ...,
                    [436.9636 , 439.00613, 498.2074 ]], dtype=float32),
     start_date=Period('1960-04', 'M')),
 gluonts.model.forecast.SampleForecast(
     info=None, item_id='series_id',
     samples=array([[618.4667 , 627.18945, 493.9572 ],
                    [607.5718 , 623.53845, 517.4912 ],
                    ...,
                    [600.19275, 612.7903 , 517.04846]], dtype=float32),
     start_date=Period('1960-07', 'M')),
 gluonts.model.forecast.SampleForecast(
     info=None, item_id='series_id',
     samples=array([[387.834  , 353.92053, 467.49762],
                    [404.7618 , 356.21448, 471.4069 ],
                    ...,
                    [447.75095, 441.40387, 465.30792]], dtype=float32),
     start_date=Period('1960-10', 'M'))]
```
The "samples" property allows you to access those values to calculate average predictions or the required percentiles.
```python
forecasts[0].samples
```
```
array([[437.57245, 429.63123, 438.20776],
       [416.22205, 446.30884, 443.11096],
       [431.7821 , 483.75003, 373.63806],
       [428.66147, 487.99554, 433.95813],
       [411.8133 , 451.14062, 444.55682],
       [435.68   , 442.5996 , 423.2483 ],
       [401.62332, 446.01062, 413.58963],
       [415.2624 , 379.83978, 449.74814],
       [401.48392, 438.77094, 428.26733],
       [414.94492, 418.52576, 441.8598 ],
       [444.6018 , 403.11273, 396.1852 ],
       [412.6574 , 422.69406, 454.55582],
       [417.16632, 423.93097, 382.55084],
       [436.61234, 409.6562 , 406.20935],
       [387.13336, 455.55115, 410.9076 ],
       [457.82724, 467.40308, 457.6879 ],
       [411.2807 , 425.8558 , 465.39404],
       [411.954  , 442.20566, 417.07324],
       [393.5013 , 451.49994, 456.37463],
       [411.49817, 431.63434, 378.68427]], dtype=float32)
```
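Here is a sketch of how those values can be aggregated. It uses a synthetic samples array of the same 20×3 shape so it runs standalone; in the real workflow you would plug in forecasts[0].samples, and the actuals below are the Jan-Mar 1960 values from the test window shown earlier:

```python
import numpy as np

rng = np.random.default_rng(0)
# synthetic stand-in for forecasts[0].samples: 20 sample paths over a 3-step horizon
samples = rng.normal(loc=[420.0, 435.0, 430.0], scale=20.0, size=(20, 3))

point_forecast = samples.mean(axis=0)              # average prediction per step
p5, p95 = np.percentile(samples, [5, 95], axis=0)  # 90% prediction interval

# actual Jan-Mar 1960 values, for a simple MAPE check of the point forecast
actuals = np.array([417.0, 391.0, 419.0])
mape = np.mean(np.abs((point_forecast - actuals) / actuals))
```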
Prediction ahead
To forecast ahead, we do not need a testing dataset. And whereas in the previous case we performed the prediction using only the input part of the testing set, we now provide the full dataset instead.

```python
forecasts = predictor.predict(ds)
forecasts = list(forecasts)
```
What we get is a 3-point prediction for Jan-Mar 1961, each point sampled 20 times.
```python
forecasts
```
```
[gluonts.model.forecast.SampleForecast(
     info=None, item_id='series_id',
     samples=array([[427.76852, 484.938  , 511.5961 ],
                    [444.39307, 457.0478 , 516.5437 ],
                    [455.9989 , 430.7521 , 459.3556 ],
                    [449.44437, 441.53247, 466.0277 ],
                    [443.8015 , 469.7965 , 445.30768],
                    [408.37363, 432.83743, 447.24597],
                    [438.03394, 430.1796 , 435.22168],
                    [401.77078, 467.92838, 463.85724],
                    [406.02524, 471.80396, 437.33307],
                    [455.00793, 447.15952, 449.0818 ],
                    [430.5515 , 482.67798, 457.6836 ],
                    [450.1662 , 463.2627 , 505.51727],
                    [435.6454 , 427.07227, 448.88635],
                    [388.99982, 447.7803 , 434.69858],
                    [431.01733, 440.27377, 436.52856],
                    [484.6833 , 495.8401 , 468.60553],
                    [452.8319 , 423.40204, 447.73932],
                    [439.3493 , 403.3471 , 454.06042],
                    [451.11032, 431.2359 , 464.2246 ],
                    [479.24243, 439.86835, 495.58832]], dtype=float32),
     start_date=Period('1961-01', 'M'))]
```
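To wrap such a forecast in a labelled structure, you can build a pandas Series over the forecast horizon. A sketch with a hypothetical point forecast (in the real workflow the values would come from forecasts[0].samples.mean(axis=0) and the start period from forecasts[0].start_date):

```python
import numpy as np
import pandas as pd

# hypothetical point forecast for the 3-step horizon starting Jan 1961
mean_forecast = np.array([444.0, 450.0, 462.0])

index = pd.period_range("1961-01", periods=3, freq="M")
forecast_series = pd.Series(mean_forecast, index=index)
print(forecast_series)
```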
And that is how the MOIRAI model works.