import unittest
import pandas as pd
import numpy as np
import os
import tempfile
import shutil
from pathlib import Path
import random
import torch

from autogluon.timeseries import TimeSeriesDataFrame
from residual_chronos.Predictor import Predictor

class TestPredictor(unittest.TestCase):
    """Test cases for Predictor class."""
    
    @classmethod
    def setUpClass(cls):
        """Prepare class-wide fixtures shared by every test method.

        Seeds the Python, NumPy and PyTorch random number generators with a
        fixed value, builds the synthetic dataset through
        ``_create_test_data`` and allocates a temporary directory whose path
        is stored on ``cls.temp_dir``.
        """
        seed = 42

        # Seed every generator in a fixed order so the NumPy draws made while
        # building the dataset are repeatable from run to run.
        for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
            seed_fn(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

        # The dataset is built only after seeding.
        cls._create_test_data()

        # Scratch directory intended for model save/load tests.
        cls.temp_dir = tempfile.mkdtemp()
    
    @classmethod
    def tearDownClass(cls):
        """Delete the temporary directory referenced by ``cls.temp_dir``.

        Removal is best-effort: ``ignore_errors=True`` keeps a cleanup
        problem (for example a directory that is already gone or a file that
        is still locked) from being reported as a test-class error, which
        would otherwise obscure the real test results.
        """
        shutil.rmtree(cls.temp_dir, ignore_errors=True)
    
    @classmethod
    def _create_test_data(cls):
        """Build the synthetic panel dataset used by all tests.

        The panel holds 10 items with 20 daily observations each, starting
        on 2023-01-01. Every row carries a noisy sinusoidal target, two
        real-valued and two categorical known covariates, and two real-valued
        and two categorical past covariates. Each item also has two
        real-valued and two categorical static features.

        Stores on the class: ``train_data``, ``test_data``,
        ``prediction_length``, ``context`` and ``known_covariates``.
        """
        horizon = 4
        num_items, num_steps = 10, 20

        # (item_id, timestamp) index covering every item/day combination.
        item_names = [f"item_{k}" for k in range(num_items)]
        days = pd.date_range(start='2023-01-01', periods=num_steps, freq='D')
        panel_index = pd.MultiIndex.from_product(
            [item_names, days], names=['item_id', 'timestamp']
        )
        n_rows = len(panel_index)

        # NOTE: the random draws below are made in a fixed order; reordering
        # them would change the generated values for a given seed.
        columns = {}

        # Target: sine wave plus unit-variance Gaussian noise.
        seasonal = np.sin(np.arange(n_rows) / 5) * 10
        columns['target'] = seasonal + np.random.normal(0, 1, n_rows)

        # Real-valued known covariates.
        columns['price'] = np.random.uniform(5, 15, n_rows)
        columns['temperature'] = np.random.uniform(0, 30, n_rows)

        # Categorical known covariates (binary flags with skewed frequencies).
        columns['promotion'] = np.random.choice([0, 1], size=n_rows, p=[0.8, 0.2])
        columns['holiday'] = np.random.choice([0, 1], size=n_rows, p=[0.9, 0.1])

        # Real-valued past covariates.
        columns['past_value1'] = np.random.normal(0, 1, n_rows)
        columns['past_value2'] = np.random.normal(0, 1, n_rows)

        # Categorical past covariates.
        columns['past_cat1'] = np.random.choice(['A', 'B', 'C'], size=n_rows)
        columns['past_cat2'] = np.random.choice([0, 1, 2], size=n_rows)

        frame = pd.DataFrame(columns, index=panel_index)
        for col in ('promotion', 'holiday', 'past_cat1', 'past_cat2'):
            frame[col] = frame[col].astype('category')

        # Per-item static features: two real-valued, then two categorical.
        static_columns = {}
        static_columns['store_size'] = np.random.uniform(1000, 10000, num_items)
        static_columns['avg_traffic'] = np.random.uniform(100, 1000, num_items)
        static_columns['store_type'] = np.random.choice(['A', 'B', 'C'], size=num_items)
        static_columns['region'] = np.random.choice(
            ['North', 'South', 'East', 'West'], size=num_items
        )
        static_frame = pd.DataFrame(static_columns, index=item_names)
        for col in ('store_type', 'region'):
            static_frame[col] = static_frame[col].astype('category')

        # Wrap everything in a TimeSeriesDataFrame and split off the horizon.
        series = TimeSeriesDataFrame(frame)
        series.static_features = static_frame
        train_part, test_part = series.train_test_split(prediction_length=horizon)

        cls.train_data = train_part
        cls.test_data = test_part
        cls.prediction_length = horizon

        # Context: everything except the last `horizon` steps of each item.
        cls.context = test_part.slice_by_timestep(None, -horizon)
        # Future window: the last `horizon` steps of each item.
        # NOTE(review): this slice appears to keep every column, including
        # the target and past covariates -- confirm Predictor.predict only
        # reads the declared known-covariate columns.
        cls.known_covariates = test_part.slice_by_timestep(-horizon, None)
    
    def test_default_initialization(self):
        """Check the attributes of a Predictor built with only required arguments.

        The prediction length and target must be stored as given, and every
        covariate / static-feature list must equal an empty list.
        """
        predictor = Predictor(
            target="target",
            prediction_length=self.prediction_length,
        )

        self.assertEqual(predictor.prediction_length, self.prediction_length)
        self.assertEqual(predictor.target, "target")

        # All feature lists are expected to be empty when none are supplied.
        feature_list_attrs = (
            "known_covariates_real",
            "known_covariates_cat",
            "static_features_real",
            "static_features_cat",
            "past_covariates_real",
            "past_covariates_cat",
        )
        for attr_name in feature_list_attrs:
            self.assertEqual(getattr(predictor, attr_name), [])
    
    def test_base_model_no_covariates(self):
        """Fit, forecast and score a Predictor configured with the target only.

        Asserts the forecast has one row per (item, horizon step) and a
        "mean" column. The evaluation scores are printed, not asserted.
        """
        print("\n==== Testing model with no covariates or static features ====")

        predictor_config = dict(
            prediction_length=self.prediction_length,
            target="target",
            eval_metric="MAE",
            regressor_types="XGB",
            bolt_model_path="bolt_small",
            random_seed=42,
        )
        predictor = Predictor(**predictor_config)

        # This test passes a 60-second time limit with ensembling disabled.
        predictor.fit(self.train_data, time_limit=60, enable_ensemble=False)

        predictions = predictor.predict(self.context, known_covariates=self.known_covariates)

        # Shape and column checks on the forecast.
        expected_rows = len(self.context.item_ids) * self.prediction_length
        self.assertEqual(len(predictions), expected_rows)
        self.assertIn("mean", predictions.columns)

        metrics = predictor.evaluate(self.test_data)
        print(f"Scores with no covariates: {metrics}")
    
    def test_model_with_known_covariates_only(self):
        """Fit, forecast and score a Predictor that uses known covariates only.

        Asserts the forecast has one row per (item, horizon step) and a
        "mean" column. The evaluation scores are printed, not asserted.
        """
        print("\n==== Testing model with only known covariates ====")

        predictor_config = dict(
            prediction_length=self.prediction_length,
            target="target",
            known_covariates_real=["price", "temperature"],
            known_covariates_cat=["promotion", "holiday"],
            eval_metric="MAE",
            regressor_types="XGB",
            bolt_model_path="bolt_small",
            random_seed=42,
        )
        predictor = Predictor(**predictor_config)

        predictor.fit(self.train_data, time_limit=30, enable_ensemble=False)

        predictions = predictor.predict(self.context, known_covariates=self.known_covariates)

        # Shape and column checks on the forecast.
        expected_rows = len(self.context.item_ids) * self.prediction_length
        self.assertEqual(len(predictions), expected_rows)
        self.assertIn("mean", predictions.columns)

        metrics = predictor.evaluate(self.test_data)
        print(f"Scores with known covariates only: {metrics}")
    
    def test_model_with_static_features_only(self):
        """Fit, forecast and score a Predictor that uses static features only.

        Asserts the forecast has one row per (item, horizon step) and a
        "mean" column. The evaluation scores are printed, not asserted.
        """
        print("\n==== Testing model with only static features ====")

        predictor_config = dict(
            prediction_length=self.prediction_length,
            target="target",
            static_features_real=["store_size", "avg_traffic"],
            static_features_cat=["store_type", "region"],
            eval_metric="MAE",
            regressor_types="XGB",
            bolt_model_path="bolt_small",
            random_seed=42,
        )
        predictor = Predictor(**predictor_config)

        predictor.fit(self.train_data, time_limit=30, enable_ensemble=False)

        predictions = predictor.predict(self.context, known_covariates=self.known_covariates)

        # Shape and column checks on the forecast.
        expected_rows = len(self.context.item_ids) * self.prediction_length
        self.assertEqual(len(predictions), expected_rows)
        self.assertIn("mean", predictions.columns)

        metrics = predictor.evaluate(self.test_data)
        print(f"Scores with static features only: {metrics}")
    
    def test_model_with_past_covariates_only(self):
        """Fit, forecast and score a Predictor that uses past covariates only.

        Asserts the forecast has one row per (item, horizon step) and a
        "mean" column. The evaluation scores are printed, not asserted.
        """
        print("\n==== Testing model with only past covariates ====")

        predictor_config = dict(
            prediction_length=self.prediction_length,
            target="target",
            past_covariates_real=["past_value1", "past_value2"],
            past_covariates_cat=["past_cat1", "past_cat2"],
            eval_metric="MAE",
            regressor_types="XGB",
            bolt_model_path="bolt_small",
            random_seed=42,
        )
        predictor = Predictor(**predictor_config)

        predictor.fit(self.train_data, time_limit=30, enable_ensemble=False)

        predictions = predictor.predict(self.context, known_covariates=self.known_covariates)

        # Shape and column checks on the forecast.
        expected_rows = len(self.context.item_ids) * self.prediction_length
        self.assertEqual(len(predictions), expected_rows)
        self.assertIn("mean", predictions.columns)

        metrics = predictor.evaluate(self.test_data)
        print(f"Scores with past covariates only: {metrics}")
    
    def test_model_with_all_feature_types(self):
        """Fit, forecast and score a Predictor given every feature group at once.

        Known covariates, static features and past covariates (real-valued
        and categorical in each case) are all declared. Asserts the forecast
        has one row per (item, horizon step) and a "mean" column. The
        evaluation scores are printed, not asserted.
        """
        print("\n==== Testing model with all feature types ====")

        predictor_config = dict(
            prediction_length=self.prediction_length,
            target="target",
            known_covariates_real=["price", "temperature"],
            known_covariates_cat=["promotion", "holiday"],
            static_features_real=["store_size", "avg_traffic"],
            static_features_cat=["store_type", "region"],
            past_covariates_real=["past_value1", "past_value2"],
            past_covariates_cat=["past_cat1", "past_cat2"],
            eval_metric="MAE",
            regressor_types="XGB",
            bolt_model_path="bolt_small",
            random_seed=42,
        )
        predictor = Predictor(**predictor_config)

        predictor.fit(self.train_data, time_limit=30, enable_ensemble=False)

        predictions = predictor.predict(self.context, known_covariates=self.known_covariates)

        # Shape and column checks on the forecast.
        expected_rows = len(self.context.item_ids) * self.prediction_length
        self.assertEqual(len(predictions), expected_rows)
        self.assertIn("mean", predictions.columns)

        metrics = predictor.evaluate(self.test_data)
        print(f"Scores with all feature types: {metrics}")
    
    def test_different_regressor_types(self):
        """Fit, forecast and score a Predictor once per regressor type.

        The regressor types "XGB", "CAT" and "RF" are each run inside their
        own ``subTest`` so that a failure for one type is reported on its own
        and does not prevent the remaining types from being exercised.
        For every type the forecast must have one row per
        (item, horizon step) and a "mean" column. The evaluation scores are
        printed, not asserted.
        """
        print("\n==== Testing different regressor types ====")

        # Loop-invariant: the expected forecast size does not depend on the regressor.
        expected_rows = len(self.context.item_ids) * self.prediction_length

        for regressor_type in ("XGB", "CAT", "RF"):
            with self.subTest(regressor_type=regressor_type):
                print(f"\nTesting regressor_type: {regressor_type}")

                model = Predictor(
                    prediction_length=self.prediction_length,
                    target="target",
                    known_covariates_real=["price"],
                    static_features_real=["store_size"],
                    eval_metric="MAE",
                    regressor_types=regressor_type,
                    bolt_model_path="bolt_small",
                    random_seed=42
                )

                # Train the model
                model.fit(self.train_data, time_limit=30, enable_ensemble=False)

                # Generate forecasts
                forecasts = model.predict(self.context, known_covariates=self.known_covariates)

                # Verify forecasts
                self.assertEqual(len(forecasts), expected_rows)
                self.assertIn("mean", forecasts.columns)

                # Evaluate the model
                scores = model.evaluate(self.test_data)
                print(f"Scores with {regressor_type} regressor: {scores}")
    
    # def test_save_and_load(self):
    #     """Test saving and loading the model."""
    #     print("\n==== Testing model save and load ====")
        
    #     model = Predictor(
    #         prediction_length=self.prediction_length,
    #         target="target",
    #         known_covariates_real=["price"],
    #         static_features_real=["store_size"],
    #         eval_metric="MAE",
    #         regressor_types="XGB",
    #         bolt_model_path="bolt_small",
    #         random_seed=42
    #     )
        
    #     # Train the model
    #     model.fit(self.train_data, time_limit=30, enable_ensemble=False)
        
    #     # Generate forecasts before saving
    #     forecasts_before = model.predict(self.context, known_covariates=self.known_covariates)
        
    #     # Save the model
    #     save_path = os.path.join(self.temp_dir, "test_model")
    #     model.save(save_path)
        
    #     # Load the model
    #     loaded_model = Predictor.load(save_path)
        
    #     # Generate forecasts after loading
    #     forecasts_after = loaded_model.predict(self.context, known_covariates=self.known_covariates)
        
    #     # Verify forecasts match
    #     pd.testing.assert_frame_equal(forecasts_before, forecasts_after)
        
    #     print("Model successfully saved and loaded with identical predictions")
    
    def test_missing_known_covariates(self):
        """Expect ``fit`` to raise KeyError for an undeclared known covariate.

        The Predictor is configured with "nonexistent_covariate", a column
        that the training data does not contain.
        """
        print("\n==== Testing error handling for missing known covariates ====")

        predictor_config = dict(
            prediction_length=self.prediction_length,
            target="target",
            known_covariates_real=["price", "temperature", "nonexistent_covariate"],
            eval_metric="MAE",
            regressor_types="XGB",
            bolt_model_path="bolt_small",
            random_seed=42,
        )
        predictor = Predictor(**predictor_config)

        # Fitting must fail with KeyError because of the unknown column.
        self.assertRaises(
            KeyError,
            predictor.fit,
            self.train_data,
            time_limit=30,
            enable_ensemble=False,
        )

        print("Test passed: Correctly raised error for missing known covariates")
    
    def test_missing_static_features(self):
        """Expect ``fit`` to raise KeyError for an undeclared static feature.

        The Predictor is configured with "nonexistent_feature", a column that
        the static-feature table of the training data does not contain.
        """
        print("\n==== Testing error handling for missing static features ====")

        predictor_config = dict(
            prediction_length=self.prediction_length,
            target="target",
            static_features_real=["store_size", "nonexistent_feature"],
            eval_metric="MAE",
            regressor_types="XGB",
            bolt_model_path="bolt_small",
            random_seed=42,
        )
        predictor = Predictor(**predictor_config)

        # Fitting must fail with KeyError because of the unknown feature.
        self.assertRaises(
            KeyError,
            predictor.fit,
            self.train_data,
            time_limit=30,
            enable_ensemble=False,
        )

        print("Test passed: Correctly raised error for missing static features")
    
    def test_missing_dynamic_features(self):
        """Expect ``fit`` to raise KeyError for an undeclared past covariate.

        The Predictor is configured with "nonexistent_dynamic_feature", a
        column that the training data does not contain.
        """
        print("\n==== Testing error handling for missing dynamic features ====")

        predictor_config = dict(
            prediction_length=self.prediction_length,
            target="target",
            past_covariates_real=["past_value1", "nonexistent_dynamic_feature"],
            eval_metric="MAE",
            regressor_types="XGB",
            bolt_model_path="bolt_small",
            random_seed=42,
        )
        predictor = Predictor(**predictor_config)

        # Fitting must fail with KeyError because of the unknown column.
        self.assertRaises(
            KeyError,
            predictor.fit,
            self.train_data,
            time_limit=30,
            enable_ensemble=False,
        )

        print("Test passed: Correctly raised error for missing dynamic features")
    
    # def test_reset_model(self):
    #     """Test the reset_model functionality."""
    #     print("\n==== Testing reset_model functionality ====")
        
    #     model = Predictor(
    #         prediction_length=self.prediction_length,
    #         target="target",
    #         known_covariates_real=["price"],
    #         eval_metric="MAE",
    #         regressor_types="XGB",
    #         bolt_model_path="bolt_small",
    #         random_seed=42
    #     )
        
    #     # Train initial model
    #     model.fit(self.train_data, time_limit=30, enable_ensemble=False)
        
    #     # Generate initial forecasts
    #     initial_forecasts = model.predict(self.context, known_covariates=self.known_covariates)
        
    #     # Train model again with reset_model=True
    #     model.fit(self.train_data, time_limit=30, enable_ensemble=False, reset_model=True)
        
    #     # Generate new forecasts
    #     reset_forecasts = model.predict(self.context, known_covariates=self.known_covariates)
        
    #     # Check that forecasts are different (models were reset)
    #     # We just check they're not identical since random initialization would make them different
    #     self.assertFalse(initial_forecasts.equals(reset_forecasts))
        
    #     print("Test passed: reset_model functionality works correctly")
    
    def test_passing_explicit_static_features(self):
        """Predict with static features taken implicitly and passed explicitly.

        The model is fitted once, then ``predict`` is called twice: first
        without a ``static_features`` argument, then with a copy of the
        training data's full static-feature table passed explicitly. Both
        forecasts must have one row per (item, horizon step) and a "mean"
        column. The two forecasts are not compared value-for-value.
        """
        print("\n==== Testing passing explicit static features ====")

        model = Predictor(
            prediction_length=self.prediction_length,
            target="target",
            known_covariates_real=["price"],
            static_features_real=["store_size"],
            static_features_cat=["store_type"],
            eval_metric="MAE",
            regressor_types="XGB",
            bolt_model_path="bolt_small",
            random_seed=42
        )

        # Train the model
        model.fit(self.train_data, time_limit=30, enable_ensemble=False)

        expected_rows = len(self.context.item_ids) * self.prediction_length

        # Generate forecasts without passing static features explicitly
        forecasts_implicit = model.predict(self.context, known_covariates=self.known_covariates)

        # Verify the implicit-path forecasts instead of discarding them
        self.assertEqual(len(forecasts_implicit), expected_rows)
        self.assertIn("mean", forecasts_implicit.columns)

        # Copy the complete static-feature table (all items, all columns)
        explicit_static_features = self.train_data.static_features.copy()

        # Generate forecasts with explicitly provided static features
        forecasts_explicit = model.predict(
            self.context,
            known_covariates=self.known_covariates,
            static_features=explicit_static_features
        )

        # Verify forecasts were generated
        self.assertEqual(len(forecasts_explicit), expected_rows)
        self.assertIn("mean", forecasts_explicit.columns)

        print("Test passed: Successfully predicted with explicit static features")
    
    # def test_incremental_training(self):
    #     """Test incremental training without resetting the model."""
    #     print("\n==== Testing incremental training ====")
        
    #     model = Predictor(
    #         prediction_length=self.prediction_length,
    #         target="target",
    #         known_covariates_real=["price"],
    #         eval_metric="MAE",
    #         regressor_types="XGB",
    #         bolt_model_path="bolt_small",
    #         random_seed=42
    #     )
        
    #     # Split train data in half
    #     train_size = len(self.train_data)
    #     train_data1 = self.train_data.iloc[:train_size//2]
    #     train_data2 = self.train_data.iloc[train_size//2:]
        
    #     # Train on first half
    #     model.fit(train_data1, time_limit=30, enable_ensemble=False)
        
    #     # Generate forecasts after first training
    #     forecasts1 = model.predict(self.context, known_covariates=self.known_covariates)
        
    #     # Continue training with second half without resetting
    #     model.fit(train_data2, time_limit=30, enable_ensemble=False, reset_model=False)
        
    #     # Generate forecasts after second training
    #     forecasts2 = model.predict(self.context, known_covariates=self.known_covariates)
        
    #     # Forecasts should be different due to additional training data
    #     self.assertFalse(forecasts1.equals(forecasts2))
        
    #     print("Test passed: Incremental training works correctly")

if __name__ == "__main__":
    # Invoke unittest's command-line entry point when this file is executed
    # directly rather than imported.
    unittest.main()