import os
import numpy as np
import pandas as pd
import termcolor
import yaml
import argparse
from tqdm import tqdm
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Any, Union, Tuple
from gift_eval.wrapper import ElectricityWrapper, GiftEvalWrapper, M5Wrapper, TrafficWrapper, ExchangeWrapper, EPFWrapper, IllnessWrapper, NonLinearSaleSyntheticWrapper, SaleSyntheticWrapper
from autogluon.timeseries import TimeSeriesPredictor
from autogluon.timeseries.metrics import MASE, MSE, SMAPE, MAPE

from residual_chronos import Predictor
from residual_chronos.Regressor import TimeSeriesRegressor
from datetime import datetime


@dataclass(frozen=True)
class EvaluationConfig:
    """Configuration for evaluating time series regressors.

    Instances are frozen; derive a modified copy with ``dataclasses.replace``.

    Args:
        dataset_name: Name of the dataset to evaluate.
        term: Forecasting horizon ("short" or "long").
        to_univariate: Whether to convert to univariate data.
        target_column: Name of the column to forecast.
        model_name: Name for the model in results.
        context_length: Number of time steps for context.
        result_folder: Folder to save evaluation results.
        regressor_config: Dictionary with regressor model configurations,
            keyed by regressor type.
        regressor_types: List of regressor types to use.
        aggregation_strategy: A strategy name, or a ``(name, config)`` tuple.
            Defaults to ``("equal", {})`` so code that indexes ``[0]`` /
            ``[1]`` gets a name and a config instead of single characters.
        known_covariates_names: Covariates that are known in advance
            (may be ``None`` when the YAML omits them).
        known_covariates_real: Known covariates declared as real (or ``None``).
        known_covariates_cat: Known covariates declared as categorical (or ``None``).
        static_features_cat: Static features declared as categorical (or ``None``).
        static_features_real: Static features declared as real (or ``None``).
        past_covariates_real: Past covariates declared as real (or ``None``).
        past_covariates_cat: Past covariates declared as categorical (or ``None``).
        bolt_model_path: Path to bolt model or "bolt_small" for default.
        fine_tune: Whether to fine-tune the model.
        fine_tune_steps: Number of steps for fine-tuning.
        use_lora: Whether to use LoRA for fine-tuning.
        time_limit: Time limit in seconds for training.
        metrics: List of metrics to evaluate.
        is_plot: Whether to generate forecast plots.
        slice_start: Starting index for slicing training data.
        timestamp: Timestamp string (``%Y%m%d_%H%M%S``) identifying the run;
            defaults to the local time at construction.
        add_date_features: Whether to request date features from the dataset.
        model_option: Model variant name (see ``tune_config``); used in
            result and plot folder names.
        aggregation_train_length_times: Multiplier applied to the dataset's
            prediction length to obtain the aggregation training length.
        test_prediction_length: Number of time steps for test prediction;
            ``None`` means use the dataset's prediction length.
    """
    dataset_name: str
    term: str = "short"
    to_univariate: bool = False
    target_column: str = "target"
    model_name: str = "Chronos"
    context_length: int = 512
    result_folder: str = "results"
    regressor_config: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    regressor_types: List[str] = field(default_factory=list)
    # A fresh (name, config) tuple per instance so the config dict is never shared.
    aggregation_strategy: Union[str, Tuple[str, Dict[str, Any]]] = field(
        default_factory=lambda: ("equal", {})
    )
    known_covariates_names: Optional[List[str]] = field(default_factory=list)
    known_covariates_real: Optional[List[str]] = field(default_factory=list)
    known_covariates_cat: Optional[List[str]] = field(default_factory=list)
    static_features_cat: Optional[List[str]] = field(default_factory=list)
    static_features_real: Optional[List[str]] = field(default_factory=list)
    past_covariates_real: Optional[List[str]] = field(default_factory=list)
    past_covariates_cat: Optional[List[str]] = field(default_factory=list)
    bolt_model_path: Optional[str] = None
    fine_tune: bool = False
    fine_tune_steps: int = 0
    use_lora: bool = False
    time_limit: int = 3600
    metrics: List[str] = field(default_factory=lambda: ["MSE", "MASE", "SMAPE", "MAPE"])
    is_plot: bool = False
    slice_start: Optional[int] = None
    timestamp: str = field(default_factory=lambda: datetime.now().strftime("%Y%m%d_%H%M%S"))
    add_date_features: bool = False
    model_option: str = "default"
    aggregation_train_length_times: int = 4
    test_prediction_length: Optional[int] = None

