from datetime import datetime, timedelta
from pathlib import Path

import numpy as np
import pandas as pd
from autogluon.timeseries import TimeSeriesDataFrame

# Seed NumPy's global RNG so every run produces the same synthetic data
np.random.seed(42)

# Daily timestamps spanning calendar year 2023 (both endpoints included)
start_date = datetime(2023, 1, 1)
end_date = datetime(2023, 12, 31)
dates = pd.date_range(start=start_date, end=end_date, freq='D')

# Identifiers of the 8 simulated stores: 'Store_1' through 'Store_8'
stores = [f'Store_{number}' for number in range(1, 9)]

# Simulation parameters for each store, keyed by store id.
# Field meanings, as consumed by generate_store_sales():
#   amplitude          - amplitude of the sine-wave component of sales
#   frequency          - divisor of the time index inside the sine, i.e. the
#                        period of the wave in days (despite the name)
#   baseline           - constant sales level the other terms are added to
#   promotion_effect   - amount added to sales on days with a promotion
#   temperature_effect - multiplier applied to the day's temperature
#   price_effect       - multiplier applied to the day's price
_PARAM_FIELDS = (
    'amplitude',
    'frequency',
    'baseline',
    'promotion_effect',
    'temperature_effect',
    'price_effect',
)

# One row per store; values are listed in the same order as _PARAM_FIELDS.
_PARAM_ROWS = (
    ('Store_1', (120, 90, 300, 30, 0.4, -15)),  # quarterly
    ('Store_2', (80, 30, 150, 40, 0.2, -5)),    # monthly
    ('Store_3', (60, 14, 250, 25, 0.3, -20)),   # bi-weekly
    ('Store_4', (50, 7, 200, 20, 0.5, -10)),    # weekly
    ('Store_5', (100, 14, 250, 30, 0.4, -15)),  # bi-weekly
    ('Store_6', (80, 30, 150, 40, 0.2, -5)),    # monthly
    ('Store_7', (60, 14, 250, 25, 0.3, -20)),   # bi-weekly
    ('Store_8', (50, 7, 200, 20, 0.5, -10)),    # weekly
)

store_params = {
    store_id: dict(zip(_PARAM_FIELDS, values))
    for store_id, values in _PARAM_ROWS
}

# Generate the data with periodic patterns and covariates
def generate_store_sales(store_names=None, timestamps=None, params_by_store=None) -> pd.DataFrame:
    """Build a long-format table of synthetic daily sales with covariates.

    For every store and every timestamp one row is produced. Sales are the
    sum of a baseline, a sine wave over the time index, a linear combination
    of three covariates (promotion flag, temperature, price) and Gaussian
    noise, clipped below at 0.

    Random numbers come from NumPy's global RNG (``np.random``); seed it
    before calling for reproducible output. Per row the draws are made in
    this order: promotion, temperature noise, price noise, sales noise.

    Args:
        store_names: Iterable of store ids to simulate. Defaults to the
            module-level ``stores``.
        timestamps: Iterable of timestamps exposing ``day_of_year`` (e.g. a
            ``pandas.DatetimeIndex``). Defaults to the module-level ``dates``.
        params_by_store: Mapping from store id to a dict with the keys
            ``amplitude``, ``frequency``, ``baseline``, ``promotion_effect``,
            ``temperature_effect`` and ``price_effect``. Defaults to the
            module-level ``store_params``.

    Returns:
        DataFrame with columns ``timestamp``, ``series_id``, ``target``,
        ``promotion``, ``temperature`` and ``price``; rows are grouped by
        store in the order given. The numeric columns other than
        ``promotion`` are rounded to 2 decimals. Empty input yields an empty
        frame that still has these columns.

    Raises:
        KeyError: If a store id or a parameter key is missing from
            ``params_by_store``.
    """
    # Resolve defaults at call time so existing zero-argument callers keep
    # using the module-level configuration.
    if store_names is None:
        store_names = stores
    if timestamps is None:
        timestamps = dates
    if params_by_store is None:
        params_by_store = store_params

    columns = ['timestamp', 'series_id', 'target', 'promotion', 'temperature', 'price']
    rows = []

    for store in store_names:
        params = params_by_store[store]
        amplitude = params['amplitude']
        frequency = params['frequency']
        baseline = params['baseline']
        promotion_effect = params['promotion_effect']
        temperature_effect = params['temperature_effect']
        price_effect = params['price_effect']

        for t, date in enumerate(timestamps):
            # NOTE: the order of the RNG calls below determines the generated
            # values for a given seed; do not reorder them.

            # 1. Promotion (binary): on with probability 0.2
            promotion = 1 if np.random.random() < 0.2 else 0

            # 2. Temperature (continuous): yearly sine pattern + noise
            day_of_year = date.day_of_year
            temperature = 20 + 15 * np.sin(2 * np.pi * day_of_year / 365) + np.random.normal(0, 2)

            # 3. Price (continuous): random fluctuations around 10
            price = 10 + np.random.normal(0, 1)

            # Periodic component: sine wave whose period is `frequency` steps
            periodic = amplitude * np.sin(2 * np.pi * t / frequency)

            # Linear combination of covariates
            covariate_effect = (
                promotion * promotion_effect +
                temperature * temperature_effect +
                price * price_effect
            )

            # Target sales, never allowed to go negative
            sales = max(0, baseline + periodic + covariate_effect + np.random.normal(0, 5))

            rows.append({
                'timestamp': date,
                'series_id': store,
                'target': round(sales, 2),
                'promotion': promotion,
                'temperature': round(temperature, 2),
                'price': round(price, 2),
            })

    # Passing the column list keeps the schema stable even when no rows exist.
    return pd.DataFrame(rows, columns=columns)

