import pandas as pd
import numpy as np
from typing import List, Dict, Any, Optional, Union, Tuple
import reprlib
import termcolor
import pickle
import os
from pathlib import Path
import structlog

from autogluon.tabular.register import ag_model_register as tabular_ag_model_register
from autogluon.timeseries import TimeSeriesDataFrame, TimeSeriesPredictor
from autogluon.timeseries.dataset.ts_dataframe import ITEMID, TimeSeriesDataFrame
from autogluon.timeseries.models.presets import MODEL_TYPES
from autogluon.timeseries.regressor import GlobalCovariateRegressor
from autogluon.timeseries.utils.features import CovariateMetadata
from autogluon.timeseries.utils.forecast import make_future_data_frame
from residual_chronos.Aggregator import (
    EnsembleAggregator, 
    EqualWeightAggregator, 
    PerformanceWeightAggregator,
    AdaptiveAggregator,
    SPAAggregator,
    SingleBestAggregator,
    LinearAggregator
)
from utils.agg_utils import plot_all

# Configure structlog once, at import time, for the whole process.
# Each event gets its log level and a "YYYY-MM-DD HH:MM:SS" timestamp attached
# and is then rendered as a console line by ConsoleRenderer.
structlog.configure(
    processors=[
        structlog.processors.add_log_level,
        structlog.processors.TimeStamper(fmt="%Y-%m-%d %H:%M:%S"),
        structlog.dev.ConsoleRenderer()
    ],
    # Loggers are created by structlog's standard-library factory.
    # NOTE(review): presumably the rendered lines are therefore subject to the
    # application's `logging` level/handler configuration - confirm that
    # info-level messages are actually emitted in the target environment.
    logger_factory=structlog.stdlib.LoggerFactory(),
)

# Module-level logger shared by the classes in this file
logger = structlog.get_logger(__name__)


