import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime, timedelta
import random
from pathlib import Path
import os
import torch
import argparse

from autogluon.timeseries import TimeSeriesDataFrame
from autogluon.timeseries.utils.features import CovariateMetadata
from residual_chronos.Regressor import CrossSectionalRegressor

# Set random seeds for reproducibility: Python's `random`, NumPy's global RNG
# and PyTorch's RNG are all seeded with the same value, as is every CUDA
# device when CUDA is available.
RANDOM_SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(RANDOM_SEED)
del _seed_fn
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

def generate_dates(start_date, periods):
    """
    Return ``periods`` consecutive daily dates beginning at ``start_date``.

    The first element is ``start_date`` itself; each following element is one
    day later. A non-positive ``periods`` yields an empty list.
    """
    one_day = timedelta(days=1)
    return [start_date + one_day * offset for offset in range(periods)]

def generate_base_sales(periods, trend_coef=0.1, seasonal_amp=10, seasonal_freq=7):
    """
    Build a baseline sales curve: level 100, plus a linear trend, plus a sine season.

    Parameters:
    -----------
    periods : int
        Length of the generated series
    trend_coef : float
        Slope of the linear trend (added per period)
    seasonal_amp : float
        Amplitude of the sinusoidal seasonal term
    seasonal_freq : int
        Period of the seasonal term in time steps (e.g., 7 for weekly)

    Returns:
    --------
    numpy.ndarray
        ``100 + trend_coef * t + seasonal_amp * sin(2*pi*t / seasonal_freq)``
        for ``t = 0 .. periods-1``
    """
    t = np.arange(periods)

    # Constant level plus linear growth.
    level = 100 + t * trend_coef

    # Periodic component; with the default seasonal_freq=7 this is weekly.
    season = seasonal_amp * np.sin(2 * np.pi * t / seasonal_freq)

    return level + season

def generate_promotion_effect(periods, promotion_prob=0.1, effect_mean=15, effect_std=3):
    """
    Simulate daily promotion flags and the sales uplift they cause.

    Draws use NumPy's global RNG: first the Bernoulli flags, then one normal
    uplift per period. The uplift is zeroed on days without a promotion.

    Parameters:
    -----------
    periods : int
        Number of time periods to simulate
    promotion_prob : float
        Per-day probability that a promotion runs
    effect_mean : float
        Mean uplift on a promotion day
    effect_std : float
        Standard deviation of the uplift on a promotion day

    Returns:
    --------
    tuple
        (promotion_indicator, promotion_effect): a 0/1 integer array and a
        float array of the same length
    """
    is_promo = np.random.binomial(1, promotion_prob, periods)
    uplift = np.random.normal(effect_mean, effect_std, periods)
    return is_promo, uplift * is_promo

def generate_temperature_effect(periods, temp_min=10, temp_max=35, effect_coef=-0.5,
                                optimal_temp=22.5):
    """
    Generate temperature values and their effect on sales.

    Temperatures follow a 365-period sine cycle spanning [temp_min, temp_max].
    Gaussian noise (std 2) is added from NumPy's global RNG, and the result is
    clipped back into that range. The sales effect is quadratic in the distance
    from ``optimal_temp``. With a negative ``effect_coef``, sales peak at
    ``optimal_temp`` and fall off as the temperature moves away from it in
    *either* direction, both colder and hotter.

    Parameters:
    -----------
    periods : int
        Number of time periods to generate
    temp_min : float
        Minimum temperature (lower clip bound)
    temp_max : float
        Maximum temperature (upper clip bound)
    effect_coef : float
        Coefficient on the squared deviation from ``optimal_temp``; a negative
        value means any deviation from the optimum reduces sales
    optimal_temp : float
        Temperature with zero effect on sales (default 22.5, the midpoint of
        the default range)

    Returns:
    --------
    tuple
        (temperature_values, temperature_effect), both float arrays of length
        ``periods``
    """
    day = np.arange(periods)

    # Yearly cycle in [0, 1], scaled into [temp_min, temp_max].
    yearly_cycle = 0.5 + 0.5 * np.sin(2 * np.pi * day / 365)
    temps = temp_min + (temp_max - temp_min) * yearly_cycle

    # Day-to-day noise, then keep values inside the configured range.
    temps += np.random.normal(0, 2, periods)
    temps = np.clip(temps, temp_min, temp_max)

    # Quadratic penalty around the optimal temperature.
    temp_effect = effect_coef * (temps - optimal_temp) ** 2

    return temps, temp_effect

