import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
import re

# Result CSV locations, organized as {dataset name: {run key: CSV path}}.
# Run keys have the form "ctx<N>_pred<M>"; extract_prediction_length() parses
# the number after "pred" out of the key to get the x-axis value of each point.
#
# NOTE: this first mapping is rebound by the second `file_paths` assignment
# directly below, so as written none of these entries are used by the rest of
# the script. Remove or rename the second assignment to plot these datasets.
file_paths = {
    "Electricity": {
        "ctx256_pred24": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_electricity_synthetic_ctx256_pred24.csv",
        "ctx256_pred48": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_electricity_synthetic_ctx256_pred48.csv",
        "ctx256_pred72": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_electricity_synthetic_ctx256_pred72.csv",
        "ctx256_pred96": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_electricity_synthetic_ctx256_pred96.csv",
        "ctx256_pred120": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_electricity_synthetic_ctx256_pred120.csv",
    },
    "Sales2": {
        "ctx256_pred24": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_nonlinear_sales_synthetic_ctx256_pred24.csv",
        "ctx256_pred48": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_nonlinear_sales_synthetic_ctx256_pred48.csv",
        "ctx256_pred72": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_nonlinear_sales_synthetic_ctx256_pred72.csv",
        "ctx256_pred96": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_nonlinear_sales_synthetic_ctx256_pred96.csv",
        "ctx256_pred120": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_nonlinear_sales_synthetic_ctx256_pred120.csv",
    },
    "Sales1": {
        "ctx256_pred24": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_sales_synthetic_ctx256_pred24.csv",
        "ctx256_pred48": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_sales_synthetic_ctx256_pred48.csv",
        "ctx256_pred72": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_sales_synthetic_ctx256_pred72.csv",
        "ctx256_pred96": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_sales_synthetic_ctx256_pred96.csv",
        "ctx256_pred120": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_sales_synthetic_ctx256_pred120.csv",
    }
}

# Active configuration: this assignment replaces the mapping above, so only the
# EPF dataset is processed. The pred48 and pred96 runs are commented out and
# therefore skipped.
file_paths = {
    "EPF": {
        "ctx256_pred24": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_epf_ctx256_pred24.csv",
        # "ctx256_pred48": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_epf_ctx256_pred48.csv",
        "ctx256_pred72": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_epf_ctx256_pred72.csv",
        # "ctx256_pred96": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_epf_ctx256_pred96.csv",
        "ctx256_pred120": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_epf_ctx256_pred120.csv",
    },
}

# Filter settings for the two models that are compared. Index 0 is the
# HopFormer configuration and index 1 the Chronos configuration; both are
# looked up positionally by extract_results_for_dataset().
# Settings common to both entries: same bolt model, no fine-tuning, no LoRA.
_ZERO_SHOT_BOLT_SMALL = {"bolt_model": "bolt_small", "fine_tune": False, "use_lora": False}

model_configs = [
    # HopFormer: rows that have regressor types and use the "spa" aggregation
    {
        "name": "HopFormer 0-shot",
        **_ZERO_SHOT_BOLT_SMALL,
        "has_regressors": True,
        "aggregation_strategy_name": "spa",
    },
    # Chronos: rows whose regressor types are empty/missing
    {
        "name": "Chronos 0-shot",
        **_ZERO_SHOT_BOLT_SMALL,
        "has_regressors": False,
    },
]

