import random
import pandas as pd
import numpy as np
from datetime import datetime
from autogluon.timeseries import TimeSeriesDataFrame
import dataclasses
import os
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass

# Seed both global RNGs at import time for reproducibility: NumPy's RNG drives
# the store parameters and the generated series, while the stdlib RNG drives
# the random.sample() call that picks which stores get plotted.
random.seed(42)
np.random.seed(42)

@dataclass(frozen=True)
class StoreParams:
    """Immutable set of parameters defining the sales pattern of one store.

    The descriptions below state how ``generate_store_sales`` in this file
    consumes each field.

    Attributes:
        amplitude: Height of the periodic (sine) sales component.
        frequency: Length of one sales cycle in time steps (days). Despite the
            name it is used as the *period* of the sine wave, i.e.
            ``sin(2*pi*t / frequency)``.
        baseline: Base sales level.
        promotion_effect: Sales lift added when a promotion is active (scaled
            by 0.7 when the competitor promotes on the same day).
        weekend_effect: Sales change added on Saturdays and Sundays.
        holiday_effect: Sales lift added on holidays.
        temperature_effect: Scale of the quadratic temperature term
            ``temperature_effect * (1 - 0.03 * (temperature - temperature_optimal)**2)``.
        temperature_optimal: Temperature at which the temperature term peaks.
        price_effect: Scale of the price term
            ``price_effect * (price / base_price) ** price_elasticity``.
        price_elasticity: Exponent of the relative price in the price term.
        competitor_effect: Sales change added when the competitor promotes.
        marketing_effect: Multiplier applied to ``sqrt(marketing_spend)``.
        marketing_decay: Factor the marketing spend is multiplied by on every
            day without a new campaign.
        inventory_effect: Impact of inventory levels.
            NOTE(review): ``generate_store_sales`` does not read this field; it
            computes its own inventory multiplier from the inventory level.
        seasonal_amplitude: Strength of the yearly (day-of-year) seasonality.
        trend_coefficient: Sales change added per elapsed time step.
        promotion_probability: Probability of a promotion on any given day.
        competitor_promotion_probability: Probability of a competitor
            promotion on any given day.
        temperature_noise: Standard deviation of the noise added to temperature.
        sales_noise: Standard deviation of the noise added to sales.
    """
    # Basic parameters
    amplitude: float
    frequency: int  # Period of the sine cycle in days (see class docstring)
    baseline: float
    
    # Additive effect parameters
    promotion_effect: float
    weekend_effect: float
    holiday_effect: float
    
    # Non-linear effect parameters
    temperature_effect: float
    temperature_optimal: float  # Temperature at which the temperature term peaks
    price_effect: float
    price_elasticity: float  # Exponent applied to the relative price
    
    # Competition and marketing effects
    competitor_effect: float
    marketing_effect: float
    marketing_decay: float  # Per-day multiplicative decay of marketing spend
    
    # Inventory effects
    inventory_effect: float  # Currently not read by generate_store_sales
    
    # Seasonality and trend
    seasonal_amplitude: float  # Yearly seasonality strength
    trend_coefficient: float  # Sales change per time step
    
    # Event probabilities
    promotion_probability: float
    competitor_promotion_probability: float
    
    # Noise parameters (standard deviations)
    temperature_noise: float
    sales_noise: float