# Generate the synthetic sales table (one row per store per timestamp)
sales_df = generate_store_sales()

# Wrap it in a TimeSeriesDataFrame, using the store id as the series
# identifier and the 'timestamp' column as the time axis
tsdf = TimeSeriesDataFrame.from_data_frame(
    df=sales_df,
    id_column='series_id',
    timestamp_column='timestamp',
)

# Display info about the TimeSeriesDataFrame
print(f"TimeSeriesDataFrame shape: {tsdf.shape}")
print(f"Number of time series: {len(tsdf.item_ids)}")
print(f"Available features: {tsdf.columns.tolist()}")
print("\nSample of the data:")
print(tsdf.head(10))

# Basic statistics of the target column, store by store
print("\nSales Statistics by Store:")
for store in stores:
    store_data = tsdf.loc[store]
    target = store_data['target']
    print(f"\n{store}:")
    print(f"Mean: {target.mean():.2f}")
    print(f"Min: {target.min():.2f}")
    print(f"Max: {target.max():.2f}")
    print(f"Std Dev: {target.std():.2f}")

# Save to CSV and pickle. The output directory is created first because
# pandas does not create missing parent directories and would raise instead.
data_dir = Path('./tests/hopformer/data')
data_dir.mkdir(parents=True, exist_ok=True)
tsdf.to_csv(data_dir / 'store_sales_data.csv')
tsdf.to_pickle(data_dir / 'store_sales_data.pkl')
print("\nData saved to 'store_sales_data.csv' and 'store_sales_data.pkl'")

# Optional: plot sales and covariates for the first four stores.
# The whole section is best-effort: if matplotlib cannot be imported the
# script reports it and carries on without a figure.
try:
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(15, 8))

    # The subplot grid has 4 rows x 2 columns, so only four stores fit:
    # left column shows the target, right column shows the covariates.
    for i, store in enumerate(stores[:4]):
        params = store_params[store]
        store_data = tsdf.loc[store]

        # Target plot
        plt.subplot(4, 2, i*2+1)
        plt.plot(store_data.index, store_data['target'])
        plt.title(f'{store} - Sales (Baseline: {params["baseline"]}, Amplitude: {params["amplitude"]})')
        plt.xlabel('Date')
        plt.ylabel('Sales')
        plt.grid(True, alpha=0.3)

        # Legend labels carry the effect size of each covariate
        temp_label = f'Temperature (effect: +{params["temperature_effect"]:.1f} per °C)'
        price_label = f'Price (effect: {params["price_effect"]:.1f} per unit)'
        promo_label = f'Promotion (effect: +{params["promotion_effect"]:.1f})'

        # Covariates plot
        plt.subplot(4, 2, i*2+2)
        plt.plot(store_data.index, store_data['temperature'], label=temp_label, alpha=0.7)
        # Price is drawn multiplied by 10 — presumably so it is visible on
        # the same axis as temperature; the legend does not mention this.
        plt.plot(store_data.index, store_data['price'] * 10, label=price_label, alpha=0.7)

        # Promotion days are marked with red ticks along y = 0
        promotion_dates = store_data[store_data['promotion'] == 1].index
        plt.scatter(promotion_dates, np.zeros(len(promotion_dates)),
                    marker='|', s=200, color='red', label=promo_label)

        plt.title(f'{store} - Covariates (Frequency: {params["frequency"]} days)')
        plt.xlabel('Date')
        plt.ylabel('Value')
        plt.legend()
        plt.grid(True, alpha=0.3)

    plt.tight_layout()

    # savefig does not create missing directories, so make sure the target
    # folder exists before writing the image.
    plot_dir = Path('./tests/hopformer/plots')
    plot_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(plot_dir / 'store_sales_visualization.png')
    plt.close(fig)  # release the figure once it has been written
    print("Visualization saved to 'store_sales_visualization.png'")
except ImportError:
    print("Matplotlib not available for visualization")
