import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime, timedelta
import random
from pathlib import Path
import os
import torch
import argparse

from autogluon.timeseries import TimeSeriesDataFrame
from autogluon.timeseries.utils.features import CovariateMetadata
from residual_chronos.Regressor import CrossSectionalRegressor

# Pin every random number generator this script touches to one shared seed
# so that repeated runs produce the same results.
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    # Seed the CUDA generators too when a GPU is present
    torch.cuda.manual_seed_all(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)


# Matplotlib format strings used for the individual models and for the stores.
_MODEL_STYLES = ('g-', 'c-', 'm-', 'y-', 'k-', 'b--', 'g--', 'c--', 'm--')
_STORE_STYLES = ('b-', 'g-', 'r-', 'c-', 'm-')


def _cycled(styles, position):
    """Return the style at ``position``, wrapping around when the palette is shorter."""
    return styles[position % len(styles)]


def _present_models(frame, model_names, target_name):
    """List ``(position, model_name, column)`` for models whose residual column is in ``frame``."""
    found = []
    for position, model_name in enumerate(model_names):
        column = f'{target_name}_residual_{model_name}'
        if column in frame.columns:
            found.append((position, model_name, column))
    return found


def _save_and_close(fig, path):
    """Apply a tight layout, write ``fig`` to ``path``, report it, and release the figure."""
    fig.tight_layout()
    fig.savefig(path)
    # Close explicitly: pyplot keeps every figure alive until it is closed
    plt.close(fig)
    print(f"Saved plot to {path}")


def _plot_detailed(store_frames, model_names, target_name, plot_dir):
    """Draw the three-panel-per-store overview (original, residuals, reconstructed)."""
    fig = plt.figure(figsize=(18, 15))
    n_rows = len(store_frames)
    label_column = f"{target_name}_label"

    for row, (store, frame) in enumerate(store_frames):
        timestamps = frame.index
        original = frame[label_column]
        ensemble = frame[target_name]
        models = _present_models(frame, model_names, target_name)

        # Panel 1: original series against the ensemble residuals
        plt.subplot(n_rows, 3, row * 3 + 1)
        plt.plot(timestamps, original, 'b-', label='Original Sales')
        plt.plot(timestamps, ensemble, 'r-', label='Ensemble Residuals')
        plt.title(f'Store {store} - Original vs Ensemble Residuals')
        plt.ylabel('Value')
        plt.legend()

        # Panel 2: ensemble residuals together with each model's residuals
        plt.subplot(n_rows, 3, row * 3 + 2)
        plt.plot(timestamps, ensemble, 'r-', label='Ensemble')
        for position, model_name, column in models:
            plt.plot(timestamps, frame[column], _cycled(_MODEL_STYLES, position), label=f'{model_name}')
        plt.title(f'Store {store} - Residuals by Model')
        plt.ylabel('Residual Value')
        plt.legend()

        # Panel 3: series reconstructed as original + residual
        # NOTE(review): whether "+" or "-" recovers the prediction depends on the
        # regressor's residual sign convention, which is not visible here - confirm.
        plt.subplot(n_rows, 3, row * 3 + 3)
        plt.plot(timestamps, original, 'b-', label='Original Sales')
        plt.plot(timestamps, original + ensemble, 'r-', label='Ensemble Prediction')
        for position, model_name, column in models:
            plt.plot(
                timestamps,
                original + frame[column],
                _cycled(_MODEL_STYLES, position),
                label=f'{model_name} Prediction',
            )
        plt.title(f'Store {store} - Reconstructed Predictions')
        plt.ylabel('Sales Value')
        plt.legend()

    _save_and_close(fig, f"{plot_dir}/ensemble_regressor_detailed_plots.png")


def _plot_residuals_comparison(store_frames, model_names, target_name, plot_dir):
    """Draw one subplot per store overlaying ensemble and per-model residuals."""
    n_rows = len(store_frames)
    fig = plt.figure(figsize=(15, 4 * n_rows))

    for row, (store, frame) in enumerate(store_frames):
        timestamps = frame.index
        plt.subplot(n_rows, 1, row + 1)
        plt.plot(timestamps, frame[target_name], 'r-', label='Ensemble Residuals')
        for position, model_name, column in _present_models(frame, model_names, target_name):
            plt.plot(
                timestamps,
                frame[column],
                _cycled(_MODEL_STYLES, position),
                label=f'{model_name} Residuals',
            )
        plt.axhline(y=0, color='k', linestyle='--', alpha=0.3)
        plt.title(f'Store {store} - Residuals Comparison')
        plt.ylabel('Residual Value')
        plt.legend()

    _save_and_close(fig, f"{plot_dir}/ensemble_regressor_residuals_comparison.png")


