"""
Plot Combined Results from Separate Scale Runs

This script loads YAML results from your hyperparameter_sweep_multiscale runs
and creates a combined plot with 3 lines per subplot (one per scale).

Usage:
    python plot_combined_scales.py \
        --small path/to/small_results.yaml \
        --medium path/to/medium_results.yaml \
        --large path/to/large_results.yaml
        
Or find files automatically:
    python plot_combined_scales.py --dir hyperparam_sweep_multiscale_results --hyperparam learning_rate
"""

import os
import sys
import argparse
import yaml
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import time
import warnings

# Silence every warning category for the whole process (this affects all
# imported libraries too, not only the code in this script).
warnings.filterwarnings('ignore')


# ============================================================================
# ICML STYLE CONFIGURATION
# ============================================================================

def setup_icml_style():
    """Install the paper-style matplotlib/seaborn defaults used by this script.

    Activates the ``seaborn-v0_8-paper`` style sheet and the ``husl`` palette,
    then overrides font, figure/savefig, axes/grid, line and legend defaults
    through ``plt.rcParams``.
    """
    typography = {
        'font.family': 'sans-serif',
        'font.sans-serif': ['DejaVu Sans', 'Arial', 'Helvetica', 'Liberation Sans'],
        'font.size': 12,
        'axes.labelsize': 14,
        'axes.titlesize': 12,
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
        'legend.fontsize': 12,
        'figure.titlesize': 14,
    }
    figure_output = {
        'figure.dpi': 150,
        'savefig.dpi': 300,
        'savefig.bbox': 'tight',
        'savefig.pad_inches': 0.05,
        'figure.constrained_layout.use': True,
    }
    axes_and_grid = {
        'axes.grid': True,
        'grid.alpha': 0.3,
        'grid.linestyle': '--',
        'axes.spines.top': False,
        'axes.spines.right': False,
        'axes.linewidth': 0.8,
        'axes.axisbelow': True,
    }
    line_defaults = {
        'lines.linewidth': 2,
        'lines.markersize': 7,
    }
    legend_defaults = {
        'legend.frameon': True,
        'legend.framealpha': 0.9,
        'legend.edgecolor': '0.8',
    }

    # Base style sheet and palette go first; the explicit overrides below win.
    plt.style.use('seaborn-v0_8-paper')
    sns.set_palette("husl")

    overrides = {}
    for group in (typography, figure_output, axes_and_grid,
                  line_defaults, legend_defaults):
        overrides.update(group)
    plt.rcParams.update(overrides)

# Applied at import time so the rcParams are in place before any figure is created.
setup_icml_style()


def load_yaml_results(filepath):
    """Read a hyperparameter_sweep_multiscale YAML results file.

    Opens *filepath* in text mode and returns whatever ``yaml.safe_load``
    produces for its contents.
    """
    with open(filepath, 'r') as handle:
        return yaml.safe_load(handle)


def find_results_files(directory, hyperparam_name, problem_type='BiTSP'):
    """Return the sorted paths of sweep result files located in *directory*.

    A file is selected when its name has the form
    ``sweep_multiscale_<hyperparam_name>_<problem_type>_*.yaml``.
    """
    filename_glob = "sweep_multiscale_{}_{}_*.yaml".format(hyperparam_name, problem_type)
    matches = glob.glob(os.path.join(directory, filename_glob))
    matches.sort()
    return matches


def _results_to_dataframe(all_data, hyperparam_name):
    """Flatten ``{scale: {algorithm: ...}}`` results into one DataFrame row per run."""
    rows = []
    for scale, scale_results in all_data.items():
        for alg_name, alg_data in scale_results.items():
            # Multi-scale files nest as results[alg][scale][hp_value] = runs;
            # single-scale files as results[alg][hp_value] = runs.
            hp_dict = alg_data[scale] if scale in alg_data else alg_data

            for hp_value_raw, runs in hp_dict.items():
                # Hyperparameter values may be stored as strings; use the
                # numeric form when possible so the x-axis is numeric.
                try:
                    hp_value = float(hp_value_raw)
                except (ValueError, TypeError):
                    hp_value = hp_value_raw

                # A null/empty run list contributes no rows instead of raising.
                for run in runs or ():
                    row = {
                        'Algorithm': alg_name,
                        'Scale': scale,
                        hyperparam_name: hp_value,
                        'Hypervolume': run['hypervolume'],
                        'Runtime (s)': run['runtime'],
                        'Solutions': run['num_solutions'],
                    }
                    # The column only exists when at least one run reports it.
                    if run.get('tour_length') is not None:
                        row['Tour Length'] = run['tour_length']
                    rows.append(row)
    return pd.DataFrame(rows)


