import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from autogluon.timeseries import TimeSeriesDataFrame
from sklearn.preprocessing import StandardScaler
import os

# Global matplotlib appearance shared by every figure created in this script.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({'figure.figsize': [15, 15], 'font.size': 12})

# Directory that holds the input CSVs, and the full path of each dataset,
# keyed by the name used for it throughout the rest of the script.
base_path = "/home/magics/hdd/sky_ws/residual_ws/tests/hopformer/data/"
file_paths = {
    dataset_name: os.path.join(base_path, csv_name)
    for dataset_name, csv_name in (
        ('original_data', "store_sales_data.csv"),
        ('forecasts', "forecasts.csv"),
        ('context_residuals', "context_residuals.csv"),
        ('residual_predictions', "residual_predictions.csv"),
        ('regressor_context', "regressor_context.csv"),
        ('regressor_predictions', "regressor_predictions.csv"),
    )
}

# Read one CSV file and wrap it in a TimeSeriesDataFrame
def load_data(file_path, id_col='item_id', timestamp_col='timestamp'):
    """Load a CSV file as a ``TimeSeriesDataFrame``.

    The timestamp column is parsed with ``pd.to_datetime``. For files whose
    path contains ``regressor_context`` or ``regressor_predictions``, a
    column named ``'0'`` (if present) is renamed to ``'mean'``.

    Args:
        file_path: Path of the CSV file to read.
        id_col: Name of the column holding the item identifier.
        timestamp_col: Name of the column holding the timestamps.

    Returns:
        The loaded ``TimeSeriesDataFrame``, or ``None`` when anything goes
        wrong; the error is printed rather than raised.
    """
    regressor_tags = ('regressor_context', 'regressor_predictions')
    try:
        frame = pd.read_csv(file_path)
        frame[timestamp_col] = pd.to_datetime(frame[timestamp_col])

        # Regressor outputs store their values under '0' instead of 'mean'.
        is_regressor_file = any(tag in file_path for tag in regressor_tags)
        if is_regressor_file and '0' in frame.columns:
            frame = frame.rename(columns={'0': 'mean'})

        return TimeSeriesDataFrame.from_data_frame(
            df=frame,
            id_column=id_col,
            timestamp_column=timestamp_col,
        )
    except Exception as e:
        # Best-effort loader: report the problem and let the caller decide.
        print(f"Error loading {file_path}: {e}")
        return None

# Read every configured CSV into `data`; a dataset that fails to load is
# stored as None (load_data reports the error itself).
print("Loading data...")
data = {}
for dataset_name, csv_path in file_paths.items():
    loaded = load_data(csv_path, id_col="item_id")
    data[dataset_name] = loaded
    if loaded is None:
        continue
    print(f"Loaded {dataset_name} with shape {loaded.shape}")

# Standardize each store's series with statistics from its pre-forecast window
def normalize_data(data_dict, context_length=30):
    """Z-score the target and forecast values of every store, in place.

    For each item id of ``data_dict['original_data']`` that also appears in
    ``data_dict['forecasts']``, a ``StandardScaler`` is fitted on the
    ``target`` values whose timestamps lie in the ``context_length`` days
    immediately before the store's first forecast timestamp. That scaler is
    then applied to the store's whole ``target`` column and to every forecast
    column other than ``item_id``/``timestamp`` (``mean`` first, then the
    remaining columns).

    Stores without forecasts, with an empty context window, or without a
    ``target`` column are left untouched.

    Args:
        data_dict: Mapping with an ``'original_data'`` entry and, optionally,
            a ``'forecasts'`` entry (which may be ``None``). Both entries are
            modified in place.
        context_length: Length, in days, of the window used to fit the scaler.

    Returns:
        None.
    """
    print("Normalizing data...")

    def _rescale(fitted_scaler, series):
        # StandardScaler expects a 2-D column; hand back a flat array.
        return fitted_scaler.transform(series.values.reshape(-1, 1)).flatten()

    originals = data_dict['original_data']
    forecasts = data_dict.get('forecasts')

    for store_id in originals.item_ids:
        # Without forecasts there is no forecast start to anchor the window on.
        if forecasts is None or store_id not in forecasts.item_ids:
            continue

        store_forecast = forecasts.loc[store_id]
        forecast_start = store_forecast.index[0]
        window_start = forecast_start - pd.Timedelta(days=context_length)

        store_frame = originals.loc[store_id]
        timestamps = store_frame.index
        in_window = (timestamps >= window_start) & (timestamps < forecast_start)
        context_frame = store_frame[in_window]
        if context_frame.empty or 'target' not in context_frame.columns:
            continue

        # Fit on the context window only, then apply to everything else.
        scaler = StandardScaler()
        scaler.fit(context_frame['target'].values.reshape(-1, 1))

        originals.loc[store_id, 'target'] = _rescale(scaler, store_frame['target'])

        # The mean forecast first, then every remaining (quantile) column.
        other_columns = [
            name for name in store_forecast.columns
            if name not in ['item_id', 'timestamp', 'mean']
        ]
        for column in ['mean', *other_columns]:
            forecasts.loc[store_id, column] = _rescale(scaler, store_forecast[column])

        print(f"  Normalized {store_id} data (mean={scaler.mean_[0]:.2f}, std={scaler.scale_[0]:.2f})")