def tune_config(config_dict: Dict[str, Any], model: str = "default", regressor_types: Optional[List[str]] = None) -> Dict[str, Any]:
    """Tune the configuration in place based on the model variant.

    Args:
        config_dict: Parsed YAML configuration. Unless ``model`` is
            ``"default"`` it must contain a ``"model"`` section; that section
            is mutated in place.
        model: Variant of the model to evaluate (one of ``model_options``).
            Variants such as ``"NN_TORCH"`` or ``"GBM"`` only pass validation
            here; apart from an optional ``regressor_types`` override they get
            no variant-specific changes.
        regressor_types: Optional override for
            ``config_dict["model"]["regressor_types"]``. Every entry must
            already be declared in the config, either in ``regressor_types``
            or as a key of ``regressor_config``. Ignored for ``"default"`` and
            for the standalone baseline models.

    Returns:
        The same ``config_dict`` object, after modification.

    Raises:
        AssertionError: If ``model`` is unknown, a requested regressor type is
            not declared, or the resulting regressor list is empty for a
            variant that needs regressors.
    """
    model_options = [
        "default",
        "regressor-spa",
        "regressor-equal",
        "regressor-singlebest",
        "regressor-linear",
        "chronos-0shot",
        "chronos-ft",
        "chronos-ft-lora",
        "hopformer-0shot",
        "hopformer-ft",
        "hopformer-ft-lora",
        "PatchTST",
        "TemporalFusionTransformer",
        "DLinear",
        "SimpleFeedForward",
        "AutoARIMA",
        "AutoCES",
        "AutoETS",
        "SeasonalAverage",
        "SeasonalNaive",
        "Naive",
        "NN_TORCH",
        "CAT",
        "FASTAI",
        "GBM",
        "LR",
        "RF",
        "XGB",
        "LGBM",
        "KNN",
        "MLP",
        "MLP_TORCH",
        "MLP_TORCH_TORCH",
    ]
    if model == "default":
        # The YAML is used as-is; a regressor_types override is not applied here.
        return config_dict

    assert model in model_options, f"Invalid model variant: {model}. Valid variants are: {model_options}"

    model_cfg = config_dict["model"]

    # Standalone baseline models: each runs as its own single regressor with
    # equal aggregation and no bolt model.
    standalone_models = [
        "PatchTST", "TemporalFusionTransformer", "SimpleFeedForward", "DLinear",
        "AutoARIMA", "AutoCES", "AutoETS", "SeasonalAverage", "SeasonalNaive", "Naive",
    ]
    if model in standalone_models:
        model_cfg["name"] = model
        model_cfg["regressor_types"] = [model]
        assert model in model_cfg["regressor_config"], f"{model} should be in regressor_config"
        model_cfg["aggregation_strategy"] = {"name": "equal", "config": {}}
        model_cfg["bolt_model_path"] = None
        return config_dict

    if regressor_types is not None:
        # Collect the declared types *before* overriding; validating against the
        # already-overwritten list could never fail.
        valid_regressor_types = set(model_cfg.get("regressor_types") or []) | set(model_cfg.get("regressor_config") or {})
        invalid = [t for t in regressor_types if t not in valid_regressor_types]
        assert not invalid, (
            f"Invalid regressor types: {regressor_types}, "
            f"valid regressor types: {sorted(valid_regressor_types, key=str)}"
        )
        model_cfg["regressor_types"] = regressor_types
        print(f"Regressor types: {model_cfg['regressor_types']}")

    if "regressor-" in model:
        model_cfg["name"] = "CrossSectionalRegressor"
        assert len(model_cfg["regressor_types"]) > 0, "CrossSectionalRegressor should have at least one regressor."
        model_cfg["bolt_model_path"] = None

    # Aggregation strategy selected by the "regressor-*" variant suffix.
    aggregation_by_variant = {
        "regressor-spa": "spa",
        "regressor-equal": "equal",
        "regressor-singlebest": "singlebest",
        "regressor-linear": "linear",
    }
    if model in aggregation_by_variant:
        model_cfg["aggregation_strategy"] = {"name": aggregation_by_variant[model], "config": {}}

    if "chronos-" in model:
        model_cfg["name"] = "Chronos"
        model_cfg["regressor_types"] = []
        model_cfg["bolt_model_path"] = "bolt_small"

    if "hopformer-" in model:
        model_cfg["name"] = "Hopformer"
        # At least one regressor, matching load_config's Hopformer check (the
        # previous `> 1` contradicted its own "at least one" message).
        assert len(model_cfg["regressor_types"]) >= 1, "Hopformer should have at least one regressor."
        model_cfg["bolt_model_path"] = "bolt_small"
        model_cfg["aggregation_strategy"] = {"name": "spa", "config": {}}

    if "-ft" in model:
        model_cfg["use_lora"] = "-lora" in model
        model_cfg["fine_tune"] = True
        # Keep an explicit non-zero step count from the YAML; otherwise default to 3000.
        if model_cfg.get("fine_tune_steps", 0) == 0:
            model_cfg["fine_tune_steps"] = 3000
    elif "-0shot" in model:
        model_cfg["fine_tune"] = False
        model_cfg["fine_tune_steps"] = 0
        model_cfg["use_lora"] = False

    return config_dict


