import pandas as pd
import numpy as np
import torch

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error

from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.split import split

from uni2ts.model.moirai import MoiraiForecast, MoiraiModule


# =========================================================
# 1. DATA LOADING + NORMALIZATION (+ item_id = 'A')
# =========================================================

def load_and_norm(path):
    """
    Load a CSV time series and standardise its first value column.

    The 'date' column becomes the index and is parsed into datetimes so the
    returned frame is indexed by timestamps rather than raw strings.  The
    first remaining column is standardised with StandardScaler, renamed to
    'target', and a constant 'item_id' column ("A") is added.

    path: location of the CSV file; it must contain a 'date' column.

    Returns the transformed DataFrame.  Any columns after the first are
    left untouched.

    NOTE: the scaler is fitted on every row of the file, so the scaling
    statistics cover the complete series.
    """
    # parse_dates=True asks pandas to parse the index column as dates.
    df = pd.read_csv(path, index_col="date", parse_dates=True)

    # Only the first value column is treated as the series to forecast.
    value_col = df.columns[0]
    df[value_col] = StandardScaler().fit_transform(df[[value_col]])

    # Rename to 'target' so downstream GluonTS code sees a consistent name.
    if value_col != "target":
        df = df.rename(columns={value_col: "target"})

    # Single-series files: tag every row with the same item id.
    df["item_id"] = "A"

    return df


# =========================================================
# 2. EVALUATION WITH GLUONTS + MOIRAI
# =========================================================

def evaluate_moirai_on_df(df, horizons, ctx_length=168, patch_size=32,
                          num_samples=100, batch_size=32, model_path=""):
    """
    Score a pretrained Moirai model on one series for several horizons.

    For every horizon the last `horizon` rows are held out as a single test
    window, the model forecasts them, and the forecast median (quantile 0.5)
    is compared with the held-out values.

    df: DataFrame with index as timestamp, columns ['target', 'item_id']
    horizons: iterable of forecast horizons (prediction lengths)
    ctx_length, patch_size, num_samples: forwarded to MoiraiForecast
    batch_size: forwarded to MoiraiForecast.create_predictor
    model_path: forwarded to MoiraiModule.from_pretrained.  The default is an
        empty placeholder and has to be overridden before a model can be
        loaded.

    Returns a list with one dict per horizon: {"horizon", "MSE", "MAE"}.
    Metrics are rounded to 4 decimals; both are NaN when the frame has no
    more than `horizon` rows.

    Raises ValueError when a model has to be loaded but `model_path` is empty.
    """
    # Convert to GluonTS PandasDataset
    ds = PandasDataset.from_long_dataframe(df, target="target", item_id="item_id")

    results = []

    # The pretrained weights do not depend on the horizon, so load them at
    # most once.  Loading is deferred until the first horizon that is really
    # evaluated, so a call where every horizon is skipped never needs a
    # checkpoint.
    module = None

    for PDT in horizons:
        # Not enough rows to hold out PDT steps and still have history left.
        if len(df) <= PDT:
            results.append({"horizon": PDT, "MSE": np.nan, "MAE": np.nan})
            continue

        print(f"Running for horizon = [{PDT}]")

        # Hold out the last PDT steps; the training part is not used here.
        _, test_template = split(ds, offset=-PDT)

        # The test span equals the horizon, so there is exactly one window.
        test_data = test_template.generate_instances(
            prediction_length=PDT,
            windows=1,
            distance=PDT,
        )

        if module is None:
            if not model_path:
                raise ValueError(
                    "model_path is empty; pass the checkpoint location "
                    "expected by MoiraiModule.from_pretrained"
                )
            module = MoiraiModule.from_pretrained(model_path)

        # Prepare Moirai model for this prediction length
        model = MoiraiForecast(
            module=module,
            prediction_length=PDT,
            context_length=ctx_length,
            patch_size=patch_size,
            num_samples=num_samples,
            target_dim=1,
            feat_dynamic_real_dim=ds.num_feat_dynamic_real,
            past_feat_dynamic_real_dim=ds.num_past_feat_dynamic_real,
        )

        predictor = model.create_predictor(batch_size=batch_size)
        forecasts = list(predictor.predict(test_data.input))

        # Pair each forecast with its label window; zip stops at the shorter
        # of the two, exactly as a lock-step loop would.
        pairs = list(zip(forecasts, test_data.label))
        y_pred = np.concatenate([forecast.quantile(0.5) for forecast, _ in pairs])
        y_true = np.concatenate([label["target"] for _, label in pairs])

        mse = mean_squared_error(y_true, y_pred)
        mae = mean_absolute_error(y_true, y_pred)

        results.append({
            "horizon": PDT,
            "MSE": round(mse, 4),
            "MAE": round(mae, 4),
        })

    return results


# =========================================================
# 3. BENCHMARK OVER ALL DATASETS
# =========================================================

def run_benchmark(datasets, horizons, **eval_kwargs):
    """
    Evaluate every dataset at every horizon and collect one result table.

    datasets: mapping of dataset name -> DataFrame accepted by
        evaluate_moirai_on_df
    horizons: forecast horizons passed to evaluate_moirai_on_df
    eval_kwargs: optional keyword arguments forwarded unchanged to
        evaluate_moirai_on_df; with none given the call is
        evaluate_moirai_on_df(df, horizons)

    Returns a DataFrame with columns ['Dataset', 'Horizon', 'MSE', 'MAE'],
    one row per (dataset, horizon).  The columns are present even when
    `datasets` is empty.
    """
    columns = ["Dataset", "Horizon", "MSE", "MAE"]
    rows = []

    for name, df in datasets.items():
        print("\n==============================")
        print(f"Running [{name}] dataset")
        print("==============================")

        rows.extend(
            {
                "Dataset": name,
                "Horizon": r["horizon"],
                "MSE": r["MSE"],
                "MAE": r["MAE"],
            }
            for r in evaluate_moirai_on_df(df, horizons, **eval_kwargs)
        )

    # Explicit columns keep the schema stable when no rows were produced.
    return pd.DataFrame(rows, columns=columns)


# =========================================================
# 4. LOAD YOUR DATASETS (WITH item_id ADDED)
# =========================================================

# (dataset name, CSV path) pairs, in the order they are benchmarked.
_DATASET_PATHS = (
    ("AUS_Elec_Demand", "energy_processed/australian_electricity_demand_dataset_processed.csv"),
    ("Electricity_Weekly", "energy_processed/electricity_weekly_dataset_processed.csv"),
    ("ETTh1", "energy_processed/ETTh1_processed.csv"),
    ("ETTh2", "energy_processed/ETTh2_processed.csv"),
    ("ETTm1", "energy_processed/ETTm1_processed.csv"),
    ("ETTm2", "energy_processed/ETTm2_processed.csv"),
    ("London_SmartMeters", "energy_processed/london_smart_meters_dataset_subset_processed.csv"),
    ("Solar_10min", "energy_processed/solar_10_minutes_dataset_processed.csv"),
)

# Every file goes through load_and_norm; keys keep the order listed above.
datasets = {
    dataset_name: load_and_norm(csv_path)
    for dataset_name, csv_path in _DATASET_PATHS
}


# =========================================================
# 5. RUN BENCHMARK
# =========================================================

# Prediction lengths evaluated for every dataset.
horizons = [24, 48, 96]

results_table = run_benchmark(datasets, horizons)

# Print the combined table, then write it to the working directory.
print("\nFINAL RESULTS:", results_table, sep="\n")
results_table.to_csv("energy_moirai.csv")