import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
import re

# Result CSVs, keyed by dataset display name and then by context-length tag
# ("ctx<N>"). Every dataset is evaluated at the same set of context lengths,
# so the nested mapping is generated rather than spelled out entry by entry.
_RESULTS_DIR = "/home/magics/hdd/sky_ws/residual_ws/results"
_CONTEXT_LENGTHS = (32, 64, 96, 128, 256, 512)
_DATASET_STEMS = {
    "Electricity": "electricity_synthetic",
    "Sales2": "nonlinear_sales_synthetic",
    "Sales1": "sales_synthetic",
}

file_paths = {
    dataset: {
        f"ctx{length}": f"{_RESULTS_DIR}/all_results_{stem}_ctx{length}.csv"
        for length in _CONTEXT_LENGTHS
    }
    for dataset, stem in _DATASET_STEMS.items()
}

# file_paths = {
#     "EPF": {
#         "ctx32": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_epf_ctx32.csv",
#         "ctx64": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_epf_ctx64.csv",
#         "ctx128": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_epf_ctx128.csv",
#         "ctx256": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_epf_ctx256.csv",
#         "ctx512": "/home/magics/hdd/sky_ws/residual_ws/results/all_results_epf_ctx512.csv",
#     },
# }

# Row-selection criteria for the two curves that get plotted. Order matters:
# index 0 is the HopFormer entry and index 1 the Chronos entry.
# Both share the same bolt model and are neither fine-tuned nor LoRA-adapted;
# they differ in whether regressor_types must be non-empty (HopFormer) or
# empty (Chronos), and only HopFormer pins an aggregation strategy.
_ZERO_SHOT_BOLT_SMALL = {
    "bolt_model": "bolt_small",
    "fine_tune": False,
    "use_lora": False,
}

model_configs = [
    {
        "name": "HopFormer 0-shot",
        **_ZERO_SHOT_BOLT_SMALL,
        "has_regressors": True,
        "aggregation_strategy_name": "spa",
    },
    {
        "name": "Chronos 0-shot",
        **_ZERO_SHOT_BOLT_SMALL,
        "has_regressors": False,
    },
]

def get_latest_model(df, bolt_model=None, use_lora=None, fine_tune=None,
                     aggregation_strategy=None, has_regressors=None,
                     larger_than_one=None, name_tag=None):
    """Return the most recent result row that satisfies every active filter.

    All filters are AND-ed together. Except for ``bolt_model``, passing
    ``None`` for a filter disables it.

    Args:
        df: Results table. ``bolt_model`` is always read and ``timestamp``
            is read whenever at least one row matches; ``model``,
            ``regressor_types``, ``use_lora``, ``fine_tune`` and
            ``aggregation_strategy_name`` are read only when the matching
            filter is active.
        bolt_model: Required value of the ``bolt_model`` column. NOTE:
            ``None`` (the default) or NaN does *not* mean "no filter"; it
            selects rows whose ``bolt_model`` is missing.
        use_lora: Required value of the ``use_lora`` column.
        fine_tune: Required value of the ``fine_tune`` column.
        aggregation_strategy: Required value of
            ``aggregation_strategy_name``; pass NaN to require a missing
            value.
        has_regressors: ``True`` keeps rows whose ``regressor_types`` is
            neither missing nor the empty string; ``False`` keeps the
            complement.
        larger_than_one: ``True`` keeps rows whose ``regressor_types`` lists
            more than one comma-separated entry; ``False`` keeps rows with
            at most one (rows without regressors included - combine with
            ``has_regressors=True`` to require exactly one). Applied
            independently of ``has_regressors``.
        name_tag: Required value of the ``model`` column.

    Returns:
        The matching row (a ``pandas.Series``) with the greatest
        ``timestamp``, or ``None`` when nothing matches.

    Raises:
        KeyError: If a column needed by an active filter is absent.
    """
    mask = pd.Series(True, index=df.index)

    if name_tag is not None:
        mask &= df['model'] == name_tag

    if has_regressors is not None or larger_than_one is not None:
        regressors = df['regressor_types']

        if has_regressors is not None:
            is_blank = regressors.isna() | (regressors == "")
            mask &= ~is_blank if has_regressors else is_blank

        # Previously nested under the has_regressors branch, which made this
        # filter a silent no-op unless has_regressors was passed as well.
        if larger_than_one is not None:
            # "More than one comma-separated entry" == "contains a comma".
            # astype(str) / na=False make missing values count as "no comma".
            has_several = regressors.astype(str).str.contains(
                ",", regex=False, na=False
            )
            mask &= has_several if larger_than_one else ~has_several

    # bolt_model is always filtered: None/NaN means "column must be missing".
    if bolt_model is None or pd.isna(bolt_model):
        mask &= df['bolt_model'].isna()
    else:
        mask &= df['bolt_model'] == bolt_model

    if use_lora is not None:
        mask &= df['use_lora'] == use_lora

    if fine_tune is not None:
        mask &= df['fine_tune'] == fine_tune

    if aggregation_strategy is not None:
        if pd.isna(aggregation_strategy):
            mask &= df['aggregation_strategy_name'].isna()
        else:
            mask &= df['aggregation_strategy_name'] == aggregation_strategy

    matches = df[mask]
    if matches.empty:
        return None

    # Newest first, then take the top row.
    return matches.sort_values('timestamp', ascending=False).iloc[0]