def plot_combined_results(
    results_files: dict,  # {'small': 'path.yaml', 'medium': 'path.yaml', 'large': 'path.yaml'}
    output_dir: str = 'combined_plots',
    metrics_to_plot: list = None,
    figure_size: tuple = None
):
    """
    Create a combined plot with one line per (algorithm, scale) in each subplot.

    Each subplot shows one metric against the swept hyperparameter: the line
    is the per-value median over runs and the shaded band is the 25th-75th
    percentile range. The figure is written as PNG and PDF into *output_dir*,
    and the best hyperparameter value per algorithm/scale (highest mean
    hypervolume) is printed.

    Parameters:
    -----------
    results_files : dict
        Dictionary mapping scale name to YAML file path. ``small``, ``medium``
        and ``large`` get fixed colors/markers and are drawn first; any other
        scale name is drawn too, with fallback styling.
    output_dir : str
        Directory to save plots (created if missing)
    metrics_to_plot : list
        Which metrics to plot; any of ``'hypervolume'``, ``'runtime'``,
        ``'solutions'``, ``'tour_length'``. Defaults to all four.
    figure_size : tuple
        Custom figure size (width, height); defaults to 5 x 4 per metric.

    Returns:
    --------
    tuple or None
        ``(df, plot_file)`` with the flattened per-run DataFrame and the PNG
        path, or ``None`` when none of the requested metrics is available.
    """

    if metrics_to_plot is None:
        metrics_to_plot = ['hypervolume', 'runtime', 'solutions', 'tour_length']

    print("\n" + "="*80)
    print("COMBINING AND PLOTTING RESULTS")
    print("="*80)

    # Load all results
    all_data = {}
    hyperparam_name = None
    hyperparam_values = None
    problem_type = None

    for scale, filepath in results_files.items():
        print(f"\nLoading {scale}: {filepath}")
        data = load_yaml_results(filepath)

        if hyperparam_name is None:
            # Metadata is taken from the first file
            hyperparam_name = data['hyperparam_name']
            hyperparam_values = data['hyperparam_values']
            problem_type = data['problem_type']
        elif data.get('hyperparam_name', hyperparam_name) != hyperparam_name:
            # Rows from this file are still stored under the first file's
            # hyperparameter column, so make the mismatch visible.
            print(f"  WARNING: file sweeps '{data.get('hyperparam_name')}', "
                  f"expected '{hyperparam_name}'")

        all_data[scale] = data['results']
        print(f"  Algorithms found: {list(data['results'].keys())}")

    print(f"\nHyperparameter: {hyperparam_name}")
    print(f"Values: {hyperparam_values}")
    print(f"Problem: {problem_type}")
    print(f"Scales: {list(all_data.keys())}")

    df = _results_to_dataframe(all_data, hyperparam_name)
    print(f"\nTotal data points: {len(df)}")

    # Define metric configurations
    metric_config = {
        'hypervolume': {'column': 'Hypervolume', 'ylabel': 'Hypervolume', 'marker': 'o'},
        'runtime': {'column': 'Runtime (s)', 'ylabel': 'Runtime (seconds)', 'marker': 's'},
        'solutions': {'column': 'Solutions', 'ylabel': 'Number of Solutions', 'marker': '^'},
        'tour_length': {'column': 'Tour Length', 'ylabel': 'Average Tour Length', 'marker': 'd'}
    }

    # Keep only requested metrics whose column actually exists in the data
    available_metrics = {
        key: cfg for key, cfg in metric_config.items()
        if key in metrics_to_plot and cfg['column'] in df.columns
    }
    n_metrics = len(available_metrics)

    if n_metrics == 0:
        print("No valid metrics to plot!")
        return

    print(f"Plotting metrics: {list(available_metrics.keys())}")

    # Fixed styling for the three standard scales
    scale_colors = {
        'small': '#2ecc71',   # Green
        'medium': '#3498db',  # Blue
        'large': '#e74c3c'    # Red
    }
    scale_markers = {
        'small': 'o',
        'medium': 's',
        'large': '^'
    }
    fallback_markers = ['D', 'v', 'P', 'X', '*']

    # Standard scales first (fixed order), then any other scale in load order,
    # so non-standard scale names are plotted instead of silently dropped.
    standard_order = ['small', 'medium', 'large']
    ordered_scales = [s for s in standard_order if s in all_data]
    ordered_scales += [s for s in all_data if s not in standard_order]

    scale_style = {}
    for position, scale in enumerate(ordered_scales):
        scale_style[scale] = (
            scale_colors.get(scale, f'C{position % 10}'),
            scale_markers.get(scale, fallback_markers[position % len(fallback_markers)]),
        )

    # Line styles for algorithms
    alg_styles = {
        'UCB-Hedge': '-',
        'Thompson-Hedge': '--'
    }

    algorithms = sorted(df['Algorithm'].unique())
    multiple_algs = len(algorithms) > 1

    # Create figure
    if figure_size is None:
        figure_size = (5 * n_metrics, 4)

    # squeeze=False always yields a 2-D array, so one metric needs no special case
    fig, axes_grid = plt.subplots(1, n_metrics, figsize=figure_size, squeeze=False)
    axes = axes_grid[0]

    os.makedirs(output_dir, exist_ok=True)

    for ax, config in zip(axes, available_metrics.values()):
        column = config['column']

        for alg in algorithms:
            for scale in ordered_scales:
                subset = df[(df['Algorithm'] == alg) & (df['Scale'] == scale)]
                # Drop runs lacking this metric so the median and the IQR band
                # are computed from exactly the same observations.
                subset = subset.dropna(subset=[column])
                if subset.empty:
                    continue

                grouped = subset.groupby(hyperparam_name)[column]
                summary = grouped.agg(
                    median='median',
                    q25=lambda s: s.quantile(0.25),
                    q75=lambda s: s.quantile(0.75),
                ).reset_index()

                if multiple_algs:
                    label = f"{alg} ({scale})"
                else:
                    label = f"{str(scale).capitalize()}"

                color, marker = scale_style[scale]
                linestyle = alg_styles.get(alg, '-')

                # Median line
                ax.plot(summary[hyperparam_name], summary['median'],
                        marker=marker, label=label, color=color,
                        linestyle=linestyle, linewidth=2, markersize=7)

                # IQR band
                ax.fill_between(summary[hyperparam_name],
                                summary['q25'], summary['q75'],
                                alpha=0.15, color=color)

        ax.set_title(f'{config["ylabel"]} vs {hyperparam_name}', fontweight='bold', pad=10)
        ax.set_ylabel(config['ylabel'])
        ax.set_xlabel(hyperparam_name.replace('_', ' ').title())
        ax.legend(frameon=True, loc='best', fontsize=12)
        ax.grid(True, alpha=0.3, linestyle='--')

    fig.suptitle(f'Multi-Scale Hyperparameter Sweep: {problem_type} Lines: median, Bands: IQR (25th-75th percentile)',
                 fontsize=14, fontweight='bold', y=1.02)

    fig.tight_layout()

    # Save
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    plot_file = os.path.join(output_dir, f'combined_{hyperparam_name}_{problem_type}_{timestamp}.png')
    # Swap only the extension; str.replace would also rewrite a '.png' that
    # happens to occur inside the directory part of the path.
    pdf_file = os.path.splitext(plot_file)[0] + '.pdf'
    fig.savefig(plot_file, dpi=300, bbox_inches='tight')
    fig.savefig(pdf_file, bbox_inches='tight')
    print(f"\nPlot saved to: {plot_file}")
    plt.close(fig)

    # Print summary
    print("\n" + "="*80)
    print("BEST VALUES BY SCALE (by Hypervolume)")
    print("="*80)

    for alg in algorithms:
        print(f"\n{alg}:")
        for scale in ordered_scales:
            subset = df[(df['Algorithm'] == alg) & (df['Scale'] == scale)]
            if subset.empty:
                continue
            mean_hv = subset.groupby(hyperparam_name)['Hypervolume'].mean().dropna()
            if mean_hv.empty:
                # No usable hypervolume values for this combination
                continue
            best_hp = mean_hv.idxmax()
            best_hv = mean_hv.loc[best_hp]
            print(f"  {scale}: best {hyperparam_name}={best_hp}, HV={best_hv:.4f}")

    return df, plot_file