def _plot_store_model_comparison(store, frame, model_names, target_name, plot_dir):
    """Draw, for a single store, the ensemble residuals plus one subplot per available model."""
    timestamps = frame.index
    ensemble = frame[target_name]
    models = _present_models(frame, model_names, target_name)
    # Size the grid from the models that are actually drawn so that every
    # subplot index below is guaranteed to fit inside it.
    n_rows = len(models) + 1  # +1 for the ensemble panel

    fig = plt.figure(figsize=(15, 3 * n_rows))

    plt.subplot(n_rows, 1, 1)
    plt.plot(timestamps, ensemble, 'r-', label='Ensemble Residuals')
    plt.axhline(y=0, color='k', linestyle='--', alpha=0.3)
    plt.title(f'Store {store} - Ensemble Residuals')
    plt.ylabel('Residual Value')
    plt.legend()

    for row, (position, model_name, column) in enumerate(models):
        plt.subplot(n_rows, 1, row + 2)  # +2 because the first subplot is the ensemble
        plt.plot(
            timestamps,
            frame[column],
            _cycled(_MODEL_STYLES, position),
            label=f'{model_name} Residuals',
        )
        # Overlay the ensemble for comparison
        plt.plot(timestamps, ensemble, 'r--', alpha=0.5, label='Ensemble Residuals')
        plt.axhline(y=0, color='k', linestyle='--', alpha=0.3)
        plt.title(f'Store {store} - {model_name} Residuals vs Ensemble')
        plt.ylabel('Residual Value')
        plt.legend()

    _save_and_close(fig, f"{plot_dir}/store_{store}_model_comparison.png")


def _plot_all_stores_by_model(store_frames, model_names, target_name, plot_dir):
    """Draw one subplot per model (ensemble first) overlaying the residuals of all stores."""
    all_models = ['Ensemble', *model_names]
    fig = plt.figure(figsize=(15, 4 * len(all_models)))

    for row, model_name in enumerate(all_models):
        plt.subplot(len(all_models), 1, row + 1)
        plt.axhline(y=0, color='k', linestyle='--', alpha=0.3)

        # The ensemble lives in the target column, models in their residual columns
        if model_name == 'Ensemble':
            column = target_name
        else:
            column = f'{target_name}_residual_{model_name}'

        drew_any = False
        for store_pos, (store, frame) in enumerate(store_frames):
            if column not in frame.columns:
                continue
            plt.plot(
                frame.index,
                frame[column],
                _cycled(_STORE_STYLES, store_pos),
                label=f'Store {store}',
            )
            drew_any = True

        # Title from the model of this subplot, independent of what was drawn
        plt.title(f'{model_name} Residuals - All Stores')
        plt.ylabel('Residual Value')
        if drew_any:
            plt.legend()

    _save_and_close(fig, f"{plot_dir}/all_stores_by_model_comparison.png")