def generate_store_params(num_stores: int = 200) -> Dict[str, StoreParams]:
    """Draw an independent random parameter set for every store.

    Each parameter is sampled from the global NumPy RNG (a normal
    distribution, or a uniform pick among fixed cycle lengths for
    ``frequency``), clamped to a sensible range where required, and rounded.

    Args:
        num_stores: Number of stores to generate parameters for

    Returns:
        Dictionary mapping store IDs (``'Store_1'`` .. ``'Store_<num_stores>'``)
        to their sales pattern parameters
    """
    cycle_lengths = [7, 14, 30, 90]  # weekly, bi-weekly, monthly, quarterly
    draw = np.random.normal  # draw(mean, std)

    def build_store() -> StoreParams:
        # Keyword arguments are evaluated left to right, so the order of the
        # arguments below IS the order in which the RNG is consumed. Do not
        # reorder them, or previously generated datasets stop being
        # reproducible.
        return StoreParams(
            # Basic parameters
            amplitude=round(max(10, draw(80, 30)), 1),
            frequency=int(np.random.choice(cycle_lengths)),
            baseline=round(max(50, draw(200, 80)), 1),

            # Linear effects
            promotion_effect=round(max(0, draw(30, 10)), 1),
            weekend_effect=round(draw(15, 5), 1),
            holiday_effect=round(max(0, draw(40, 15)), 1),

            # Non-linear effects
            temperature_effect=round(draw(0.3, 0.1), 2),
            temperature_optimal=round(draw(22, 3), 1),  # optimum varies by store
            price_effect=round(min(0, draw(-12, 5)), 1),
            price_elasticity=round(max(1.0, draw(1.5, 0.3)), 2),  # >1 means elastic demand

            # Competition and marketing
            competitor_effect=round(min(0, draw(-15, 5)), 1),
            marketing_effect=round(max(0, draw(0.4, 0.1)), 2),
            marketing_decay=round(min(1.0, max(0.5, draw(0.85, 0.05))), 2),  # kept in [0.5, 1]

            # Inventory
            inventory_effect=round(draw(0.2, 0.05), 2),

            # Seasonality and trend
            seasonal_amplitude=round(max(0, draw(30, 10)), 1),
            trend_coefficient=round(draw(0.01, 0.005), 3),  # small growth trend

            # Event probabilities
            promotion_probability=round(max(0.05, min(0.4, draw(0.2, 0.05))), 2),
            competitor_promotion_probability=round(max(0.05, min(0.3, draw(0.15, 0.05))), 2),

            # Noise parameters
            temperature_noise=round(max(0.5, draw(2.0, 0.5)), 1),
            sales_noise=round(max(2.0, draw(10.0, 3.0)), 1),
        )

    return {f'Store_{number}': build_store() for number in range(1, num_stores + 1)}


def generate_date_range(start_year: int = 2020, num_years: int = 3) -> pd.DatetimeIndex:
    """Build a daily index spanning whole calendar years.

    Args:
        start_year: First calendar year covered (starts on January 1st)
        num_years: How many consecutive calendar years to cover

    Returns:
        DatetimeIndex with one entry per day, from January 1st of
        ``start_year`` through December 31st of the last covered year
    """
    final_year = start_year + num_years - 1
    first_day = datetime(year=start_year, month=1, day=1)
    last_day = datetime(year=final_year, month=12, day=31)
    return pd.date_range(first_day, last_day, freq='D')