def get_latest_model(df, bolt_model=None, use_lora=None, fine_tune=None,
                     aggregation_strategy=None, has_regressors=None, larger_than_one=None, name_tag=None):
    """Return the most recent result row that matches every given criterion.

    Args:
        df: Results table. The 'timestamp' and 'bolt_model' columns are always
            read; 'model', 'regressor_types', 'use_lora', 'fine_tune' and
            'aggregation_strategy_name' are read only when the corresponding
            filter is supplied.
        bolt_model: Required value of the 'bolt_model' column. Unlike the other
            filters, ``None`` (or NaN) does not disable the filter: it selects
            rows whose 'bolt_model' is missing.
        use_lora: Required value of 'use_lora'; ``None`` disables the filter.
        fine_tune: Required value of 'fine_tune'; ``None`` disables the filter.
        aggregation_strategy: Required value of 'aggregation_strategy_name';
            NaN selects rows where it is missing, ``None`` disables the filter.
        has_regressors: ``True`` keeps rows whose 'regressor_types' is neither
            missing nor empty, ``False`` keeps rows where it is missing or
            empty; ``None`` disables the filter.
        larger_than_one: ``True`` keeps rows whose 'regressor_types' lists more
            than one comma-separated entry; ``False`` keeps all other rows
            (a single entry, an empty string, or a missing value); ``None``
            disables the filter. Applied independently of ``has_regressors``.
        name_tag: Required value of the 'model' column; ``None`` disables the
            filter.

    Returns:
        The matching row (a pandas Series) with the greatest 'timestamp', or
        ``None`` when no row matches.
    """
    mask = pd.Series(True, index=df.index)

    if name_tag is not None:
        mask &= df['model'] == name_tag

    if has_regressors is not None:
        regressors = df['regressor_types']
        is_blank = regressors.isna() | (regressors == "")
        mask &= (~is_blank) if has_regressors else is_blank

    # Previously this filter lived inside the has_regressors branch and was
    # silently ignored whenever has_regressors was left at None.
    if larger_than_one is not None:
        # "More than one entry" is equivalent to "contains a comma"; missing
        # values are treated as empty strings and therefore never match.
        has_many = (
            df['regressor_types']
            .fillna("")
            .astype(str)
            .str.contains(",", regex=False)
            .astype(bool)
        )
        mask &= has_many if larger_than_one else ~has_many

    # bolt_model is always filtered: None/NaN means "rows without a bolt model".
    if bolt_model is None or pd.isna(bolt_model):
        mask &= df['bolt_model'].isna()
    else:
        mask &= df['bolt_model'] == bolt_model

    if use_lora is not None:
        mask &= df['use_lora'] == use_lora

    if fine_tune is not None:
        mask &= df['fine_tune'] == fine_tune

    if aggregation_strategy is not None:
        if pd.isna(aggregation_strategy):
            mask &= df['aggregation_strategy_name'].isna()
        else:
            mask &= df['aggregation_strategy_name'] == aggregation_strategy

    # Newest first, so the first row is the latest matching run.
    filtered = df[mask].sort_values('timestamp', ascending=False)

    if filtered.empty:
        return None
    return filtered.iloc[0]

def extract_context_length(filename):
    """Return the integer that follows 'ctx' in ``filename``, or None if absent."""
    found = re.search(r'ctx(\d+)', filename)
    return int(found.group(1)) if found else None

def extract_prediction_length(filename):
    """Return the integer that follows 'pred' in ``filename``, or None if absent."""
    found = re.search(r'pred(\d+)', filename)
    return int(found.group(1)) if found else None

def _mase_or_nan(row):
    """Return |MASE_mean| of a result row, or NaN when the row or field is missing."""
    if row is not None and 'MASE_mean' in row:
        return abs(row['MASE_mean'])
    return np.nan


def extract_results_for_dataset(dataset_files):
    """Collect HopFormer and Chronos MASE scores for one dataset.

    Args:
        dataset_files: Mapping of run key (e.g. "ctx256_pred24") to the path of
            the results CSV for that run. The prediction length is parsed from
            the key, not from the path.

    Returns:
        A tuple ``(prediction_lengths, hopformer_mase, chronos_mase)`` of three
        equally long lists, ordered by ascending prediction length. A score is
        NaN when no matching row (or no 'MASE_mean' field) was found. Files
        that are missing, have an unparsable key, or cannot be read are
        reported on stdout and contribute no entry at all.
    """
    hop_cfg, chronos_cfg = model_configs[0], model_configs[1]

    # One (prediction_length, hopformer_mase, chronos_mase) tuple per usable
    # file. Keeping the three values together guarantees the returned lists
    # stay aligned; previously the prediction length was recorded before the
    # CSV was read, so a read failure left the lists with different lengths
    # and the reordering step raised IndexError.
    points = []

    for file_key, filepath in dataset_files.items():
        if not os.path.exists(filepath):
            print(f"Warning: File not found - {filepath}")
            continue

        pred_length = extract_prediction_length(file_key)
        if pred_length is None:
            print(f"Warning: Could not parse prediction length from {file_key}")
            continue

        # Deliberately broad: any unreadable file is reported and skipped so
        # that one bad CSV does not abort the whole plot.
        try:
            df = pd.read_csv(filepath)
        except Exception as e:
            print(f"Error reading {filepath}: {e}")
            continue

        hopformer_row = get_latest_model(
            df,
            bolt_model=hop_cfg["bolt_model"],
            use_lora=hop_cfg["use_lora"],
            fine_tune=hop_cfg["fine_tune"],
            aggregation_strategy=hop_cfg["aggregation_strategy_name"],
            has_regressors=hop_cfg["has_regressors"],
        )
        chronos_row = get_latest_model(
            df,
            bolt_model=chronos_cfg["bolt_model"],
            use_lora=chronos_cfg["use_lora"],
            fine_tune=chronos_cfg["fine_tune"],
            has_regressors=chronos_cfg["has_regressors"],
        )

        points.append((pred_length, _mase_or_nan(hopformer_row), _mase_or_nan(chronos_row)))

    # Sort on the prediction length only (NaN scores must not take part in the
    # comparison); list.sort is stable, so ties keep their input order.
    points.sort(key=lambda point: point[0])

    prediction_lengths = [point[0] for point in points]
    hopformer_mase = [point[1] for point in points]
    chronos_mase = [point[2] for point in points]

    return prediction_lengths, hopformer_mase, chronos_mase

