import random
import pandas as pd
import numpy as np
from datetime import datetime
from autogluon.timeseries import TimeSeriesDataFrame
import dataclasses
import os
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass

# Seed both global RNGs so every run produces the same data:
# NumPy's ``np.random`` drives all generated parameters and observations, and
# the stdlib ``random`` module drives which stores get picked for plotting.
_SEED = 42
random.seed(_SEED)
np.random.seed(_SEED)

@dataclass(frozen=True)
class StoreParams:
    """Parameters defining the synthetic sales pattern of a single store.

    Instances are immutable (``frozen=True``) and validated on construction,
    so invalid values fail here with a clear message instead of deep inside
    the data generator.

    Args:
        amplitude: Height of the periodic (sine-wave) sales component
        frequency: Length of one sales cycle in days, i.e. the *period* of the
            sine wave (despite the name, larger values mean slower cycles).
            Must be positive.
        baseline: Base sales level
        promotion_effect: Sales lift when a promotion is active
        temperature_effect: Sales change per degree Celsius
        price_effect: Sales change per unit price increase
        promotion_probability: Probability of having a promotion on any given
            day; must lie in [0, 1]
        temperature_noise: Standard deviation of temperature noise; must be >= 0
        sales_noise: Standard deviation of random noise added to sales;
            must be >= 0

    Raises:
        ValueError: If ``frequency`` is not positive, ``promotion_probability``
            is outside [0, 1], or either noise level is negative.
    """
    amplitude: float
    frequency: int
    baseline: float
    promotion_effect: float
    temperature_effect: float
    price_effect: float
    promotion_probability: float
    temperature_noise: float
    sales_noise: float

    def __post_init__(self) -> None:
        # ``frequency`` is used as a divisor (the cycle period) when the sales
        # sine wave is built, so zero or negative values are rejected.
        # ``not x > 0`` style comparisons also reject NaN.
        if not self.frequency > 0:
            raise ValueError(f"frequency must be positive, got {self.frequency!r}")
        if not 0.0 <= self.promotion_probability <= 1.0:
            raise ValueError(
                f"promotion_probability must be in [0, 1], got {self.promotion_probability!r}"
            )
        # Both noise levels are used as standard deviations, which cannot be negative.
        for field_name in ("temperature_noise", "sales_noise"):
            value = getattr(self, field_name)
            if not value >= 0:
                raise ValueError(f"{field_name} must be non-negative, got {value!r}")


def generate_store_params(num_stores: int = 200) -> Dict[str, StoreParams]:
    """Create randomly drawn sales-pattern parameters for ``num_stores`` stores.

    Each store is sampled independently from the same distributions using
    NumPy's global RNG (``np.random``), so seeding that RNG makes the result
    reproducible. Continuous parameters are normal draws clamped to sensible
    bounds and then rounded; the cycle length is picked uniformly from
    7, 14, 30 or 90 days.

    Args:
        num_stores: How many stores to create. IDs run from ``Store_1`` to
            ``Store_<num_stores>``; a non-positive value yields an empty dict.

    Returns:
        Dictionary from store ID to its ``StoreParams``, in ascending ID order.
    """
    # (mean, std) of the normal distribution behind each continuous parameter.
    normal_specs = {
        'amplitude': (80, 30),
        'baseline': (200, 80),
        'promotion_effect': (30, 10),
        'temperature_effect': (0.3, 0.1),
        'price_effect': (-12, 5),
        'promotion_probability': (0.2, 0.05),
        'temperature_noise': (2.0, 0.5),
        'sales_noise': (10.0, 3.0),
    }
    # Candidate cycle lengths in days: weekly, bi-weekly, monthly, quarterly.
    cycle_lengths = [7, 14, 30, 90]

    def draw(name):
        """Return one normal sample for the parameter called *name*."""
        mean, std = normal_specs[name]
        return np.random.normal(mean, std)

    result = {}
    for idx in range(1, num_stores + 1):
        # Keyword arguments are evaluated left to right, so the RNG is consumed
        # in the same fixed order for every store (amplitude first, sales noise last).
        result[f'Store_{idx}'] = StoreParams(
            amplitude=round(max(10, draw('amplitude')), 1),
            frequency=int(np.random.choice(cycle_lengths)),
            baseline=round(max(50, draw('baseline')), 1),
            promotion_effect=round(max(0, draw('promotion_effect')), 1),
            temperature_effect=round(draw('temperature_effect'), 2),
            price_effect=round(min(0, draw('price_effect')), 1),
            promotion_probability=round(max(0.05, min(0.4, draw('promotion_probability'))), 2),
            temperature_noise=round(max(0.5, draw('temperature_noise')), 1),
            sales_noise=round(max(2.0, draw('sales_noise')), 1),
        )
    return result