def generate_store_sales(store_params: Dict[str, StoreParams], 
                        dates: pd.DatetimeIndex) -> pd.DataFrame:
    """Generate synthetic sales data with realistic non-linear patterns and complex covariates.
    
    Stores are simulated one after another, one day at a time. All randomness
    comes from the global NumPy RNG, so the result depends on the RNG state on
    entry and on the exact order of the draws below; reordering the random
    calls changes the generated data.
    
    Args:
        store_params: Dictionary mapping store IDs to their parameters
        dates: DatetimeIndex of dates to generate data for
        
    Returns:
        DataFrame in long format with one row per (store, date) and the columns
        ``timestamp``, ``series_id``, ``target``, ``promotion``,
        ``competitor_promotion``, ``temperature``, ``price``, ``is_weekend``,
        ``is_holiday``, ``marketing_spend`` and ``inventory``.
    """
    all_data = []
    stores = list(store_params.keys())
    
    # Define holidays (simplified for illustration).
    # Only 2020-2023 are listed; dates outside these years get is_holiday = 0.
    holidays = [
        # Major US holidays for 2020-2023
        "2020-01-01", "2020-01-20", "2020-02-17", "2020-05-25", "2020-07-04", "2020-09-07", 
        "2020-10-12", "2020-11-11", "2020-11-26", "2020-12-25",
        "2021-01-01", "2021-01-18", "2021-02-15", "2021-05-31", "2021-07-04", "2021-09-06", 
        "2021-10-11", "2021-11-11", "2021-11-25", "2021-12-25",
        "2022-01-01", "2022-01-17", "2022-02-21", "2022-05-30", "2022-07-04", "2022-09-05", 
        "2022-10-10", "2022-11-11", "2022-11-24", "2022-12-25",
        "2023-01-01", "2023-01-16", "2023-02-20", "2023-05-29", "2023-07-04", "2023-09-04", 
        "2023-10-09", "2023-11-11", "2023-11-23", "2023-12-25"
    ]
    # A set of Timestamps gives constant-time membership tests in the inner loop
    holidays = set(pd.to_datetime(holidays))
    
    for store in stores:
        # Get parameters for this store
        params = store_params[store]
        store_data = []
        
        # Marketing spend carried over from day to day (decays until a new campaign starts)
        marketing_spend = 0
        
        # Inventory level on a 0-100 scale, carried over from day to day
        inventory = 100  # Start with 100% inventory
        
        # t is the number of days elapsed since the first date
        for t, date in enumerate(dates):
            # Is it weekend? (dayofweek: Monday=0 ... Sunday=6)
            is_weekend = 1 if date.dayofweek >= 5 else 0
            
            # Is it holiday?
            is_holiday = 1 if date in holidays else 0
            
            # Generate covariates
            # 1. Promotion (binary): occasional promotions with store-specific probability
            promotion = 1 if np.random.random() < params.promotion_probability else 0
            
            # 2. Competitor promotion (binary), drawn independently of our own promotion
            competitor_promotion = 1 if np.random.random() < params.competitor_promotion_probability else 0
            
            # 3. Temperature (continuous): yearly sine around 20 with amplitude 15,
            #    plus normal noise with store-specific standard deviation
            day_of_year = date.day_of_year
            temperature = 20 + 15 * np.sin(2 * np.pi * day_of_year / 365) + np.random.normal(0, params.temperature_noise)
            
            # 4. Price: 10% discount on promotion days, times a slow deterministic
            #    +/-5% oscillation (no random draw is involved here)
            base_price = 10
            price_modifier = 0.9 if promotion else 1.0
            price = base_price * price_modifier * (1 + 0.05 * np.sin(t/30))  # Small price fluctuations
            
            # 5. Marketing spend (occasional marketing campaigns)
            if np.random.random() < 0.03:  # New marketing campaign starts with 3% probability
                # Spend level is uniform in [70, 130); this consumes a second random draw
                marketing_spend = 100 * (0.7 + 0.6 * np.random.random())  # Random spend level
            else:
                # Marketing effect decays over time
                marketing_spend *= params.marketing_decay
            
            # 6. Inventory levels (varies based on past sales and restocking)
            if t % 14 == 0:  # Restock every 14 days (including day 0)
                inventory = min(100, inventory + 70)  # Add 70 points, capped at 100
            
            # 7. Calculate yearly seasonality effect
            yearly_seasonal_effect = params.seasonal_amplitude * np.sin(2 * np.pi * day_of_year / 365)
            
            # 8. Calculate long-term trend (linear in elapsed days)
            trend_effect = params.trend_coefficient * t
            
            # Generate periodic component (sine wave); params.frequency is used
            # as the period of the cycle in days
            periodic = params.amplitude * np.sin(2 * np.pi * t / params.frequency)
            
            # NON-LINEAR EFFECTS:
            
            # 1. Price elasticity (non-linear)
            # Relative price raised to the elasticity, scaled by params.price_effect
            price_effect = params.price_effect * (price / base_price) ** params.price_elasticity
            
            # 2. Temperature effect (quadratic - optimal temperature with falloff)
            # Sales decrease when temperature is too high or too low compared to optimal
            temp_deviation = temperature - params.temperature_optimal
            temperature_effect = params.temperature_effect * (1 - 0.03 * temp_deviation**2)
            
            # 3. Diminishing returns on promotion when competitor also promotes
            promotion_effect = params.promotion_effect
            if promotion and competitor_promotion:
                # Reduce effectiveness when both promote simultaneously
                promotion_effect *= 0.7
            
            # 4. Inventory effect (non-linear - very low inventory hurts sales)
            # NOTE: this local multiplier is unrelated to params.inventory_effect,
            # which is not read anywhere in this function.
            inventory_effect = 1.0  # Default multiplier
            if inventory < 30:
                # Multiplier falls linearly from 1.0 at inventory 30 to 0.5 at inventory 0
                inventory_effect = 0.5 + 0.5 * (inventory / 30)
            
            # 5. Marketing effect with diminishing returns (square root of spend)
            marketing_effect = params.marketing_effect * np.sqrt(marketing_spend)
            
            # Combine all effects
            
            # Base components
            base_component = params.baseline + periodic + yearly_seasonal_effect + trend_effect
            
            # Linear components
            linear_effects = (
                is_weekend * params.weekend_effect +
                is_holiday * params.holiday_effect +
                competitor_promotion * params.competitor_effect
            )
            
            # Non-linear components
            nonlinear_effects = (
                promotion * promotion_effect +
                temperature_effect +
                price_effect +
                marketing_effect
            )
            
            # Calculate target sales with all effects and noise.
            # The inventory multiplier scales the deterministic part only; the
            # noise is added afterwards and the result is floored at zero.
            raw_sales = base_component + linear_effects + nonlinear_effects
            sales = max(0, raw_sales * inventory_effect + np.random.normal(0, params.sales_noise))
            
            # Update inventory based on sales
            inventory_used = min(inventory, sales / 10)  # Assume each 10 sales units uses 1% inventory
            inventory = max(0, inventory - inventory_used)

            # Create row.
            # NOTE(review): 'inventory' is the level AFTER today's sales were
            # deducted, so it depends on today's target; confirm this is
            # intended before treating it as a covariate known in advance.
            row = {
                'timestamp': date,
                'series_id': store,
                'target': round(sales, 2),
                'promotion': promotion,
                'competitor_promotion': competitor_promotion,
                'temperature': round(temperature, 2),
                'price': round(price, 2),
                'is_weekend': is_weekend,
                'is_holiday': is_holiday,
                'marketing_spend': round(marketing_spend, 2),
                'inventory': round(inventory, 2)
            }
            store_data.append(row)
        
        all_data.extend(store_data)
    
    return pd.DataFrame(all_data)