def load_config(config_file: str, model_option: str = "default", regressor_types: Optional[List[str]] = None) -> Tuple[Dict[str, Any], "EvaluationConfig"]:
    """Load, tune and validate an evaluation configuration from a YAML file.

    Args:
        config_file: Path to the YAML configuration file.
        model_option: Model variant forwarded to ``tune_config``.
        regressor_types: Optional regressor-type override forwarded to
            ``tune_config``.

    Returns:
        A ``(config_dict, config)`` tuple: the tuned raw dictionary (dumped
        next to the results) and the ``EvaluationConfig`` built from it.

    Raises:
        FileNotFoundError: If the configuration file doesn't exist.
        yaml.YAMLError: If there's an error parsing the YAML file.
        AssertionError: If the model section is inconsistent with the model name.
        ValueError: If the model name is unknown, or a bolt model is configured
            for a model other than Chronos / Hopformer.
    """
    with open(config_file, 'r', encoding="utf-8") as f:
        config_dict = yaml.safe_load(f)

    config_dict = tune_config(config_dict, model_option, regressor_types)

    # Extract relevant configuration sections
    dataset_config = config_dict.get('data', {})
    model_config = config_dict.get('model', {})
    evaluation_config = config_dict.get('evaluation', {})

    # Normalize to a (name, config) tuple. The YAML may give a mapping with
    # "name"/"config", a bare strategy name, or nothing at all; calling .get on
    # the string forms used to raise AttributeError.
    raw_strategy = model_config.get('aggregation_strategy') or "equal"
    if isinstance(raw_strategy, str):
        aggregation_strategy = (raw_strategy, {})
    else:
        aggregation_strategy = (raw_strategy.get("name", "equal"), raw_strategy.get("config", {}))

    model_name = model_config.get('name', 'Chronos')
    bolt_model_path = model_config.get('bolt_model_path', None)
    configured_regressors = model_config.get('regressor_types', [])
    model_key = model_name.lower()

    # Only Chronos and Hopformer can use a bolt model.
    if bolt_model_path:
        if model_key == "hopformer":
            assert aggregation_strategy[0] == "spa", "Bolt model is only supported for SPA aggregation strategy. Temporary force to unify Hopformer."
        elif model_key != "chronos":
            raise ValueError(f"Invalid model name: {model_name} with bolt model {bolt_model_path}.")

    standalone_models = ["patchtst", "temporalfusiontransformer", "simplefeedforward", "dlinear", "autoarima", "autoces", "autoets", "seasonalaverage", "seasonalnaive", "naive"]
    if model_key == "chronos":
        assert configured_regressors == [], "Chronos does not support custom regressors."
    elif model_key == "crosssectionalregressor":
        assert bolt_model_path is None, "CrossSectionalRegressor does not support bolt model."
        assert all(r in model_config.get('regressor_config', {}) for r in configured_regressors), "All regressors should be in regressor_config"
    elif model_key == "hopformer":
        assert configured_regressors != [], "HopFormer requires at least one regressor."
        assert aggregation_strategy[0] == "spa", "HopFormer only supports SPA aggregation strategy."
        assert bolt_model_path is not None, "HopFormer requires a bolt model."
    elif model_key in standalone_models:
        assert configured_regressors == [model_name], "These models do not support custom regressors."
    else:
        raise ValueError(f"Invalid model name: {model_name}")

    # Covariate lists default to None (not []) when absent; downstream code
    # (Predictor, save_results) receives them unchanged.
    config = EvaluationConfig(
        dataset_name=dataset_config.get('source', 'covid_deaths'),
        term=dataset_config.get('term', 'short'),
        to_univariate=dataset_config.get('to_univariate', False),
        target_column=dataset_config.get('target_column', 'target'),
        model_name=model_name,
        context_length=dataset_config.get('context_length', 512),
        result_folder=evaluation_config.get('result_folder', 'results'),
        regressor_config=model_config.get('regressor_config', {}),
        regressor_types=configured_regressors,
        known_covariates_names=dataset_config.get('known_covariates_names', None),
        known_covariates_real=dataset_config.get('known_covariates_real', None),
        known_covariates_cat=dataset_config.get('known_covariates_cat', None),
        static_features_cat=dataset_config.get('static_features_cat', None),
        static_features_real=dataset_config.get('static_features_real', None),
        past_covariates_real=dataset_config.get('past_covariates_real', None),
        past_covariates_cat=dataset_config.get('past_covariates_cat', None),
        bolt_model_path=bolt_model_path,
        fine_tune=model_config.get('fine_tune', False),
        fine_tune_steps=model_config.get('fine_tune_steps', 0),
        use_lora=model_config.get('use_lora', False),
        time_limit=model_config.get('time_limit', 3600),
        metrics=evaluation_config.get('metrics', ["MSE", "MASE", "SMAPE", "MAPE"]),
        is_plot=evaluation_config.get('is_plot', False),
        slice_start=dataset_config.get('slice_start', None),
        aggregation_strategy=aggregation_strategy,
        add_date_features=dataset_config.get('add_date_features', False),
        model_option=model_option,
        aggregation_train_length_times=model_config.get('aggregation_train_length_times', 4),
        test_prediction_length=dataset_config.get('test_prediction_length', None)
    )
    print(f"Loaded configuration: {config}")
    return config_dict, config