# Normalize the loaded datasets in place. Each store's scaler is fitted on the
# 30 days of target history that precede its first forecast timestamp.
normalize_data(data, context_length=30)

# Helpers and entry point for the four-panel visualization of a single store
def _store_frame(dataset_key, store_id):
    """Return the rows of ``data[dataset_key]`` for ``store_id``, or ``None``.

    ``None`` is returned when the dataset is missing from ``data``, failed to
    load (stored as ``None``), or holds no rows for the store.
    """
    dataset = data.get(dataset_key)
    if dataset is None or store_id not in dataset.item_ids:
        return None
    return dataset.loc[store_id]


def _last_rows(frame, n_rows):
    """Return the trailing ``n_rows`` rows of ``frame`` (all rows if ``None``)."""
    if n_rows is None:
        return frame
    if n_rows <= 0:
        # A plain ``frame[-0:]`` would return the whole frame.
        return frame.iloc[:0]
    return frame.iloc[-n_rows:]


def _quantile_columns(frame):
    """Return ``(level, column_name)`` pairs sorted by quantile level.

    Only columns whose name parses as a float are kept, so the original
    column name is used for lookups instead of a ``str(float(...))`` round trip.
    """
    pairs = []
    for column in frame.columns:
        if column in ('item_id', 'timestamp', 'mean'):
            continue
        try:
            pairs.append((float(column), column))
        except (TypeError, ValueError):
            continue  # not a quantile column
    pairs.sort(key=lambda pair: pair[0])
    return pairs


