import os
import json
import re
from statistics import NormalDist

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Apply seaborn's white-grid theme globally so every figure drawn below shares it.
_PLOT_STYLE = "whitegrid"
sns.set_theme(style=_PLOT_STYLE)

#################################################################
# 1. Data Loading and Parsing
#################################################################

# Results files are named <model>_<rep_mode>_<dataset>...-checkpoint-<step>_results.json.

def parse_filename_dynamically(filename, experiment_mode, rep_modes=('nl', 'special')):
    """
    Parse a results filename of the form
    ``<model-name>_<representation_mode>[_<dataset_type>]-checkpoint-<step>_results.json``.

    The representation mode is used as the anchor that separates the model name
    (which may itself contain underscores) from the dataset type. Modes are tried
    in the order given by ``rep_modes``. A mode only matches when it is a whole
    underscore-delimited token: it must be preceded by '_' and followed by '_' or
    the end of the prefix, so mode 'nl' does not match '_nlp'. When a mode occurs
    several times, the last occurrence is used.

    Args:
        filename: Bare file name (no directory component).
        experiment_mode: Value copied verbatim into the result's 'experiment_mode'.
        rep_modes: Representation modes to look for, in priority order.

    Returns:
        A dict with keys 'model_name', 'representation_mode', 'training_type',
        'checkpoint_step' and 'experiment_mode'. Returns None, after printing a
        warning, if the filename does not follow the expected format.
    """
    try:
        # 1. Isolate the experiment prefix and the checkpoint step.
        base_name = filename.removesuffix('_results.json')
        prefix, checkpoint_str = base_name.rsplit('-checkpoint-', 1)
        checkpoint_step = int(checkpoint_str)

        # 2. Find the representation mode as a complete '_'-delimited token.
        #    The greedy '(.*)' makes the regex pick the LAST such occurrence, so a
        #    model name that happens to contain the token earlier stays intact.
        match = None
        parsed_rep_mode = None
        for mode in rep_modes:
            match = re.fullmatch(rf'(.*)_{re.escape(mode)}(?:_(.*))?', prefix, flags=re.DOTALL)
            if match:
                parsed_rep_mode = mode
                break

        if match is None:
            raise ValueError(f"Could not find a known representation mode in filename prefix: {prefix}")

        # 3. Clean up the parsed parts.
        model_name = match.group(1)
        dataset_type = (match.group(2) or '').lstrip('_')
        dataset_type = dataset_type.removesuffix('_best_move')

        # 4. Prettier model names for plot legends.
        cleaned_model_name = model_name.replace('meta-llama_', 'Llama-').replace('Qwen_', 'Qwen-')

        return {
            'model_name': cleaned_model_name,
            'representation_mode': parsed_rep_mode,
            'training_type': dataset_type,
            'checkpoint_step': checkpoint_step,
            'experiment_mode': experiment_mode
        }

    except ValueError as e:
        # Covers a missing '-checkpoint-' marker, a non-integer step and an
        # unknown representation mode: such files are skipped, not fatal.
        print(f"Warning: Could not parse filename '{filename}'. Error: {e}")
        return None

def load_and_parse_results(results_folder):
    """
    Load every ``*.json`` results file under ``results_folder`` into one DataFrame.

    The folder is walked recursively. The name of the directory that directly
    contains a file is used as that file's experiment mode. Filenames are parsed
    with ``parse_filename_dynamically``, and files it rejects are skipped. Files
    are also skipped, with a warning, if they:
      - cannot be read,
      - are not valid JSON, or
      - lack the expected 'overall_stats' / 'fine_grained_stats' entries.
    This way one bad file does not abort the whole analysis.

    Args:
        results_folder: Root directory to search.

    Returns:
        DataFrame with one row per loaded file: the parsed filename fields plus
        'overall_accuracy', 'correct_predictions', 'total_samples' and
        'fine_grained_stats'. Empty if nothing could be loaded.
    """
    all_results = []
    for root, dirs, files in os.walk(results_folder):
        # Sorting `dirs` in place also fixes the order in which os.walk descends,
        # so the resulting row order is deterministic across filesystems.
        dirs.sort()
        experiment_mode = os.path.basename(root)

        for filename in sorted(files):
            if not filename.endswith(".json"):
                continue

            parsed_data = parse_filename_dynamically(filename, experiment_mode)
            if parsed_data is None:
                continue

            filepath = os.path.join(root, filename)
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    content = json.load(f)
                overall_stats = content['overall_stats']
                record = {
                    **parsed_data,
                    'overall_accuracy': overall_stats['accuracy_percent'],
                    'correct_predictions': overall_stats['correct_predictions'],
                    'total_samples': overall_stats['total_samples'],
                    'fine_grained_stats': content['fine_grained_stats'],
                }
            except (OSError, ValueError, KeyError, TypeError) as e:
                # ValueError also covers json.JSONDecodeError and UnicodeDecodeError.
                print(f"Warning: Could not load results from '{filepath}'. Error: {e}")
                continue
            all_results.append(record)

    return pd.DataFrame(all_results)