def get_metric_objects(metric_names: List[str]) -> Tuple[List[Any], List[str]]:
    """Instantiate metric objects for the requested metric names.

    Args:
        metric_names: Metric names. Supported (case-sensitive): "MSE",
            "MASE", "SMAPE", "MAPE".

    Returns:
        ``(metrics, names)``: parallel lists of metric instances and their
        names, in request order. Unsupported names are skipped, and a
        warning is printed so that a typo in the config does not silently
        drop a metric.
    """
    metric_factories = {
        "MSE": MSE,
        "MASE": MASE,
        "SMAPE": SMAPE,
        "MAPE": MAPE,
    }

    metrics = []
    valid_names = []

    for name in metric_names:
        factory = metric_factories.get(name)
        if factory is None:
            print(f"Warning: unsupported metric '{name}' ignored; supported metrics are {list(metric_factories)}")
            continue
        metrics.append(factory())
        valid_names.append(name)

    return metrics, valid_names


def evaluate_predictor(predictor, test_dataset, context_length, prediction_length,
                       metrics=None, metric_names=None, is_plot=False, config=None, test_prediction_length=None):
    """Evaluate a trained predictor on the dataset's test windows.

    Args:
        predictor: Trained predictor exposing ``prediction_length``,
            ``known_covariates_names`` and ``predict_longer``.
        test_dataset: Dataset wrapper exposing ``name``, ``term`` and
            ``get_test_data_by_prediction_length``.
        context_length: Context length for test data.
        prediction_length: Prediction length the predictor was trained with.
        metrics: List of metric objects; defaults to MSE, MASE, SMAPE and MAPE.
        metric_names: Names matching ``metrics``. When custom ``metrics`` are
            given without names, the metric class names are used.
        is_plot: Whether to plot forecasts (only when ``config`` is given).
        config: Evaluation configuration. Supplies ``target_column``,
            ``add_date_features`` and plot settings. When ``None``, the target
            column defaults to ``"target"``.
        test_prediction_length: Horizon to evaluate; defaults to
            ``prediction_length``. Must be one of 24, 48, 72, 96 or 120.

    Returns:
        DataFrame with per-window metrics, Dictionary with summary statistics.

    Raises:
        AssertionError: If ``predictor.prediction_length`` differs from
            ``prediction_length``, or a window's forecast length does not match
            its target.
        ValueError: If the evaluated horizon is not supported.
    """
    assert predictor.prediction_length == prediction_length, f"Prediction length mismatch, predictor.prediction_length: {predictor.prediction_length}, prediction_length: {prediction_length}"

    if metrics is None:
        metrics, metric_names = get_metric_objects(["MSE", "MASE", "SMAPE", "MAPE"])
    elif metric_names is None:
        # Label custom metrics by class name so zip() and the summary still work.
        metric_names = [type(metric).__name__ for metric in metrics]

    window_metrics = []
    add_date_features = config.add_date_features if config is not None else False
    # Resolve the target once. Reading config.target_column with config=None used to
    # raise inside the per-metric try block and silently turn every score into NaN.
    target_column = config.target_column if config is not None else "target"

    print(f"Evaluating model on {test_dataset.name} test data...")

    input_prediction_length = test_prediction_length if test_prediction_length is not None else prediction_length
    # Expected number of test windows per supported horizon (used as the progress-bar
    # total). NOTE(review): presumably mirrors what get_test_data_by_prediction_length
    # yields; confirm. Unsupported horizons are rejected, as before.
    windows_by_length = {24: 10, 48: 5, 72: 2, 96: 1, 120: 1}
    if input_prediction_length not in windows_by_length:
        raise ValueError(f"Invalid prediction length: {input_prediction_length}")
    total_windows = windows_by_length[input_prediction_length]

    known_covariates_names = predictor.known_covariates_names
    test_windows = test_dataset.get_test_data_by_prediction_length(
        context_length=context_length,
        add_date_features=add_date_features,
        prediction_length=input_prediction_length,
    )

    for window_idx, (input_df, target_df) in enumerate(
        tqdm(test_windows, total=total_windows, desc="Processing test windows")
    ):
        # History + future rows; the metrics need both, and known covariates
        # are taken from the combined frame.
        data = pd.concat([input_df, target_df]).sort_index()
        predictions = predictor.predict_longer(
            input_df,
            prediction_length=input_prediction_length,
            known_covariates=data[known_covariates_names] if known_covariates_names else None,
        )

        # Check if the predictions and target have the same length
        assert len(predictions) == len(target_df), f"Prediction length mismatch, len(predictions): {len(predictions)}, len(target_df): {len(target_df)}"

        window_result = {'window': window_idx + 1}

        for metric, name in zip(metrics, metric_names):
            # Best effort per metric: one failing metric records NaN, not a crash.
            try:
                score = metric(
                    data=data,
                    predictions=predictions,
                    prediction_length=input_prediction_length,
                    target=target_column,
                )
                window_result[name] = score
                print(f"Window {window_idx+1} {name}: {score:.4f}")
            except Exception as e:
                print(f"Error calculating {name} for window {window_idx+1}: {str(e)}")
                window_result[name] = np.nan

        window_metrics.append(window_result)

        if is_plot and config is not None:
            print(termcolor.colored(f"Plotting forecasts for window {window_idx+1}", "green"))
            plot_forecasts(
                input_df,
                predictions,
                target_df,
                window_idx,
                test_dataset.name,
                test_dataset.term,
                config,
                output_folder=config.result_folder
            )

        # Free memory before the next window is materialized
        del predictions, input_df, target_df, data

    window_df = pd.DataFrame(window_metrics)
    summary = calculate_summary_stats(window_df, metric_names)

    return window_df, summary