def plot_store_comprehensive(store_id, fig=None, axs=None, context_length=128, prediction_length=24):
    """Draw the four-panel overview of one store from the module-level ``data``.

    Panels, top to bottom: (1) context, actuals, mean forecast and prediction
    intervals; (2) regressor context/predictions against the actuals;
    (3) residual context/predictions and the true residuals (actual minus
    regressor prediction); (4) the temperature, price and promotion covariates.
    Optional datasets and covariate columns that are unavailable are skipped.

    Args:
        store_id: Item id to plot; must exist in ``data['original_data']``.
        fig: Existing figure to draw on. A new one is created when either
            ``fig`` or ``axs`` is ``None``.
        axs: Sequence of four axes, one per panel.
        context_length: Number of trailing rows (not days) shown before the
            forecast start; ``None`` shows the whole history.
        prediction_length: Number of rows added to ``context_length`` to size
            the covariate panel's window ending at the last forecast timestamp.

    Returns:
        The ``(fig, axs)`` pair. When the store has no forecasts the axes are
        returned without anything drawn on them.
    """
    if fig is None or axs is None:
        fig, axs = plt.subplots(4, 1, figsize=(15, 20), sharex=True,
                                gridspec_kw={'height_ratios': [2, 1.5, 1.5, 1.5]})

    # Define colors for consistency across panels
    colors = {
        'context': 'blue',
        'actual': 'green',
        'forecast': 'red',
        'residual': 'purple',
        'regressor': 'orange'
    }

    original = data['original_data'].loc[store_id]

    # Everything is anchored on the forecast window; nothing to draw without it.
    forecast_df = _store_frame('forecasts', store_id)
    if forecast_df is None:
        return fig, axs

    forecast_start = forecast_df.index[0]
    forecast_end = forecast_df.index[-1]

    context_data = _last_rows(original[original.index < forecast_start], context_length)
    actual_data = original[(original.index >= forecast_start) & (original.index <= forecast_end)]

    # 1. Main plot with original data, forecasts and quantiles
    ax = axs[0]
    ax.plot(context_data.index, context_data['target'],
            color=colors['context'], linestyle='-', linewidth=2, label='Context')
    ax.plot(actual_data.index, actual_data['target'],
            color=colors['actual'], linestyle='-', linewidth=2, label='Actual')
    ax.plot(forecast_df.index, forecast_df['mean'],
            color=colors['forecast'], linestyle='--', linewidth=2, label='Forecast (Mean)')

    # Pair the i-th lowest quantile with the i-th highest; outer bands are
    # drawn first and only the widest one gets a legend entry.
    quantiles = _quantile_columns(forecast_df)
    for i in range(len(quantiles) // 2):
        low_q, low_col = quantiles[i]
        high_q, high_col = quantiles[-(i + 1)]
        ax.fill_between(
            forecast_df.index,
            forecast_df[low_col],
            forecast_df[high_col],
            color=colors['forecast'], alpha=max(0.2 - (i * 0.05), 0.05),
            # ':g' avoids the truncation of int(q * 100) (e.g. 0.29 -> 28).
            label=f'{low_q * 100:g}%-{high_q * 100:g}% PI' if i == 0 else None
        )

    ax.set_title(f'Sales Forecast for {store_id} (Normalized)', fontsize=14)
    ax.set_ylabel('Sales (Normalized)', fontsize=12)
    ax.grid(True, alpha=0.3)

    # 2. Regressor context and predictions plot
    ax = axs[1]

    reg_context = _store_frame('regressor_context', store_id)
    if reg_context is not None:
        reg_context = _last_rows(reg_context, context_length)
        ax.plot(reg_context.index, reg_context['mean'],
                color=colors['regressor'], linestyle='-', linewidth=2, label='Regressor Context')

    reg_pred = _store_frame('regressor_predictions', store_id)
    if reg_pred is not None:
        ax.plot(reg_pred.index, reg_pred['mean'],
                color=colors['regressor'], linestyle='--', linewidth=2, label='Regressor Predictions')

    # Ground truth for comparison
    ax.plot(actual_data.index, actual_data['target'],
            color=colors['actual'], linestyle='-', linewidth=2, label='Actual (Normalized)')

    ax.set_title('Regressor Component', fontsize=14)
    ax.set_ylabel('Value', fontsize=12)
    ax.grid(True, alpha=0.3)

    # 3. Residual context and predictions plot
    ax = axs[2]

    res_context = _store_frame('context_residuals', store_id)
    if res_context is not None:
        res_context = _last_rows(res_context, context_length)
        # The residual file may store its values under 'target' or 'mean'.
        value_col = 'target' if 'target' in res_context.columns else 'mean'
        ax.plot(res_context.index, res_context[value_col],
                color=colors['residual'], linestyle='-', linewidth=2, label='Residual Context')

    res_pred = _store_frame('residual_predictions', store_id)
    if res_pred is not None:
        ax.plot(res_pred.index, res_pred['mean'],
                color=colors['residual'], linestyle='--', linewidth=2, label='Residual Predictions')

    # True residuals = actual - regressor prediction on the shared timestamps
    if reg_pred is not None:
        common_idx = actual_data.index.intersection(reg_pred.index)
        if len(common_idx) > 0:
            true_residuals = (actual_data.loc[common_idx, 'target'].values
                              - reg_pred.loc[common_idx, 'mean'].values)
            ax.plot(common_idx, true_residuals,
                    color='black', linestyle=':', linewidth=2, label='True Residuals')

    ax.set_title('Residual Component', fontsize=14)
    ax.set_ylabel('Value', fontsize=12)
    ax.grid(True, alpha=0.3)

    # 4. Covariates plot
    ax = axs[3]

    # Rows covering both the context and the forecast periods
    window = None if context_length is None else context_length + prediction_length
    full_period = _last_rows(original[original.index <= forecast_end], window)

    if 'temperature' in full_period.columns:
        ax.plot(full_period.index, full_period['temperature'],
                label='Temperature', color='orange')

    if 'price' in full_period.columns:
        # Scaled for visibility next to the temperature curve
        ax.plot(full_period.index, full_period['price'] * 10,
                label='Price (×10)', color='purple')

    if 'promotion' in full_period.columns:
        # Mark each promotion timestamp with a one-day vertical span
        for promo_start in full_period[full_period['promotion'] == 1].index:
            ax.axvspan(promo_start, promo_start + pd.Timedelta(days=1), alpha=0.2, color='red')
        # Empty line so the spans get a single legend entry
        ax.plot([], [], color='red', alpha=0.3, linewidth=10, label='Promotion')

    ax.set_title('Covariates', fontsize=14)
    ax.set_xlabel('Date', fontsize=12)
    ax.set_ylabel('Value', fontsize=12)
    ax.grid(True, alpha=0.3)

    # Vertical line at forecast start on every panel; labelled once, on the top one
    for a in axs:
        a.axvline(x=forecast_start, color='black', linestyle='--', alpha=0.5,
                  label='Forecast Start' if a is axs[0] else None)

    # Format x-axis to show dates nicely
    plt.setp(axs[3].xaxis.get_majorticklabels(), rotation=45, ha='right')
    axs[3].xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))

    # Legends on all panels, with duplicate labels collapsed
    for a in axs:
        handles, labels = a.get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        a.legend(by_label.values(), by_label.keys(), loc='best')

    return fig, axs