#################################################################
# 2. Finding the Best Checkpoint & CI Calculation
#################################################################

def find_best_checkpoints(df):
    """
    Keep only the highest-accuracy checkpoint of every experiment.

    An experiment is identified by the (model_name, training_type,
    experiment_mode) triple. Ties are resolved by ``idxmax``, i.e. the first
    row with the maximal 'overall_accuracy' wins.

    Returns:
        An independent copy of the selected rows, or an empty DataFrame when
        ``df`` is empty.
    """
    if df.empty:
        return pd.DataFrame()

    experiment_keys = ['model_name', 'training_type', 'experiment_mode']
    winner_labels = df.groupby(experiment_keys)['overall_accuracy'].idxmax()
    winners = df.loc[winner_labels]
    return winners.copy()

def calculate_ci(correct, total, confidence=0.95):
    """
    Return the Wilson score interval of a binomial proportion as error margins.

    Args:
        correct: Number of successes.
        total: Number of trials.
        confidence: Two-sided confidence level, strictly between 0 and 1.
            The default of 0.95 corresponds to z ≈ 1.96.

    Returns:
        ``(lower_margin, upper_margin)`` on the 0-1 proportion scale, such that
        the interval is ``[p - lower_margin, p + upper_margin]`` with
        ``p = correct / total``. Both margins are clamped at 0 to absorb
        floating-point round-off at p = 0 or p = 1. Returns ``(0, 0)`` when
        ``total`` is 0 (or negative).

    Raises:
        ValueError: If ``confidence`` is not strictly between 0 and 1.
    """
    if total <= 0:
        return 0, 0
    if not 0 < confidence < 1:
        raise ValueError(f"confidence must be strictly between 0 and 1, got {confidence}")

    # Two-sided critical value of the standard normal for the requested level.
    z = NormalDist().inv_cdf(0.5 + confidence / 2)
    p = correct / total

    center = p + (z**2 / (2 * total))
    denominator = 1 + (z**2 / total)
    half_width = z * np.sqrt((p * (1 - p)) / total + (z**2 / (4 * total**2)))

    lower_bound = (center - half_width) / denominator
    upper_bound = (center + half_width) / denominator

    # The Wilson interval always contains p; clamp so round-off can never
    # produce a (tiny) negative margin.
    return max(p - lower_bound, 0.0), max(upper_bound - p, 0.0)


#################################################################
# 3. Plotting Functions
#################################################################
# Each plotting function below writes its PNG file(s) to `output_dir`, creating the directory if needed.