def calculate_summary_stats(window_df, metric_names):
    """Calculate summary statistics from window metrics.

    Args:
        window_df: DataFrame with per-window metrics (one column per metric).
        metric_names: List of metric names; ``None`` is treated as empty.

    Returns:
        Dictionary with ``num_windows`` plus ``<name>_mean`` and
        ``<name>_std`` (population std, ``ddof=0``) for each metric. Both are
        NaN when the metric has no non-missing values, including when the
        column is absent (e.g. no windows were evaluated).
    """
    metric_names = list(metric_names or [])
    summary = {'num_windows': len(window_df)}

    for name in metric_names:
        # A missing column used to raise KeyError; treat it as "no values".
        if name in window_df.columns:
            values = window_df[name].dropna().to_numpy()
        else:
            values = np.array([])
        if values.size:
            summary[f'{name}_mean'] = np.mean(values)
            summary[f'{name}_std'] = np.std(values)
        else:
            summary[f'{name}_mean'] = np.nan
            summary[f'{name}_std'] = np.nan

    print("\n=== Summary Statistics ===")
    for name in metric_names:
        print(f"{name}: {summary[f'{name}_mean']:.4f} ± {summary[f'{name}_std']:.4f}")

    return summary


def save_results(summary, config, result_folder="results", config_dict=None):
    """Save evaluation results to CSV files with an organized folder structure.

    Writes:
      * ``<result_folder>/<dataset>/<model>/<model_option>/<timestamp>/metrics.csv``
        containing this run only;
      * ``<result_folder>/all_results_<dataset>[_with_date]_ctx<C>_pred<P>.csv``,
        a consolidated file that each run appends a row to;
      * ``config.yaml`` in the run folder, when ``config_dict`` is given.

    Args:
        summary: Dictionary with summary statistics.
        config: Evaluation configuration (``EvaluationConfig`` or any object
            with the same attributes). ``aggregation_strategy`` may be a bare
            name or a ``(name, config)`` pair.
        result_folder: Base folder to save the results.
        config_dict: Tuned configuration dictionary dumped as ``config.yaml``
            in the run folder; skipped when ``None``.

    Returns:
        Path to the run-specific metrics file.
    """
    timestamp = config.timestamp

    # Filesystem-safe pieces for the folder names
    dataset_name = config.dataset_name.replace('/', '_')
    model_name = config.model_name.replace(' ', '_')
    model_option = config.model_option.replace('-', '_')

    folder_path = os.path.join(result_folder, dataset_name, model_name, model_option, timestamp)
    os.makedirs(folder_path, exist_ok=True)
    config_path = os.path.join(folder_path, "config.yaml")

    # A bare strategy name has no config; indexing the string would yield single characters.
    strategy = config.aggregation_strategy
    if isinstance(strategy, str):
        strategy_name, strategy_config = strategy, {}
    else:
        strategy_name, strategy_config = strategy[0], strategy[1]

    def _join_or_none(values):
        # Comma-join an optional name list; empty/missing lists are stored as None.
        return ','.join(values) if values else None

    summary_with_info = {
        'dataset': config.dataset_name,
        'model': model_name,
        'term': config.term,
        'timestamp': timestamp,
        'regressor_types': ','.join(config.regressor_types),
        'known_covariates': _join_or_none(config.known_covariates_names),
        'known_covariates_real': _join_or_none(config.known_covariates_real),
        'known_covariates_cat': _join_or_none(config.known_covariates_cat),
        'static_features_cat': _join_or_none(config.static_features_cat),
        'static_features_real': _join_or_none(config.static_features_real),
        'past_covariates_real': _join_or_none(config.past_covariates_real),
        'past_covariates_cat': _join_or_none(config.past_covariates_cat),
        'bolt_model': config.bolt_model_path,
        'fine_tune': config.fine_tune,
        'fine_tune_steps': config.fine_tune_steps,
        'use_lora': config.use_lora,
        "aggregation_strategy_name": strategy_name,
        "aggregation_strategy_config": strategy_config,
        "config_path": config_path,
        **summary
    }

    # Run-specific summary
    output_file = os.path.join(folder_path, "metrics.csv")
    pd.DataFrame([summary_with_info]).to_csv(output_file, index=False)

    # Consolidated file shared by all runs with the same dataset/context/horizon
    date_tag = "_with_date" if config.add_date_features else ""
    consolidated_name = f"all_results_{dataset_name}{date_tag}_ctx{config.context_length}_pred{config.test_prediction_length}.csv"
    print(f"Saving consolidated results to: {consolidated_name}")
    consolidated_file = os.path.join(result_folder, consolidated_name)

    if os.path.exists(consolidated_file):
        df = pd.read_csv(consolidated_file)
        df = pd.concat([df, pd.DataFrame([summary_with_info])], ignore_index=True)
    else:
        df = pd.DataFrame([summary_with_info])

    df.to_csv(consolidated_file, index=False)

    if config_dict is not None:
        with open(config_path, "w") as f:
            yaml.dump(config_dict, f)
        print(f" - Configuration copied to: {config_path}")

    print(f"Results saved to:")
    print(f" - Run specific: {output_file}")
    print(f" - Consolidated: {consolidated_file}")

    return output_file


