{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4b323a7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#!git clone -b update-gluonts https://github.com/time-series-foundation-models/lag-llama/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bea538ee",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "c:\\Users\\EJ17UF\\OneDrive - Aalborg Universitet\\Desktop\\AAAI\\lag-llama\n"
     ]
    }
   ],
   "source": [
    "# Explicit line magic: works even when IPython automagic is disabled.\n",
    "%cd lag-llama"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fb20f732",
   "metadata": {},
   "outputs": [],
   "source": [
    "#!pip install -r requirements.txt  # this could take some time # ignore the errors displayed by colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7355278a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#!pip install -U torch torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b0cdb4c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#!huggingface-cli download time-series-foundation-models/Lag-Llama lag-llama.ckpt --local-dir lag-llama"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1a5ba550",
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import islice\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "import matplotlib.dates as mdates\n",
    "\n",
    "import torch\n",
    "from gluonts.evaluation import make_evaluation_predictions, Evaluator\n",
    "from gluonts.dataset.repository.datasets import get_dataset\n",
    "\n",
    "from gluonts.dataset.pandas import PandasDataset\n",
    "from gluonts.dataset.common import ListDataset\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from lag_llama.gluon.estimator import LagLlamaEstimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "da9b249d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "import logging\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from sklearn.metrics import mean_squared_error, mean_absolute_error\n",
    "\n",
    "# Suppress every warning (the blanket filter already covers the targeted numpy one below)\n",
    "# and drop Lightning log records below ERROR to keep the evaluation output short.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "warnings.filterwarnings(\"ignore\", message=\".*non-tuple sequence for multidimensional indexing.*\")\n",
    "logging.getLogger(\"lightning\").setLevel(logging.ERROR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b9b24921",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from types import ModuleType\n",
    "\n",
    "\n",
    "def create_dummy_module(module_path):\n",
    "    \"\"\"\n",
    "    Make sure every module along a dotted path exists in ``sys.modules``.\n",
    "\n",
    "    Missing modules are created as empty ``ModuleType`` objects, registered in\n",
    "    ``sys.modules`` and, when they have a parent, attached as an attribute of it.\n",
    "    Modules that are already present are reused as-is.\n",
    "\n",
    "    Returns the leaf module, i.e. ``sys.modules[module_path]``.\n",
    "    \"\"\"\n",
    "    parts = module_path.split('.')\n",
    "    current = ''\n",
    "    parent = None\n",
    "\n",
    "    for part in parts:\n",
    "        current = current + '.' + part if current else part\n",
    "        if current not in sys.modules:\n",
    "            module = ModuleType(current)\n",
    "            sys.modules[current] = module\n",
    "            if parent:\n",
    "                setattr(sys.modules[parent], part, module)\n",
    "        parent = current\n",
    "\n",
    "    return sys.modules[module_path]\n",
    "\n",
    "\n",
    "# Presumably needed so the Lag-Llama checkpoint, which references these loss classes,\n",
    "# can be unpickled with the installed gluonts -- confirm against your gluonts version.\n",
    "# NOTE: if gluonts already provides this module, the real module is returned and the\n",
    "# assignments at the end of this cell replace its classes for the rest of the session.\n",
    "gluonts_module = create_dummy_module('gluonts.torch.modules.loss')\n",
    "\n",
    "\n",
    "class _DummyLoss:\n",
    "    \"\"\"No-op stand-in for a gluonts loss class.\n",
    "\n",
    "    Accepts and ignores any constructor arguments, returns 0.0 when called, and\n",
    "    resolves any unknown public attribute to a callable that returns None.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, *args, **kwargs):\n",
    "        pass\n",
    "\n",
    "    def __call__(self, *args, **kwargs):\n",
    "        return 0.0\n",
    "\n",
    "    def __getattr__(self, name):\n",
    "        # Refuse dunder names: copy and pickle probe optional hooks with getattr()\n",
    "        # (e.g. __deepcopy__, __setstate__, __getstate__). Answering them with a\n",
    "        # no-op lambda would make copy.deepcopy() return None and pickle drop state.\n",
    "        if name.startswith('__') and name.endswith('__'):\n",
    "            raise AttributeError(name)\n",
    "        return lambda *args, **kwargs: None\n",
    "\n",
    "\n",
    "class DistributionLoss(_DummyLoss):\n",
    "    \"\"\"Placeholder for ``gluonts.torch.modules.loss.DistributionLoss``.\"\"\"\n",
    "\n",
    "\n",
    "class NegativeLogLikelihood(_DummyLoss):\n",
    "    \"\"\"Placeholder for ``gluonts.torch.modules.loss.NegativeLogLikelihood``.\"\"\"\n",
    "\n",
    "\n",
    "# Expose the placeholders under the module path and names used above.\n",
    "gluonts_module.DistributionLoss = DistributionLoss\n",
    "gluonts_module.NegativeLogLikelihood = NegativeLogLikelihood"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e0535802",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `parse_dates=True` only parses the index, and no index_col is given, so it would be\n",
    "# a no-op here; the DATE column is converted explicitly instead.\n",
    "df = pd.read_csv('5G_millisecond.csv')\n",
    "df['DATE'] = pd.to_datetime(df['DATE'])\n",
    "df = df.set_index('DATE')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e49a58bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cast every numeric column (ints, floats, bools) to float32; strings, categoricals\n",
    "# and datetimes are left as they are instead of failing the cast.\n",
    "for col in df.columns:\n",
    "    if pd.api.types.is_numeric_dtype(df[col]):\n",
    "        df[col] = df[col].astype('float32')\n",
    "\n",
    "# 80/20 split by row position (assumes rows are already in time order).\n",
    "train_end = round(len(df) * 0.8)\n",
    "\n",
    "train = PandasDataset(df.iloc[:train_end], freq=\"100ms\", target=\"mac_dl_brate\")\n",
    "test = PandasDataset(df.iloc[train_end:], freq=\"100ms\", target=\"mac_dl_brate\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "bd826670",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment settings used by the evaluation cells below.\n",
    "prediction_length = 96        # forecast horizon, in steps of the series frequency (\"100ms\")\n",
    "context_length = 5            # context window passed to the Lag-Llama estimator\n",
    "num_samples = 1               # NOTE(review): not passed to evaluate_lag_llama in the seed loop\n",
    "batch_size = 128              # NOTE(review): not passed to evaluate_lag_llama in the seed loop\n",
    "device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "510e290d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_lag_llama_predictor(prediction_length, device, context_length=32, use_rope_scaling=False, num_samples=1,\n",
    "                              batch_size=1, ckpt_path=\"lag-llama.ckpt\"):\n",
    "    \"\"\"\n",
    "    Build a GluonTS predictor from the pretrained Lag-Llama checkpoint.\n",
    "\n",
    "    Args:\n",
    "        prediction_length: forecast horizon in time steps.\n",
    "        device: torch device used as ``map_location`` when reading the checkpoint.\n",
    "        context_length: context window passed to ``LagLlamaEstimator``.\n",
    "        use_rope_scaling: if True, enable linear RoPE scaling with factor\n",
    "            ``max(1, (context_length + prediction_length) / <checkpoint context_length>)``.\n",
    "        num_samples: sample paths drawn per forecast (``num_parallel_samples``).\n",
    "            With the default of 1 the forecast mean is that single sample path.\n",
    "        batch_size: batch size passed to ``LagLlamaEstimator``.\n",
    "        ckpt_path: checkpoint file read for hyper-parameters and passed to the estimator.\n",
    "\n",
    "    Returns:\n",
    "        The predictor from ``LagLlamaEstimator.create_predictor``.\n",
    "    \"\"\"\n",
    "    # SECURITY: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.\n",
    "    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)\n",
    "    estimator_args = ckpt[\"hyper_parameters\"][\"model_kwargs\"]\n",
    "\n",
    "    # Linear RoPE factor is > 1 only when context + horizon exceeds the checkpoint's\n",
    "    # context length; it is used only if use_rope_scaling is True.\n",
    "    rope_scaling_arguments = {\n",
    "        \"type\": \"linear\",\n",
    "        \"factor\": max(1.0, (context_length + prediction_length) / estimator_args[\"context_length\"]),\n",
    "    }\n",
    "\n",
    "    estimator = LagLlamaEstimator(\n",
    "        ckpt_path=ckpt_path,\n",
    "        prediction_length=prediction_length,\n",
    "        context_length=context_length,\n",
    "        # Architecture hyper-parameters are copied from the checkpoint.\n",
    "        input_size=estimator_args[\"input_size\"],\n",
    "        n_layer=estimator_args[\"n_layer\"],\n",
    "        n_embd_per_head=estimator_args[\"n_embd_per_head\"],\n",
    "        n_head=estimator_args[\"n_head\"],\n",
    "        scaling=estimator_args[\"scaling\"],\n",
    "        time_feat=estimator_args[\"time_feat\"],\n",
    "        nonnegative_pred_samples=True,\n",
    "        rope_scaling=rope_scaling_arguments if use_rope_scaling else None,\n",
    "        batch_size=batch_size,\n",
    "        num_parallel_samples=num_samples,\n",
    "    )\n",
    "\n",
    "    lightning_module = estimator.create_lightning_module()\n",
    "    transformation = estimator.create_transformation()\n",
    "    predictor = estimator.create_predictor(transformation, lightning_module)\n",
    "\n",
    "    return predictor\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7193f250",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rolling_evaluation(dataset, predictor, prediction_length=96, stride=1):\n",
    "    \"\"\"\n",
    "    Rolling-origin evaluation on the first series of ``dataset``.\n",
    "\n",
    "    Forecast origins are placed every ``stride`` steps, starting at index\n",
    "    ``prediction_length``. Each forecast is conditioned on the whole series up to\n",
    "    its origin (an expanding window) and compared with the next\n",
    "    ``prediction_length`` actual values, truncated at the end of the series.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with columns ``timestamp``, ``mean`` (forecast mean) and\n",
    "        ``actual``, one row per forecast step; overlapping windows are stacked.\n",
    "        Empty if the series is not longer than ``prediction_length``.\n",
    "    \"\"\"\n",
    "    ts_entry = next(iter(dataset))\n",
    "    values = ts_entry[\"target\"]\n",
    "\n",
    "    start_time = (\n",
    "        ts_entry[\"start\"].to_timestamp()\n",
    "        if isinstance(ts_entry[\"start\"], pd.Period)\n",
    "        else pd.Timestamp(ts_entry[\"start\"])\n",
    "    )\n",
    "    index = pd.date_range(start=start_time, periods=len(values), freq=dataset.freq)\n",
    "\n",
    "    # Position where each context ends and its forecast begins.\n",
    "    context_ends = list(range(prediction_length, len(values), stride))\n",
    "    if not context_ends:\n",
    "        return pd.DataFrame(columns=[\"timestamp\", \"mean\", \"actual\"])\n",
    "\n",
    "    contexts = [{\"target\": values[:end], \"start\": start_time} for end in context_ends]\n",
    "    forecasts = predictor.predict(ListDataset(contexts, freq=dataset.freq))\n",
    "\n",
    "    results = []\n",
    "    # Pair each forecast with its own origin; a plain forecast counter equals the\n",
    "    # origin offset only when stride == 1.\n",
    "    for context_end, forecast in zip(context_ends, forecasts):\n",
    "        max_len = min(prediction_length, len(values) - context_end)\n",
    "        results.append(pd.DataFrame({\n",
    "            \"timestamp\": index[context_end : context_end + max_len],\n",
    "            \"mean\": forecast.mean_ts.values[:max_len],\n",
    "            \"actual\": values[context_end : context_end + max_len],\n",
    "        }))\n",
    "\n",
    "    return pd.concat(results, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31f6739a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_lag_llama(seed, train, test, prediction_length, context_length, device, num_samples=1):\n",
    "    \"\"\"\n",
    "    Score Lag-Llama forecasts on ``test`` for one random seed.\n",
    "\n",
    "    Seeds torch and numpy, builds a predictor, runs a stride-1 rolling evaluation\n",
    "    on the first test series, and computes RMSE and MAE after min-max scaling both\n",
    "    forecasts and actuals with a scaler fitted on all training targets.\n",
    "\n",
    "    Args:\n",
    "        num_samples: sample paths per forecast, forwarded to ``build_lag_llama_predictor``.\n",
    "\n",
    "    Returns:\n",
    "        ``(rmse, mae)`` computed on the scaled values.\n",
    "    \"\"\"\n",
    "    torch.manual_seed(seed)\n",
    "    np.random.seed(seed)\n",
    "\n",
    "    predictor = build_lag_llama_predictor(\n",
    "        prediction_length=prediction_length,\n",
    "        context_length=context_length,\n",
    "        device=device,\n",
    "        num_samples=num_samples,\n",
    "    )\n",
    "\n",
    "    forecasts_df = rolling_evaluation(\n",
    "        test,\n",
    "        predictor=predictor,\n",
    "        prediction_length=prediction_length,\n",
    "        stride=1,\n",
    "    )\n",
    "\n",
    "    # Fit the scaler on training targets only, so no test statistics leak into scaling.\n",
    "    # np.concatenate avoids boxing every value into a Python list first.\n",
    "    train_values = np.concatenate([np.asarray(entry[\"target\"]) for entry in train]).reshape(-1, 1)\n",
    "    scaler_target = MinMaxScaler().fit(train_values)\n",
    "\n",
    "    actual_scaled = scaler_target.transform(forecasts_df[[\"actual\"]].values)\n",
    "    pred_scaled = scaler_target.transform(forecasts_df[[\"mean\"]].values)\n",
    "\n",
    "    rmse = np.sqrt(mean_squared_error(actual_scaled, pred_scaled))\n",
    "    mae = mean_absolute_error(actual_scaled, pred_scaled)\n",
    "\n",
    "    return rmse, mae"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d0b9fe6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repeat the evaluation with several seeds to see how much the metrics vary.\n",
    "seeds = [42, 99, 123]\n",
    "results = []\n",
    "\n",
    "for seed in seeds:\n",
    "    rmse, mae = evaluate_lag_llama(\n",
    "        seed=seed,\n",
    "        train=train,\n",
    "        test=test,\n",
    "        prediction_length=prediction_length,\n",
    "        context_length=context_length,\n",
    "        device=device,\n",
    "    )\n",
    "    results.append(dict(seed=seed, rmse=rmse, mae=mae))\n",
    "    print(f\"Seed {seed}: RMSE: {rmse:.4f}, MAE: {mae:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2fd1c9f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Seed 42: RMSE: 0.0431, MAE: 0.0260\n",
      "Seed 99: RMSE: 0.0526, MAE: 0.0280\n",
      "Seed 123: RMSE: 0.0465, MAE: 0.0265\n",
      "\\Summary\n",
      "RMSE mean: 0.0474, std: 0.0039\n",
      "MAE  mean: 0.0268, std: 0.0009\n"
     ]
    }
   ],
   "source": [
    "# Per-seed scores, then mean and population std (np.std, ddof=0) across seeds.\n",
    "for record in results:\n",
    "    print(f\"Seed {record['seed']}: RMSE: {record['rmse']:.4f}, MAE: {record['mae']:.4f}\")\n",
    "\n",
    "print(\"\\nSummary\")\n",
    "for label, key in ((\"RMSE\", \"rmse\"), (\"MAE \", \"mae\")):\n",
    "    scores = [record[key] for record in results]\n",
    "    print(f\"{label} mean: {np.mean(scores):.4f}, std: {np.std(scores):.4f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tf-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