def main():
    """Command-line entry point: collect result files and plot them.

    Files are chosen in one of two ways:

    * explicitly, with ``--small`` and ``--large`` (``--medium`` is optional);
    * automatically, with ``--dir`` and ``--hyperparam``: the files returned by
      ``find_results_files`` are assigned to a scale from their file name, and
      if no name carries a scale marker, from each file's ``problem_scales``
      entry.

    Exits through ``parser.error`` when neither combination is supplied, and
    returns without plotting when no file could be matched.
    """
    parser = argparse.ArgumentParser(
        description='Plot combined results from separate scale runs'
    )

    # Option 1: Specify files directly
    parser.add_argument('--small', type=str, help='Path to small scale results YAML')
    parser.add_argument('--medium', type=str,
                        help='Path to medium scale results YAML (optional)')
    parser.add_argument('--large', type=str, help='Path to large scale results YAML')

    # Option 2: Auto-find in directory
    parser.add_argument('--dir', type=str, help='Directory containing result files')
    parser.add_argument('--hyperparam', type=str, help='Hyperparameter name (for auto-find)')
    parser.add_argument('--problem', type=str, default='BiTSP', help='Problem type')

    # Output options
    parser.add_argument('--output-dir', type=str, default='combined_plots')
    parser.add_argument('--metrics', type=str, nargs='+',
                        default=['hypervolume', 'runtime', 'solutions', 'tour_length'])
    parser.add_argument('--figsize', type=float, nargs=2, default=None)

    args = parser.parse_args()

    scale_names = ('small', 'medium', 'large')

    # Determine files
    if args.small and args.large:
        # --medium is included only when given, so two-scale runs keep working
        results_files = {
            scale: getattr(args, scale)
            for scale in scale_names
            if getattr(args, scale)
        }
    elif args.dir and args.hyperparam:
        # Auto-find files
        all_files = find_results_files(args.dir, args.hyperparam, args.problem)
        print(f"Found {len(all_files)} files in {args.dir}")

        # Inspect only the part of the file name after the fixed prefix, so a
        # directory or hyperparameter name containing e.g. "_small_" cannot
        # mislabel a file.
        prefix = f"sweep_multiscale_{args.hyperparam}_{args.problem}_"
        results_files = {}
        for path in all_files:
            stem = os.path.splitext(os.path.basename(path))[0]
            tail = stem[len(prefix):] if stem.startswith(prefix) else stem
            tokens = tail.split('_')
            for scale in scale_names:
                if scale in tokens:
                    # all_files is sorted, so for several files of one scale
                    # the lexicographically last path is kept.
                    results_files[scale] = path
                    break

        # If files don't have scale in name, load and check
        if not results_files and all_files:
            for path in all_files:
                # An empty YAML file loads as None; treat it as having no scales
                data = load_yaml_results(path) or {}
                for scale in data.get('problem_scales') or []:
                    results_files[scale] = path

        print(f"Matched files: {results_files}")
    else:
        parser.error("Provide either --small/--medium/--large OR --dir with --hyperparam")

    if not results_files:
        print("No result files found!")
        return

    figure_size = tuple(args.figsize) if args.figsize else None

    plot_combined_results(
        results_files=results_files,
        output_dir=args.output_dir,
        metrics_to_plot=args.metrics,
        figure_size=figure_size
    )


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()