def plot_forecasts(input_df, forecast_df, target_df, window_idx, dataset_name, term, config, output_folder="plots"):
    """Plot historical input, forecast and actual values for up to 12 items.

    The figure is saved as ``window_<window_idx+1>_forecast.png`` under
    ``<output_folder>/<dataset>/<model>/<model_option>/<timestamp>/<term>/windows/``.
    Styling is applied through a temporary ``rc_context``, so global
    matplotlib settings are left untouched.

    Args:
        input_df: TimeSeriesDataFrame containing historical input data.
        forecast_df: TimeSeriesDataFrame containing forecasts. The ``mean``
            column is plotted when present (otherwise the first column). When
            both ``"0.1"`` and ``"0.9"`` columns exist, they are drawn as an
            80% band.
        target_df: TimeSeriesDataFrame containing actual future values.
        window_idx: Zero-based index of the current window being evaluated.
        dataset_name: Name of the dataset being evaluated.
        term: Forecasting term; used as a folder level.
        config: Evaluation configuration (timestamp, model name/option, target column).
        output_folder: Base folder for plots.

    Returns:
        The (already closed) Figure, or ``None`` when the three frames share
        no item IDs and nothing is plotted.
    """
    import matplotlib.pyplot as plt

    timestamp = config.timestamp

    # Extract dataset and model info for folder naming
    dataset_name_safe = dataset_name.replace('/', '_')
    model_name_safe = config.model_name.replace(' ', '_')
    target_column = config.target_column
    model_option = config.model_option.replace('-', '_')
    assert target_column in input_df.columns, f"Target column {target_column} not found in input_df"

    # <output_folder>/<dataset>/<model>/<model_option>/<timestamp>/<term>/windows/
    folder_path = os.path.join(
        output_folder,
        dataset_name_safe,
        model_name_safe,
        model_option,
        timestamp,
        term,
        "windows"
    )
    os.makedirs(folder_path, exist_ok=True)

    # Items present in all three frames, kept in input_df order. A set intersection
    # has no stable order (string hashing is randomized per process), so the
    # plotted subset used to change between runs.
    forecast_ids = set(forecast_df.item_ids)
    target_ids = set(target_df.item_ids)
    common_item_ids = [
        item_id for item_id in dict.fromkeys(input_df.item_ids)
        if item_id in forecast_ids and item_id in target_ids
    ]

    # Plot at most 12 items
    item_ids = common_item_ids[:12]
    num_items = len(item_ids)
    if num_items == 0:
        # plt.subplots would fail with zero columns; nothing to draw.
        print(f"No common items to plot for window {window_idx+1}; skipping.")
        return None

    # Per-item MSE shown in each subplot title
    mse_scores = {}
    for item_id in item_ids:
        try:
            actual_values = target_df.loc[item_id][target_column].values
            forecast_values = forecast_df.loc[item_id]['mean'].values
            mse_scores[item_id] = np.mean((actual_values - forecast_values) ** 2)
        except Exception as e:
            # Best effort: the plot is still drawn, just without a score.
            print(f"Error calculating MSE for item {item_id}: {str(e)}")
            mse_scores[item_id] = np.nan

    # Determine the layout based on number of items
    if num_items == 1:
        nrows, ncols = 1, 1
        figsize = (10, 6)
    elif num_items <= 3:
        nrows, ncols = 1, num_items
        figsize = (6*num_items, 6)
    elif num_items <= 4:
        nrows, ncols = 2, 2
        figsize = (12, 10)
    elif num_items <= 8:
        nrows, ncols = 2, 4
        figsize = (16, 10)
    else:  # 9-12 items
        nrows, ncols = 3, 4
        figsize = (16, 12)

    # Style parameters must be active when the axes are created (grid, spines and
    # tick direction are read at creation time). Previously they were set globally
    # after plt.subplots, so they only took effect on later figures.
    style = {
        'font.family': 'sans-serif',
        'axes.grid': True,
        'axes.grid.which': 'both',
        'grid.linestyle': '--',
        'grid.alpha': 0.3,
        'axes.spines.top': False,
        'axes.spines.right': False,
        'xtick.direction': 'in',
        'ytick.direction': 'in',
        'lines.linewidth': 2
    }

    # matplotlib's default "tab10" colors
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728',
              '#9467bd', '#8c564b', '#e377c2', '#7f7f7f',
              '#bcbd22', '#17becf']

    with plt.rc_context(style):
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, squeeze=False)
        fig.subplots_adjust(hspace=0.4, wspace=0.3)

        for i, item_id in enumerate(item_ids):
            row_idx = i // ncols
            col_idx = i % ncols
            ax = axes[row_idx, col_idx]

            input_ts = input_df.loc[item_id][target_column]
            forecast_ts = forecast_df.loc[item_id]['mean'] if 'mean' in forecast_df.columns else forecast_df.loc[item_id].iloc[:, 0]
            target_ts = target_df.loc[item_id][target_column]

            # 10th/90th percentile columns, when present, give an 80% band
            lower_quantile = None
            upper_quantile = None
            if '0.1' in forecast_df.columns and '0.9' in forecast_df.columns:
                lower_quantile = forecast_df.loc[item_id]['0.1']
                upper_quantile = forecast_df.loc[item_id]['0.9']

            # Historical data: only the last 64 points, to keep the forecast readable
            ax.plot(input_ts.index[-64:], input_ts.values[-64:],
                    label='Historical', color=colors[0], linewidth=2)

            ax.plot(forecast_ts.index, forecast_ts.values,
                    label='Forecast', color=colors[1], linewidth=2, linestyle='-')

            if lower_quantile is not None and upper_quantile is not None:
                ax.fill_between(
                    forecast_ts.index,
                    lower_quantile.values,
                    upper_quantile.values,
                    color=colors[1], alpha=0.2,
                    label='80% Confidence'
                )

            ax.plot(target_ts.index, target_ts.values,
                    label='Actual', color=colors[2], linewidth=2, linestyle='-')

            mse_value = mse_scores.get(item_id, np.nan)
            title = f"Item: {item_id} - MSE: {mse_value:.4f}" if not np.isnan(mse_value) else f"Item: {item_id}"
            ax.set_title(title)

            plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

            # Light grey spines and ticks instead of black
            for side in ('bottom', 'left', 'right', 'top'):
                ax.spines[side].set_color('#cccccc')
            ax.tick_params(axis='x', colors='#555555')
            ax.tick_params(axis='y', colors='#555555')

            # Axis labels on the outer plots only
            if row_idx == nrows-1:
                ax.set_xlabel('Time')
            if col_idx == 0:
                ax.set_ylabel('Value')

        # Hide unused subplots
        for i in range(num_items, nrows * ncols):
            axes[i // ncols, i % ncols].set_visible(False)

        # Common legend below the grid
        handles, labels = axes[0, 0].get_legend_handles_labels()
        fig.legend(handles, labels, loc='lower center',
                   bbox_to_anchor=(0.5, -0.05), ncol=4, frameon=False)

        fig.tight_layout()
        fig.savefig(os.path.join(folder_path, f'window_{window_idx+1}_forecast.png'), dpi=300, bbox_inches='tight')

    plt.close(fig)

    return fig


def evaluate_regressor(config, config_dict):
    """Train and evaluate a predictor end to end for one configuration.

    The steps are:
      1. Choose the dataset wrapper for ``config.dataset_name``.
      2. Fetch the training data, optionally sliced from ``config.slice_start``.
      3. Fit ``Predictor`` with the configured regressors / bolt model.
      4. Score the predictor on the test windows.
      5. Persist the results.

    Args:
        config: Evaluation configuration.
        config_dict: Tuned configuration dictionary, saved next to the results.

    Returns:
        Summary statistics dictionary.
    """
    print(f"Loading dataset: {config.dataset_name}, term: {config.term}")
    dataset_key = config.dataset_name

    # Dataset names whose wrapper is constructed from (name, term).
    term_wrappers = {
        "traffic": TrafficWrapper,
        "exchange": ExchangeWrapper,
        "epf": EPFWrapper,
        "illness": IllnessWrapper,
        "nonlinear_sales_synthetic": NonLinearSaleSyntheticWrapper,
        "sales_synthetic": SaleSyntheticWrapper,
        "electricity_synthetic": ElectricityWrapper,
    }
    if dataset_key == "m5":
        # M5 is built from the name alone.
        gift_dataset = M5Wrapper(name=dataset_key)
    elif dataset_key in term_wrappers:
        gift_dataset = term_wrappers[dataset_key](name=dataset_key, term=config.term)
    else:
        # Anything else is treated as a GIFT-Eval dataset.
        gift_dataset = GiftEvalWrapper(
            name=dataset_key,
            term=config.term,
            to_univariate=config.to_univariate
        )

    train_data = gift_dataset.get_train_data(add_date_features=config.add_date_features)

    if config.slice_start is not None:
        print(f"Slicing training data from {config.slice_start} to None")
        train_data = train_data.slice_by_timestep(config.slice_start, None)

    print(f"Training predictor with custom regressor configuration...")
    untrained = Predictor(
        prediction_length=gift_dataset.prediction_length,
        context_length=config.context_length,
        target=config.target_column,
        known_covariates_names=config.known_covariates_names,
        known_covariates_real=config.known_covariates_real,
        known_covariates_cat=config.known_covariates_cat,
        static_features_cat=config.static_features_cat,
        static_features_real=config.static_features_real,
        past_covariates_real=config.past_covariates_real,
        past_covariates_cat=config.past_covariates_cat,
        eval_metric="MASE",
        regressor_types=config.regressor_types,
        regressor_hyperparameters=config.regressor_config,
        aggregation_strategy=config.aggregation_strategy,
        # Aggregation is trained on a multiple of the forecast horizon.
        aggregation_train_length=config.aggregation_train_length_times * gift_dataset.prediction_length,
        bolt_model_path=config.bolt_model_path,
        random_seed=123,
        regressor_fit_time_fraction=0.5,
        regressor_validation_fraction=0.1,
        verbosity=2
    )
    predictor = untrained.fit(
        train_data,
        time_limit=config.time_limit,
        enable_ensemble=False,
        fine_tune=config.fine_tune,
        fine_tune_steps=config.fine_tune_steps,
        use_lora=config.use_lora,
        context_length=config.context_length,
    )

    metrics, metric_names = get_metric_objects(config.metrics)

    _, summary = evaluate_predictor(
        predictor=predictor,
        test_dataset=gift_dataset,
        context_length=config.context_length,
        prediction_length=gift_dataset.prediction_length,
        metrics=metrics,
        metric_names=metric_names,
        is_plot=config.is_plot,
        config=config,
        test_prediction_length=config.test_prediction_length
    )

    save_results(summary, config, config.result_folder, config_dict)

    return summary


def main():
    """Command-line entry point: parse arguments, load the config, run the evaluation.

    Returns:
        The summary statistics dictionary produced by ``evaluate_regressor``.
    """
    cli = argparse.ArgumentParser(description="Evaluate time series regressor")
    cli.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
    cli.add_argument("--model", type=str, required=False, default="default", help="Variant of the model to evaluate")
    cli.add_argument("--regressor_types", nargs='+', default=None, help="List of regressor types to evaluate")
    options = cli.parse_args()

    config_dict, config = load_config(options.config, options.model, options.regressor_types)

    separator = "--------------------------------"
    banner = (
        separator,
        f"important: you are using the new evaluation function with test_prediction_length={config.test_prediction_length}",
        separator,
    )
    for line in banner:
        print(line)

    summary = evaluate_regressor(config, config_dict=config_dict)

    print("Evaluation complete!")
    return summary


if __name__ == "__main__":
    main()