def create_time_series_df(sales_df: pd.DataFrame) -> TimeSeriesDataFrame:
    """Wrap the long-format sales table in an AutoGluon TimeSeriesDataFrame.

    Args:
        sales_df: DataFrame containing sales data, with the store identifier
            in ``series_id`` and the date in ``timestamp``

    Returns:
        AutoGluon TimeSeriesDataFrame for time series modeling
    """
    # Names of the columns that identify the series and the time axis
    column_roles = {'id_column': 'series_id', 'timestamp_column': 'timestamp'}
    return TimeSeriesDataFrame.from_data_frame(df=sales_df, **column_roles)


def print_statistics(tsdf: TimeSeriesDataFrame) -> None:
    """Print an overview of the dataset followed by per-store target statistics.

    Args:
        tsdf: TimeSeriesDataFrame to analyze
    """
    overview = (
        f"TimeSeriesDataFrame shape: {tsdf.shape}",
        f"Number of time series: {len(tsdf.item_ids)}",
        f"Available features: {tsdf.columns.tolist()}",
        "\nSample of the data:",
    )
    for line in overview:
        print(line)
    print(tsdf.head(10))

    # Basic statistics by store: (printed label, Series method name)
    summaries = (("Mean", "mean"), ("Min", "min"), ("Max", "max"), ("Std Dev", "std"))
    print("\nSales Statistics by Store:")
    for item_id in tsdf.item_ids:
        per_store = tsdf.loc[item_id]
        print(f"\n{item_id}:")
        sales = per_store['target']
        for label, method in summaries:
            print(f"{label}: {getattr(sales, method)():.2f}")


