import matplotlib.pyplot as plt
import pandas as pd
import torch
import numpy as np
from tsfm.model.kairos import AutoModel

# Load the pretrained Kairos model from a local checkpoint directory.
pipeline = AutoModel.from_pretrained("./checkpoints/kairos-small")

# Name of the column holding the series to forecast.
col_name = "Target"
# Read the evaluation data and coerce the target column to float64 in one step;
# a missing column raises KeyError.
df = pd.read_csv("./datasets/test.csv").astype({col_name: float})


# Maximum number of leading observations fed to the model, and forecast horizon.
context_length = 2048
prediction_length = 96

# Positional slice of (at most) the first `context_length` observations.
context = df[col_name].iloc[:context_length]
# The CSV may contain fewer than `context_length` rows; the forecast must start
# right after the last observation actually given to the model, otherwise it is
# drawn at the wrong x-positions.
n_context = len(context)

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    forecast = pipeline(
        # unsqueeze(0) adds a leading dimension (presumably batch of size 1).
        past_target=torch.as_tensor(context.to_numpy(), dtype=torch.float32).unsqueeze(0),
        prediction_length=prediction_length,
        generation=True,
    )
forecast = forecast["prediction_outputs"]
forecast_index = range(n_context, n_context + prediction_length)
# 10th/50th/90th percentiles taken across axis 0 of the first output.
# NOTE(review): assumes forecast[0] is (num_samples, prediction_length) — confirm.
low, median, high = np.quantile(forecast[0].detach().cpu().numpy(), [0.1, 0.5, 0.9], axis=0)

# Plot the observed history, the ground-truth continuation, the median
# forecast, and the 10%-90% (80%) prediction interval; save and display.
series = df[col_name]
history = series.iloc[:context_length]
actual = series.iloc[context_length:context_length + prediction_length]

plt.figure(figsize=(16, 4))
# Explicit positional x-coordinates keep history, labels and forecast on the
# same axis regardless of the DataFrame's index.
plt.plot(np.arange(len(history)), history.to_numpy(), color="royalblue", label="historical data")

plt.plot(
    np.arange(len(history), len(history) + len(actual)),
    actual.to_numpy(),
    color="green",
    label="label",
)
plt.plot(forecast_index, median, color="tomato", label="forecast")
# low/high are the 0.1/0.9 quantiles of the forecast; shade the band between them.
plt.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% prediction interval")
plt.legend()
plt.grid()
plt.savefig("pipeline.png")
plt.show()