import warnings
from pathlib import Path

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from autogluon.timeseries import TimeSeriesDataFrame

# Global Matplotlib styling: start from the seaborn whitegrid sheet, then
# override fonts, line weights, grid and export settings for every figure.
_RC_OVERRIDES = {
    # Figure size and text sizes
    'figure.figsize': [24, 6],
    'figure.titlesize': 16,
    'font.family': 'serif',
    'font.serif': ['Times', 'Computer Modern Roman'],
    'font.size': 10,
    'axes.labelsize': 12,
    'axes.titlesize': 14,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
    'legend.fontsize': 9,
    # Line weights and grid
    'lines.linewidth': 2,
    'axes.linewidth': 0.8,
    'axes.grid': True,
    'grid.alpha': 0.3,
    # Export defaults used by savefig
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.05,
}
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update(_RC_OVERRIDES)

def _to_timeseries_frame(frame):
    """Parse ``frame['timestamp']`` in place, then wrap it as a TimeSeriesDataFrame."""
    frame['timestamp'] = pd.to_datetime(frame['timestamp'])
    return TimeSeriesDataFrame.from_data_frame(
        df=frame,
        id_column='item_id',
        timestamp_column='timestamp',
    )


def _quantile_columns(frame):
    """Return the forecast columns that are neither identifiers nor the mean forecast."""
    excluded = {'item_id', 'timestamp', 'mean'}
    return [col for col in frame.columns if col not in excluded]


# Load the observed series and both models' forecast files.
print("Loading data...")
data = pd.read_csv("./tests/hopformer/data/store_sales_data.csv")
hopformer_forecasts = pd.read_csv("./tests/hopformer/data/forecasts_hopformer1.csv")
chronos_forecasts = pd.read_csv("./tests/hopformer/data/forecasts_chronos1.csv")

# Index every frame by (item_id, timestamp).
tsdf = _to_timeseries_frame(data)
hopformer_tsdf = _to_timeseries_frame(hopformer_forecasts)
chronos_tsdf = _to_timeseries_frame(chronos_forecasts)

# Remaining columns are the per-quantile forecasts (e.g. '0.1' ... '0.9').
hopformer_quantile_columns = _quantile_columns(hopformer_forecasts)
chronos_quantile_columns = _quantile_columns(chronos_forecasts)

# Number of days of history drawn before each forecast start.
context_length = 32

# Plot colours. Context and Actual share one colour so the ground-truth
# series reads as a single continuous line across the forecast boundary.
model_colors = {
    'Context': '#008000',     # Green
    'Actual': '#008000',      # Green (same as context for consistency)
    'Hopformer': '#FF0000',   # Red
    'Chronos': '#0000FF',     # Blue
}

# Forecast accuracy metric used in the legend labels.
def calculate_mase(store_id, forecast_df, original_df, seasonal_period=7):
    """
    Calculate the Mean Absolute Scaled Error (MASE) of a point forecast.

    The mean absolute error of the 'mean' forecast over its horizon is divided
    by the in-sample mean absolute error of a seasonal-naive forecast
    (y[t] vs. y[t - seasonal_period]). That in-sample error is computed on all
    history strictly before the first forecast timestamp. The default period
    of 7 suits daily data with weekly seasonality.

    Missing (NaN) observations are skipped rather than propagated, so one gap
    in the history or horizon does not turn the whole score into NaN.

    Args:
        store_id: Item id selected from the first index level of both frames.
        forecast_df: Frame indexed by (item_id, timestamp) with a 'mean' column.
        original_df: Frame indexed by (item_id, timestamp) with a 'target'
            column covering both the history and the forecast dates.
        seasonal_period: Lag, in time steps, of the seasonal-naive benchmark.

    Returns:
        The MASE as a float, or NaN (with a warning) when it cannot be
        computed. That happens with too little history, a zero naive-error
        scale, or no overlapping non-missing forecast/actual pairs.

    Raises:
        ValueError: If ``seasonal_period`` is smaller than 1.
        KeyError: If the store, or a forecast date, is missing from a frame.
    """
    if seasonal_period < 1:
        raise ValueError(f"seasonal_period must be >= 1, got {seasonal_period}")

    forecast_data = forecast_df.loc[store_id]
    forecast_dates = forecast_data.index
    actual_data = original_df.loc[store_id]

    # Actuals aligned element-for-element with the forecast horizon.
    actual_values = actual_data.loc[forecast_dates, 'target'].to_numpy(dtype=float)
    forecast_values = forecast_data['mean'].to_numpy(dtype=float)

    # min() rather than [0]: an unsorted forecast index still gives the true start.
    forecast_start = forecast_dates.min()
    train_values = actual_data.loc[actual_data.index < forecast_start, 'target'].to_numpy(dtype=float)

    # Seasonal-naive in-sample errors, vectorised. Both slices are empty when the
    # history is not longer than one seasonal period.
    seasonal_errors = np.abs(train_values[seasonal_period:] - train_values[:-seasonal_period])
    seasonal_errors = seasonal_errors[~np.isnan(seasonal_errors)]
    scale = seasonal_errors.mean() if seasonal_errors.size else 0.0
    if scale == 0:
        warnings.warn("Cannot calculate MASE: insufficient training data or zero seasonal errors")
        return np.nan

    forecast_errors = np.abs(actual_values - forecast_values)
    forecast_errors = forecast_errors[~np.isnan(forecast_errors)]
    if forecast_errors.size == 0:
        warnings.warn("Cannot calculate MASE: no overlapping non-missing forecast and actual values")
        return np.nan

    return float(forecast_errors.mean() / scale)