def main():
    """Plot MASE versus prediction horizon for every dataset in ``file_paths``.

    Draws one panel per dataset with a HopFormer and a Chronos curve, prints a
    text table of the plotted values to stdout, and saves the figure as PNG and
    PDF under ``./results/plots/`` (the directory is created if necessary).
    """
    # Set up publication-quality plot style
    plt.rcParams.update({
        'font.family': 'serif',
        'font.serif': ['Times', 'Computer Modern Roman'],
        'font.size': 9,
        'axes.labelsize': 10,
        'axes.titlesize': 11,
        'xtick.labelsize': 8,
        'ytick.labelsize': 8,
        'legend.fontsize': 8,
        'figure.titlesize': 12,
        'lines.linewidth': 1.5,
        'axes.linewidth': 0.8,
        'savefig.dpi': 300,
        'savefig.bbox': 'tight',
        'savefig.pad_inches': 0.05,
        'axes.grid': True,
        'grid.alpha': 0.3,
    })

    # One panel per configured dataset. The layout used to be fixed at three
    # panels, which left empty axes for fewer datasets and raised IndexError
    # for more. Each panel keeps the original width (20 in / 3 panels).
    n_panels = max(len(file_paths), 1)
    fig, axes_grid = plt.subplots(1, n_panels, figsize=(20 * n_panels / 3, 3), squeeze=False)
    axes = list(axes_grid[0])

    # Thin out the left and bottom spines on every panel
    for ax in axes:
        ax.spines['left'].set_linewidth(0.1)
        ax.spines['bottom'].set_linewidth(0.1)

    # Colors (colorblind-friendly)
    hopformer_color = '#4e79a7'  # Blue
    chronos_color = '#f28e2c'    # Orange

    for i, (dataset_name, dataset_files) in enumerate(file_paths.items()):
        prediction_lengths, hopformer_mase, chronos_mase = extract_results_for_dataset(dataset_files)

        ax = axes[i]

        # Curves with a marker at every data point
        ax.plot(prediction_lengths, hopformer_mase, 'o-', color=hopformer_color,
                linewidth=1.5, label='HopFormer 0-shot')
        ax.plot(prediction_lengths, chronos_mase, 's-', color=chronos_color,
                linewidth=1.5, label='Chronos 0-shot')

        # Put ticks exactly at the evaluated prediction lengths
        ax.set_xticks(prediction_lengths)
        ax.set_xticklabels([str(x) for x in prediction_lengths])

        # Reference line at y=1.0
        ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5)

        ax.set_title(dataset_name)
        ax.set_xlabel('Prediction Horizon (hours)')

        # The y-label and legend are shown on the leftmost panel only
        if i == 0:
            ax.set_ylabel('MASE Score (lower is better)')
            ax.legend(loc='upper left')

        # Print data summary
        print(f"\n{dataset_name} Dataset:")
        print("Prediction Length | HopFormer MASE | Chronos MASE")
        print("-" * 50)
        for pred, hopformer, chronos in zip(prediction_lengths, hopformer_mase, chronos_mase):
            print(f"{pred:16d} | {hopformer:13.4f} | {chronos:11.4f}")

    # Ensure proper spacing between subplots
    fig.subplots_adjust(wspace=0.1, top=0.90)

    # savefig does not create missing directories, so make sure the target exists
    output_dir = './results/plots'
    os.makedirs(output_dir, exist_ok=True)
    png_path = f"{output_dir}/prediction_horizon_comparison.png"
    pdf_path = f"{output_dir}/prediction_horizon_comparison.pdf"

    fig.savefig(png_path, dpi=300, bbox_inches='tight')
    fig.savefig(pdf_path, bbox_inches='tight')

    print(f"\nPlots saved as '{png_path}' and '{pdf_path}'")


# Run the plotting pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()