def extract_context_length(filename):
    """Return the integer following the first ``ctx`` in *filename*.

    Yields ``None`` when the text contains no ``ctx<digits>`` token.
    """
    hit = re.search(r'ctx(\d+)', filename)
    return None if hit is None else int(hit.group(1))

def _latest_abs_mase(df, config):
    """Return |MASE_mean| of the newest row matching *config*, or NaN if none."""
    row = get_latest_model(
        df,
        bolt_model=config["bolt_model"],
        use_lora=config["use_lora"],
        fine_tune=config["fine_tune"],
        # Optional key: configs without it do not filter on the strategy.
        aggregation_strategy=config.get("aggregation_strategy_name"),
        has_regressors=config["has_regressors"],
    )
    if row is None or 'MASE_mean' not in row:
        return np.nan
    return abs(row['MASE_mean'])


def extract_results_for_dataset(dataset_files):
    """Collect HopFormer and Chronos MASE values for one dataset.

    Args:
        dataset_files: Mapping from a key containing ``ctx<N>`` to the path
            of the results CSV for that context length.

    Returns:
        A tuple ``(context_lengths, hopformer_mase, chronos_mase)`` of three
        equally long lists, ordered by ascending context length. A MASE
        entry is NaN when the CSV has no matching row or no ``MASE_mean``
        column. Entries whose file is missing, whose key has no parsable
        context length, or whose CSV cannot be read are reported on stdout
        and left out of all three lists.
    """
    records = []  # (context_length, hopformer_mase, chronos_mase)

    for file_key, filepath in dataset_files.items():
        if not os.path.exists(filepath):
            print(f"Warning: File not found - {filepath}")
            continue

        ctx_length = extract_context_length(file_key)
        if ctx_length is None:
            print(f"Warning: Could not parse context length from {file_key}")
            continue

        try:
            df = pd.read_csv(filepath)
        except Exception as e:
            # Best effort: report and skip. Nothing has been recorded for
            # this file yet, so the three result lists stay aligned.
            print(f"Error reading {filepath}: {e}")
            continue

        records.append((
            ctx_length,
            _latest_abs_mase(df, model_configs[0]),  # HopFormer
            _latest_abs_mase(df, model_configs[1]),  # Chronos
        ))

    # Sort on the context length only; MASE values may be NaN.
    records.sort(key=lambda record: record[0])

    context_lengths = [record[0] for record in records]
    hopformer_mase = [record[1] for record in records]
    chronos_mase = [record[2] for record in records]

    return context_lengths, hopformer_mase, chronos_mase