def save_data(tsdf: TimeSeriesDataFrame, data_dir: str = './tests/hopformer/data') -> None:
    """Write the dataset to ``data_dir`` as both a CSV file and a pickle file.

    The directory is created if it does not exist yet.

    Args:
        tsdf: TimeSeriesDataFrame to save
        data_dir: Directory to save data files

    Raises:
        OSError: If creating the directory or writing either file fails
    """
    try:
        os.makedirs(data_dir, exist_ok=True)

        csv_file = os.path.join(data_dir, 'store_sales_data.csv')
        pickle_file = os.path.join(data_dir, 'store_sales_data.pkl')

        tsdf.to_csv(csv_file)
        tsdf.to_pickle(pickle_file)

        confirmation = f"\nData saved to '{csv_file}' and '{pickle_file}'"
        print(confirmation)
    except OSError as err:
        failure = f"Failed to create directory or save files: {err}"
        raise OSError(failure) from err


def plot_stores_sample(tsdf: TimeSeriesDataFrame, 
                      store_params: Dict[str, StoreParams],
                      plot_dir: str = './tests/hopformer/plots',
                      sample_size: int = 4) -> None:
    """Plot sales data and covariates for a sample of stores.
    
    Three PNG files are written to ``plot_dir``: the sales series of the
    sampled stores, their temperature/promotion covariates, and histograms of
    nine parameters across *all* stores. The sample is drawn with the global
    ``random`` module. The panel grid grows with the number of sampled stores
    (two columns, at least two rows), so any ``sample_size`` is supported.
    
    If matplotlib cannot be imported, a message is printed and nothing is
    plotted. As a side effect the global ``plt.rcParams`` are updated with the
    publication settings below and stay updated after the call.
    
    Args:
        tsdf: TimeSeriesDataFrame containing store data
        store_params: Dictionary mapping store IDs to their parameters
        plot_dir: Directory to save the visualization
        sample_size: Number of stores to include in the visualization
        
    Raises:
        OSError: If creating the directory or saving a figure fails
    """
    figures = []  # every figure created below, so all of them get closed on exit
    try:
        import matplotlib.pyplot as plt
        # Import the submodule explicitly rather than relying on pyplot having
        # loaded matplotlib.dates as a side effect.
        import matplotlib.dates as mdates
        from matplotlib.lines import Line2D
        
        # Ensure the plot directory exists
        os.makedirs(plot_dir, exist_ok=True)
        
        # Set publication-quality settings
        plt.rcParams.update({
            # Use serif font for publication
            'font.family': 'serif',
            'font.serif': ['Times New Roman', 'DejaVu Serif', 'Computer Modern Roman'],
            'font.size': 9,
            'axes.labelsize': 10,
            'axes.titlesize': 11,
            'xtick.labelsize': 8,
            'ytick.labelsize': 8,
            'legend.fontsize': 8,
            'figure.titlesize': 12,
            
            # Remove top and right spines
            'axes.spines.top': False,
            'axes.spines.right': False,
            
            # Set ticks inward
            'xtick.direction': 'in',
            'ytick.direction': 'in',
            
            # Line widths
            'lines.linewidth': 1.5,
            'axes.linewidth': 0.8,
            'grid.linewidth': 0.5,
            
            # Figure size - optimized for two-column layout
            'figure.figsize': (7.2, 4.45),  # Optimal for two-column, fits full page width
            'figure.dpi': 300,
        })
        
        # Set up a color-blind safe palette (Tableau 10)
        tableau10 = ['#4e79a7', '#f28e2c', '#e15759', '#76b7b2', 
                     '#59a14f', '#edc949', '#af7aa1', '#ff9da7', 
                     '#9c755f', '#bab0ab']
        
        # Randomly sample stores for visualization
        all_stores = list(store_params.keys())
        sample_stores = sorted(random.sample(all_stores, min(sample_size, len(all_stores))))
        n_panels = len(sample_stores)
        
        # Size the grid from the sample instead of assuming exactly four
        # stores. Two rows are kept as the minimum so the default
        # (sample_size=4) still yields the 2x2, 7.2 x 4.45 inch figure.
        n_cols = 2
        n_rows = max(2, (n_panels + n_cols - 1) // n_cols)
        grid_figsize = (7.2, 4.45 * n_rows / 2)
        
        # Create output paths
        sales_path = os.path.join(plot_dir, 'store_sales_sample.png')
        covariates_path = os.path.join(plot_dir, 'store_covariates_sample.png')
        params_path = os.path.join(plot_dir, 'store_params_distribution.png')
        
        # Create sales figure
        fig1, axes1 = plt.subplots(n_rows, n_cols, figsize=grid_figsize, constrained_layout=True)
        figures.append(fig1)
        axes1 = axes1.flatten()
        
        # Create covariates figure
        fig2, axes2 = plt.subplots(n_rows, n_cols, figsize=grid_figsize, constrained_layout=True)
        figures.append(fig2)
        axes2 = axes2.flatten()
        
        for i, store in enumerate(sample_stores):
            params = store_params[store]
            store_data = tsdf.loc[store]
            in_left_column = i % n_cols == 0
            # True when no other plotted panel sits below this one
            in_bottom_row = i + n_cols >= n_panels
            
            # Sales plot
            ax_sales = axes1[i]
            ax_sales.plot(store_data.index, store_data['target'], 
                         color=tableau10[0], linewidth=1.5)
            
            # Major ticks every January and July
            ax_sales.xaxis.set_major_locator(mdates.MonthLocator(bymonth=[1, 7]))
            ax_sales.xaxis.set_major_formatter(mdates.DateFormatter('%b %Y'))
            
            # Add title with store name
            ax_sales.set_title(f'{store}', fontsize=10)
            
            # Add y-label (sales)
            if in_left_column:
                ax_sales.set_ylabel('Sales', fontsize=9)
            
            # Add x-label (date) for bottom row
            if in_bottom_row:
                ax_sales.set_xlabel('Date', fontsize=9)
            
            # Add subtle gridlines
            ax_sales.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
            
            # Add annotation with parameters
            param_text = (f"B:{params.baseline}, A:{params.amplitude}, F:{params.frequency}d\n"
                         f"Noise:{params.sales_noise:.1f}")
            ax_sales.annotate(param_text, xy=(0.05, 0.95), xycoords='axes fraction',
                             ha='left', va='top', fontsize=7, alpha=0.7)
            
            # Covariates plot
            ax_cov = axes2[i]
            
            # Plot temperature
            ax_cov.plot(store_data.index, store_data['temperature'], 
                       color=tableau10[1], alpha=0.7, linewidth=1.5, 
                       label='Temp')
            
            # Major ticks every January and July
            ax_cov.xaxis.set_major_locator(mdates.MonthLocator(bymonth=[1, 7]))
            ax_cov.xaxis.set_major_formatter(mdates.DateFormatter('%b %Y'))
            
            # Add title
            ax_cov.set_title(f'{store} Covariates', fontsize=10)
            
            # Add y-label for left column
            if in_left_column:
                ax_cov.set_ylabel('Value', fontsize=9)
            
            # Add x-label for bottom row
            if in_bottom_row:
                ax_cov.set_xlabel('Date', fontsize=9)
            
            # Add subtle gridlines
            ax_cov.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
            
            # Add promotion markers along y = 0
            promotion_dates = store_data[store_data['promotion'] == 1].index
            ax_cov.scatter(promotion_dates, np.zeros(len(promotion_dates)), 
                         marker='|', s=80, color=tableau10[2], alpha=0.8)
            
            # Add parameter annotation
            effect_text = (f"Temp:{params.temperature_effect:.2f}, "
                         f"Promo:{params.promotion_effect:.1f}, "
                         f"P(promo):{params.promotion_probability:.2f}")
            ax_cov.annotate(effect_text, xy=(0.05, 0.95), xycoords='axes fraction',
                           ha='left', va='top', fontsize=7, alpha=0.7)
        
        # Hide grid cells that received no store (odd sample or fewer than
        # four stores) so the output contains no empty axes.
        for unused in range(n_panels, n_rows * n_cols):
            axes1[unused].set_visible(False)
            axes2[unused].set_visible(False)
        
        # Create a shared legend for the second figure
        legend_elements = [
            Line2D([0], [0], color=tableau10[1], lw=1.5, alpha=0.7, label='Temperature'),
            Line2D([0], [0], marker='|', color=tableau10[2], markersize=8, 
                 linestyle='None', alpha=0.8, label='Promotion')
        ]
        
        fig2.legend(handles=legend_elements, loc='lower center', bbox_to_anchor=(0.5, 0.01),
                   ncol=2, frameon=False, fontsize=8)
        
        # Create a figure to show parameter distributions over ALL stores
        fig3, axes3 = plt.subplots(3, 3, figsize=(7.2, 6.5), constrained_layout=True)
        figures.append(fig3)
        
        # (panel title, StoreParams attribute), in row-major panel order
        histogram_specs = [
            ('Amplitude', 'amplitude'),
            ('Frequency (days)', 'frequency'),
            ('Baseline', 'baseline'),
            ('Promotion Effect', 'promotion_effect'),
            ('Temperature Effect', 'temperature_effect'),
            ('Price Effect', 'price_effect'),
            ('Promotion Probability', 'promotion_probability'),
            ('Temperature Noise', 'temperature_noise'),
            ('Sales Noise', 'sales_noise'),
        ]
        
        # Plot histograms of parameters
        for ax, color, (title, attribute) in zip(axes3.flatten(), tableau10, histogram_specs):
            values = [getattr(p, attribute) for p in store_params.values()]
            if attribute == 'frequency':
                # One bin per distinct cycle length; at least one bin so an
                # empty store_params does not request zero bins.
                bins = max(1, len(set(values)))
            else:
                bins = 15
            ax.hist(values, bins=bins, color=color, alpha=0.7)
            ax.set_title(title)
            ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
        
        # Save all figures
        fig1.savefig(sales_path, dpi=300, bbox_inches='tight')
        fig2.savefig(covariates_path, dpi=300, bbox_inches='tight')
        fig3.savefig(params_path, dpi=300, bbox_inches='tight')
        
        print(f"Visualizations saved to '{sales_path}', '{covariates_path}', and '{params_path}'")
    
    except ImportError:
        # Plotting is best-effort: a missing matplotlib must not abort data generation
        print("Matplotlib not available for visualization")
    except OSError as err:
        raise OSError(f"Failed to create directory or save plot: {err}") from err
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        # `figures` is empty when the import failed, so `plt` is only touched
        # when it is bound.
        for fig in figures:
            plt.close(fig)


def main(data_dir: str = './tests/hopformer/data', 
         plot_dir: str = './tests/hopformer/plots',
         num_stores: int = 200,
         num_years: int = 2):
    """Run the full pipeline: generate, summarize, save and plot the data.
    
    The generated series always start on January 1st, 2020.
    
    Args:
        data_dir: Directory to save data files
        plot_dir: Directory to save plot files
        num_stores: Number of stores to generate
        num_years: Number of years of data to generate
        
    Raises:
        OSError: If creating a directory or writing a file fails
    """
    # Inputs of the simulation: per-store parameters and the daily time axis
    params_by_store = generate_store_params(num_stores)
    date_index = generate_date_range(start_year=2020, num_years=num_years)
    
    # Simulate the sales and wrap them for AutoGluon
    print(f"Generating data for {num_stores} stores over {num_years} years...")
    series = create_time_series_df(generate_store_sales(params_by_store, date_index))
    
    # Report, persist and visualize
    print_statistics(series)
    save_data(series, data_dir)
    plot_stores_sample(series, params_by_store, plot_dir, sample_size=4)


if __name__ == "__main__":
    import argparse
    
    cli = argparse.ArgumentParser(description='Generate synthetic store sales data')
    
    # (flag, value type, default, help text) for every command-line option
    option_specs = (
        ('--data-dir', str, './data/sales2/data', 'Directory to save data files'),
        ('--plot-dir', str, './data/sales2/plots', 'Directory to save plot files'),
        ('--num-stores', int, 200, 'Number of stores to generate'),
        ('--num-years', int, 3, 'Number of years of data to generate'),
    )
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    
    options = cli.parse_args()
    
    # argparse turns '--data-dir' into the attribute 'data_dir', and so on,
    # which matches main()'s keyword parameters one to one.
    main(
        data_dir=options.data_dir,
        plot_dir=options.plot_dir,
        num_stores=options.num_stores,
        num_years=options.num_years,
    )