def generate_date_range(start_year: int = 2020, num_years: int = 3) -> pd.DatetimeIndex:
    """Build a daily calendar covering whole calendar years.

    The range starts on 1 January of ``start_year`` and ends on 31 December of
    ``start_year + num_years - 1`` (inclusive), with one entry per day.
    A ``num_years`` below 1 makes the end precede the start, giving an empty
    index.

    Args:
        start_year: First calendar year in the range
        num_years: How many consecutive calendar years to cover

    Returns:
        DatetimeIndex with daily frequency spanning the requested years
    """
    last_year = start_year + num_years - 1
    first_day = datetime(start_year, 1, 1)
    last_day = datetime(last_year, 12, 31)
    return pd.date_range(first_day, last_day, freq='D')


def generate_store_sales(store_params: Dict[str, StoreParams], 
                        dates: pd.DatetimeIndex) -> pd.DataFrame:
    """Generate synthetic daily sales with a periodic pattern and covariates.

    For every store (in ``store_params`` order) and every date, a promotion
    flag, a noisy seasonal temperature and a noisy sales value are drawn from
    NumPy's global RNG. Sales are::

        baseline + amplitude * sin(2*pi*t / frequency)
            + promotion * promotion_effect
            + temperature * temperature_effect
            + price * price_effect
            + normal(0, sales_noise)

    clipped at 0, where ``t`` is the position of the date within ``dates``.

    Args:
        store_params: Dictionary mapping store IDs to their parameters
        dates: DatetimeIndex of dates to generate one row per store for

    Returns:
        Long-format DataFrame with the columns ``timestamp``, ``series_id``,
        ``target``, ``promotion``, ``temperature`` and ``price``. The columns
        are present even when there are no stores or no dates, so code that
        selects them by name does not fail on an empty result.
    """
    columns = ['timestamp', 'series_id', 'target', 'promotion', 'temperature', 'price']
    rows = []

    for store, params in store_params.items():
        for t, date in enumerate(dates):
            # 1. Promotion (binary): on with the store-specific daily probability.
            promotion = 1 if np.random.random() < params.promotion_probability else 0

            # 2. Temperature (continuous): yearly sine over day-of-year plus
            #    store-specific noise. NOTE: the cycle is fixed at 365 days,
            #    also in leap years.
            day_of_year = date.day_of_year
            temperature = 20 + 15 * np.sin(2 * np.pi * day_of_year / 365) + np.random.normal(0, params.temperature_noise)

            # 3. Price (continuous): constant for now.
            price = 10

            # Periodic component; ``frequency`` is the period in days.
            periodic = params.amplitude * np.sin(2 * np.pi * t / params.frequency)

            # Linear combination of covariates.
            covariate_effect = (
                promotion * params.promotion_effect + 
                temperature * params.temperature_effect + 
                price * params.price_effect
            )

            # Target sales with store-specific noise, clipped so sales are never negative.
            sales = max(0, params.baseline + periodic + covariate_effect + np.random.normal(0, params.sales_noise))

            rows.append({
                'timestamp': date,
                'series_id': store,
                'target': round(sales, 2),
                'promotion': promotion,
                'temperature': round(temperature, 2),
                'price': round(price, 2)
            })

    # Explicit columns keep the schema stable when ``rows`` is empty.
    return pd.DataFrame(rows, columns=columns)


