"""
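Smoke test: fit the residual-Chronos Predictor with default regressor hyperparameters (LR, XGB, GBM, CAT; ensembling disabled) and save its forecasts.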
"""
from autogluon.timeseries import TimeSeriesDataFrame
from autogluon.timeseries.metrics import MASE, SMAPE
from residual_chronos.Predictor import Predictor
import pandas as pd
import numpy as np
import random
import torch

# Seed the Python, NumPy and PyTorch RNGs (plus every CUDA device, when one is
# available) with one shared value so repeated runs behave the same way.
RANDOM_SEED = 100
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(RANDOM_SEED)
del _seed_fn
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

# Read the example store-sales CSV into a TimeSeriesDataFrame and preview it.
DATA_PATH = "./tests/hopformer/data/store_sales_data.csv"
print("Loading data...")
data = TimeSeriesDataFrame.from_path(DATA_PATH)
print("Original data:")
print(data.head())

# Split data into train and test sets.
# NOTE(review): with AutoGluon's train_test_split, test_data presumably holds
# the full series and train_data drops the last `prediction_length` steps of
# each series — confirm against the installed autogluon version.
prediction_length = 24
context_length = 512
train_data, test_data = data.train_test_split(prediction_length=prediction_length)

# Context: up to `context_length` steps immediately preceding the final
# `prediction_length` steps (the forecast horizon) of each series.
context = test_data.slice_by_timestep(-(context_length + prediction_length), -prediction_length)
# Rows of the forecast horizon, passed to predict() as known covariates.
# NOTE(review): this slice keeps every column, including "target"; confirm
# Predictor.predict ignores it so ground-truth values cannot leak into forecasts.
known_covariates = test_data.slice_by_timestep(-prediction_length, None)

# Initialize the Predictor model
print("\n=== Initializing model... ===")
model = Predictor(
    prediction_length=prediction_length,
    target="target",
    known_covariates_real=["promotion", "temperature", "price"],
    # NOTE(review): an earlier comment claimed this metric is mapped to
    # mean_squared_error for the regressor — confirm inside Predictor.
    eval_metric="MAE",
    regressor_types=["LR", "XGB", "GBM", "CAT"],
    # Empty dicts: no overrides, i.e. each regressor's default hyperparameters.
    regressor_hyperparameters={
        "LR": {},
        "XGB": {},
        "GBM": {},
        "CAT": {},
    },
    bolt_model_path="bolt_small",
    random_seed=RANDOM_SEED,
    # Give the regressors 50% of the time budget (presumably of fit()'s time_limit).
    regressor_fit_time_fraction=0.5,
    # Use 10% of the regressor's data for validation.
    regressor_validation_fraction=0.1,
    aggregation_strategy=("SPA", {}),
)


# Fit on the whole training split with ensembling and fine-tuning turned off.
print("\n=== Training model on full data... ===")
fit_kwargs = {
    "train_data": train_data,
    "time_limit": 120,  # NOTE(review): presumably seconds — confirm in Predictor.fit
    "enable_ensemble": False,
    "fine_tune": False,
}
model.fit(**fit_kwargs)

# Forecast the horizon that follows `context`, then save the result to the
# same data directory the input CSV is read from.
print("\n=== Generating forecasts... ===")
forecasts = model.predict(data=context, known_covariates=known_covariates)

# Relative path, matching how the input data is loaded, so the script works
# from the repository root on any machine rather than one specific home directory.
forecasts.to_csv("./tests/hopformer/data/forecasts.csv")