def generate_price_effect(periods, base_price=10, price_change_prob=0.05, price_effect_coef=-5):
    """
    Generate price changes and their effect on sales.

    Prices follow a step function that starts at ``base_price``. On each day
    after the first, with probability ``price_change_prob``, the current price
    is multiplied by ``1 + U(-0.15, 0.1)``, and the new price persists until
    the next change. The final series is clipped to
    [0.7 * base_price, 1.3 * base_price].

    Parameters:
    -----------
    periods : int
        Number of time periods to generate
    base_price : float
        Starting and reference price
    price_change_prob : float
        Probability of a price change on any given day (except day 0)
    price_effect_coef : float
        Linear price coefficient; a negative value means prices above
        ``base_price`` reduce sales

    Returns:
    --------
    tuple
        (price_values, price_effect), both float arrays of length ``periods``
    """
    prices = np.full(periods, base_price, dtype=float)

    # Track the running price and write each day once. This is O(n) instead
    # of rewriting the whole remaining tail on every change.
    current_price = float(base_price)
    for day in range(1, periods):
        if np.random.random() < price_change_prob:
            # Change price by -15% to +10%
            current_price *= 1 + np.random.uniform(-0.15, 0.1)
        prices[day] = current_price

    # NOTE(review): clipping is applied once, after the whole walk. The
    # underlying (unclipped) walk can therefore drift past a bound, and the
    # clipped price then stays pinned there until the walk comes back.
    # Confirm whether per-step clipping was intended.
    prices = np.clip(prices, base_price * 0.7, base_price * 1.3)

    # Linear effect relative to the reference (base) price.
    price_effect = price_effect_coef * (prices - base_price)

    return prices, price_effect

def generate_store_type_effect(store_type):
    """
    Look up the additive baseline-sales offset for a store type.

    Parameters:
    -----------
    store_type : str
        Store category: 'urban', 'suburban' or 'rural'

    Returns:
    --------
    int
        20 for 'urban', 0 for 'suburban', -15 for 'rural'; any other
        (unknown) type maps to 0
    """
    offsets_by_type = {'rural': -15, 'suburban': 0, 'urban': 20}
    offset = offsets_by_type.get(store_type)
    return 0 if offset is None else offset

def generate_store_size_effect(store_size):
    """
    Compute the baseline-sales boost contributed by store floor area.

    Parameters:
    -----------
    store_size : float
        Store area in thousands of square feet

    Returns:
    --------
    float
        Two sales units per 1000 sq ft, i.e. ``2 * store_size``
    """
    units_per_thousand_sq_ft = 2
    return units_per_thousand_sq_ft * store_size

def generate_multi_store_sales_data(n_stores=5, periods=365, plot=False, plot_dir=None, noise_std=1.0):
    """
    Generate sales data for multiple stores with different characteristics.

    For every store, a random type ('urban', 'suburban' or 'rural') and size
    (20-100k sq ft) are drawn, along with a store-specific linear trend and
    weekly-seasonality amplitude. Daily sales are the sum of:
    - that baseline,
    - the store type and size offsets,
    - the promotion, temperature and price effects,
    - Gaussian noise.
    The total is floored at zero.

    Parameters:
    -----------
    n_stores : int
        Number of stores to generate (must be >= 1)
    periods : int
        Number of daily time periods to generate, starting 2022-01-01
    plot : bool
        Whether to plot the data generation process (two subplots per store:
        sales components and covariates) into a new matplotlib figure
    plot_dir : str or pathlib.Path, optional
        When given together with ``plot=True``, the finished figure is saved
        as ``plot_dir / 'sales_data_generation.png'``; the directory is
        created if needed
    noise_std : float
        Standard deviation of the additive Gaussian noise (default 1.0)

    Returns:
    --------
    TimeSeriesDataFrame
        Built from a long-format frame with columns item_id ('store_<k>'),
        timestamp, sales, promotion, temperature and price. The per-store
        static features (store_size and a categorical store_type, indexed by
        item_id) are attached as ``.static_features``.

    Raises:
    -------
    ValueError
        If ``n_stores`` is less than 1.
    """
    if n_stores < 1:
        raise ValueError(f"n_stores must be at least 1, got {n_stores}")

    # Define store characteristics
    store_types = ['urban', 'suburban', 'rural']

    # One DataFrame per store, concatenated at the end
    store_frames = []

    # Per-store static features, in store order
    static_features = {'store_size': [], 'store_type': []}

    if plot:
        # One shared figure: a row of two subplots per store
        plt.figure(figsize=(15, n_stores * 5))

    # Daily timestamps shared by every store
    dates = generate_dates(datetime(2022, 1, 1), periods)

    for store_id in range(1, n_stores + 1):
        # Assign store characteristics
        store_type = random.choice(store_types)
        store_size = random.uniform(20, 100)  # in 1000 sq ft
        static_features['store_size'].append(store_size)
        static_features['store_type'].append(store_type)

        # Baseline sales with store-specific parameters
        trend_coef = random.uniform(0.05, 0.2)  # Different trend for each store
        seasonal_amp = random.uniform(5, 15)     # Different seasonality amplitude
        seasonal_freq = 7  # Weekly seasonality
        base_sales = generate_base_sales(periods, trend_coef, seasonal_amp, seasonal_freq)

        # Store type and size shift the whole baseline
        base_sales += generate_store_type_effect(store_type)
        base_sales += generate_store_size_effect(store_size)

        # Covariates and their effects
        promotions, promo_effect = generate_promotion_effect(periods)
        temperatures, temp_effect = generate_temperature_effect(periods)
        prices, price_effect = generate_price_effect(periods)

        noise = np.random.normal(0, noise_std, periods)

        # Combine all effects, then ensure non-negative sales
        sales = base_sales + promo_effect + temp_effect + price_effect + noise
        sales = np.maximum(sales, 0)

        # Build the store's frame column-wise; the scalar item_id is broadcast
        store_frames.append(pd.DataFrame({
            'item_id': f'store_{store_id}',
            'timestamp': dates,
            'sales': sales,
            'promotion': promotions,
            'temperature': temperatures,
            'price': prices,
        }))

        if plot:
            plt_idx = (store_id - 1) * 2 + 1

            # Left subplot: cumulative sales components
            plt.subplot(n_stores, 2, plt_idx)
            plt.plot(dates, base_sales, label='Base Sales')
            plt.plot(dates, base_sales + promo_effect, label='+ Promotion Effect')
            plt.plot(dates, base_sales + promo_effect + temp_effect, label='+ Temperature Effect')
            plt.plot(dates, base_sales + promo_effect + temp_effect + price_effect, label='+ Price Effect')
            plt.plot(dates, sales, 'k-', alpha=0.7, label='Final Sales (with noise)')
            plt.title(f'Store {store_id} ({store_type}, {store_size:.1f}k sq ft) - Sales Components')
            plt.ylabel('Sales')
            plt.legend()

            # Right subplot: covariates
            plt.subplot(n_stores, 2, plt_idx + 1)
            plt.plot(dates, promotions * 10, 'r-', label='Promotion (×10)')
            plt.plot(dates, temperatures, 'g-', label='Temperature (°C)')
            plt.plot(dates, prices, 'b-', label='Price ($)')
            plt.title(f'Store {store_id} - Covariates')
            plt.ylabel('Value')
            plt.legend()

    combined_df = pd.concat(store_frames, ignore_index=True)
    ts_df = TimeSeriesDataFrame(combined_df)

    # Static features indexed by the same item ids used in the series
    static_features_df = pd.DataFrame({
        'store_size': static_features['store_size'],
        'store_type': static_features['store_type']
    }, index=[f'store_{i}' for i in range(1, n_stores + 1)])
    static_features_df['store_type'] = static_features_df['store_type'].astype('category')
    ts_df.static_features = static_features_df

    if plot:
        plt.tight_layout()
        if plot_dir:
            # Accept str as well as Path. Save once, after every store's
            # subplots exist and the layout is final, so the file holds the
            # complete figure.
            plot_dir = Path(plot_dir)
            plot_dir.mkdir(parents=True, exist_ok=True)
            plt.savefig(plot_dir / 'sales_data_generation.png')

    return ts_df

