from pathlib import Path

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from autogluon.timeseries import TimeSeriesDataFrame

# Set plot style
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.figsize'] = [15, 10]
plt.rcParams['font.size'] = 12

# Load the data
print("Loading data...")
data = pd.read_csv("./tests/hopformer/data/store_sales_data.csv")  # Full data
forecasts = pd.read_csv("./tests/hopformer/data/forecasts.csv")    # Forecast data

# Convert to TimeSeriesDataFrame (if not already)
data['timestamp'] = pd.to_datetime(data['timestamp'])
tsdf = TimeSeriesDataFrame.from_data_frame(
    df=data,
    id_column='item_id',
    timestamp_column='timestamp'
)

# Convert forecast data
forecasts['timestamp'] = pd.to_datetime(forecasts['timestamp'])
forecast_tsdf = TimeSeriesDataFrame.from_data_frame(
    df=forecasts,
    id_column='item_id',
    timestamp_column='timestamp'
)


def _as_quantile_level(label):
    """Return *label* as a float if it names a level strictly inside (0, 1), else None."""
    try:
        level = float(label)
    except (TypeError, ValueError):
        return None
    # NaN/inf and values outside (0, 1) are not quantile levels.
    return level if 0.0 < level < 1.0 else None


# Extract the quantile columns from the forecast dataframe. Only columns whose
# name parses as a level in (0, 1) are kept, so stray columns (for example an
# 'Unnamed: 0' index column written by to_csv) no longer make float() raise.
quantile_columns = [
    col for col in forecasts.columns
    if col not in ['item_id', 'timestamp', 'mean'] and _as_quantile_level(col) is not None
]
quantile_values = sorted(_as_quantile_level(col) for col in quantile_columns)

# Define prediction parameters
# Forecast horizon = number of rows of the first forecast series (0 when there are no series).
prediction_length = (
    len(forecast_tsdf.loc[forecast_tsdf.item_ids[0]])
    if len(forecast_tsdf.item_ids) > 0
    else 0
)
context_length = 192  # Adjust as needed to show relevant context

# Function to plot forecasts for a store
def plot_store_forecast(store_id, ax=None):
    """Plot context history, actuals, and the probabilistic forecast for one store.

    Reads the module-level ``tsdf`` (observed series), ``forecast_tsdf``
    (forecasts with a ``mean`` column plus one column per quantile level),
    ``quantile_values`` (sorted quantile levels) and ``context_length``.

    Args:
        store_id: item id present in both ``tsdf`` and ``forecast_tsdf``.
        ax: matplotlib Axes to draw on; a new figure and Axes are created when None.

    Returns:
        The Axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(12, 6))

    # Get store data
    store_data = tsdf.loc[store_id]
    store_forecast = forecast_tsdf.loc[store_id]

    # Get time ranges.
    # NOTE(review): context_length is applied as a number of *days* here; if the
    # series is not daily, a positional `.tail(context_length)` may be intended.
    forecast_start = store_forecast.index[0]
    context_start = forecast_start - pd.Timedelta(days=context_length)

    # Context is the half-open window [context_start, forecast_start);
    # actuals are every observation from forecast_start onward.
    in_context = (store_data.index >= context_start) & (store_data.index < forecast_start)
    context_data = store_data[in_context]
    actual_data = store_data[store_data.index >= forecast_start]

    # Plot context data
    ax.plot(context_data.index, context_data['target'],
            color='blue', linestyle='-', linewidth=2, label='Context')

    # Plot actual data
    ax.plot(actual_data.index, actual_data['target'],
            color='green', linestyle='-', linewidth=2, label='Actual')

    # Plot mean forecast
    ax.plot(store_forecast.index, store_forecast['mean'],
            color='red', linestyle='--', linewidth=2, label='Forecast (Mean)')

    # Pair quantiles from the outside in: (lowest, highest), (2nd lowest,
    # 2nd highest), ... With an odd count the middle quantile stays unpaired,
    # so one loop covers both the even and odd cases.
    n_pairs = len(quantile_values) // 2
    for i in range(n_pairs):
        low_q = quantile_values[i]
        high_q = quantile_values[-(i + 1)]
        # i == 0 is the widest interval. Wider intervals get a lower alpha
        # (innermost 0.2, stepping down by 0.05, floored at 0.05).
        alpha = max(0.2 - (n_pairs - 1 - i) * 0.05, 0.05)
        ax.fill_between(store_forecast.index,
                        store_forecast[_quantile_column(store_forecast, low_q)],
                        store_forecast[_quantile_column(store_forecast, high_q)],
                        color='red', alpha=alpha,
                        label=f'{_percent_label(low_q)}-{_percent_label(high_q)} Prediction Interval')

    # Format the plot
    ax.set_title(f'Forecast for {store_id}', fontsize=14)
    ax.set_xlabel('Date', fontsize=12)
    ax.set_ylabel('Sales', fontsize=12)

    # Format x-axis to show dates nicely
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
    plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha='right')

    # Add vertical line at forecast start
    ax.axvline(x=forecast_start, color='black', linestyle='--', alpha=0.5,
               label='Forecast Start')

    # Add grid and legend
    ax.grid(True, alpha=0.3)

    # Handle legend - show only unique entries (later handles win for duplicate labels)
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax.legend(by_label.values(), by_label.keys(), loc='best')

    return ax


def _quantile_column(frame, level):
    """Return the column label of *frame* that holds quantile *level*.

    The canonical ``str(level)`` label is tried first. Failing that, any column
    whose label parses to the same float is accepted, so a header such as
    ``"0.10"`` still matches level ``0.1``.

    Raises:
        KeyError: if no column corresponds to *level*.
    """
    canonical = str(level)
    if canonical in frame.columns:
        return canonical
    for col in frame.columns:
        try:
            if float(col) == level:
                return col
        except (TypeError, ValueError):
            continue  # non-numeric column such as 'mean'
    raise KeyError(canonical)


def _percent_label(level):
    """Format a quantile level as a percentage, e.g. 0.1 -> '10%', 0.025 -> '2.5%'."""
    # ':g' hides float noise (0.29 * 100 == 28.999999999999996 -> '29') and keeps
    # fractional percents such as 2.5, which int() would truncate to 2.
    return f'{level * 100:g}%'

# Create a plot for each store
store_ids = list(forecast_tsdf.item_ids)
comparison_plot_path = './tests/hopformer/plots/forecast_comparison.png'

if store_ids:
    # squeeze=False always returns a 2-D array of Axes, so the single-store
    # case needs no special handling.
    fig, axs = plt.subplots(len(store_ids), 1, figsize=(15, 5 * len(store_ids)), squeeze=False)
    for ax, store_id in zip(axs[:, 0], store_ids):
        plot_store_forecast(store_id, ax)

    plt.tight_layout()
    # savefig does not create missing directories, so make sure the target exists.
    Path(comparison_plot_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(comparison_plot_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"Visualization saved to '{comparison_plot_path}'")
else:
    # plt.subplots(0, 1) would raise; nothing to plot without forecasts.
    print("No forecasts to plot; skipping comparison figure.")

# Create a more detailed plot showing covariates for one store (first store)
if len(forecast_tsdf.item_ids) > 0:
    example_store = forecast_tsdf.item_ids[0]
    
    # Create a figure with 2 subplots
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10), sharex=True, 
                                  gridspec_kw={'height_ratios': [2, 1]})
    
    # Plot forecast on top subplot
    plot_store_forecast(example_store, ax1)
    
    # Plot covariates on bottom subplot
    store_data = tsdf.loc[example_store]
    forecast_start = forecast_tsdf.loc[example_store].index[0]