import os
import pandas as pd
import numpy as np
import random
from residual_chronos.Aggregator import EnsembleAggregator
import torch
from typing import Optional, List, Union, Dict, Any
from pathlib import Path

from autogluon.timeseries import TimeSeriesDataFrame, TimeSeriesPredictor
from autogluon.timeseries.utils.features import CovariateMetadata
from autogluon.timeseries.regressor import GlobalCovariateRegressor
from autogluon.timeseries.transforms.target_scaler import LocalStandardScaler
from autogluon.timeseries.metrics import TimeSeriesScorer
from autogluon.timeseries.utils.forecast import make_future_data_frame

from residual_chronos.Regressor import CrossSectionalRegressor

class Predictor:
    """
    A class that implements a four-step forecasting approach:
    1. Apply target scaling using LocalStandardScaler
    2. Use CrossSectionalRegressor to remove covariate effects and get residuals
    3. Apply Chronos-Bolt model to predict residuals
    4. Inverse transform the predictions to get final forecasts
    
    This implementation follows the process in test_res3.py exactly.
    
    Example usage:
    -------------
    # Initialize model
    model = Predictor(
        prediction_length=8,
        target="sales",
        known_covariates_real=["price"],
        known_covariates_cat=["promotion"],
        static_features_real=["store_size"],
        static_features_cat=["store_type"],
        regressor_types=["XGB", "RF", "CAT"],  # Use multiple regressors
        regressor_hyperparameters={
            "XGB": {"learning_rate": 0.1},
            "RF": {"n_estimators": 100},
            "CAT": {"iterations": 100}
        },
        eval_metric="MAE"
    )
    
    # Fit model
    model.fit(train_data)
    
    # Generate forecasts
    forecasts = model.predict(context_data, known_covariates=future_covariates)
    """
    
    # Mapping from TimeSeriesPredictor metrics to regressor metrics
    _METRIC_MAPPING = {
        # Map public metric names to regressor metric names
        "MAE": "mean_absolute_error",
        "MSE": "mean_squared_error",
        "RMSE": "root_mean_squared_error",
        "MAPE": "mean_absolute_percentage_error",
        "SMAPE": "symmetric_mean_absolute_percentage_error",
        "MASE": "mean_absolute_scaled_error",
        "WQL": "weighted_quantile_loss",
        "SQL": "scaled_quantile_loss",
        # Add identity mappings for cases where users pass the full name
        "mean_absolute_error": "mean_absolute_error",
        "mean_squared_error": "mean_squared_error",
        "root_mean_squared_error": "root_mean_squared_error",
        "mean_absolute_percentage_error": "mean_absolute_percentage_error",
        "symmetric_mean_absolute_percentage_error": "symmetric_mean_absolute_percentage_error",
        "mean_absolute_scaled_error": "mean_absolute_scaled_error",
        "weighted_quantile_loss": "weighted_quantile_loss",
        "scaled_quantile_loss": "scaled_quantile_loss",
    }
    
    def __init__(
        self,
        prediction_length: int = 1,
        target: str = "target",
        known_covariates_names: Optional[List[str]] = None,
        known_covariates_real: Optional[List[str]] = None,
        known_covariates_cat: Optional[List[str]] = None,
        static_features_cat: Optional[List[str]] = None,
        static_features_real: Optional[List[str]] = None,
        past_covariates_real: Optional[List[str]] = None,
        past_covariates_cat: Optional[List[str]] = None,
        eval_metric: str = "MAE",
        regressor_types: Optional[Union[str, List[str]]] = None,
        regressor_hyperparameters: Optional[Dict[str, Dict[str, Any]]] = None,
        bolt_model_path: str = "bolt_small",
        random_seed: int = 123,
        verbosity: int = 2,
        regressor_fit_time_fraction: float = 0.5,
        regressor_validation_fraction: float = 0.1,
        aggregation_strategy: Union[str, EnsembleAggregator] = 'equal',
        aggregation_train_length: Optional[int] = None,
        lora_cfg: Optional[Dict[str, Any]] = None,
        context_length: Optional[int] = None,
    ):
        """
        Initialize the Predictor model.

        Besides storing the configuration, construction seeds the global
        ``random``, NumPy and torch RNGs with ``random_seed`` and builds the
        target scaler, the covariate regressor and the residual predictor.

        Parameters
        ----------
        prediction_length : int
            Length of the prediction horizon
        target : str
            Name of the target column
        known_covariates_names : List[str], optional
            Names of all covariates that are known in advance (future values available at prediction time)
            DEPRECATED: Use known_covariates_real and known_covariates_cat instead.
            When given, every name is treated as a real-valued covariate.
        known_covariates_real : List[str], optional
            Names of numerical covariates that are known in advance (future values available at prediction time)
        known_covariates_cat : List[str], optional
            Names of categorical covariates that are known in advance (future values available at prediction time)
        static_features_cat : List[str], optional
            Names of categorical static features (features that remain constant over time)
        static_features_real : List[str], optional
            Names of numerical static features (features that remain constant over time)
        past_covariates_real : List[str], optional
            Names of numerical past covariates (features only known in the past)
        past_covariates_cat : List[str], optional
            Names of categorical past covariates (features only known in the past)
        eval_metric : str
            Evaluation metric for model selection (e.g., 'MAE', 'MSE', 'MAPE').
            Unknown names fall back to 'mean_absolute_error' for the regressor.
        regressor_types : Union[str, List[str]], optional
            Type(s) of regressor to use for modeling covariate effects (e.g., 'XGB', 'CAT', 'RF')
            Can be a single string or a list of strings. ``None`` or an empty
            list disables the covariate regressor.
        regressor_hyperparameters : Dict[str, Dict[str, Any]], optional
            Hyperparameters for each regressor type, keyed by regressor name
        bolt_model_path : str
            Path to the Chronos-Bolt model to use (e.g., 'bolt_small', 'bolt_mini').
            If this path exists on disk, a saved residual predictor is loaded from it.
        random_seed : int
            Random seed for reproducibility
        verbosity : int
            Verbosity level for logging
        regressor_fit_time_fraction : float
            Fraction of training time allocated to fitting the regressor
        regressor_validation_fraction : float
            Fraction of data used for validation in the regressor
        aggregation_strategy : Union[str, EnsembleAggregator]
            Strategy forwarded to the regressor when several regressor types
            are configured; a single regressor always uses 'equal'.
        aggregation_train_length : Optional[int]
            Number of data points to use for training the aggregator
        lora_cfg : Dict[str, Any], optional
            LoRA configuration passed to Chronos when ``fit(use_lora=True)``
            is used. Defaults to r=8, lora_alpha=16, lora_dropout=0.05 on the
            "q", "k", "v", "o" modules.
        context_length : Optional[int]
            Number of data points to use for context length

        Raises
        ------
        ValueError
            If known_covariates_names is combined with known_covariates_real
            or known_covariates_cat.
        """
        def _as_name_list(names) -> List[str]:
            """Return a fresh list of names so the caller's object is never aliased."""
            if names is None:
                return []
            if isinstance(names, str):
                # A bare string is one name, not a sequence of characters.
                return [names]
            return list(names)

        self.prediction_length = prediction_length
        self.quantile_levels = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
        self.target = target

        # Handle deprecated known_covariates_names parameter
        if known_covariates_names is not None:
            if known_covariates_real is not None or known_covariates_cat is not None:
                raise ValueError(
                    "You provided both known_covariates_names and known_covariates_real/known_covariates_cat. "
                    "Please use only known_covariates_real and known_covariates_cat."
                )
            print("WARNING: known_covariates_names is deprecated. Please use known_covariates_real and known_covariates_cat instead.")
            # Assume all are real-valued as was the previous default
            self.known_covariates_real = _as_name_list(known_covariates_names)
            self.known_covariates_cat = []
        else:
            self.known_covariates_real = _as_name_list(known_covariates_real)
            self.known_covariates_cat = _as_name_list(known_covariates_cat)

        # For backward compatibility
        self.known_covariates_names = self.known_covariates_real + self.known_covariates_cat

        self.static_features_cat = _as_name_list(static_features_cat)
        self.static_features_real = _as_name_list(static_features_real)
        self.past_covariates_real = _as_name_list(past_covariates_real)
        self.past_covariates_cat = _as_name_list(past_covariates_cat)
        self.eval_metric = eval_metric

        # Normalise regressor_types to a list; None is kept to mean "no regressor"
        if regressor_types is None:
            self.regressor_types = None
        else:
            self.regressor_types = _as_name_list(regressor_types)

        self.regressor_hyperparameters = regressor_hyperparameters or {}
        self.bolt_model_path = bolt_model_path
        self.random_seed = random_seed
        self.verbosity = verbosity
        self.regressor_fit_time_fraction = regressor_fit_time_fraction
        self.regressor_validation_fraction = regressor_validation_fraction
        self.aggregation_strategy = aggregation_strategy
        self.aggregation_train_length = aggregation_train_length
        self.context_length = context_length
        # Map the eval_metric to regressor metric
        self.regressor_eval_metric = self._map_metric(eval_metric)

        # Default LoRA configuration, used only when the caller supplies none
        default_lora_cfg = {
            'r': 8,
            'lora_alpha': 16,
            'lora_dropout': 0.05,
            'target_modules': ["q", "k", "v", "o"],
        }
        self.lora_cfg = lora_cfg if lora_cfg is not None else default_lora_cfg

        # Seed every global RNG this pipeline touches
        random.seed(random_seed)
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(random_seed)

        # Initialize target scaler
        self.target_scaler = LocalStandardScaler(target=self.target)

        # Initialize covariate regressor
        self._init_covariate_regressor()

        # Initialize residual predictor
        self._init_residual_predictor()

        # Flag to track if the model has been fit
        self._is_fit = False
    
    def _map_metric(self, metric: str) -> str:
        """Map a public-facing metric name to the corresponding regressor metric name."""
        if metric in self._METRIC_MAPPING:
            return self._METRIC_MAPPING[metric]
        else:
            # Default to mean_absolute_error if unknown metric
            print(f"Warning: Unknown metric '{metric}'. Using 'mean_absolute_error' for regressor.")
            return "mean_absolute_error"
    
    def _init_covariate_regressor(self):
        """Build ``self.covariate_regressor`` from the configured regressor types.

        No regressor types  -> ``None`` (the covariate step is skipped).
        One regressor type  -> CrossSectionalRegressor with 'equal' aggregation.
        Several types       -> CrossSectionalRegressor using the configured
                               aggregation strategy and aggregation train length.
        """
        metadata = CovariateMetadata(
            static_features_cat=self.static_features_cat,
            static_features_real=self.static_features_real,
            known_covariates_real=self.known_covariates_real,
            known_covariates_cat=self.known_covariates_cat,
            past_covariates_real=self.past_covariates_real,
            past_covariates_cat=self.past_covariates_cat
        )

        model_names = self.regressor_types
        if model_names is None or len(model_names) == 0:
            self.covariate_regressor = None
            return

        # Arguments shared by the single-model and the multi-model setup
        shared_kwargs = dict(
            target=self.target,
            covariate_metadata=metadata,
            include_static_features=True,
            include_item_id=True,
            models_hyperparameters=self.regressor_hyperparameters,
            fit_time_fraction=self.regressor_fit_time_fraction,
            validation_fraction=self.regressor_validation_fraction,
            eval_metric=self.regressor_eval_metric,
            random_seed=self.random_seed,
            prediction_length=self.prediction_length,
            verbosity=self.verbosity,
        )

        if len(model_names) == 1:
            # A lone model needs no learned aggregation
            self.covariate_regressor = CrossSectionalRegressor(
                model_names=[model_names[0]],
                aggregation_strategy='equal',
                **shared_kwargs,
            )
        else:
            # Several models are combined with the user-selected strategy
            self.covariate_regressor = CrossSectionalRegressor(
                model_names=model_names,
                aggregation_strategy=self.aggregation_strategy,
                aggregation_train_length=self.aggregation_train_length,
                **shared_kwargs,
            )
    
    def _init_residual_predictor(self):
        """Set ``self.residual_predictor``.

        A saved TimeSeriesPredictor is loaded when ``bolt_model_path`` points
        at an existing filesystem path; otherwise a fresh, unfitted
        TimeSeriesPredictor is created.
        """
        model_path = self.bolt_model_path
        if model_path is None or not os.path.exists(model_path):
            print(f"residual predictor not found, creating new one")
            self.residual_predictor = TimeSeriesPredictor(
                prediction_length=self.prediction_length,
                target=self.target,
                eval_metric=self.eval_metric,
                verbosity=self.verbosity
            )
            return

        print(f"Loading residual predictor from {model_path}")
        self.residual_predictor = TimeSeriesPredictor.load(model_path, require_version_match=False)

    def _validate_features(self, data: TimeSeriesDataFrame) -> None:
        """
        Validate that all configured features exist in the data.
        
        Parameters
        ----------
        data : TimeSeriesDataFrame
            The data to validate
            
        Raises
        ------
        KeyError
            If any configured feature is missing from the data
        """
        # Check known covariates (real and categorical)
        for covariate in self.known_covariates_real + self.known_covariates_cat:
            if covariate not in data.columns:
                raise KeyError(f"Known covariate '{covariate}' not found in data columns: {list(data.columns)}")
        
        # Check past covariates (real and categorical)
        for covariate in self.past_covariates_real + self.past_covariates_cat:
            if covariate not in data.columns:
                raise KeyError(f"Past covariate '{covariate}' not found in data columns: {list(data.columns)}")
        
        # Check static features (real and categorical)
        if self.static_features_real or self.static_features_cat:
            if data.static_features is None:
                raise KeyError(f"Static features were configured but data has no static_features attribute")
            
            for feature in self.static_features_real + self.static_features_cat:
                if feature not in data.static_features.columns:
                    raise KeyError(f"Static feature '{feature}' not found in static_features columns: {list(data.static_features.columns)}")

    def fit(
        self,
        train_data: Union[TimeSeriesDataFrame, pd.DataFrame, Path, str],
        time_limit: Optional[int] = None,
        enable_ensemble: bool = False,
        reset_model: bool = False,
        fine_tune: bool = False,
        use_lora: bool = False,
        **kwargs
    ) -> "Predictor":
        """
        Fit the scaler, the covariate regressor and the residual predictor.

        Parameters
        ----------
        train_data : TimeSeriesDataFrame
            Training data; anything else is passed to the TimeSeriesDataFrame
            constructor first. It should contain:
            - The target column (e.g., 'unit_sales')
            - Any known covariates (e.g., 'price', 'promotion')
            - Any past covariates
            - Static features (as the static_features attribute)
            - Item IDs and timestamps

            Example:
            ```
                                target  price  promotion
            item_id timestamp
            A       2022-01-01   100.0   9.99          0
                    2022-01-02   120.0   9.99          0
                    2022-01-03   150.0   7.99          1
            B       2022-01-01    50.0  14.99          0
                    2022-01-02    55.0  14.99          0
                    2022-01-03    70.0  12.99          1
            ```

        time_limit : int, optional
            Total time budget in seconds. It is split between the regressor
            (``regressor_fit_time_fraction``) and the residual predictor.
            When not given, each stage gets 60 seconds.
        enable_ensemble : bool
            Forwarded to the residual predictor's ``fit``
        reset_model : bool
            If True, reinitialize the model components before fitting
        fine_tune : bool
            Passed as the Chronos ``fine_tune`` hyperparameter
        use_lora : bool
            If True, adds ``self.lora_cfg`` as the Chronos ``lora``
            hyperparameter and sets ``fine_tune_lr`` to 3e-4
        **kwargs :
            ``fine_tune_steps`` (default 1000) and ``eval_during_fine_tune``
            (default True) are read and passed to Chronos

        Returns
        -------
        self : Predictor
            The fitted model instance
        """
        if not isinstance(train_data, TimeSeriesDataFrame):
            train_data = TimeSeriesDataFrame(train_data)

        # Fail early if a configured feature is absent
        self._validate_features(train_data)

        # Work on a copy so the caller's frame stays untouched
        working_data = train_data.copy()

        if reset_model:
            self.target_scaler = LocalStandardScaler(target=self.target)
            self._init_covariate_regressor()
            self._init_residual_predictor()

        # Split the time budget between the two trainable stages
        if time_limit:
            regressor_budget = int(time_limit * self.regressor_fit_time_fraction)
            predictor_budget = time_limit - regressor_budget
        else:
            regressor_budget = 60
            predictor_budget = 60

        # STEP 1: scale the target
        if self.target_scaler is None:
            scaled = working_data
        else:
            scaled = self.target_scaler.fit_transform(working_data)

        # STEP 2: strip covariate effects, leaving residuals
        if self.covariate_regressor is None:
            residuals = scaled
        else:
            residuals = self.covariate_regressor.fit_transform(
                scaled,
                time_limit=regressor_budget,
                context_length=self.context_length,
            )

        # STEP 3: train Chronos-Bolt on the residuals, unless the predictor is already fitted
        has_predictor = self.residual_predictor is not None and self.bolt_model_path is not None
        if has_predictor and not self.residual_predictor._learner.is_fit:
            chronos_cfg = {
                "model_path": self.bolt_model_path,
                "fine_tune": fine_tune,
                "fine_tune_steps": kwargs.get("fine_tune_steps", 1000),
                "eval_during_fine_tune": kwargs.get("eval_during_fine_tune", True),
            }
            if use_lora:
                # LoRA settings apply to the Chronos model only
                chronos_cfg["lora"] = self.lora_cfg
                chronos_cfg["fine_tune_lr"] = 3e-4

            self.residual_predictor = self.residual_predictor.fit(
                residuals,
                hyperparameters={"Chronos": chronos_cfg},
                enable_ensemble=enable_ensemble,
                time_limit=predictor_budget,
                random_seed=self.random_seed,
            )

        self._is_fit = True
        return self
    
    def _create_zero_predictions(self, context_data: TimeSeriesDataFrame) -> TimeSeriesDataFrame:
        """Build an all-zero forecast frame for the horizon following ``context_data``.

        Parameters
        ----------
        context_data : TimeSeriesDataFrame
            The context data whose items and frequency define the future index

        Returns
        -------
        TimeSeriesDataFrame
            Zeros in a "mean" column plus one column per entry of
            ``self.quantile_levels`` (named by ``str(level)``), indexed by the
            frame returned from ``make_future_data_frame``
        """
        horizon_frame = make_future_data_frame(
            context_data,
            prediction_length=self.prediction_length,
            freq=context_data.freq,
        )
        horizon_index = pd.MultiIndex.from_frame(horizon_frame)

        column_names = ["mean", *map(str, self.quantile_levels)]
        zeros = np.zeros((len(horizon_index), len(column_names)))

        return TimeSeriesDataFrame(
            pd.DataFrame(zeros, index=horizon_index, columns=column_names)
        )

    def predict(
        self,
        data: Union[TimeSeriesDataFrame, pd.DataFrame, Path, str],
        known_covariates: Optional[Union[TimeSeriesDataFrame, pd.DataFrame, Path, str]] = None,
        static_features: Optional[Union[pd.DataFrame, Path, str]] = None,
        **kwargs
    ) -> TimeSeriesDataFrame:
        """
        Generate forecasts for the given data.
        
        Parameters
        ----------
        data : TimeSeriesDataFrame
            Historical time series data for forecasting (context). This should contain:
            - The target column (e.g., 'unit_sales')
            - Any past covariates used during training
            - Item IDs and timestamps
            
            Example:
            ```
                                target  price  promotion
            item_id timestamp                           
            A       2022-01-01   100.0   9.99          0
                    2022-01-02   120.0   9.99          0
            B       2022-01-01    50.0  14.99          0
                    2022-01-02    55.0  14.99          0
            ```
            
        known_covariates : TimeSeriesDataFrame, optional
            Known covariates for the forecast horizon. This should contain:
            - The covariate columns specified in known_covariates_real and known_covariates_cat
            - Item IDs and timestamps for the forecast horizon
            
            Example (for prediction_length=2):
            ```
                                price  promotion  holiday
            item_id timestamp                    
            A       2022-01-03   7.99          1     True
                    2022-01-04   7.99          1    False
            B       2022-01-03  12.99          1     True
                    2022-01-04  12.99          0    False
            ```
        
        static_features : pd.DataFrame, optional
            Static features for the forecast horizon. This should contain:
            - The static feature columns specified in static_features_cat and static_features_real
            - Item IDs as index
            
            Example:
            ```
                     region  store_size
            item_id                  
            A       northeast     10000
            B       southwest      5000
            ```
            
        **kwargs :
            Additional arguments for prediction
            
        Returns
        -------
        TimeSeriesDataFrame
            DataFrame containing the forecasts, with columns for mean and quantile predictions
            
            Example:
            ```
                                mean    0.1    0.5    0.9
            item_id timestamp                           
            A       2022-01-03  140.0  120.0  140.0  160.0
                    2022-01-04  145.0  125.0  145.0  165.0
            B       2022-01-03   65.0   55.0   65.0   75.0
                    2022-01-04   60.0   50.0   60.0   70.0
            ```
        """
        if not self._is_fit:
            raise RuntimeError("Model has not been fit yet. Call fit() first.")
        
        # Convert to TimeSeriesDataFrame if needed
        if not isinstance(data, TimeSeriesDataFrame):
            data = TimeSeriesDataFrame(data)
        if known_covariates is not None and not isinstance(known_covariates, TimeSeriesDataFrame):
            known_covariates = TimeSeriesDataFrame(known_covariates)
        
        # If static_features is None but data has static_features, use those
        if static_features is None and hasattr(data, 'static_features') and data.static_features is not None:
            static_features = data.static_features
            
        # Check if provided static_features have the correct columns
        if static_features is not None:
            missing_cat = set(self.static_features_cat) - set(static_features.columns)
            missing_real = set(self.static_features_real) - set(static_features.columns)
            
            if missing_cat or missing_real:
                missing_cols = list(missing_cat) + list(missing_real)
                raise ValueError(f"Static features are missing columns: {missing_cols}")
        
        # Check if provided known_covariates have the correct columns
        if known_covariates is not None:
            missing_cat = set(self.known_covariates_cat) - set(known_covariates.columns)
            missing_real = set(self.known_covariates_real) - set(known_covariates.columns)
            
            if missing_cat or missing_real:
                missing_cols = list(missing_cat) + list(missing_real)
                raise ValueError(f"Known covariates are missing columns: {missing_cols}")
        
        # Create a copy to avoid modifying the input data
        context_data = data.copy()
        
        # Following the exact steps from test_res3.py:
        
        # 1. Scale the context data
        if self.target_scaler is not None:
            scaled_context = self.target_scaler.fit_transform(context_data)
        else:
            scaled_context = context_data
        
        # 2. Get residuals from the context data
        if self.covariate_regressor is not None:
            # TODO do we need to fit again here? we only transformer here, better not fit again
            context_residuals = self.covariate_regressor.fit_transform(scaled_context, context_length=self.context_length)
        else:
            context_residuals = scaled_context
        
        # 3. Predict residuals using Chronos-Bolt
        if self.residual_predictor is not None and self.bolt_model_path is not None:
            print(f"predicting residuals using residual predictor")
            residual_predictions = self.residual_predictor.predict(context_residuals, random_seed=self.random_seed) # TODO: only prediction t:t+prediction_length
            context_residuals.to_csv("/home/magics/hdd/sky_ws/residual_ws/tests/hopformer/data/context_residuals.csv")
            residual_predictions.to_csv("/home/magics/hdd/sky_ws/residual_ws/tests/hopformer/data/residual_predictions.csv")
        else:
            print(f"no residual predictor, creating zero predictions")
            # Create zero predictions using the helper method
            residual_predictions = self._create_zero_predictions(data)
        
        # 4. Inverse transform with the covariate regressor to add back the covariate effects
        # if known_covariates is None:
        #     raise ValueError("known_covariates must be provided for prediction")
        
        # Inverse transform the residual predictions
        if self.covariate_regressor is not None:
            final_predictions = self.covariate_regressor.inverse_transform(
                predictions=residual_predictions, # add back the covariate effects
                known_covariates=known_covariates, # used for prediction 
                static_features=static_features,
                context_data=scaled_context,
            )
        else:
            # If there's no covariate regressor, use the residual predictions (which are 0)
            final_predictions = residual_predictions
        
        # 5. Inverse scale the predictions
        if self.target_scaler is not None:
            final_predictions = self.target_scaler.inverse_transform(final_predictions)
        
        return final_predictions
    
    def evaluate(
        self,
        data: Union[TimeSeriesDataFrame, pd.DataFrame, Path, str],
        metrics: Optional[List[Union[str, TimeSeriesScorer]]] = None,
        fit_model: bool = False,
        **kwargs
    ) -> Dict[str, float]:
        """
        Evaluate the model on test data.

        The last ``prediction_length`` steps of every series are held out as
        ground truth; everything before them is the forecasting context.

        Parameters
        ----------
        data : TimeSeriesDataFrame
            Test data containing both inputs and ground truth values. This should contain:
            - The target column for the entire evaluation period
            - All covariates used during training
            - Item IDs and timestamps

            Example (for prediction_length=2):
            ```
                                target  price  promotion
            item_id timestamp
            A       2022-01-01   100.0   9.99          0  # ← These are context
                    2022-01-02   120.0   9.99          0  # ← (used for prediction)
                    2022-01-03   140.0   7.99          1  # ← These are ground truth
                    2022-01-04   145.0   7.99          1  # ← (used for evaluation)
            B       2022-01-01    50.0  14.99          0
                    2022-01-02    55.0  14.99          0
                    2022-01-03    65.0  12.99          1
                    2022-01-04    60.0  12.99          0
            ```

        metrics : List[Union[str, TimeSeriesScorer]], optional
            Metrics to evaluate. Strings are looked up case-insensitively in
            AutoGluon's ``AVAILABLE_METRICS`` (e.g. 'MASE', 'mae'); scorer
            objects are used as given. Defaults to MASE and SMAPE.
        fit_model : bool
            If True, fit on the context portion before predicting
        **kwargs :
            ``time_limit`` (default 60), ``fine_tune`` (default False) and
            ``enable_ensemble`` (default False) are forwarded to ``fit`` when
            ``fit_model`` is True

        Returns
        -------
        Dict[str, float]
            Dictionary of metric names and scores, keyed by ``str(metric)``
            for user-supplied metrics

        Raises
        ------
        RuntimeError
            If the model has not been fit and ``fit_model`` is False
        ValueError
            If a metric name is not a known AutoGluon metric
        """
        from autogluon.timeseries.metrics import AVAILABLE_METRICS, MASE, SMAPE

        # Convert to TimeSeriesDataFrame if needed
        if not isinstance(data, TimeSeriesDataFrame):
            data = TimeSeriesDataFrame(data)

        # Split data into context and ground truth
        context = data.slice_by_timestep(None, -self.prediction_length)
        # NOTE(review): the horizon slice appears to carry every column of
        # `data`, including the target — confirm the regressor only reads the
        # configured known covariates from it.
        known_covariates = data.slice_by_timestep(-self.prediction_length, None)
        static_features = known_covariates.static_features

        if fit_model:
            self.fit(context, time_limit=kwargs.get("time_limit", 60), fine_tune=kwargs.get("fine_tune", False), enable_ensemble=kwargs.get("enable_ensemble", False))

        if not self._is_fit:
            raise RuntimeError("Model has not been fit yet. Call fit() first.")

        # Generate predictions using the predict method
        predictions = self.predict(context, known_covariates=known_covariates, static_features=static_features)

        # Resolve the requested metrics into callable scorers
        if metrics is None:
            scorers = [MASE(), SMAPE()]
            metric_names = ["MASE", "SMAPE"]
        else:
            scorers = []
            metric_names = []
            for metric in metrics:
                if isinstance(metric, str):
                    # A plain string is not callable; look up its scorer class
                    scorer_cls = AVAILABLE_METRICS.get(metric.upper())
                    if scorer_cls is None:
                        raise ValueError(
                            f"Unknown metric '{metric}'. Available metrics: {sorted(AVAILABLE_METRICS)}"
                        )
                    scorers.append(scorer_cls())
                else:
                    scorers.append(metric)
                metric_names.append(str(metric))

        # Calculate metrics
        results = {}
        for scorer, name in zip(scorers, metric_names):
            results[name] = scorer(
                data=data,
                predictions=predictions,
                prediction_length=self.prediction_length,
                target=self.target
            )

        return results
    
    def save(self, path: Union[str, Path]) -> None:
        """
        Save the model components to disk.

        Writes ``target_scaler.pkl``, ``covariate_regressor.pkl`` and
        ``config.pkl`` into ``path`` (created if missing) and asks the
        residual predictor to save itself under ``path / "residual_predictor"``.

        Parameters
        ----------
        path : str or Path
            Directory path to save the model
        """
        import pickle

        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        # Save each component separately
        with open(path / "target_scaler.pkl", "wb") as f:
            pickle.dump(self.target_scaler, f)

        with open(path / "covariate_regressor.pkl", "wb") as f:
            pickle.dump(self.covariate_regressor, f)

        # Save the residual predictor using its own method
        if self.residual_predictor is not None:
            # NOTE(review): upstream AutoGluon's TimeSeriesPredictor.save()
            # appears to take no path argument — confirm the installed
            # version accepts one.
            self.residual_predictor.save(path / "residual_predictor")

        # Save configuration. Every constructor argument is recorded so the
        # settings can be restored; getattr keeps this working for instances
        # that were created without these attributes.
        config = {
            "prediction_length": self.prediction_length,
            "target": self.target,
            "known_covariates_real": self.known_covariates_real,
            "known_covariates_cat": self.known_covariates_cat,
            "static_features_cat": self.static_features_cat,
            "static_features_real": self.static_features_real,
            "past_covariates_real": self.past_covariates_real,
            "past_covariates_cat": self.past_covariates_cat,
            "eval_metric": self.eval_metric,
            "regressor_types": self.regressor_types,
            "regressor_hyperparameters": self.regressor_hyperparameters,
            "bolt_model_path": self.bolt_model_path,
            "random_seed": self.random_seed,
            "verbosity": self.verbosity,
            "regressor_fit_time_fraction": self.regressor_fit_time_fraction,
            "regressor_validation_fraction": self.regressor_validation_fraction,
            "aggregation_strategy": getattr(self, "aggregation_strategy", 'equal'),
            "aggregation_train_length": getattr(self, "aggregation_train_length", None),
            "lora_cfg": getattr(self, "lora_cfg", None),
            "context_length": getattr(self, "context_length", None),
            "_is_fit": self._is_fit
        }

        with open(path / "config.pkl", "wb") as f:
            pickle.dump(config, f)
    
    @classmethod
    def load(cls, path: Union[str, Path]) -> "Predictor":
        """
        Load a saved model from disk.
        
        Parameters
        ----------
        path : str or Path
            Directory path where the model was saved
            
        Returns
        -------
        Predictor
            The loaded model
        """
        import pickle
        
        path = Path(path)
        
        # Load configuration
        with open(path / "config.pkl", "rb") as f:
            config = pickle.load(f)
        
        # Handle legacy single regressor_type config
        if "regressor_type" in config and "regressor_types" not in config:
            config["regressor_types"] = [config.pop("regressor_type")]
        
        # Extract _is_fit flag and remove it from config before passing to __init__
        is_fit = config.pop("_is_fit", False)
        
        # Create a new instance with the loaded configuration
        instance = cls(**config)
        instance._is_fit = is_fit
        
        # Load each component
        with open(path / "target_scaler.pkl", "rb") as f:
            instance.target_scaler = pickle.load(f)
        
        with open(path / "covariate_regressor.pkl", "rb") as f:
            instance.covariate_regressor = pickle.load(f)
        
        # Load the residual predictor using its own method
        instance.residual_predictor = TimeSeriesPredictor.load(path / "residual_predictor")
        
        return instance 

    def predict_longer(
        self,
        data: Union[TimeSeriesDataFrame, pd.DataFrame, Path, str],
        prediction_length: Optional[int] = None,
        known_covariates: Optional[Union[TimeSeriesDataFrame, pd.DataFrame, Path, str]] = None,
        static_features: Optional[Union[pd.DataFrame, Path, str]] = None,
        **kwargs
    ) -> TimeSeriesDataFrame:
        """
        Generate forecasts for a longer horizon than the model was trained for.

        The forecast is built iteratively: each pass calls ``predict`` for
        ``self.prediction_length`` steps, then the ``mean`` forecast is renamed
        to the target column and appended to the context for the next pass.
        If ``prediction_length`` is not a multiple of
        ``self.prediction_length``, the last pass still forecasts a full
        ``self.prediction_length`` window and the combined result is truncated.

        Parameters
        ----------
        data : TimeSeriesDataFrame
            Historical time series data for forecasting (context).
        prediction_length : int, optional
            The desired total prediction length. Defaults to
            ``self.prediction_length``. Values less than or equal to
            ``self.prediction_length`` fall back to a single ``predict`` call,
            whose result is returned unchanged (not truncated).
        known_covariates : TimeSeriesDataFrame, optional
            Known covariates for the entire forecast horizon.
        static_features : pd.DataFrame, optional
            Static features for the forecast.
        **kwargs :
            Additional arguments passed to the predict method.

        Returns
        -------
        TimeSeriesDataFrame
            DataFrame containing the forecasts for the entire prediction horizon.
        """
        if prediction_length is None:
            prediction_length = self.prediction_length

        # A single pass is enough; let predict() handle its inputs directly
        if prediction_length <= self.prediction_length:
            return self.predict(data, known_covariates=known_covariates, static_features=static_features, **kwargs)

        # Convert inputs to TimeSeriesDataFrame if needed. The covariates are
        # sliced with slice_by_timestep below, so they need the same treatment
        # as the context data.
        if not isinstance(data, TimeSeriesDataFrame):
            data = TimeSeriesDataFrame(data)
        if known_covariates is not None and not isinstance(known_covariates, TimeSeriesDataFrame):
            known_covariates = TimeSeriesDataFrame(known_covariates)

        step = self.prediction_length

        # Initialize context and results
        context = data.copy()
        all_predictions = []
        steps_predicted = 0

        # Predict iteratively until we reach desired length
        while steps_predicted < prediction_length:
            curr_covariates = None
            if known_covariates is not None:
                # Cumulative prefix of the covariates: everything except the
                # trailing steps that lie beyond the window forecast in this
                # pass. Once nothing is left to trim, the full frame is used.
                # NOTE(review): this hands predict() covariates that overlap
                # the already-extended context — presumably predict() aligns
                # them on the forecast index; confirm.
                steps_to_trim = prediction_length - steps_predicted - step
                end_point = -steps_to_trim if steps_to_trim > 0 else None
                curr_covariates = known_covariates.slice_by_timestep(None, end_point)

            # Make prediction with current context
            predictions = self.predict(context, known_covariates=curr_covariates, static_features=static_features, **kwargs)

            # Add predictions to results
            all_predictions.append(predictions)
            steps_predicted += step

            # The point forecast becomes the "observed" target of the next pass
            next_context = predictions.rename(columns={'mean': self.target})

            # Add known covariates to next context if available
            # (index-aligned assignment against the forecast rows)
            if curr_covariates is not None:
                for col in curr_covariates.columns:
                    if col != self.target:
                        next_context[col] = curr_covariates[col]

            # Append to context and restore (item, timestamp) ordering
            context = pd.concat([context, next_context]).sort_index()

        # Concatenating per-pass forecasts interleaves the items, so sort first
        # to make each item's rows contiguous and chronological, then keep
        # exactly prediction_length steps per item.
        final_predictions = pd.concat(all_predictions).sort_index()
        return final_predictions.slice_by_timestep(0, prediction_length)