def main():
    """
    Command-line entry point: simulate multi-store sales data and optionally save it.

    Options are --n_stores, --periods, --plot and --output_dir. When
    --output_dir is given:
    - the series and the static features are written there as CSV files;
    - with --plot, the figure is also saved under ``<output_dir>/plots``.
    Without --output_dir, nothing is written to disk.
    """
    parser = argparse.ArgumentParser(description='Test CrossSectionalRegressor with simulated data')
    parser.add_argument('--n_stores', type=int, default=5, help='Number of stores to simulate')
    parser.add_argument('--periods', type=int, default=365, help='Number of days to simulate')
    parser.add_argument('--plot', action='store_true', help='Plot the data generation process')
    parser.add_argument('--output_dir', type=str, default=None, help='Directory to save the generated data')
    args = parser.parse_args()

    output_dir = Path(args.output_dir) if args.output_dir else None

    # --output_dir is optional (default None), and Path(None) raises TypeError.
    # Only derive and create the plot directory when there is an output
    # directory and plotting was requested.
    plot_dir = None
    if output_dir is not None and args.plot:
        plot_dir = output_dir / "plots"
        plot_dir.mkdir(parents=True, exist_ok=True)

    # Generate data
    print(f"Generating sales data for {args.n_stores} stores over {args.periods} days...")
    data = generate_multi_store_sales_data(n_stores=args.n_stores, periods=args.periods, plot=args.plot, plot_dir=plot_dir)

    # Display data sample
    print("\nGenerated data sample:")
    print(data.head())

    # Display static features
    print("\nStatic features:")
    print(data.static_features)

    # Save data if output directory is specified
    if output_dir is not None:
        output_dir.mkdir(parents=True, exist_ok=True)
        data.to_csv(output_dir / "simulated_sales_data.csv")
        data.static_features.to_csv(output_dir / "simulated_static_features.csv")
        print(f"\nData saved to {output_dir}")

    # TODO: the regressor test is not wired up yet. `test_ensemble_regressor`
    # is not defined in this script, so the calls below stay disabled and the
    # messages are placeholders.
    print("\nTesting CrossSectionalRegressor...")
    # residuals = test_ensemble_regressor(data, plot_results=args.plot)

    print("\nResiduals sample:")
    # print(residuals.head())

    print("\nTest completed successfully!")


if __name__ == "__main__":
    main()