def create_time_series_df(sales_df: pd.DataFrame) -> TimeSeriesDataFrame:
    """Wrap a long-format sales table as an AutoGluon ``TimeSeriesDataFrame``.

    The ``series_id`` column is used as the item identifier and the
    ``timestamp`` column as the time column.

    Args:
        sales_df: Long-format sales table containing ``series_id`` and
            ``timestamp`` columns

    Returns:
        The converted TimeSeriesDataFrame
    """
    id_col, time_col = 'series_id', 'timestamp'
    converted = TimeSeriesDataFrame.from_data_frame(
        df=sales_df,
        id_column=id_col,
        timestamp_column=time_col,
    )
    return converted


def print_statistics(tsdf: TimeSeriesDataFrame) -> None:
    """Print a short overview of *tsdf* followed by per-store target statistics.

    The overview lists the shape, the number of item IDs, the column names and
    the first 10 rows. Then, for every item ID (in ``tsdf.item_ids`` order),
    the mean, minimum, maximum and standard deviation of the ``target`` column
    are printed with two decimals.

    Args:
        tsdf: TimeSeriesDataFrame to analyze; must contain a ``target`` column
    """
    header_lines = [
        f"TimeSeriesDataFrame shape: {tsdf.shape}",
        f"Number of time series: {len(tsdf.item_ids)}",
        f"Available features: {tsdf.columns.tolist()}",
        "\nSample of the data:",
    ]
    # One print call with a newline separator emits the same text as one print per line.
    print(*header_lines, sep="\n")
    print(tsdf.head(10))

    print("\nSales Statistics by Store:")
    for item_id in tsdf.item_ids:
        target = tsdf.loc[item_id]['target']
        print(f"\n{item_id}:")
        print(f"Mean: {target.mean():.2f}")
        print(f"Min: {target.min():.2f}")
        print(f"Max: {target.max():.2f}")
        print(f"Std Dev: {target.std():.2f}")


def save_data(tsdf: TimeSeriesDataFrame, data_dir: str = './tests/hopformer/data') -> None:
    """Write *tsdf* to ``data_dir`` as both CSV and pickle.

    The directory is created if it does not exist. The files are named
    ``store_sales_data.csv`` and ``store_sales_data.pkl``; their paths are
    printed once both have been written.

    Args:
        tsdf: TimeSeriesDataFrame to save
        data_dir: Directory to write the data files into

    Raises:
        OSError: If the directory cannot be created or either file cannot be
            written; the underlying error is chained as the cause.
    """
    csv_name, pkl_name = 'store_sales_data.csv', 'store_sales_data.pkl'
    try:
        os.makedirs(data_dir, exist_ok=True)
        csv_path = os.path.join(data_dir, csv_name)
        pkl_path = os.path.join(data_dir, pkl_name)

        # CSV first, then pickle.
        for write, destination in ((tsdf.to_csv, csv_path), (tsdf.to_pickle, pkl_path)):
            write(destination)

        print(f"\nData saved to '{csv_path}' and '{pkl_path}'")
    except OSError as exc:
        raise OSError(f"Failed to create directory or save files: {exc}") from exc


def _publication_rc_params() -> dict:
    """Return the matplotlib rcParams overrides used for publication-quality figures."""
    return {
        # Use serif font for publication
        'font.family': 'serif',
        'font.serif': ['Times New Roman', 'DejaVu Serif', 'Computer Modern Roman'],
        'font.size': 9,
        'axes.labelsize': 10,
        'axes.titlesize': 11,
        'xtick.labelsize': 8,
        'ytick.labelsize': 8,
        'legend.fontsize': 8,
        'figure.titlesize': 12,

        # Remove top and right spines
        'axes.spines.top': False,
        'axes.spines.right': False,

        # Set ticks inward
        'xtick.direction': 'in',
        'ytick.direction': 'in',

        # Line widths
        'lines.linewidth': 1.5,
        'axes.linewidth': 0.8,
        'grid.linewidth': 0.5,

        # Default figure size: full page width in a two-column layout
        'figure.figsize': (7.2, 4.45),
        'figure.dpi': 300,
    }