def plot_best_checkpoint_performance(best_df, output_dir="plots"):
    """
    Save one grouped bar plot per experiment mode that compares the best
    checkpoint of each model on each training data type, with Wilson-score
    error bars.

    Args:
        best_df: DataFrame of best checkpoints with columns 'experiment_mode',
            'training_type', 'model_name', 'overall_accuracy' (in percent),
            'correct_predictions' and 'total_samples'.
        output_dir: Directory that receives
            ``best_checkpoint_performance_<mode>.png`` (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)

    for mode in sorted(best_df['experiment_mode'].unique()):
        subset_df = best_df[best_df['experiment_mode'] == mode].copy()
        if subset_df.empty:
            continue

        x_order = sorted(subset_df['training_type'].unique())
        # Use all models for a consistent legend order across plots
        hue_order = sorted(best_df['model_name'].unique())

        # Create the figure only after the empty check so no figure is left open.
        plt.figure(figsize=(12, 8))
        ax = sns.barplot(
            data=subset_df,
            x='training_type',
            y='overall_accuracy',
            hue='model_name',
            palette='viridis',
            order=x_order,
            hue_order=hue_order
        )

        # (training_type, model_name) -> row, for the bars that actually have data.
        data_map = {
            (row['training_type'], row['model_name']): row
            for _, row in subset_df.iterrows()
        }

        # Snapshot the containers first. Every ax.errorbar() call below appends
        # an ErrorbarContainer to ax.containers, and iterating the live list
        # would walk into those as well.
        bar_containers = list(ax.containers)
        for container in bar_containers:
            # seaborn labels each bar container with its hue level (the model name).
            hue_cat = container.get_label()
            for bar in container:
                x_center = bar.get_x() + bar.get_width() / 2.
                if not np.isfinite(x_center):
                    continue
                # Categorical slots sit at x = 0..len(x_order)-1, and dodged bars
                # stay within +/-0.5 of their slot. Rounding the centre therefore
                # recovers the category even when seaborn omits bars for missing
                # (training type, model) combinations.
                x_idx = int(round(x_center))
                if not 0 <= x_idx < len(x_order):
                    continue
                data_row = data_map.get((x_order[x_idx], hue_cat))
                if data_row is None:
                    continue

                lower_err, upper_err = calculate_ci(
                    data_row['correct_predictions'], data_row['total_samples'])
                # calculate_ci works on the 0-1 proportion scale while the y axis
                # is in percent, so rescale. Clamp at 0 because recent matplotlib
                # versions reject negative error values.
                err = np.array([[max(lower_err, 0.0) * 100], [max(upper_err, 0.0) * 100]])

                ax.errorbar(x=x_center, y=bar.get_height(), yerr=err, fmt='none', c='black', capsize=3)

        plt.title(f'Best Checkpoint Performance ({mode.replace("_", " ").title()})', fontsize=16)
        plt.ylabel('Overall Accuracy (%)', fontsize=12)
        plt.xlabel('Training Data Type', fontsize=12)
        plt.xticks(rotation=15, ha='right')
        plt.ylim(0, 105)
        plt.legend(title='Model', loc='upper left')
        plt.tight_layout()

        plot_path = os.path.join(output_dir, f'best_checkpoint_performance_{mode}.png')
        plt.savefig(plot_path)
        print(f"Saved plot to {plot_path}")
        plt.close()

def plot_checkpoint_progression(df, experiment_prefix, output_dir="plots"):
    """
    Plot overall accuracy against checkpoint step for a single experiment.

    Args:
        df: DataFrame of all parsed results (one row per checkpoint) with columns
            'model_name', 'representation_mode', 'training_type',
            'experiment_mode', 'checkpoint_step' and 'overall_accuracy'.
            The DataFrame is not modified.
        experiment_prefix: Experiment key of the form
            '<model_name>_<representation_mode>_<training_type>_<experiment_mode>'.
        output_dir: Directory that receives ``progression_<prefix>.png``
            (created if missing).

    Prints a warning and returns without plotting if no row matches.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Build the experiment key as a local Series instead of a new column, so
    # the caller's DataFrame is left untouched.
    full_prefix = (
        df['model_name'] + '_' + df['representation_mode'] + '_'
        + df['training_type'] + '_' + df['experiment_mode']
    )
    subset_df = df[full_prefix == experiment_prefix]

    if subset_df.empty:
        print(f"Warning: No data found for prefix '{experiment_prefix}'. Skipping plot.")
        return

    # Sort by checkpoint step for a clean line plot (sort_values returns a copy).
    subset_df = subset_df.sort_values('checkpoint_step')

    plt.figure(figsize=(14, 7))
    sns.lineplot(data=subset_df, x='checkpoint_step', y='overall_accuracy', marker='o', label='Overall Accuracy')

    title_prefix = experiment_prefix.replace('_', ' ').replace('  ', ' ')
    plt.title(f'Performance Progression for:\n{title_prefix}', fontsize=16)
    plt.ylabel('Overall Accuracy (%)', fontsize=12)
    plt.xlabel('Checkpoint Step', fontsize=12)
    plt.legend()
    plt.grid(True, which='both', linestyle='--')
    plt.tight_layout()

    safe_prefix = re.sub(r'[^a-zA-Z0-9_-]', '_', experiment_prefix)
    plot_path = os.path.join(output_dir, f'progression_{safe_prefix}.png')
    plt.savefig(plot_path)
    print(f"Saved plot to {plot_path}")
    plt.close()