# Draw one model's forecast for one store onto a shared axis.
def plot_store_forecast(store_id, model_name, forecast_df, quantile_columns, ax, color):
    """
    Draw one model's forecast for one store onto ``ax``.

    When ``model_name`` is 'Hopformer', the ground truth is also drawn. That
    covers ``context_length`` days of history, the observed values from the
    forecast start onward, and a dashed vertical line at the forecast start.
    Every call then adds:

    - the model's mean forecast, labelled with its MASE from ``calculate_mase``;
    - a shaded 10%-90% prediction interval, when both quantile columns exist.

    Args:
        store_id: Item id to plot. It is looked up in ``forecast_df`` and in
            the module-level ``tsdf``.
        model_name: Display name used in the legend labels.
        forecast_df: Frame indexed by (item_id, timestamp) with a 'mean' column
            and quantile columns named like '0.1' and '0.9'.
        quantile_columns: Quantile column names available for this model. Only
            intervals whose bounds appear here (and in the frame) are shaded.
        ax: Matplotlib axes to draw on.
        color: Colour for the mean line and the interval band.

    Returns:
        The same ``ax``, for chaining.
    """
    store_data = tsdf.loc[store_id]
    store_forecast = forecast_df.loc[store_id]

    # Window: `context_length` days of history before the first forecast step.
    forecast_start = store_forecast.index[0]
    context_start = forecast_start - pd.Timedelta(days=context_length)

    context_data = store_data[(store_data.index >= context_start) & (store_data.index < forecast_start)]
    actual_data = store_data[store_data.index >= forecast_start]

    mase = calculate_mase(store_id, forecast_df, tsdf)
    mase_str = f"{mase:.4f}" if not np.isnan(mase) else "N/A"

    # Ground truth is shared by all models on this axis, so draw it only once
    # (with Hopformer).
    if model_name == 'Hopformer':
        ax.plot(context_data.index, context_data['target'],
                color=model_colors['Context'], linestyle='-', linewidth=1.8,
                label='Context', marker='', alpha=1.0)

        # Observed values over the horizon, in the same colour as the context.
        ax.plot(actual_data.index, actual_data['target'],
                color=model_colors['Actual'], linestyle='-', linewidth=1.8,
                label='Actual', marker='', alpha=1.0)

        ax.axvline(x=forecast_start, color='#555555', linestyle='--', alpha=0.7,
                   linewidth=1.0, label='Forecast Start')

    # Mean forecast as a dashed line in the model's colour.
    ax.plot(store_forecast.index, store_forecast['mean'],
            color=color, linestyle='--', linewidth=2.0,
            label=f'{model_name} (MASE: {mase_str})',
            marker='', alpha=1.0)

    # Shaded prediction interval(s), widest first. Only the 10%-90% band is
    # drawn; append narrower pairs (e.g. ('0.2', '0.8')) to layer more bands.
    # Each later band gets a slightly lower alpha; overlaps darken the centre.
    quantile_pairs = [('0.1', '0.9')]
    available = set(store_forecast.columns).intersection(quantile_columns)
    for i, (low_q, high_q) in enumerate(quantile_pairs):
        if low_q not in available or high_q not in available:
            continue
        # round(), not int(): e.g. 0.29 * 100 is 28.999... in floating point.
        low_pct = round(float(low_q) * 100)
        high_pct = round(float(high_q) * 100)
        # The model name is part of the label. Without it both models' bands
        # share one label, and the legend code that de-duplicates by label
        # keeps only the last model's band.
        ax.fill_between(store_forecast.index,
                        store_forecast[low_q],
                        store_forecast[high_q],
                        color=color, alpha=0.15 - i * 0.03,
                        label=f'{model_name} {low_pct}%-{high_pct}% PI' if i == 0 else None)

    # Month/day tick labels, rotated so neighbouring dates do not collide.
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %d'))
    plt.setp(ax.xaxis.get_majorticklabels(), rotation=30, ha='right')

    ax.grid(True, alpha=0.2, linestyle='-')

    return ax