def _format_date_axis(ax) -> None:
    """Place x ticks on every January and July and label them like 'Jan 2020'."""
    import matplotlib.dates as mdates

    ax.xaxis.set_major_locator(mdates.MonthLocator([1, 7]))
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %Y'))


def _plot_sales_panel(ax, store, store_data, params, color, left_column, bottom_row) -> None:
    """Draw one store's ``target`` series and annotate its main parameters."""
    ax.plot(store_data.index, store_data['target'], color=color, linewidth=1.5)
    _format_date_axis(ax)
    ax.set_title(f'{store}', fontsize=10)
    if left_column:
        ax.set_ylabel('Sales', fontsize=9)
    if bottom_row:
        ax.set_xlabel('Date', fontsize=9)
    ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)

    param_text = (f"B:{params.baseline}, A:{params.amplitude}, F:{params.frequency}d\n"
                  f"Noise:{params.sales_noise:.1f}")
    ax.annotate(param_text, xy=(0.05, 0.95), xycoords='axes fraction',
                ha='left', va='top', fontsize=7, alpha=0.7)


def _plot_covariates_panel(ax, store, store_data, params, temp_color, promo_color,
                           left_column, bottom_row) -> None:
    """Draw one store's temperature series plus a '|' marker at y=0 on each promotion day."""
    ax.plot(store_data.index, store_data['temperature'],
            color=temp_color, alpha=0.7, linewidth=1.5, label='Temp')
    _format_date_axis(ax)
    ax.set_title(f'{store} Covariates', fontsize=10)
    if left_column:
        ax.set_ylabel('Value', fontsize=9)
    if bottom_row:
        ax.set_xlabel('Date', fontsize=9)
    ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)

    promotion_dates = store_data[store_data['promotion'] == 1].index
    ax.scatter(promotion_dates, np.zeros(len(promotion_dates)),
               marker='|', s=80, color=promo_color, alpha=0.8)

    effect_text = (f"Temp:{params.temperature_effect:.2f}, "
                   f"Promo:{params.promotion_effect:.1f}, "
                   f"P(promo):{params.promotion_probability:.2f}")
    ax.annotate(effect_text, xy=(0.05, 0.95), xycoords='axes fraction',
                ha='left', va='top', fontsize=7, alpha=0.7)


def _plot_param_distributions(store_params, palette):
    """Create a 3x3 figure showing the distribution of every StoreParams field across stores."""
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(3, 3, figsize=(7.2, 6.5), constrained_layout=True)
    all_params = list(store_params.values())

    # (row, col, attribute, title) for the continuous parameters; the colour
    # index is row * 3 + col so each panel keeps a distinct palette colour.
    continuous_panels = [
        (0, 0, 'amplitude', 'Amplitude'),
        (0, 2, 'baseline', 'Baseline'),
        (1, 0, 'promotion_effect', 'Promotion Effect'),
        (1, 1, 'temperature_effect', 'Temperature Effect'),
        (1, 2, 'price_effect', 'Price Effect'),
        (2, 0, 'promotion_probability', 'Promotion Probability'),
        (2, 1, 'temperature_noise', 'Temperature Noise'),
        (2, 2, 'sales_noise', 'Sales Noise'),
    ]
    for row, col, attr, title in continuous_panels:
        ax = axes[row, col]
        ax.hist([getattr(p, attr) for p in all_params], bins=15,
                color=palette[row * 3 + col], alpha=0.7)
        ax.set_title(title)

    # Frequency takes only a handful of discrete values (e.g. 7/14/30/90).
    # Equal-width histogram bins would merge neighbouring values, so show
    # one bar per distinct value instead.
    freq_ax = axes[0, 1]
    periods, counts = np.unique([p.frequency for p in all_params], return_counts=True)
    freq_ax.bar([str(period) for period in periods], counts, color=palette[1], alpha=0.7)
    freq_ax.set_title('Frequency (days)')

    for ax in axes.flat:
        ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
    return fig