class TimeSeriesRegressor(TimeSeriesPredictor):
    """
    Enhanced TimeSeriesPredictor with additional functionality for model tracking
    and efficient prediction over arbitrary windows.
    """
    
    @property
    def best_model_name(self):
        """Name of the best model as reported by the trainer loaded from the learner."""
        trainer = self._learner.load_trainer()
        return trainer.get_model_best()

    @property
    def _cached_predictions_path(self) -> Path:
        """Directory ``<self.path>/cached_predictions`` used for prediction cache files."""
        return Path(self.path, "cached_predictions")

    def __init__(
        self,
        model_name: str,
        model_hyperparameters: Optional[Dict[str, Any]] = None,
        **kwargs
    ):
        """
        Build the underlying predictor and remember which single model it wraps.

        Args:
            model_name: Name of the underlying model (e.g., "AutoARIMA", "DeepAR")
            model_hyperparameters: Hyperparameters for that model; stored as given
                (may be None) and handed to ``fit`` by ``fit_and_predict``
            **kwargs: Forwarded unchanged to ``TimeSeriesPredictor.__init__``
        """
        # The parent constructor runs first so the predictor is fully set up
        # before the extra bookkeeping attributes are attached.
        super().__init__(**kwargs)
        self.model_name, self.model_hyperparameters = model_name, model_hyperparameters

    def predict_all(self, data: TimeSeriesDataFrame, horizon: Optional[int]=None, min_ctx: int=5, window_size: Optional[int]=None, context_length: Optional[int]=None) -> TimeSeriesDataFrame:
        """
        Get fitted values by step-by-step prediction with configurable horizon.

        The last ``window_size`` timesteps of every item are re-predicted in
        chunks. For each chunk the data is cut (using negative indexing) right
        before the chunk, the best model forecasts from that cut, and the leading
        rows of the forecast are kept. The rows of ``data`` that precede the
        window are returned unchanged next to the forecasts.

        Args:
            data: TimeSeriesDataFrame to predict over
            horizon: Number of forecast steps kept per prediction call
                (default=None, meaning ``self.prediction_length``)
            min_ctx: Minimum context window size (default=5)
            window_size: Window size for prediction (default=None, meaning predict all possible,
                i.e. everything after the first ``min_ctx`` steps of the longest item)
            context_length: Number of timesteps before each cut that are passed to the model
                (default=None, meaning all rows before the cut are passed)
        Returns:
            Concatenation of the rows of ``data`` before the window and the kept
            forecast rows, sorted by index. Since ``pd.concat`` aligns columns,
            forecast columns are expected to be NaN on the pre-window rows.
        Raises:
            AssertionError: if ``horizon`` exceeds ``self.prediction_length``, or
                the data / ``window_size`` leave fewer than ``min_ctx`` context steps.
        """
        horizon = self.prediction_length if horizon is None else horizon

        # Validate inputs
        assert horizon <= self.prediction_length, "Horizon must be less than or equal to the prediction length"

        # Get the maximum length across all time series
        max_len = int(data.num_timesteps_per_item().max())
        assert max_len >= min_ctx, "Train data must have at least min_ctx timestamps per item"
        if horizon != self.prediction_length:
            print(termcolor.colored(f"Horizon is not equal to prediction length, using prediction length: {self.prediction_length}", "yellow"))

        # If window_size is not provided, predict as much as possible after min_ctx
        if window_size is None:
            window_size = max_len - min_ctx
            assert window_size > 0, "window_size too small, need at least min_ctx points before window"
        else:
            # Ensure window_size is valid
            assert min_ctx <= max_len - window_size, "window_size too large, need at least min_ctx points before window"

        # Resolve the model once; the property reloads the trainer on every access.
        best_model = self.best_model_name

        def _forecast_from(steps_from_end: int) -> TimeSeriesDataFrame:
            """Forecast from the cut located ``steps_from_end`` timesteps before each item's end."""
            start_t = -(steps_from_end + context_length) if context_length else None
            context = data.slice_by_timestep(start_t, -steps_from_end)
            return self.predict(context, model=best_model)  # only one model is used

        # Baseline: the full history before the window (not including predictions).
        # This is always the complete pre-window slice, even when `context_length`
        # limits what the model sees, so the result covers the same rows as `data`.
        predictions_ls = [data.slice_by_timestep(None, -window_size)]

        # Cuts (counted from the end) whose forecasts each contribute `horizon` rows.
        # Empty when window_size <= horizon; the oldest chunk below then spans the
        # whole window, which replaces the former special-case early return.
        prediction_points = list(range(horizon, window_size, horizon))

        # Oldest chunk: starts at the beginning of the window and is the only one
        # that may be shorter than `horizon` (window_size need not be divisible).
        oldest_chunk_len = window_size - (prediction_points[-1] if prediction_points else 0)
        predictions_ls.append(_forecast_from(window_size).slice_by_timestep(None, oldest_chunk_len))

        # Remaining chunks, each exactly `horizon` rows long
        for t in prediction_points:
            predictions_ls.append(_forecast_from(t).slice_by_timestep(None, horizon))

        # Combine all predictions and sort by index
        return pd.concat(predictions_ls).sort_index()

    def fit_and_predict(self, data: TimeSeriesDataFrame, time_limit: Optional[int], 
                        window_size: Optional[int]=None, random_seed: Optional[int]=None, 
                        horizon: Optional[int]=None, context_length: Optional[int]=None,
                        min_ctx: int=5, use_cache: bool=True) -> pd.Series:
        """
        Fit the model (if not fitted yet) and return its 'mean' predictions over ``data``.

        Args:
            data: TimeSeriesDataFrame used for fitting (also passed as tuning data) and prediction
            time_limit: Time limit forwarded to ``fit``
            window_size: Forwarded to ``predict_all``
            random_seed: Random seed forwarded to ``fit``
            horizon: Forwarded to ``predict_all``
            context_length: Forwarded to ``predict_all``
            min_ctx: Forwarded to ``predict_all``
            use_cache: Whether to read/write predictions from/to the on-disk cache
        Returns:
            The 'mean' column of the predictions, indexed like ``data``
        Raises:
            AssertionError: if freshly computed predictions do not share ``data``'s index
        """
        if not self._learner.is_fit:
            self.fit(
                data,
                hyperparameters={
                    self.model_name: self.model_hyperparameters
                }, 
                random_seed=random_seed,
                time_limit=time_limit,
                tuning_data=data
            )
        
        # check if cached predictions are available
        dataset_hash = self._compute_dataset_hash(data, horizon, window_size, min_ctx, context_length)
        if use_cache:
            cached_predictions = self._get_cached_pred(dataset_hash)
            if cached_predictions is not None:
                if data.index.equals(cached_predictions.index):
                    print(termcolor.colored(f"Using cached predictions for {self.model_name}", "green"))
                    return cached_predictions
                # A cache entry whose index does not match is unusable. Instead of
                # aborting, fall through and recompute; the fresh predictions are
                # saved under the same hash below and overwrite the stale file.
                stale_file = self._cached_predictions_path / f"{dataset_hash}.pkl"
                print(termcolor.colored(
                    f"Index of data and cached predictions must be the same. "
                    f"Discarding the cached predictions {stale_file} and predicting from scratch.",
                    "yellow",
                ))
        
        # predict the data
        pred_data = self.predict_all(data, horizon=horizon, window_size=window_size, min_ctx=min_ctx, context_length=context_length)

        # reindex to the original item_id order
        original_item_id_order = data.item_ids
        pred_data = pred_data.reindex(original_item_id_order, level=ITEMID)

        # check if the index of data and pred_data are the same
        assert data.index.equals(pred_data.index), "Index of data and pred_data must be the same, pred_data is sorted so the data might not be sorted"

        mean_pred = pred_data['mean']

        # save the cached predictions
        if use_cache:
            self._save_cached_pred(dataset_hash, mean_pred)

        return mean_pred

    @staticmethod
    def _compute_dataset_hash(
        data: TimeSeriesDataFrame, 
        horizon: Optional[int], 
        window_size: Optional[int], 
        min_ctx: Optional[int], 
        context_length: Optional[int],
        known_covariates: Optional[TimeSeriesDataFrame] = None
    ) -> str:
        """
        Compute a string that identifies the dataset together with the prediction settings.

        Adapted from https://github.com/autogluon/autogluon/blob/9f0475abc7981fa55ccd8f917a98fb2c60a5c930/timeseries/src/autogluon/timeseries/trainer.py#L1124

        Args:
            data: Dataset to hash; its ``static_features`` attribute is hashed as well
            horizon: Prediction setting mixed into the hash (may be None)
            window_size: Prediction setting mixed into the hash (may be None)
            min_ctx: Prediction setting mixed into the hash (may be None)
            context_length: Prediction setting mixed into the hash (may be None)
            known_covariates: Optional covariates to hash alongside ``data``
        Returns:
            Concatenation of three md5 hex digests (data, known covariates, static features)
        """
        from hashlib import md5

        # The settings are joined with an explicit separator. Feeding the bare
        # digits one after another made e.g. (horizon=1, window_size=23) and
        # (horizon=12, window_size=3) produce the same key. Keys therefore differ
        # from those written before this change; such older cache files are simply
        # not found any more and the predictions get recomputed.
        settings_token = "|".join(
            str(setting) for setting in (horizon, window_size, min_ctx, context_length)
        ).encode("utf-8")

        def hash_pandas_df(df: Optional[pd.DataFrame]) -> str:
            """Compute a hash string for a pandas DataFrame (or for its absence)."""
            if df is not None:
                # Convert in case TimeSeriesDataFrame object is passed
                df = pd.DataFrame(df, copy=True)
                df.reset_index(inplace=True)
                # Sort columns so the hash does not depend on column order
                df.sort_index(inplace=True, axis=1)
                hashable_object = pd.util.hash_pandas_object(df).values
            else:
                hashable_object = "0".encode("utf-8")
            h = md5(hashable_object)
            h.update(b"|")
            h.update(settings_token)
            return h.hexdigest()
        
        combined_hash = hash_pandas_df(data) + hash_pandas_df(known_covariates) + hash_pandas_df(data.static_features)
        return combined_hash

    def _get_cached_pred(
        self, dataset_hash: str
    ) -> Union[pd.Series, None]:
        """Load cached predictions for given dataset_hash from disk, if possible.

        Returns None when no cache file exists for the hash, or when the file
        exists but cannot be unpickled (a warning is printed in that case).
        """
        cache_file = self._cached_predictions_path / f"{dataset_hash}.pkl"
        if not cache_file.exists():
            return None

        try:
            return pd.read_pickle(self._cached_predictions_path / f"{dataset_hash}.pkl")
        except Exception:
            # Best effort by design: any unreadable cache entry is treated as a miss.
            print(termcolor.colored(f"Cached predictions for {dataset_hash} are corrupted. Predictions will be made from scratch.", "yellow"))
            return None

    def _save_cached_pred(
        self,
        dataset_hash: str,
        pred: pd.Series,
    ) -> None:
        """Pickle ``pred`` to ``<cache dir>/<dataset_hash>.pkl``, creating the directory if needed."""
        cache_dir = self._cached_predictions_path
        cache_dir.mkdir(parents=True, exist_ok=True)
        pred.to_pickle(cache_dir / f"{dataset_hash}.pkl")

    def _predict_(self, context_data: TimeSeriesDataFrame, use_cache: bool=True, **kwargs) -> pd.Series:
        """
        Forecast from ``context_data`` with the best model and return the 'mean' column.

        Args:
            context_data: History to forecast from; also used to derive the cache key
            use_cache: Whether to look up / store the result in the on-disk cache
            **kwargs: Accepted but not used by this method
        Returns:
            The 'mean' forecast, either loaded from the cache or freshly predicted
        """
        # All prediction settings are fixed to 0 here, so the key depends on the data only
        cache_key = self._compute_dataset_hash(context_data, horizon=0, window_size=0, min_ctx=0, context_length=0)

        cached = self._get_cached_pred(cache_key) if use_cache else None
        if cached is not None:
            print(termcolor.colored(f"Using cached predictions for {self.model_name}", "green"))
            return cached

        forecast = self.predict(data=context_data, model=self.best_model_name)
        mean_forecast = forecast['mean']

        if use_cache:
            self._save_cached_pred(cache_key, mean_forecast)
        return mean_forecast