# Create comprehensive plots for a hand-picked subset of stores.
# savefig does not create missing directories, so make sure the target exists.
plot_dir = "./tests/hopformer/plots"
os.makedirs(plot_dir, exist_ok=True)

# Positions (within the loaded item ids) of the stores to visualize; positions
# beyond the number of available stores are skipped instead of raising.
all_store_ids = data['original_data'].item_ids
selected_positions = (4, 6, 2, 3)
store_ids = [all_store_ids[pos] for pos in selected_positions if pos < len(all_store_ids)]

# Only stores that actually have forecasts can be plotted.
forecast_item_ids = data['forecasts'].item_ids
stores_with_forecasts = [s for s in store_ids if s in forecast_item_ids]

for store_id in stores_with_forecasts:
    print(f"Creating visualization for {store_id}...")
    fig, axs = plot_store_comprehensive(store_id, context_length=128, prediction_length=24)

    fig.tight_layout()
    plot_path = os.path.join(plot_dir, f"{store_id}_normalized.png")
    fig.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f"Visualization saved to {plot_path}")

# Create a combined visualization with all stores in one figure (one column each)
n_stores = len(stores_with_forecasts)
if n_stores > 0:
    print("\nCreating combined visualization for all stores...")
    # squeeze=False keeps the axes array 2-D even when there is a single store.
    fig, all_axs = plt.subplots(4, n_stores, figsize=(7*n_stores, 20),
                                sharex='col', squeeze=False)

    for column, store_id in enumerate(stores_with_forecasts):
        # Hand the function just the column of axes that belongs to this store
        plot_store_comprehensive(store_id, fig, all_axs[:, column])

    fig.tight_layout()
    png_path = os.path.join(plot_dir, "all_stores_normalized.png")
    pdf_path = os.path.join(plot_dir, "all_stores_normalized.pdf")
    fig.savefig(png_path, dpi=300, bbox_inches='tight')
    fig.savefig(pdf_path, format='pdf', bbox_inches='tight')
    plt.close(fig)

    print(f"Combined visualization saved to {png_path} and {pdf_path}")

print("Visualizations complete!")