def plot_residuals(residuals, model_names, target_name, plot_dir="tests/results/plots", max_stores=3):
    """
    Create comprehensive plots for model residuals comparison.

    Four kinds of PNG files are written to ``plot_dir``:
    ``ensemble_regressor_detailed_plots.png``,
    ``ensemble_regressor_residuals_comparison.png``,
    one ``store_<id>_model_comparison.png`` per plotted store, and
    ``all_stores_by_model_comparison.png``. Each figure is closed after it
    has been saved.

    Parameters
    ----------
    residuals : TimeSeriesDataFrame
        The residuals dataframe from the ensemble regressor fit_transform.
        It must have an ``item_id`` index level and the columns
        ``target_name`` (ensemble residuals) and ``f"{target_name}_label"``
        (original values). Columns named
        ``f"{target_name}_residual_{model}"`` are plotted when present.
    model_names : list of str
        Names of the models used in the ensemble
    target_name : str
        Name of the target column (e.g., 'unit_sales')
    plot_dir : str, optional
        Directory to save the plots, by default "tests/results/plots"
    max_stores : int, optional
        Maximum number of stores to plot, by default 3
    """
    # exist_ok avoids the check-then-create race of a separate existence test
    os.makedirs(plot_dir, exist_ok=True)

    # Limit to the requested number of stores
    stores = residuals.index.unique(level='item_id').to_list()[:max_stores]
    if not stores:
        # Nothing to draw; zero-row subplot grids and zero-height figures are invalid
        print("No stores found in residuals; nothing to plot.")
        return

    # Slice each store once and reuse the frame in every plot
    store_frames = [(store, residuals.loc[store]) for store in stores]

    _plot_detailed(store_frames, model_names, target_name, plot_dir)
    _plot_residuals_comparison(store_frames, model_names, target_name, plot_dir)
    for store, frame in store_frames:
        _plot_store_model_comparison(store, frame, model_names, target_name, plot_dir)
    _plot_all_stores_by_model(store_frames, model_names, target_name, plot_dir)

def test_ensemble_regressor(data, plot_results=True, plot_path="tests/results/plots"):
    """
    Fit a CrossSectionalRegressor on ``data`` and return the resulting residuals.

    Parameters
    ----------
    data
        Dataset handed to the regressor's ``fit_transform``.
    plot_results : bool, optional
        When True, the residuals are passed to ``plot_residuals``, by default True
    plot_path : str, optional
        Directory for the plots; created if it does not exist,
        by default "tests/results/plots"

    Returns
    -------
    The object returned by ``fit_transform`` (the residuals).
    """
    # Ensure the output directory is available up front
    plot_dir_missing = not os.path.exists(plot_path)
    if plot_dir_missing:
        os.makedirs(plot_path)

    target = 'unit_sales'
    model_names = ["XGB", "RF", "CAT", "GBM"]

    # Per-model hyperparameters, keyed by the names above
    hyperparameters = {
        "XGB": {"learning_rate": 0.1, "max_depth": 5},
        "RF": {"n_estimators": 100},
        "CAT": {"iterations": 100, "depth": 5},
        "GBM": {"learning_rate": 0.1, "max_depth": 5},
    }

    # Only the real-valued known covariates are declared; the static-feature
    # and categorical-covariate entries are left disabled.
    covariate_metadata = CovariateMetadata(
        # static_features_cat=['store_type'],
        # static_features_real=['store_size'],
        known_covariates_real=['scaled_price', 'promotion_email', 'promotion_homepage'],
        # known_covariates_cat=['promotion_email', 'promotion_homepage']
    )

    ensemble_regressor = CrossSectionalRegressor(
        model_names=model_names,
        target=target,
        covariate_metadata=covariate_metadata,
        include_static_features=True,
        include_item_id=True,
        models_hyperparameters=hyperparameters,
        fit_time_fraction=0.5,
        validation_fraction=0.1,
        eval_metric="mean_absolute_error",
        aggregation_strategy="equal",
    )

    # Keep the target column and the per-model residual columns in the output
    residuals = ensemble_regressor.fit_transform(
        data,
        keep_target_column=True,
        include_individual_residuals=True,
    )

    if not plot_results:
        return residuals

    plot_residuals(
        residuals=residuals,
        model_names=model_names,
        target_name=target,
        plot_dir=plot_path,
    )
    return residuals

def main():
    """Load the grocery sales CSV, run the regressor test with plotting, and print summaries."""
    dataset_url = "https://autogluon.s3.amazonaws.com/datasets/timeseries/grocery_sales/test.csv"
    ts_data = TimeSeriesDataFrame.from_path(dataset_url)

    # Show the first rows of the loaded series
    print("\nGenerated data sample:")
    print(ts_data.head())

    # Show the static features attached to the dataset
    print("\nStatic features:")
    print(ts_data.static_features)

    # Fit the regressor and write the plots to the default directory
    print("\nTesting CrossSectionalRegressor...")
    residuals = test_ensemble_regressor(ts_data, plot_results=True)

    # Show the first rows of the residuals
    print("\nResiduals sample:")
    print(residuals.head())

    print("\nTest completed successfully!")

# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()