class CovariateRegressor(GlobalCovariateRegressor):
    """
    A GlobalCovariateRegressor that also tracks the model name.
    """

    @property
    def _cached_predictions_path(self) -> Optional[Path]:
        """Cache directory below ``self.path``, or None while no ``path`` attribute is set."""
        if not hasattr(self, 'path'):
            return None
        return Path(self.path) / "cached_predictions"

    def fit_and_predict(self, data: TimeSeriesDataFrame, time_limit: Optional[int], window_size: Optional[int]=None, random_seed: Optional[int]=None, **kwargs) -> Optional[pd.Series]:
        """
        Fit the regressor when required and predict over (the tail of) ``data``.

        Args:
            data: Data used for fitting and prediction
            time_limit: Time limit forwarded to ``fit``
            window_size: If truthy, only the last ``window_size`` timesteps per item are predicted
            random_seed: Accepted for interface compatibility; not used by this method
            **kwargs: Accepted but not used by this method
        Returns:
            Predictions indexed like the predicted rows, or None when the regressor
            is disabled and no cached predictions exist.
        """
        # Fit the model if not already fit or if refit is required
        if not self.is_fit() or self.refit_during_predict:
            self.fit(data, time_limit=time_limit)

        cache_key = self._compute_dataset_hash(data, horizon=0, window_size=window_size or 0, min_ctx=0, context_length=0)

        # Caching is only possible once a cache directory is known
        if self._cached_predictions_path:
            hit = self._get_cached_pred(cache_key)
            if hit is not None:
                print(termcolor.colored(f"Using cached predictions for {self.model_name}", "green"))
                return hit

        # A disabled model produces no predictions
        if self.disabled:
            return None

        window = data if not window_size else data.slice_by_timestep(-window_size, None)
        raw_pred = self._predict(window, static_features=data.static_features)
        y_series = pd.Series(raw_pred, index=window.index)

        if self._cached_predictions_path:
            self._save_cached_pred(cache_key, y_series)

        return y_series
    
    def _predict_(self, known_covariates: TimeSeriesDataFrame, static_features: pd.DataFrame=None, **kwargs) -> pd.Series:
        """
        Predict from ``known_covariates`` with on-disk caching.

        Args:
            known_covariates: Covariates to predict from; required
            static_features: Must be None (the cache key does not cover static features passed here)
            **kwargs: Accepted but not used by this method
        Returns:
            Predictions indexed like ``known_covariates``
        Raises:
            AssertionError: if ``known_covariates`` is None or ``static_features`` is given
        """
        assert known_covariates is not None, "known_covariates must be provided"
        assert static_features is None, "datahash only support no static features"

        cache_key = self._compute_dataset_hash(known_covariates, horizon=0, window_size=0, min_ctx=0, context_length=0)

        # Caching is only possible once a cache directory is known
        if self._cached_predictions_path:
            hit = self._get_cached_pred(cache_key)
            if hit is not None:
                print(termcolor.colored(f"Using cached predictions for {self.model_name}", "green"))
                return hit

        raw_pred = self._predict(data=known_covariates, static_features=static_features)
        y_pred = pd.Series(raw_pred, index=known_covariates.index)

        if self._cached_predictions_path:
            self._save_cached_pred(cache_key, y_pred)

        return y_pred

    def save(self, path: Union[str, Path]) -> None:
        """Pickle the regressor's configuration and fitted model to disk.

        Args:
            path: Destination. An existing directory means the file
                 'regressor.pkl' inside it; anything else is used as the file
                 path itself. Missing parent directories are created.

        Returns:
            None

        Raises:
            IOError: If the file cannot be written.
        """
        path = Path(path)
        if path.is_dir():
            path = path / "regressor.pkl"

        # Create directory if it doesn't exist
        os.makedirs(path.parent, exist_ok=True)

        # Attributes that make up the persisted state, in the order they are stored
        persisted_attrs = (
            'target',
            'model_name',
            'model_hyperparameters',
            'refit_during_predict',
            'tabular_eval_metric',
            'max_num_samples',
            'validation_fraction',
            'fit_time_fraction',
            'include_static_features',
            'include_item_id',
            'disabled',
            'covariate_metadata',
            'model',
        )
        state_dict = {attr: getattr(self, attr) for attr in persisted_attrs}

        with open(path, 'wb') as f:
            pickle.dump(state_dict, f)
        logger.info(f"Regressor saved successfully to {path}")
   
    @classmethod
    def load(cls, path: Union[str, Path]) -> "CovariateRegressor":
        """Load a regressor from a file.
        
        Args:
            path: Path to the saved regressor file. If path is a directory,
                 the function will look for 'regressor.pkl' in that directory.
                 
        Returns:
            A CovariateRegressor instance.
            
        Raises:
            FileNotFoundError: If the file doesn't exist.
            ValueError: If the loaded object is not a valid regressor state
                (not a dict, or a dict with missing entries).
        """
        path = Path(path)
        if path.is_dir():
            path = path / "regressor.pkl"
            
        if not path.exists():
            raise FileNotFoundError(f"Regressor file not found: {path}")
        
        # SECURITY: unpickling can execute arbitrary code. Only load files that
        # were written by `save` and come from a trusted location.
        with open(path, 'rb') as f:
            state_dict = pickle.load(f)

        # Validate the payload up front so a foreign or truncated file surfaces
        # as the documented ValueError instead of a KeyError/TypeError below.
        required_keys = (
            'target',
            'model_name',
            'model_hyperparameters',
            'refit_during_predict',
            'tabular_eval_metric',
            'max_num_samples',
            'validation_fraction',
            'fit_time_fraction',
            'include_static_features',
            'include_item_id',
            'disabled',
            'covariate_metadata',
            'model',
        )
        if not isinstance(state_dict, dict):
            raise ValueError(
                f"Invalid regressor state in {path}: expected a dict, got {type(state_dict).__name__}"
            )
        missing_keys = [key for key in required_keys if key not in state_dict]
        if missing_keys:
            raise ValueError(
                f"Invalid regressor state in {path}: missing keys {missing_keys}"
            )
            
        # Create a new instance with the saved parameters
        instance = cls(
            model_name=state_dict['model_name'],
            model_hyperparameters=state_dict['model_hyperparameters'],
            eval_metric=state_dict['tabular_eval_metric'],
            refit_during_predict=state_dict['refit_during_predict'],
            max_num_samples=state_dict['max_num_samples'],
            covariate_metadata=state_dict['covariate_metadata'],
            target=state_dict['target'],
            validation_fraction=state_dict['validation_fraction'],
            fit_time_fraction=state_dict['fit_time_fraction'],
            include_static_features=state_dict['include_static_features'],
            include_item_id=state_dict['include_item_id'],
        )
        
        # Set the state that couldn't be passed to __init__
        instance.model = state_dict['model']
        instance.disabled = state_dict['disabled']
        
        logger.info(f"Regressor loaded successfully from {path}")
        return instance

    @staticmethod
    def _compute_dataset_hash(
        data: TimeSeriesDataFrame, 
        horizon: Optional[int], 
        window_size: Optional[int], 
        min_ctx: Optional[int], 
        context_length: Optional[int],
        known_covariates: Optional[TimeSeriesDataFrame] = None
    ) -> str:
        """
        Compute a string that identifies the dataset together with the prediction settings.

        Adapted from https://github.com/autogluon/autogluon/blob/9f0475abc7981fa55ccd8f917a98fb2c60a5c930/timeseries/src/autogluon/timeseries/trainer.py#L1124

        Args:
            data: Dataset to hash; its ``static_features`` attribute is hashed as well
            horizon: Prediction setting mixed into the hash (may be None)
            window_size: Prediction setting mixed into the hash (may be None)
            min_ctx: Prediction setting mixed into the hash (may be None)
            context_length: Prediction setting mixed into the hash (may be None)
            known_covariates: Optional covariates to hash alongside ``data``
        Returns:
            Concatenation of three md5 hex digests (data, known covariates, static features)
        """
        from hashlib import md5

        # The settings are joined with an explicit separator. Feeding the bare
        # digits one after another made e.g. (horizon=1, window_size=23) and
        # (horizon=12, window_size=3) produce the same key. Keys therefore differ
        # from those written before this change; such older cache files are simply
        # not found any more and the predictions get recomputed.
        settings_token = "|".join(
            str(setting) for setting in (horizon, window_size, min_ctx, context_length)
        ).encode("utf-8")

        def hash_pandas_df(df: Optional[pd.DataFrame]) -> str:
            """Compute a hash string for a pandas DataFrame (or for its absence)."""
            if df is not None:
                # Convert in case TimeSeriesDataFrame object is passed
                df = pd.DataFrame(df, copy=True)
                df.reset_index(inplace=True)
                # Sort columns so the hash does not depend on column order
                df.sort_index(inplace=True, axis=1)
                hashable_object = pd.util.hash_pandas_object(df).values
            else:
                hashable_object = "0".encode("utf-8")
            h = md5(hashable_object)
            h.update(b"|")
            h.update(settings_token)
            return h.hexdigest()
        
        combined_hash = hash_pandas_df(data) + hash_pandas_df(known_covariates) + hash_pandas_df(data.static_features)
        return combined_hash

    def _get_cached_pred(
        self, dataset_hash: str
    ) -> Union[pd.Series, None]:
        """Load cached predictions for given dataset_hash from disk, if possible.

        Returns None when no cache file exists for the hash, or when the file
        exists but cannot be unpickled (a warning is printed in that case).
        """
        cache_file = self._cached_predictions_path / f"{dataset_hash}.pkl"
        if not cache_file.exists():
            return None

        try:
            return pd.read_pickle(self._cached_predictions_path / f"{dataset_hash}.pkl")
        except Exception:
            # Best effort by design: any unreadable cache entry is treated as a miss.
            print(termcolor.colored(f"Cached predictions for {dataset_hash} are corrupted. Predictions will be made from scratch.", "yellow"))
            return None

    def _save_cached_pred(
        self,
        dataset_hash: str,
        pred: pd.Series,
    ) -> None:
        """Pickle ``pred`` to ``<cache dir>/<dataset_hash>.pkl``, creating the directory if needed."""
        cache_dir = self._cached_predictions_path
        cache_dir.mkdir(parents=True, exist_ok=True)
        pred.to_pickle(cache_dir / f"{dataset_hash}.pkl")