# One row of three panels, one store per panel; both models share each panel.
fig, axs = plt.subplots(1, 3, figsize=(24, 6), sharex=False, sharey=False)
axs = axs.flatten()  # 1-D array of axes regardless of grid shape

# Stores to plot. Keep only those present in the observed data and in both
# forecast files; a missing store would otherwise raise KeyError while plotting.
candidate_stores = ['Store_3', 'Store_5', 'Store_7']
available_stores = (
    set(tsdf.index.get_level_values('item_id'))
    & set(hopformer_tsdf.index.get_level_values('item_id'))
    & set(chronos_tsdf.index.get_level_values('item_id'))
)
common_stores = [store for store in candidate_stores if store in available_stores]

if len(common_stores) < len(axs):
    print(f"Warning: Only {len(common_stores)} stores available for plotting")

# zip() stops at the shorter sequence, so at most len(axs) stores are drawn.
for ax, store_id in zip(axs, common_stores):
    # Hopformer goes first: it also draws the shared ground truth.
    plot_store_forecast(store_id, 'Hopformer', hopformer_tsdf, hopformer_quantile_columns, ax, model_colors['Hopformer'])
    plot_store_forecast(store_id, 'Chronos', chronos_tsdf, chronos_quantile_columns, ax, model_colors['Chronos'])

    ax.set_title(f'Store: {store_id}', fontsize=14)
    ax.set_xlabel('Date', fontsize=12)
    ax.set_ylabel('Sales', fontsize=12)

    # Per-panel legend with duplicate labels collapsed (last handle wins).
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax.legend(by_label.values(), by_label.keys(), loc='best')

# Hide panels left without a store so they do not render as empty frames.
for ax in axs[len(common_stores):]:
    ax.set_visible(False)

# Figure-level finishing touches, applied once after every panel is drawn.
def finalize_plots(fig, axs):
    """
    Apply figure-wide styling after all panels have been drawn.

    The function:

    - adds a suptitle;
    - replaces the per-panel legends with one figure-level legend along the
      bottom;
    - reserves layout space for the title and legend;
    - gives every panel black spines, outward ticks and a light grid.

    Per-panel labels of the form ``'<model> (MASE: <value>)'`` carry
    store-specific scores. The shared legend is built from the first panel
    only, so those scores are copied into a text box inside each panel and
    stripped from the shared legend. Every store keeps its own MASE.

    Args:
        fig: The matplotlib Figure holding ``axs``.
        axs: Sequence of Axes (at least one) to style. The first panel
            supplies the shared legend entries.
    """
    # Title sits inside the top band reserved by the layout rect below.
    fig.suptitle('Comparison of Hopformer and Chronos Sales Forecasts',
                 fontsize=16, y=0.90)

    # Shared legend from the first panel. Duplicate labels collapse (last
    # handle wins), and the store-specific " (MASE: ...)" suffix is dropped.
    handles, labels = axs[0].get_legend_handles_labels()
    by_label = {}
    for handle, label in zip(handles, labels):
        by_label[label.partition(' (MASE:')[0]] = handle

    # One row along the bottom: one column per legend entry.
    fig.legend(by_label.values(), by_label.keys(),
               loc='lower center', ncol=max(len(by_label), 1), frameon=False,
               columnspacing=1.8, handletextpad=0.7,
               bbox_to_anchor=(0.5, 0.01))

    # Lay out `fig` itself rather than pyplot's "current" figure; keep the
    # bottom 8% for the legend and the top 8% for the title.
    fig.tight_layout(rect=[0, 0.08, 1, 0.92])
    fig.subplots_adjust(wspace=0.20)

    for ax in axs:
        # Preserve this panel's MASE values before its legend is removed.
        _, ax_labels = ax.get_legend_handles_labels()
        mase_labels = [label for label in dict.fromkeys(ax_labels) if '(MASE:' in label]
        if mase_labels:
            ax.text(0.02, 0.98, '\n'.join(mase_labels),
                    transform=ax.transAxes, ha='left', va='top', fontsize=9,
                    bbox={'facecolor': 'white', 'edgecolor': '#CCCCCC', 'alpha': 0.8})

        # The figure-level legend replaces the per-panel one.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

        # Medium-weight black borders on all four sides.
        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(1.5)
            spine.set_color('black')

        # No ticks on the top and right edges; black outward ticks elsewhere.
        ax.tick_params(top=False, right=False,
                       which='both', direction='out', length=4, width=0.8,
                       color='black')

        ax.grid(True, alpha=0.3, linestyle='-', color='#CCCCCC')

# Finish styling, then save before show() so the saved image is the finished figure.
finalize_plots(fig, axs)

output_path = './tests/hopformer/plots/hopformer_vs_chronos.png'
# savefig does not create missing directories; create the target directory
# first so a fresh checkout does not fail with FileNotFoundError.
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
# Save this figure explicitly rather than whatever pyplot considers current.
fig.savefig(output_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"Visualization saved to '{output_path}'")
