import pandas as pd
import numpy as np
import unittest
from residual_chronos.Predictor import Predictor
from autogluon.timeseries.dataset.ts_dataframe import TimeSeriesDataFrame

class TestPredictor(unittest.TestCase):
    """End-to-end tests for ``Predictor`` on the simulated sales fixtures.

    Every test fits a fresh model on the context window (each series without
    its last ``prediction_length`` timesteps) and forecasts the held-out
    window, varying only which components (regressor / bolt model) are enabled.
    """

    @classmethod
    def setUpClass(cls):
        """Load the fixture data once and derive the splits shared by all tests."""
        # Paths are relative to the current working directory.
        sales_data = pd.read_csv("tests/data/simulated_sales_data.csv")
        static_features = pd.read_csv("tests/data/simulated_static_features.csv")

        # Create TimeSeriesDataFrame
        cls.ts_data = TimeSeriesDataFrame(
            sales_data,
            id_column="item_id",
            timestamp_column="timestamp",
            static_features=static_features
        )

        cls.prediction_length = 8
        cls.RANDOM_SEED = 100

        # Hold out the last `prediction_length` timesteps: the context is used
        # both for fitting and as the history handed to predict(); the held-out
        # slice is handed to predict() as known covariates.
        # NOTE(review): the held-out slice still carries the "sales" column;
        # presumably Predictor ignores the target there - confirm.
        cls.context = cls.ts_data.slice_by_timestep(None, -cls.prediction_length)
        cls.known_covariates = cls.ts_data.slice_by_timestep(-cls.prediction_length, None)
        cls.static_features = cls.ts_data.static_features

        # Mean sales per item over the context window (index level 0 is the
        # item id); used as the expected forecast in
        # test_without_regressor_predictor.
        cls.context_means = cls.context.groupby(level=0)["sales"].mean()

    def _fit_and_predict(self, regressor_types, bolt_model_path):
        """Fit a fresh ``Predictor`` on the context window and forecast the held-out window.

        Args:
            regressor_types: Forwarded to ``Predictor(regressor_types=...)``.
            bolt_model_path: Forwarded to ``Predictor(bolt_model_path=...)``.

        Returns:
            Whatever ``Predictor.predict`` returns, guaranteed to be non-empty.
        """
        model = Predictor(
            prediction_length=self.prediction_length,
            target="sales",
            eval_metric="MAE",
            regressor_types=regressor_types,
            bolt_model_path=bolt_model_path,
            random_seed=self.RANDOM_SEED,
            regressor_fit_time_fraction=0.5,
            regressor_validation_fraction=0.1
        )

        model.fit(
            train_data=self.context,
            time_limit=120,
        )

        predictions = model.predict(
            data=self.context,
            known_covariates=self.known_covariates,
            static_features=self.static_features
        )

        print("Predictions sample:")
        print(predictions.head())

        # An empty forecast would let every downstream check pass vacuously.
        self.assertGreater(len(predictions), 0, "Predictor returned no predictions")
        return predictions

    def test_without_regressor_predictor(self):
        """With neither regressor nor predictor, forecasts must equal each item's context mean."""
        print("\n=== TEST 1: Without regressor and predictor ===")

        predictions = self._fit_and_predict(regressor_types=None, bolt_model_path=None)

        print("\nContext means:")
        print(self.context_means)

        tolerance = 0.001  # 0.1% relative tolerance
        mismatched_items = []

        for item_id in predictions.index.get_level_values(0).unique():
            item_context_mean = self.context_means[item_id]
            item_forecast = predictions.loc[item_id]['mean']

            diffs = np.abs(item_forecast - item_context_mean)
            # abs(): a negative mean must not produce a negative (unsatisfiable) bound.
            max_allowed_diff = abs(item_context_mean) * tolerance
            # Negate "within tolerance" instead of testing "diff > bound" so
            # that NaN forecasts are reported as mismatches too.
            out_of_tolerance = ~(diffs <= max_allowed_diff)

            if out_of_tolerance.any():
                mismatched_items.append(item_id)
                print(f"Item {item_id}: Context mean = {item_context_mean:.2f}")
                print(f"  - Values that don't match: {item_forecast[out_of_tolerance].values}")

        self.assertFalse(
            mismatched_items,
            "Not all prediction values match their corresponding item means "
            f"(mismatched items: {mismatched_items})"
        )

    def test_with_only_regressor(self):
        """With only the XGB regressor enabled, forecasts must be free of NaN."""
        print("\n=== TEST 2: With only regressor ===")

        predictions = self._fit_and_predict(regressor_types=["XGB"], bolt_model_path=None)

        self.assertFalse(
            predictions.isnull().any().any(),
            "Predictions contain NaN values"
        )

    def test_with_only_predictor(self):
        """With only the bolt model enabled, forecasts must be free of NaN."""
        print("\n=== TEST 3: With only predictor ===")

        predictions = self._fit_and_predict(regressor_types=None, bolt_model_path="bolt_small")

        self.assertFalse(
            predictions.isnull().any().any(),
            "Predictions contain NaN values"
        )

# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()