class CrossSectionalRegressor:
    """
    An ensemble of GlobalCovariateRegressor models.
    
    This class maintains a list of GlobalCovariateRegressor models and combines
    their predictions using a weighted ensemble approach.
    
    Parameters
    ----------
    model_names : List[str]
        List of model names to use in the ensemble (e.g., ["XGB", "RF", "CAT"])
    target : str
        Name of the target column
    covariate_metadata : CovariateMetadata
        Metadata about covariates (static, known, past)
    include_static_features : bool, optional
        Whether to include static features in the model
    include_item_id : bool, optional
        Whether to include item_id as a feature
    models_hyperparameters : Dict[str, Dict[str, Any]], optional
        Hyperparameters for each model, keyed by model name
    aggregation_strategy : Union[str, Tuple[str, Dict[str, Any]], EnsembleAggregator], optional
        Strategy for aggregating model predictions. Must be exactly one of:
        - string: One of 'equal', 'performance', 'adaptive', 'spa', 'singlebest', or 'linear'
        - tuple: (strategy_name, kwargs_dict) where strategy_name is one of the above
          and kwargs_dict contains parameters for that specific strategy
        - EnsembleAggregator instance: A pre-configured aggregator instance
    fit_time_fraction : float, optional
        Fraction of time limit to allocate to model fitting
    validation_fraction : float, optional 
        Fraction of data to use for validation
    eval_metric : str, optional
        Evaluation metric to use for model validation
    """
    
    # Valid aggregation strategies
    VALID_STRATEGIES = {
        'equal', 'performance', 'adaptive', 'spa', 'singlebest', 'linear'
    }
    AVAILABLE_TABULAR_MODELS = tabular_ag_model_register.key_to_cls_map()
    AVAILABLE_TS_MODELS = MODEL_TYPES
    
    # Mapping from TimeSeriesPredictor metrics to regressor metrics
    _METRIC_MAPPING = {
        # Add identity mappings for cases where users pass the full name
        "mean_absolute_error": "MAE",
        "mean_squared_error": "MSE",
        "root_mean_squared_error": "RMSE",
        "mean_absolute_percentage_error": "MAPE",
        "symmetric_mean_absolute_percentage_error": "SMAPE",
        "mean_absolute_scaled_error": "MASE",
        "weighted_quantile_loss": "WQL",
        "scaled_quantile_loss": "SQL",
    }

    def __init__(
        self,
        model_names: List[str],
        target: str,
        covariate_metadata: CovariateMetadata,
        prediction_length: int,
        include_static_features: bool = True,
        include_item_id: bool = True,
        models_hyperparameters: Optional[Dict[str, Dict[str, Any]]] = None,
        aggregation_strategy: Union[str, Tuple[str, Dict[str, Any]], EnsembleAggregator] = 'equal',
        aggregation_train_length: Optional[int] = None,
        fit_time_fraction: float = 0.5,
        validation_fraction: float = 0.1,
        eval_metric: str = "mean_absolute_error",
        random_seed: int = 123,
        verbosity: int = 0,
    ):
        """Store the configuration, then build the aggregator and the member regressors."""
        # --- what to model ---
        self.model_names = model_names
        self.models_hyperparameters = models_hyperparameters if models_hyperparameters else {}
        self.target = target
        self.covariate_metadata = covariate_metadata
        self.prediction_length = prediction_length

        # --- feature handling ---
        self.include_static_features = include_static_features
        self.include_item_id = include_item_id

        # --- training / evaluation settings ---
        self.fit_time_fraction = fit_time_fraction
        self.validation_fraction = validation_fraction
        self.eval_metric = eval_metric
        self.random_seed = random_seed
        self.verbosity = verbosity
        self.aggregation_train_length = aggregation_train_length

        # Both builders read the attributes assigned above, so they must run last
        # (aggregator first, then the regressors).
        self.aggregator = self._initialize_aggregator(aggregation_strategy)
        self.regressors = self._initialize_regressors()
    
    def _initialize_aggregator(self, aggregation_strategy):
        """
        Initialize the ensemble aggregator based on the specified strategy.
        
        Parameters
        ----------
        aggregation_strategy : Union[str, Tuple[str, Dict[str, Any]], EnsembleAggregator]
            The aggregation strategy to use, which must be exactly one of:
            - string: name of the strategy (case-insensitive)
            - tuple: (strategy_name, kwargs_dict); kwargs_dict is not modified
            - EnsembleAggregator instance (returned as is)
            
        Returns
        -------
        EnsembleAggregator
            The initialized aggregator instance. Aggregators created here always
            receive ``num_models=len(self.model_names)``.
            
        Raises
        ------
        ValueError
            If the strategy name is not one of ``VALID_STRATEGIES``
        """
        # If an aggregator instance is provided, return it directly
        if isinstance(aggregation_strategy, EnsembleAggregator):
            return aggregation_strategy
        
        # If a tuple is provided, extract the strategy name and kwargs
        if isinstance(aggregation_strategy, tuple) and len(aggregation_strategy) == 2:
            strategy_name, kwargs = aggregation_strategy
        else:
            # Otherwise, assume it's just a strategy name with no kwargs
            strategy_name, kwargs = aggregation_strategy, {}
        
        # Validate that the strategy is one of the allowed values
        strategy_name = str(strategy_name).lower()
        if strategy_name not in self.VALID_STRATEGIES:
            # Sorted so the message is stable (set iteration order is not)
            raise ValueError(
                f"Invalid aggregation strategy: '{strategy_name}'. "
                f"Must be one of: {', '.join(sorted(self.VALID_STRATEGIES))}"
            )
        
        # Work on a copy: injecting num_models must not leak into the dict the
        # caller passed in (which may be reused for another ensemble).
        kwargs = dict(kwargs or {})
        kwargs['num_models'] = len(self.model_names)
            
        # Map each strategy name to its aggregator class
        aggregator_classes = {
            'equal': EqualWeightAggregator,
            'performance': PerformanceWeightAggregator,
            'adaptive': AdaptiveAggregator,
            'spa': SPAAggregator,
            'singlebest': SingleBestAggregator,
            'linear': LinearAggregator,
        }
        aggregator_cls = aggregator_classes.get(strategy_name)
        if aggregator_cls is None:
            # Only reachable if VALID_STRATEGIES lists a name without a class here
            raise ValueError(f"Invalid aggregation strategy: '{strategy_name}'")
        return aggregator_cls(**kwargs)
    
    def _initialize_regressors(self):
        """Build one regressor per entry of ``self.model_names``, in order.

        A ``'path'`` entry in a model's hyperparameters is removed from the
        hyperparameters and passed on as the location to load that model from.
        Raises ValueError for names that are neither tabular nor time series models.
        """
        built = []

        for model_name in self.model_names:
            # Copy so popping 'path' does not alter the stored configuration
            params = self.models_hyperparameters.get(model_name, {}).copy()
            load_path = params.pop('path', None)

            # Pick the builder matching the model family; tabular takes precedence
            if model_name in self.AVAILABLE_TABULAR_MODELS:
                build = self._init_tabular_regressor
            elif model_name in self.AVAILABLE_TS_MODELS:
                build = self._init_ts_regressor
            else:
                available_models = list(self.AVAILABLE_TS_MODELS.keys()) + list(self.AVAILABLE_TABULAR_MODELS.keys())
                raise ValueError(f"Invalid model name: '{model_name}', available models: {available_models}")

            built.append(build(model_name, params, load_path))

        return built

    def _init_tabular_regressor(self, model_name, hyperparams, model_path=None):
        """Create a new tabular regressor, or load one when ``model_path`` is given.

        Args:
            model_name: Name of the tabular model.
            hyperparams: Model hyperparameters (only used when creating a new regressor).
            model_path: Optional path to load model from; also stored on the
                loaded regressor as its ``path``.

        Returns:
            Initialized CovariateRegressor.
        """
        if not model_path:
            # Nothing to load: configure a fresh regressor from the ensemble settings
            fresh_config = dict(
                model_name=model_name,
                target=self.target,
                covariate_metadata=self.covariate_metadata,
                include_static_features=self.include_static_features,
                include_item_id=self.include_item_id,
                model_hyperparameters=hyperparams,
                fit_time_fraction=self.fit_time_fraction,
                validation_fraction=self.validation_fraction,
                eval_metric=self.eval_metric,
            )
            return CovariateRegressor(**fresh_config)

        logger.info(f"Loading tabular regressor {model_name} from {model_path}")
        loaded = CovariateRegressor.load(model_path)
        loaded.path = model_path
        return loaded

    def _init_ts_regressor(self, model_name, hyperparams, model_path=None):
        """Create a new time series regressor, or load a predictor when ``model_path`` is given.

        Args:
            model_name: Name of the time series model.
            hyperparams: Model hyperparameters (only used when creating a new regressor).
            model_path: Optional path to load model from.

        Returns:
            A new TimeSeriesRegressor, or the predictor returned by
            ``TimeSeriesPredictor.load`` when loading from ``model_path``.

        Raises:
            AssertionError: if a loaded predictor's prediction_length differs
                from the ensemble's.
        """
        if not model_path:
            # Nothing to load: configure a fresh regressor from the ensemble settings
            ts_metric = self._METRIC_MAPPING.get(self.eval_metric, self.eval_metric)
            return TimeSeriesRegressor(
                model_name=model_name,
                model_hyperparameters=hyperparams,
                target=self.target,
                eval_metric=ts_metric,
                prediction_length=self.prediction_length,
                verbosity=self.verbosity,
            )

        logger.info(f"Loading time series regressor {model_name} from {model_path}")

        # Load the base predictor; version mismatches are tolerated
        regressor = TimeSeriesPredictor.load(model_path, require_version_match=False)
        assert regressor.prediction_length == self.prediction_length, f"prediction_length of loaded regressor {model_name} is not equal to the prediction_length of the ensemble regressor. Please use the same prediction_length for all regressors."

        return regressor

    def fit_transform(
        self, 
        data: TimeSeriesDataFrame, 
        time_limit: Optional[int] = None,
        context_length: Optional[int] = None,
        include_individual_residuals: bool = False,
        keep_target_column: bool = False,
        predictions_csv_path: Optional[Union[str, Path]] = None,
    ) -> TimeSeriesDataFrame:
        """
        Fit each regressor and transform the data by combining their predictions.
        
        Parameters
        ----------
        data : TimeSeriesDataFrame
            Input data for training
        time_limit : int, optional
            Time limit for fitting in seconds; split evenly (integer division)
            across the regressors
        context_length : int, optional
            Context length forwarded to each regressor's ``fit_and_predict``
        include_individual_residuals : bool, optional
            Whether to include individual model residuals in the output
        keep_target_column : bool, optional
            Whether to keep the original target column in the output
        predictions_csv_path : str or Path, optional
            If given, the combined predictions are additionally written to this
            CSV file (parent directories are created). Nothing is written by
            default or when there are no combined predictions.
            
        Returns
        -------
        TimeSeriesDataFrame
            Transformed data with residuals; ``data`` itself when the ensemble
            has no regressors
        """
        if not self.regressors:
            return data
        
        # Calculate time limit per regressor if provided
        per_regressor_time_limit = time_limit // len(self.regressors) if time_limit else None
        
        # Fit models and collect predictions
        model_predictions = self._fit_and_collect_predictions(data, per_regressor_time_limit, context_length)
        
        # Stack predictions and fit the aggregator
        combined_predictions = self._combine_predictions(model_predictions, data[self.target], train_length=self.aggregation_train_length)

        # Optional debug dump. This used to be an unconditional write to a
        # hard-coded, machine-specific path, which failed on any other machine
        # and whenever no predictions were produced (combined_predictions None).
        if predictions_csv_path is not None and combined_predictions is not None:
            csv_path = Path(predictions_csv_path)
            csv_path.parent.mkdir(parents=True, exist_ok=True)
            combined_predictions.to_csv(csv_path)
        
        # Calculate residuals and prepare the output DataFrame
        return self._prepare_output_dataframe(
            data=data, 
            target=data[self.target], 
            combined_predictions=combined_predictions, 
            model_predictions=model_predictions,
            keep_target_column=keep_target_column, 
            include_individual_residuals=include_individual_residuals
        )
    
    def _fit_and_collect_predictions(self, data: TimeSeriesDataFrame, time_limit: Optional[int], context_length: Optional[int]) -> List[pd.Series]:
        """
        Fit every regressor on ``data`` and gather what each one predicts.

        Regressors are handled in the order of ``self.regressors`` and a progress
        line is printed before each fit. A regressor whose ``fit_and_predict``
        returns ``None`` contributes nothing, so the returned list can be shorter
        than ``self.regressors``.
        """
        collected = []

        for i, regressor in enumerate(self.regressors):
            print(f"Fitting regressor {i+1}/{len(self.regressors)}: {self.model_names[i]}")

            fitted_prediction = regressor.fit_and_predict(
                data,
                time_limit,
                context_length=context_length,
                random_seed=self.random_seed,
            )
            if fitted_prediction is None:
                # Nothing to aggregate for this regressor.
                continue
            collected.append(fitted_prediction)

        return collected
    
    def _combine_predictions(self, model_predictions: List[pd.Series], target: pd.Series, train_length: Optional[int]=None) -> pd.Series:
        """Combine individual model predictions using the aggregator."""
        if not model_predictions:
            return None
        
        # Prepare training data
        if train_length is None:
            # Use all data when train_length is None
            x = np.column_stack([pred.values for pred in model_predictions])
            y = target.values
        else:
            # Use only the last train_length points when specified
            x = np.column_stack([pred.groupby(level=0, group_keys=False).tail(train_length).values for pred in model_predictions])
            y = target.groupby(level=0, group_keys=False).tail(train_length).values
        
        # Fit the aggregator
        self.aggregator.fit(x, y)
        y_pred = self.aggregator.predict(np.column_stack([pred.values for pred in model_predictions]))
        
        # Get the combined predictions
        return pd.Series(y_pred, index=target.index)
    
    def _prepare_output_dataframe(
        self, 
        data: TimeSeriesDataFrame, 
        target: pd.Series, 
        combined_predictions: pd.Series, 
        model_predictions: List[pd.Series], 
        keep_target_column: bool, 
        include_individual_residuals: bool
    ):
        """
        Build the residual frame returned to the caller.

        Works on a copy of ``data``: the target column is overwritten with
        ``target - combined_predictions``. Optionally the original target is kept
        under ``"<target>_label"`` and one ``"<target>_residual_<model>"`` column
        is added per entry of ``model_predictions``.
        """
        output = data.copy()
        actual = target.values

        # Residual of the aggregated prediction replaces the target column.
        output[self.target] = actual - combined_predictions.values

        if keep_target_column:
            output[f"{self.target}_label"] = actual

        if not include_individual_residuals:
            return output

        # NOTE(review): assumes model_predictions[i] belongs to self.model_names[i] -
        # confirm that callers never drop entries from the list.
        for i, prediction in enumerate(model_predictions):
            col_name = f"{self.target}_residual_{self.model_names[i]}"
            output[col_name] = actual - prediction.values

        return output
    
    def inverse_transform(
        self,
        predictions: TimeSeriesDataFrame,
        known_covariates: TimeSeriesDataFrame,
        static_features: Optional[pd.DataFrame] = None,
        include_individual_predictions: bool = False,
        context_data: Optional[TimeSeriesDataFrame] = None,
    ) -> TimeSeriesDataFrame:
        """
        Apply the inverse transformation by combining regressor outputs.

        The aggregated regressor prediction is added back onto every column of
        ``predictions``. If there are no regressors, or no regressor predictions
        are available, ``predictions`` is returned unchanged.
        
        Parameters
        ----------
        predictions : TimeSeriesDataFrame
            Predicted residuals
        known_covariates : TimeSeriesDataFrame
            Known covariates for the prediction period
        static_features : pd.DataFrame, optional
            Static features
        include_individual_predictions : bool, optional
            If True, include columns for each individual model's predictions
        context_data : TimeSeriesDataFrame, optional
            Passed to the covariate alignment step and to each regressor's
            prediction call.
            
        Returns
        -------
        TimeSeriesDataFrame
            Inverse transformed predictions
        """
        if not self.regressors:
            return predictions
        
        # Alignment has to happen here, before the regressors are queried
        known_covariates = self._align_covariates_with_forecast_index(known_covariates=known_covariates, data=context_data)
        if known_covariates is not None:
            assert predictions.index.equals(known_covariates.index), "Index of residuals predictions and known_covariates must be the same"
        
        # Get predictions from each model
        model_predictions = self._get_model_predictions(known_covariates, static_features, context_data)

        if not model_predictions:
            return predictions
        
        # Combine model predictions
        pred_array = np.column_stack([pred.values for pred in model_predictions])
        covariate_effect = pd.Series(self.aggregator.predict(pred_array), index=predictions.index)

        # Debug dump of the aggregated forecast. The location is a
        # developer-specific directory, so only write when it actually exists
        # rather than raising on every other machine.
        # TODO(review): make this location configurable or remove the dump.
        debug_dir = Path("/home/magics/hdd/sky_ws/residual_ws/tests/hopformer/data")
        if debug_dir.is_dir():
            covariate_effect.to_csv(debug_dir / "regressor_predictions.csv")

        # Scope the float formatting to this print only; the context manager
        # restores whatever print options the caller had configured instead of
        # resetting the global formatter to None.
        with np.printoptions(formatter={'float': '{:.2f}'.format}):
            print(f"{type(self.aggregator).__name__}(normalizer: {self.aggregator.normalizer}) coef: {self.aggregator.coef_}")
        
        # Create the final output with combined predictions
        return self._create_final_predictions(
            residual_predictions=predictions, 
            covariate_effect=covariate_effect, 
            model_predictions=model_predictions, 
            include_individual_predictions=include_individual_predictions
        )
    
    def _align_covariates_with_forecast_index(
        self,
        known_covariates: Optional[TimeSeriesDataFrame],
        data: TimeSeriesDataFrame,
    ) -> Optional[TimeSeriesDataFrame]:
        """Restrict ``known_covariates`` to the forecast horizon that follows ``data``.

        Returns ``None`` when the covariate metadata lists no known covariates.
        Otherwise the target column (if present) is dropped and the rows matching
        the future index built from ``data`` are selected. A ``ValueError`` is
        raised if item_ids or forecast timestamps are missing.

        copy from https://github.com/autogluon/autogluon/blob/73173f69a04454d6f9da2066aefe035bac51336b/timeseries/src/autogluon/timeseries/learner.py#L122
        """
        expected_covariates = self.covariate_metadata.known_covariates
        if expected_covariates is None or len(expected_covariates) == 0:
            return None
        assert known_covariates is not None

        # The target must not leak into the covariates handed to the regressors.
        if self.target in known_covariates.columns:
            known_covariates = known_covariates.drop(self.target, axis=1)

        missing_item_ids = data.item_ids.difference(known_covariates.item_ids)
        if len(missing_item_ids) > 0:
            raise ValueError(
                f"known_covariates are missing information for the following item_ids: {reprlib.repr(missing_item_ids.to_list())}."
            )

        future_frame = make_future_data_frame(data, prediction_length=self.prediction_length, freq=data.freq)
        forecast_index = pd.MultiIndex.from_frame(future_frame)

        try:
            aligned = known_covariates.loc[forecast_index]  # type: ignore
        except KeyError:
            raise ValueError(
                "`known_covariates` should include the `item_id` and `timestamp` values covering the forecast horizon "
                "(i.e., the next `prediction_length` time steps following the end of each time series in the input "
                "data). Use `TimeSeriesPredictor.make_future_data_frame` to generate the required `item_id` and "
                "`timestamp` combinations for the `known_covariates`."
            )
        return aligned
    
    def _get_model_predictions(self, known_covariates: TimeSeriesDataFrame, static_features: pd.DataFrame, context_data: TimeSeriesDataFrame) -> List[pd.Series]:
        """
        Query every regressor for its forecast.

        Regressors are called in the order of ``self.regressors`` with the same
        covariates, static features and context data; a progress line is printed
        before each call. The result holds one entry per regressor.
        """
        collected = []

        for i, regressor in enumerate(self.regressors):
            print(f"Predicting with regressor {i+1}/{len(self.regressors)}: {self.model_names[i]}")

            collected.append(
                regressor._predict_(
                    known_covariates=known_covariates,
                    static_features=static_features,
                    context_data=context_data,
                )
            )

        return collected
    
    def _create_final_predictions(
        self, 
        residual_predictions: TimeSeriesDataFrame, 
        covariate_effect: pd.Series, 
        model_predictions: List[pd.Series], 
        include_individual_predictions: bool
    ):
        """
        Add the aggregated regressor forecast back onto the residual forecast.

        Works on a copy of ``residual_predictions``: ``covariate_effect`` is added
        positionally to every column. When requested, one ``"<target>_<model>"``
        column per entry of ``model_predictions`` is appended as well.
        """
        combined = residual_predictions.copy()

        # The same offset applies to every existing prediction column.
        for col in residual_predictions.columns:
            combined[col] = residual_predictions[col] + covariate_effect.values

        if not include_individual_predictions:
            return combined

        for i, prediction in enumerate(model_predictions):
            col_name = f"{self.target}_{self.model_names[i]}"
            combined[col_name] = prediction.values

        return combined