def main():
    """Plot MASE versus context length, one panel per dataset in ``file_paths``.

    For every dataset the HopFormer and Chronos curves are drawn on a
    log2-scaled x-axis together with a dashed reference line at MASE = 1.0,
    and a text summary of the plotted values is printed. The figure is
    written to ``./results/plots`` as PNG and PDF; the directory is created
    if it does not exist.
    """
    # Set up publication-quality plot style
    plt.rcParams.update({
        'font.family': 'serif',
        'font.serif': ['Times', 'Computer Modern Roman'],
        'font.size': 9,
        'axes.labelsize': 10,
        'axes.titlesize': 11,
        'xtick.labelsize': 8,
        'ytick.labelsize': 8,
        'legend.fontsize': 8,
        'figure.titlesize': 12,
        'lines.linewidth': 1.5,
        'axes.linewidth': 0.8,
        'savefig.dpi': 300,
        'savefig.bbox': 'tight',
        'savefig.pad_inches': 0.05,
        'axes.grid': True,
        'grid.alpha': 0.3,
        # 'axes.spines.top': False,
        # 'axes.spines.right': False,
    })

    n_datasets = len(file_paths)
    if n_datasets == 0:
        print("No datasets configured - nothing to plot.")
        return

    # One panel per configured dataset, laid out in a single row. The width
    # scales so that three datasets give the 20 x 3 inch figure; squeeze=False
    # keeps `axes` indexable even when there is only one dataset.
    fig, axes_grid = plt.subplots(
        1, n_datasets, figsize=(20 * n_datasets / 3, 3), squeeze=False
    )
    axes = axes_grid[0]

    # Colors (colorblind-friendly)
    hopformer_color = '#4e79a7'  # Blue
    chronos_color = '#f28e2c'    # Orange

    for i, (dataset_name, dataset_files) in enumerate(file_paths.items()):
        context_lengths, hopformer_mase, chronos_mase = extract_results_for_dataset(dataset_files)

        ax = axes[i]

        # Plot with markers at data points
        ax.plot(context_lengths, hopformer_mase, 'o-', color=hopformer_color,
                linewidth=1.5, label='HopFormer 0-shot')
        ax.plot(context_lengths, chronos_mase, 's-', color=chronos_color,
                linewidth=1.5, label='Chronos 0-shot')

        # Log2 x-axis with a tick at every evaluated context length
        ax.set_xscale('log', base=2)
        ax.set_xticks(context_lengths)
        ax.set_xticklabels([str(x) for x in context_lengths], rotation=45)

        # Reference line at y=1.0
        ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5)

        ax.set_title(dataset_name)
        ax.set_xlabel('Context Length')

        # The y-label and legend appear on the leftmost panel only
        if i == 0:
            ax.set_ylabel('MASE Score (lower is better)')
            ax.legend(loc='upper right')

        # Print data summary
        print(f"\n{dataset_name} Dataset:")
        print("Context Length | HopFormer MASE | Chronos MASE")
        print("-" * 45)
        # `chronos`, not `chr`: avoid shadowing the builtin
        for ctx, hopformer, chronos in zip(context_lengths, hopformer_mase, chronos_mase):
            print(f"{ctx:14d} | {hopformer:13.4f} | {chronos:11.4f}")

    # Add a super title for the entire figure - closer to the plots
    # plt.suptitle('Effect of Context Length on Model Performance Across Datasets',
    #              fontsize=14, y=0.98)  # y=0.98 brings it closer to the plots

    # Horizontal gap between panels and headroom above them
    fig.subplots_adjust(wspace=0.1, top=0.90)

    # savefig does not create missing directories, so make sure it exists
    output_dir = './results/plots'
    os.makedirs(output_dir, exist_ok=True)
    png_path = f'{output_dir}/context_length_comparison.png'
    pdf_path = f'{output_dir}/context_length_comparison.pdf'

    # Save as PNG (explicit DPI) and PDF
    fig.savefig(png_path, dpi=300, bbox_inches='tight')
    fig.savefig(pdf_path, bbox_inches='tight')

    print(f"\nPlots saved as '{png_path}' and '{pdf_path}'")


# Build and save the plots only when this file is executed as a script;
# importing it as a module has no plotting side effects.
if __name__ == "__main__":
    main()