import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime, timedelta
import random
from pathlib import Path
import os
import torch
import argparse

from autogluon.timeseries import TimeSeriesDataFrame
from autogluon.timeseries.utils.features import CovariateMetadata
from residual_chronos.Regressor import CrossSectionalRegressor

# Seed every random number generator this script imports (torch, NumPy and
# the stdlib ``random`` module) with one shared value so runs are repeatable.
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    # Only touch the CUDA generators when a CUDA device is reported.
    torch.cuda.manual_seed_all(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

def _iter_model_residuals(store_residuals, model_names, line_styles):
    """Yield ``(model_name, column, line_style)`` for models with a residual column.

    Models whose ``sales_residual_<model>`` column is absent from
    ``store_residuals`` are skipped. The line style is picked from the model's
    position in ``model_names`` (clamped to the last available style), so a
    model keeps the same style in every figure even if another model is missing.
    """
    last_style = len(line_styles) - 1
    for position, model_name in enumerate(model_names):
        column = f'sales_residual_{model_name}'
        if column in store_residuals.columns:
            yield model_name, column, line_styles[min(position, last_style)]


def _finish_figure(fig, output_file, description):
    """Tighten, save and announce ``fig``, then close it so it does not stay open."""
    fig.tight_layout()
    fig.savefig(output_file)
    print(f"{description} saved to {output_file}")
    plt.close(fig)


def test_ensemble_regressor(data, plot_results=True, plot_path="tests/results/plots"):
    """Fit a CrossSectionalRegressor on ``data`` and optionally plot its residuals.

    The regressor is built with the XGB, RF and CAT models, fitted via
    ``fit_transform`` and the resulting frame is returned. When
    ``plot_results`` is true, four kinds of PNG figures are written to
    ``plot_path`` for the first three stores:

    * ``ensemble_regressor_detailed_plots.png`` - a 3x3 grid per store:
      original vs. ensemble residuals, residuals by model, and
      original + residual curves.
    * ``ensemble_regressor_residuals_comparison.png`` - one row per store with
      ensemble and per-model residuals overlaid.
    * ``store_<id>_model_comparison.png`` - one figure per store with one row
      for the ensemble and one per model.
    * ``all_stores_by_model_comparison.png`` - one row per model with all
      plotted stores overlaid.

    The plotting code reads the columns ``sales``, ``sales_label`` and
    ``sales_residual_<model>`` from the transformed frame; per-model columns
    that are missing are skipped.

    Args:
        data: Time series frame to fit on. It must expose ``static_features``
            (its index supplies the store ids) and support ``.loc[store]``.
        plot_results: Whether to draw and save the figures.
        plot_path: Directory for the figures. It is created if missing, even
            when ``plot_results`` is false.

    Returns:
        The frame returned by ``CrossSectionalRegressor.fit_transform``.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(plot_path, exist_ok=True)

    target = 'sales'

    # Static feature declarations are currently disabled:
    covariate_metadata = CovariateMetadata(
        # static_features_cat=['store_type'],
        # static_features_real=['store_size'],
        known_covariates_real=['temperature', 'price'],
        known_covariates_cat=['promotion']
    )

    model_names = ["XGB", "RF", "CAT"]
    ensemble_regressor = CrossSectionalRegressor(
        model_names=model_names,
        target=target,
        covariate_metadata=covariate_metadata,
        include_static_features=True,
        include_item_id=True,
        models_hyperparameters={
            "XGB": {"learning_rate": 0.1, "max_depth": 5},
            "RF": {"n_estimators": 100},
            "CAT": {"iterations": 100, "depth": 5}
        },
        fit_time_fraction=0.5,
        validation_fraction=0.1,
        eval_metric="mean_absolute_error"
    )

    residuals = ensemble_regressor.fit_transform(
        data, keep_target_column=True, include_individual_residuals=True
    )

    if not plot_results:
        return residuals

    # Line styles are defined once, up front, so every figure below can use
    # them regardless of which loops actually executed.
    model_styles = ['g-', 'c-', 'm-', 'y-', 'k-', 'b--', 'g--', 'c--', 'm--']
    store_styles = ['b-', 'g-', 'r-', 'c-', 'm-']

    # Limit to the first 3 stores for clarity; the 3x3 / 3x1 grids rely on this.
    stores = data.static_features.index.tolist()[:3]

    # ---- Figure 1: per-store detail grid --------------------------------
    fig = plt.figure(figsize=(18, 15))
    for i, store in enumerate(stores):
        store_residuals = residuals.loc[store]
        # Use the residual frame's own index so x and y always line up.
        timestamps = store_residuals.index
        original = store_residuals['sales_label']
        ensemble = store_residuals['sales']
        present_models = list(
            _iter_model_residuals(store_residuals, model_names, model_styles)
        )

        # Column 1: original sales vs. ensemble residuals.
        plt.subplot(3, 3, i * 3 + 1)
        plt.plot(timestamps, original, 'b-', label='Original Sales')
        plt.plot(timestamps, ensemble, 'r-', label='Ensemble Residuals')
        plt.title(f'Store {store} - Original vs Ensemble Residuals')
        plt.ylabel('Value')
        plt.legend()

        # Column 2: ensemble and per-model residuals overlaid.
        plt.subplot(3, 3, i * 3 + 2)
        plt.plot(timestamps, ensemble, 'r-', label='Ensemble')
        for model_name, column, style in present_models:
            plt.plot(timestamps, store_residuals[column], style, label=f'{model_name}')
        plt.title(f'Store {store} - Residuals by Model')
        plt.ylabel('Residual Value')
        plt.legend()

        # Column 3: original + residual curves.
        # NOTE(review): if residual = actual - prediction, the prediction would
        # be original - residual; confirm the regressor's sign convention.
        plt.subplot(3, 3, i * 3 + 3)
        plt.plot(timestamps, original, 'b-', label='Original Sales')
        plt.plot(timestamps, original + ensemble, 'r-', label='Ensemble Prediction')
        for model_name, column, style in present_models:
            plt.plot(
                timestamps,
                original + store_residuals[column],
                style,
                label=f'{model_name} Prediction',
            )
        plt.title(f'Store {store} - Reconstructed Predictions')
        plt.ylabel('Sales Value')
        plt.legend()

    _finish_figure(
        fig,
        f"{plot_path}/ensemble_regressor_detailed_plots.png",
        "Ensemble regressor detailed plots",
    )

    # ---- Figure 2: residuals only, one row per store --------------------
    fig = plt.figure(figsize=(15, 12))
    for i, store in enumerate(stores):
        store_residuals = residuals.loc[store]
        timestamps = store_residuals.index

        plt.subplot(3, 1, i + 1)
        plt.plot(timestamps, store_residuals['sales'], 'r-', label='Ensemble Residuals')
        # Same 'sales_residual_<model>' columns as every other figure.
        for model_name, column, style in _iter_model_residuals(
            store_residuals, model_names, model_styles
        ):
            plt.plot(
                timestamps, store_residuals[column], style, label=f'{model_name} Residuals'
            )
        plt.axhline(y=0, color='k', linestyle='--', alpha=0.3)
        plt.title(f'Store {store} - Residuals Comparison')
        plt.ylabel('Residual Value')
        plt.legend()

    _finish_figure(
        fig,
        f"{plot_path}/ensemble_regressor_residuals_comparison.png",
        "Ensemble regressor residuals comparison plots",
    )

    # ---- Figure 3: one figure per store, one row per model --------------
    for store in stores:
        store_residuals = residuals.loc[store]
        timestamps = store_residuals.index
        ensemble = store_residuals['sales']
        present_models = list(
            _iter_model_residuals(store_residuals, model_names, model_styles)
        )
        # Rows are counted from the models actually drawn (+1 for the
        # ensemble) so the grid never has gaps or out-of-range positions.
        num_rows = len(present_models) + 1

        fig = plt.figure(figsize=(15, 3 * num_rows))

        plt.subplot(num_rows, 1, 1)
        plt.plot(timestamps, ensemble, 'r-', label='Ensemble Residuals')
        plt.axhline(y=0, color='k', linestyle='--', alpha=0.3)
        plt.title(f'Store {store} - Ensemble Residuals')
        plt.ylabel('Residual Value')
        plt.legend()

        for row, (model_name, column, style) in enumerate(present_models, start=2):
            plt.subplot(num_rows, 1, row)
            plt.plot(
                timestamps, store_residuals[column], style, label=f'{model_name} Residuals'
            )
            # Faded ensemble curve for comparison.
            plt.plot(timestamps, ensemble, 'r--', alpha=0.5, label='Ensemble Residuals')
            plt.axhline(y=0, color='k', linestyle='--', alpha=0.3)
            plt.title(f'Store {store} - {model_name} Residuals vs Ensemble')
            plt.ylabel('Residual Value')
            plt.legend()

        _finish_figure(
            fig,
            f"{plot_path}/store_{store}_model_comparison.png",
            f"Store {store} model comparison plots",
        )

    # ---- Figure 4: one row per model, all plotted stores overlaid -------
    all_models = ['Ensemble'] + model_names
    fig = plt.figure(figsize=(15, 4 * len(all_models)))
    for j, model_name in enumerate(all_models):
        plt.subplot(len(all_models), 1, j + 1)
        plt.axhline(y=0, color='k', linestyle='--', alpha=0.3)

        # The column is resolved per model, outside the store loop, so the
        # title below never depends on whether any store was plotted.
        if model_name == 'Ensemble':
            column = 'sales'
        else:
            column = f'sales_residual_{model_name}'

        for i, store in enumerate(stores):
            store_residuals = residuals.loc[store]
            if column not in store_residuals.columns:
                continue
            plt.plot(
                store_residuals.index,
                store_residuals[column],
                store_styles[i],
                label=f'Store {store}',
            )

        plt.title(f'{model_name} Residuals - All Stores')
        plt.ylabel('Residual Value')
        plt.legend()

    _finish_figure(
        fig,
        f"{plot_path}/all_stores_by_model_comparison.png",
        "All stores by model comparison plots",
    )
    return residuals

def main():
    """Load the simulated CSV data, run the regressor test and print samples.

    Reads the sales and static-feature CSVs from ``tests/data``, wraps them in
    a ``TimeSeriesDataFrame``, prints a preview of the series and its static
    features, calls ``test_ensemble_regressor`` with plotting enabled, and
    finally prints the head of the returned residual frame.
    """
    sales_frame = pd.read_csv("tests/data/simulated_sales_data.csv")
    static_frame = pd.read_csv("tests/data/simulated_static_features.csv")

    # Assemble the time series container from the two raw tables.
    series = TimeSeriesDataFrame(
        sales_frame,
        id_column="item_id",
        timestamp_column="timestamp",
        static_features=static_frame,
    )

    # Preview of the time series rows.
    print("\nGenerated data sample:")
    preview = series.head()
    print(preview)

    # Item-level static features attached to the frame.
    print("\nStatic features:")
    print(series.static_features)

    # Fit the regressor and write the diagnostic plots.
    print("\nTesting CrossSectionalRegressor...")
    regressor_output = test_ensemble_regressor(series, plot_results=True)

    print("\nResiduals sample:")
    print(regressor_output.head())

    print("\nTest completed successfully!")

# Run the end-to-end check only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main() 