def plot_accuracy_vs_complexity(best_df, output_dir="plots"):
    """
    For each best checkpoint, plot accuracy per game-outcome category.

    Each row's ``fine_grained_stats['by_minimax_outcome_score']`` maps an outcome
    score (a string key parsed as int) to a dict with 'correct' and 'total'
    counts. Scores are bucketed as follows:
      - Fast Win: > 2
      - Slow Win: 1..2
      - Draw: 0
      - Slow Loss: -1..-2
      - Fast Loss: <= -3
    Each category's accuracy pools the raw counts of its scores
    (sum(correct) / sum(total) * 100). This replaces an unweighted mean of
    per-score accuracies, which would let rare scores dominate. A category with
    zero samples is shown as 0.

    Rows without this breakdown are skipped. One PNG per remaining row is
    written to ``output_dir`` (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)
    category_order = ["Fast Win", "Slow Win", "Draw", "Slow Loss", "Fast Loss"]

    def categorize_score(s):
        if s > 2: return "Fast Win"
        if s > 0: return "Slow Win"
        if s == 0: return "Draw"
        if s > -3: return "Slow Loss"
        return "Fast Loss"

    for _, row in best_df.iterrows():
        fine_grained = row.get('fine_grained_stats')
        if not isinstance(fine_grained, dict):
            continue
        stats = fine_grained.get('by_minimax_outcome_score')
        if not stats:
            continue

        per_score = pd.DataFrame.from_dict(stats, orient='index')
        per_score['complexity_category'] = [categorize_score(int(score)) for score in per_score.index]

        # Pool the raw counts per category, then compute the accuracy once.
        plot_data = (
            per_score.groupby('complexity_category')[['correct', 'total']]
            .sum()
            .reset_index()
        )
        safe_total = plot_data['total'].where(plot_data['total'] > 0)  # 0 -> NaN
        plot_data['accuracy'] = (plot_data['correct'] / safe_total * 100).fillna(0)

        plt.figure(figsize=(12, 7))
        # A fixed hue_order keeps each category's colour identical across plots.
        sns.barplot(data=plot_data, x='complexity_category', y='accuracy', palette='coolwarm',
                    order=category_order, hue='complexity_category', hue_order=category_order,
                    dodge=False)

        model_info = f"{row['model_name']} ({row['training_type']})"
        plt.title(f'Accuracy vs. Strategic Complexity for {model_info}\n(Task: {row["experiment_mode"].title()})', fontsize=16)
        plt.ylabel('Accuracy (%)', fontsize=12)
        plt.xlabel('Outcome Category (Complexity)', fontsize=12)
        plt.ylim(0, 105)
        plt.tight_layout()

        safe_name = re.sub(r'[^a-zA-Z0-9_-]', '_', model_info)
        plot_path = os.path.join(output_dir, f'complexity_vs_accuracy_{safe_name}_{row["experiment_mode"]}.png')
        plt.savefig(plot_path)
        print(f"Saved plot to {plot_path}")
        plt.close()
        
# Additional analyses of the best checkpoints: performance heatmap, precision under pressure, and distraction index.

def plot_performance_heatmap(best_df, output_dir="plots"):
    """
    Save, for every experiment mode, a Model x Training-Type heatmap of the best
    overall accuracy as ``performance_heatmap_<mode>.png`` in ``output_dir``.

    Cells are annotated with the accuracy to two decimals. Duplicate
    (model, training type) rows are averaged by ``pivot_table``'s default
    aggregation.
    """
    os.makedirs(output_dir, exist_ok=True)

    for mode in sorted(best_df['experiment_mode'].unique()):
        mode_rows = best_df[best_df['experiment_mode'] == mode]
        if mode_rows.empty:
            continue

        # Rows = models, columns = training types, cells = best accuracy.
        accuracy_grid = mode_rows.pivot_table(
            index='model_name',
            columns='training_type',
            values='overall_accuracy',
        )

        plt.figure(figsize=(10, 7))
        sns.heatmap(
            accuracy_grid,
            annot=True,
            fmt=".2f",
            cmap="viridis",
            linewidths=.5,
            cbar_kws={'label': 'Best Accuracy (%)'},
        )

        pretty_mode = mode.replace("_", " ").title()
        plt.title(f'Heatmap of Best Model Performance ({pretty_mode})', fontsize=16)
        plt.ylabel('Model', fontsize=12)
        plt.xlabel('Training Data Type', fontsize=12)
        plt.xticks(rotation=10)
        plt.yticks(rotation=0)
        plt.tight_layout()

        out_path = os.path.join(output_dir, f'performance_heatmap_{mode}.png')
        plt.savefig(out_path)
        print(f"Saved plot to {out_path}")
        plt.close()


def plot_precision_under_pressure(best_df, output_dir="plots"):
    """
    Plot each best checkpoint's accuracy on states that have exactly one best
    move ('best_move' experiment mode only).

    The per-row value is read from
    ``fine_grained_stats['by_num_best_moves']['1']`` as correct / total * 100.
    Rows without that bucket, or whose total is not positive, are plotted as 0.

    Saves ``precision_under_pressure.png`` to ``output_dir``. Prints a warning
    and returns early when there is no 'best_move' data.
    """
    os.makedirs(output_dir, exist_ok=True)
    target_mode = 'best_move'

    critical_df = best_df[best_df['experiment_mode'] == target_mode].copy()
    if critical_df.empty:
        print("Warning: No 'best_move' data to generate 'Precision Under Pressure' plot.")
        return

    def single_best_move_accuracy(record):
        bucket = (
            record.get('fine_grained_stats', {})
            .get('by_num_best_moves', {})
            .get('1', {})
        )
        if not bucket or not bucket['total'] > 0:
            return 0
        return (bucket['correct'] / bucket['total']) * 100

    critical_df['critical_accuracy'] = critical_df.apply(single_best_move_accuracy, axis=1)

    plt.figure(figsize=(12, 8))
    sns.barplot(
        data=critical_df,
        x='training_type',
        y='critical_accuracy',
        hue='model_name',
        palette='magma',
    )

    plt.title('Precision Under Pressure: Accuracy on Single-Best-Move States', fontsize=16)
    plt.ylabel('Accuracy (%)', fontsize=12)
    plt.xlabel('Training Data Type', fontsize=12)
    plt.xticks(rotation=15, ha='right')
    plt.ylim(0, 105)
    plt.legend(title='Model', loc='upper left')
    plt.tight_layout()

    out_path = os.path.join(output_dir, 'precision_under_pressure.png')
    plt.savefig(out_path)
    print(f"Saved plot to {out_path}")
    plt.close()


def plot_distraction_index(best_df, output_dir="plots"):
    """
    Plot how much each best checkpoint's accuracy drops between states with few
    legal moves and states with many legal moves ('best_move' mode only).

    The low-choice bucket pools states with 1-3 legal moves, and the high-choice
    bucket pools states with 6-8, both read from
    ``fine_grained_stats['by_num_legal_moves']``. The plotted
    "distraction drop" is low-choice accuracy minus high-choice accuracy, in
    percentage points. A bucket with no samples counts as 0% accuracy.

    Saves ``distraction_index.png`` to ``output_dir``. Prints a warning and
    returns early when there is no 'best_move' data.
    """
    os.makedirs(output_dir, exist_ok=True)
    target_mode = 'best_move'

    mode_df = best_df[best_df['experiment_mode'] == target_mode].copy()
    if mode_df.empty:
        print("Warning: No 'best_move' data to generate 'Distraction Index' plot.")
        return

    def pooled_accuracy(record, move_counts):
        """Accuracy (%) pooled over the given legal-move counts; 0 when there are no samples."""
        by_moves = record.get('fine_grained_stats', {}).get('by_num_legal_moves', {})
        buckets = [by_moves.get(str(count), {}) for count in move_counts]
        non_empty = [bucket for bucket in buckets if bucket]
        hits = sum(bucket.get('correct', 0) for bucket in non_empty)
        seen = sum(bucket.get('total', 0) for bucket in non_empty)
        return hits / seen * 100 if seen > 0 else 0

    mode_df['low_choice_accuracy'] = mode_df.apply(
        lambda record: pooled_accuracy(record, (1, 2, 3)), axis=1)
    mode_df['high_choice_accuracy'] = mode_df.apply(
        lambda record: pooled_accuracy(record, (6, 7, 8)), axis=1)
    mode_df['distraction_drop'] = mode_df['low_choice_accuracy'] - mode_df['high_choice_accuracy']

    # Largest drop first. Without an explicit `order`, seaborn lists the x
    # categories in order of first appearance, so this row order is what the
    # plot shows.
    mode_df = mode_df.sort_values('distraction_drop', ascending=False)

    plt.figure(figsize=(12, 8))
    sns.barplot(
        data=mode_df,
        x='model_name',
        y='distraction_drop',
        hue='training_type',
        palette='plasma',
        dodge=True,
    )

    plt.title('Model Distraction Index (Performance Drop with More Choices)', fontsize=16)
    plt.ylabel('Accuracy Drop (Low-Choice states vs. High-Choice states)', fontsize=12)
    plt.xlabel('Model', fontsize=12)
    plt.xticks(rotation=15, ha='right')
    plt.tight_layout()

    out_path = os.path.join(output_dir, 'distraction_index.png')
    plt.savefig(out_path)
    print(f"Saved plot to {out_path}")
    plt.close()

#################################################################
# 4. Main Runner Function
#################################################################

# Entry point that ties together loading, best-checkpoint selection, and plotting.

def analyze_and_plot(results_folder, plots_output_dir="plots"):
    """
    Run the full analysis: load results, select best checkpoints, draw all plots.

    Args:
        results_folder: Root directory searched recursively for results JSON files.
        plots_output_dir: Directory that receives every generated PNG.

    Prints a message and returns early if no results could be loaded.
    """
    print("--- Loading and Parsing Results ---")
    full_df = load_and_parse_results(results_folder)
    if full_df.empty:
        print("No results found. Exiting.")
        return

    print("\n--- Finding Best Checkpoints ---")
    best_checkpoints_df = find_best_checkpoints(full_df)
    print("Best performing checkpoints identified:")
    print(best_checkpoints_df[['model_name', 'training_type', 'experiment_mode', 'checkpoint_step', 'overall_accuracy']].round(2))

    print("\n--- Generating Plots ---")
    plot_best_checkpoint_performance(best_checkpoints_df, plots_output_dir)
    plot_accuracy_vs_complexity(best_checkpoints_df, plots_output_dir)
    plot_performance_heatmap(best_checkpoints_df, plots_output_dir)
    plot_precision_under_pressure(best_checkpoints_df, plots_output_dir)
    plot_distraction_index(best_checkpoints_df, plots_output_dir)

    print("\n--- Generating Progression Plots for All Experiments ---")
    if not best_checkpoints_df.empty:
        # One progression plot per distinct (model, representation, training
        # type, experiment mode). The experiment mode must be part of the key:
        # the same model/rep/training can appear in several modes, and each
        # needs its own plot. The key format matches the one that
        # plot_checkpoint_progression rebuilds from full_df.
        key_columns = ['model_name', 'representation_mode', 'training_type', 'experiment_mode']
        experiments = best_checkpoints_df[key_columns].drop_duplicates()
        for model, rep_mode, training, mode in experiments.itertuples(index=False, name=None):
            plot_checkpoint_progression(full_df, f"{model}_{rep_mode}_{training}_{mode}", plots_output_dir)

    print("\n--- Analysis Complete ---")


if __name__ == '__main__':
    import argparse

    # Defaults preserve the original hard-coded locations; both can now be
    # overridden from the command line.
    RESULTS_FOLDER_PATH = "/mnt/shared/stlm-logic/results"
    PLOTS_OUTPUT_DIR = "/home/data/stlm-game-logic/analysis_plots"

    parser = argparse.ArgumentParser(
        description="Load evaluation results, select the best checkpoints and generate analysis plots.")
    parser.add_argument('--results-folder', default=RESULTS_FOLDER_PATH,
                        help="Root folder searched recursively for *_results.json files (default: %(default)s)")
    parser.add_argument('--plots-output-dir', default=PLOTS_OUTPUT_DIR,
                        help="Directory that receives the generated PNG files (default: %(default)s)")
    args = parser.parse_args()

    analyze_and_plot(args.results_folder, args.plots_output_dir)