def plot_stores_sample(tsdf: TimeSeriesDataFrame, 
                      store_params: Dict[str, StoreParams],
                      plot_dir: str = './tests/hopformer/plots',
                      sample_size: int = 4) -> None:
    """Plot sales data and covariates for a random sample of stores.

    Writes three PNG files (300 dpi) into ``plot_dir``:

    * ``store_sales_sample.png``: target series of each sampled store
    * ``store_covariates_sample.png``: temperature and promotion days of each
      sampled store
    * ``store_params_distribution.png``: distribution of every parameter
      across all stores

    Sampled stores are laid out two per row, so any ``sample_size`` works.
    Styling is applied through a temporary ``rc_context``, so global
    matplotlib settings are left untouched. All figures created here are
    closed before returning.

    If matplotlib cannot be imported, a message is printed and the function
    returns without writing anything.

    Args:
        tsdf: TimeSeriesDataFrame containing store data; ``tsdf.loc[store]``
            must provide ``target``, ``temperature`` and ``promotion`` columns
        store_params: Dictionary mapping store IDs to their parameters
        plot_dir: Directory to save the visualization (created if missing)
        sample_size: Number of stores to include; capped at the number of
            stores, and negative values are treated as 0

    Raises:
        OSError: If the directory cannot be created or a figure cannot be saved
    """
    try:
        import matplotlib.pyplot as plt
        from matplotlib.lines import Line2D
    except ImportError:
        # Plotting is optional: degrade gracefully instead of failing the pipeline.
        print("Matplotlib not available for visualization")
        return

    # Color-blind safe palette (Tableau 10)
    tableau10 = ['#4e79a7', '#f28e2c', '#e15759', '#76b7b2',
                 '#59a14f', '#edc949', '#af7aa1', '#ff9da7',
                 '#9c755f', '#bab0ab']

    sales_path = os.path.join(plot_dir, 'store_sales_sample.png')
    covariates_path = os.path.join(plot_dir, 'store_covariates_sample.png')
    params_path = os.path.join(plot_dir, 'store_params_distribution.png')

    figures = []
    try:
        os.makedirs(plot_dir, exist_ok=True)

        # Randomly sample stores (uses the stdlib ``random`` global state), sorted by ID.
        all_stores = list(store_params.keys())
        n_sample = max(0, min(sample_size, len(all_stores)))
        sample_stores = sorted(random.sample(all_stores, n_sample))

        # Two panels per row; height scales with the row count
        # (7.2 x 4.45 inches for the default 2x2 grid).
        n_cols = 2
        n_rows = max(1, (len(sample_stores) + n_cols - 1) // n_cols)
        figsize = (7.2, 4.45 * n_rows / 2)

        with plt.rc_context(_publication_rc_params()):
            fig1, axes1 = plt.subplots(n_rows, n_cols, figsize=figsize,
                                       constrained_layout=True, squeeze=False)
            figures.append(fig1)
            fig2, axes2 = plt.subplots(n_rows, n_cols, figsize=figsize,
                                       constrained_layout=True, squeeze=False)
            figures.append(fig2)
            sales_axes = axes1.ravel()
            cov_axes = axes2.ravel()

            for i, store in enumerate(sample_stores):
                params = store_params[store]
                store_data = tsdf.loc[store]
                left_column = i % n_cols == 0
                # The last panel in each column gets the x-label (also for odd sample sizes).
                bottom_row = i + n_cols >= len(sample_stores)
                _plot_sales_panel(sales_axes[i], store, store_data, params,
                                  tableau10[0], left_column, bottom_row)
                _plot_covariates_panel(cov_axes[i], store, store_data, params,
                                       tableau10[1], tableau10[2], left_column, bottom_row)

            # Hide grid cells without a store (odd sample size or too few stores).
            for ax in (*sales_axes[len(sample_stores):], *cov_axes[len(sample_stores):]):
                ax.set_visible(False)

            # Shared legend for the covariates figure
            legend_elements = [
                Line2D([0], [0], color=tableau10[1], lw=1.5, alpha=0.7, label='Temperature'),
                Line2D([0], [0], marker='|', color=tableau10[2], markersize=8,
                       linestyle='None', alpha=0.8, label='Promotion')
            ]
            fig2.legend(handles=legend_elements, loc='lower center', bbox_to_anchor=(0.5, 0.01),
                        ncol=2, frameon=False, fontsize=8)

            fig3 = _plot_param_distributions(store_params, tableau10)
            figures.append(fig3)

            # Save inside the rc_context so the styling also applies at draw time.
            fig1.savefig(sales_path, dpi=300, bbox_inches='tight')
            fig2.savefig(covariates_path, dpi=300, bbox_inches='tight')
            fig3.savefig(params_path, dpi=300, bbox_inches='tight')

        print(f"Visualizations saved to '{sales_path}', '{covariates_path}', and '{params_path}'")
    except OSError as err:
        raise OSError(f"Failed to create directory or save plot: {err}") from err
    finally:
        # pyplot keeps every open figure alive until it is closed explicitly.
        for fig in figures:
            plt.close(fig)


def main(data_dir: str = './tests/hopformer/data', 
         plot_dir: str = './tests/hopformer/plots',
         num_stores: int = 200,
         num_years: int = 2):
    """Run the full pipeline: generate, summarise, save and plot synthetic sales.

    Steps: draw per-store parameters, build a daily calendar starting on
    1 January 2020, simulate sales, convert them to a ``TimeSeriesDataFrame``,
    print summary statistics, write the data files and plot a sample of four
    stores.

    NOTE: these keyword defaults (``num_years=2``, directories under
    ``./tests/hopformer``) differ from the command-line defaults used when the
    module is run as a script (3 years, directories under ``./data/sales1``).

    Args:
        data_dir: Directory to save data files
        plot_dir: Directory to save plot files
        num_stores: Number of stores to generate
        num_years: Number of calendar years of data to generate

    Raises:
        OSError: If a data/plot directory cannot be created or a file cannot
            be written
    """
    store_params = generate_store_params(num_stores)
    calendar = generate_date_range(start_year=2020, num_years=num_years)

    print(f"Generating data for {num_stores} stores over {num_years} years...")
    tsdf = create_time_series_df(generate_store_sales(store_params, calendar))

    print_statistics(tsdf)
    save_data(tsdf, data_dir)
    plot_stores_sample(tsdf, store_params, plot_dir, sample_size=4)


if __name__ == "__main__":
    import argparse

    # Command-line entry point; the defaults here intentionally override the
    # keyword defaults of ``main``.
    parser = argparse.ArgumentParser(description='Generate synthetic store sales data')
    parser.add_argument('--data-dir', type=str, default='./data/sales1/data',
                        help='Directory to save data files')
    parser.add_argument('--plot-dir', type=str, default='./data/sales1/plots',
                        help='Directory to save plot files')
    parser.add_argument('--num-stores', type=int, default=200,
                        help='Number of stores to generate')
    parser.add_argument('--num-years', type=int, default=3,
                        help='Number of years of data to generate')

    args = parser.parse_args()

    # With zero (or negative) stores or years no rows would be generated, so
    # fail fast with a usage error instead of an obscure downstream failure.
    for flag, value in (('--num-stores', args.num_stores), ('--num-years', args.num_years)):
        if value < 1:
            parser.error(f"{flag} must be a positive integer, got {value}")

    main(
        data_dir=args.data_dir, 
        plot_dir=args.plot_dir,
        num_stores=args.num_stores,
        num_years=args.num_years
    )