def reshape_residuals_to_2d(residuals_series, fill_value=np.nan):
    """
    Reshape a residuals Series with MultiIndex (item_id, timestamp) into a 2D array.

    Rows follow the order in which item_ids first appear in the index and
    columns follow the order in which timestamps first appear, so
    ``reshaped_array[i, j]`` is the value of ``item_ids[i]`` at
    ``all_timestamps[j]``. If an (item_id, timestamp) pair occurs more than
    once, the last occurrence is kept.

    Parameters:
    -----------
    residuals_series : pd.Series
        Series with MultiIndex containing residual values. The item level must
        be named ``'item_id'``; the timestamp is read from index level 1.
    fill_value : float, optional
        Value to use for missing timestamps, default is np.nan

    Returns:
    --------
    reshaped_array : np.ndarray
        2D array with shape (n_items, n_unique_timestamps)
    item_ids : pd.Index
        Unique item_ids corresponding to rows of reshaped_array
    all_timestamps : pd.Index
        Unique timestamps corresponding to columns of reshaped_array
    """
    idx = residuals_series.index
    item_level = idx.get_level_values('item_id')
    time_level = idx.get_level_values(1)

    item_ids = item_level.unique()
    all_timestamps = time_level.unique()
    n_items, n_times = len(item_ids), len(all_timestamps)

    values = residuals_series.to_numpy()

    # Row / column position of every observation in the output grid.
    row_codes = item_ids.get_indexer(item_level)
    col_codes = all_timestamps.get_indexer(time_level)

    # Fast path: a plain reshape is only valid when every item has every
    # timestamp AND the rows are stored item by item in column order. Equal
    # per-item counts alone do not guarantee that, so verify it explicitly.
    is_complete_grid = (
        values.size == n_items * n_times
        and np.array_equal(row_codes, np.repeat(np.arange(n_items), n_times))
        and np.array_equal(col_codes, np.tile(np.arange(n_times), n_items))
    )
    if is_complete_grid:
        return values.reshape(n_items, n_times), item_ids, all_timestamps

    # General path: scatter each value into its cell. The dtype is promoted so
    # that neither the values nor fill_value are truncated (e.g. float
    # residuals combined with an integer fill_value).
    out_dtype = np.result_type(values.dtype, np.asarray(fill_value).dtype)
    reshaped_array = np.full((n_items, n_times), fill_value, dtype=out_dtype)
    reshaped_array[row_codes, col_codes] = values

    return reshaped_array, item_ids, all_timestamps