"""
test multiple regressors with different hyperparameters and aggregation strategy
"""
from autogluon.timeseries import TimeSeriesDataFrame
from autogluon.timeseries.metrics import MASE, SMAPE
from residual_chronos.Predictor import Predictor
import pandas as pd
import numpy as np
import random
import torch

# Seed every random number generator this script may draw from (Python, NumPy,
# PyTorch CPU) so repeated runs are reproducible.
RANDOM_SEED = 100
for _seed_generator in (random.seed, np.random.seed, torch.manual_seed):
    _seed_generator(RANDOM_SEED)
if torch.cuda.is_available():
    # CUDA devices have their own generators; seed all of them as well.
    torch.cuda.manual_seed_all(RANDOM_SEED)

# Download the public grocery-sales example dataset and print a short overview
# (first rows, column names, overall shape).
print("Loading data...")
_dataset_url = "https://autogluon.s3.amazonaws.com/datasets/timeseries/grocery_sales/test.csv"
data = TimeSeriesDataFrame.from_path(_dataset_url)
print("Original data:")
print(data.head())
_column_names = data.columns.tolist()
print("\nData columns:", _column_names)
print("Data shape:", data.shape)

# Split data into train and test sets. Per AutoGluon's `train_test_split` docs,
# `test_data` keeps every series in full while `train_data` drops the last
# `prediction_length` steps of each series (the evaluation horizon).
prediction_length = 8
train_data, test_data = data.train_test_split(prediction_length=prediction_length)

# Explain the data structure
print("\n===== DATA STRUCTURE EXPLANATION =====")
print("Target column:", "unit_sales")
print("Known covariates columns:", ["scaled_price", "promotion_email", "promotion_homepage"])
print("\nTrain data contains historical information:")
print(train_data.head(3))

print("\nTest data contains future data for evaluation:")
print(test_data.head(3))

# Separate each test series into its history (`context`, what a forecast would be
# conditioned on) and its final `prediction_length` steps (the forecast horizon).
context = test_data.slice_by_timestep(None, -prediction_length)
known_covariates = test_data.slice_by_timestep(-prediction_length, None)
# The sliced frame carries the item-level static features, so read them from the
# horizon slice above instead of slicing `test_data` a second time.
static_features = known_covariates.static_features
# NOTE(review): `known_covariates` still contains the target column "unit_sales" for
# the horizon. Before passing it to predict(), confirm the Predictor drops columns
# outside `known_covariates_names`; otherwise the target would leak into the forecast.

print("\nContext data (for prediction):")
print(context.head(3))
print("\nKnown covariates data (for forecast horizon):")
print(known_covariates.head(3))

# Initialize the Predictor model: forecast `unit_sales` using the three known
# covariates, with residual regressors combined by the configured aggregation strategy.
print("\n=== Initializing model... ===")
model = Predictor(
    prediction_length=prediction_length,
    target="unit_sales",
    known_covariates_names=["scaled_price", "promotion_email", "promotion_homepage"],
    # Alternative configuration: declare real-valued and categorical covariates separately.
    # known_covariates_real=["scaled_price"],
    # known_covariates_cat=["promotion_email", "promotion_homepage"],
    # NOTE(review): an earlier comment claimed MAE "gets mapped to mean_squared_error for
    # the regressor"; an MAE objective would normally map to absolute error — confirm
    # the mapping inside Predictor.
    eval_metric="MAE",
    # Residual regressors to train. "CAT" is presumably another supported option
    # (currently disabled).
    regressor_types=["ETS", "AutoARIMA", "XGB", "GBM"],
    regressor_hyperparameters={
        "GBM": {
            "learning_rate": 0.1,
            "max_depth": 6,
            "n_estimators": 100
        },
        "XGB": {
            "learning_rate": 0.1,
            "max_depth": 6,
            "n_estimators": 100
        },
        "AutoARIMA": {
            "seasonal": True,          # default=True, use False for non-seasonal
            "d": None, "D": None,      # let unit-root tests decide differencing orders
            "max_p": 5, "max_P": 2,    # search limits
        },
        "ETS": {
            "error": "add",            # additive errors
            "trend": "add",            # Holt additive trend
            "seasonal": "add",         # "add" or "mul"
            "damped_trend": True,      # damp the trend (damping parameter phi)
            # NOTE(review): originally commented "monthly seasonality"; whether 12 steps
            # is a meaningful season depends on the dataset's frequency — confirm.
            "season_length": 12
        }
    },
    regressor_fit_time_fraction=0.5,  # Allocate 50% of the fit time budget to the regressors
    regressor_validation_fraction=0.1,  # Use 10% of data for validation
    aggregation_strategy=("spa", {"sigma": 0.1, "normalizer": "softmax"})
    # Alternative: aggregation_strategy="equal"
)


# Fit on the training split only. `train_data` comes from `train_test_split`, so the
# last `prediction_length` steps of every series stay unseen for evaluation.
print("\n=== Training model on training data... ===")
model.fit(
    train_data=train_data,
    time_limit=240,  # training budget, presumably in seconds (60 was used previously)
    enable_ensemble=False,
    fine_tune=False,
)

# # Make predictions 
# print("\n=== Generating forecasts... ===")
# forecasts = model.predict(
#     data=context,
#     known_covariates=known_covariates,
#     static_features=static_features
# )

# print("\nForecast sample:")
# print(forecasts.head())
# print("\nForecast columns:", forecasts.columns.tolist())
# print("Forecast shape:", forecasts.shape)

# Evaluate the fitted model on the held-out horizon and check the scores against
# reference values.
print("\n=== Performance Metrics ===")

# Reference scores this regression test expects. The negative values presumably
# follow AutoGluon's convention of reporting error metrics as higher-is-better.
metric_answer = {
    "MASE": -1.0553,
    "SMAPE": -0.4296,
}
# Maximum absolute deviation allowed between an observed and an expected score.
METRIC_TOLERANCE = 1e-4

evaluation_results = model.evaluate(test_data)
for metric_name, score in evaluation_results.items():
    print(f"{metric_name}: {score:.4f}")

# Fail loudly if an expected metric is absent instead of silently skipping it;
# otherwise the check passes vacuously (e.g. if only the configured eval_metric is
# reported).
missing_metrics = sorted(set(metric_answer) - set(evaluation_results))
if missing_metrics:
    raise AssertionError(
        f"evaluate() did not report expected metrics {missing_metrics}; "
        f"got {sorted(evaluation_results)}"
    )

# Explicit raises rather than `assert`, so the check still runs under `python -O`.
# `not (... < tol)` also fails on NaN scores, as the original comparison did.
for metric_name, expected in metric_answer.items():
    score = evaluation_results[metric_name]
    if not abs(score - expected) < METRIC_TOLERANCE:
        raise AssertionError(
            f"{metric_name} = {score:.6f} differs from expected {expected:.6f} "
            f"by more than {METRIC_TOLERANCE}"
        )
