import os
import pickle
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from datetime import datetime
import glob
import seaborn as sns


def get_latest_results_file(results_dir):
    """
    Find the most recent ``rule_recall_summary_*.pkl`` file in a directory.

    Recency is decided by the timestamp encoded at the end of the file stem,
    not by filesystem modification time. Two layouts are recognised:

    * ``..._YYYYMMDD_HHMMSS`` - full date and time
    * ``..._HHMMSS``          - time only

    Args:
        results_dir (str | Path): Directory to search (non-recursive).

    Returns:
        Path | None: The file with the greatest timestamp, or None (after
        printing a warning) when no matching file exists.
    """
    result_files = list(Path(results_dir).glob("rule_recall_summary_*.pkl"))
    if not result_files:
        print(f'''Warning: PLease note, no rule_recall_summary_ was found in {results_dir} ''')
        return None

    def extract_timestamp(f):
        # The date and the time are separate '_'-delimited segments, so the
        # date has to be read from the second-to-last segment; looking only at
        # the last segment would rank files by time of day alone.
        parts = f.stem.split('_')
        try:
            if len(parts) >= 2 and len(parts[-2]) == 8 and parts[-2].isdigit():
                return datetime.strptime(f"{parts[-2]}-{parts[-1]}", "%Y%m%d-%H%M%S")
            # Time-only stems parse to a 1900-01-01 datetime, so any fully
            # dated file sorts after them.
            return datetime.strptime(parts[-1], "%H%M%S")
        except ValueError as e:
            print(f"Warning: Could not parse timestamp from {f.name}: {e}")
            # Unparseable names sort before every valid timestamp.
            return datetime.min

    return max(result_files, key=extract_timestamp)


def extract_opec_value(p_dir):
    """
    Read the OPEC value encoded in the name of a directory's parent.

    The parent directory name is split on the literal ``OPEC``; the text that
    follows the first occurrence, up to the next underscore, is converted to
    an integer.

    Args:
        p_dir (str | Path): Path to the example directory.

    Returns:
        int: OPEC value.

    Raises:
        IndexError: If the parent directory name does not contain ``OPEC``.
        ValueError: If the text after ``OPEC`` is not an integer.
    """
    parent_dir_name = Path(p_dir).parent.name
    pieces = parent_dir_name.split('OPEC')
    after_marker = pieces[1]
    digits, _, _ = after_marker.partition('_')
    return int(digits)

def load_metrics_from_dir(p_dir):
    """
    Load the metrics stored in the newest results pickle under ``p_dir/results``.

    The unpickled object is normalised as follows:

    * a dict whose values are all dicts (e.g. a single-row DataFrame that was
      converted with ``to_dict``) is flattened by taking the first inner value
      for every key;
    * any other dict is returned unchanged;
    * an object exposing ``to_dict`` is converted and flattened the same way;
    * anything else is returned as loaded.

    Args:
        p_dir (Path): Example directory containing a ``results`` subdirectory.

    Returns:
        The normalised metrics, or None when the results directory or file is
        missing or when loading/normalising raises (a message is printed).
    """
    results_dir = p_dir / "results"
    if not results_dir.exists():
        print(f"Warning: Results directory not found in {p_dir}")
        return None

    latest_file = get_latest_results_file(results_dir)
    if not latest_file:
        return None

    try:
        with open(latest_file, 'rb') as handle:
            loaded = pickle.load(handle)

        if isinstance(loaded, dict):
            is_nested = all(isinstance(inner, dict) for inner in loaded.values())
            if not is_nested:
                # Already a flat mapping of metric name -> value.
                return loaded
            # Dict of dicts: keep only the first entry of each inner dict.
            return {name: list(inner.values())[0] for name, inner in loaded.items()}

        if not hasattr(loaded, 'to_dict'):
            return loaded

        # Object with a to_dict method (e.g. a DataFrame): convert, then flatten.
        as_dict = loaded.to_dict()
        return {name: list(inner.values())[0] for name, inner in as_dict.items()}
    except Exception as e:
        print(f"Error loading {latest_file}: {e}")
        return None


def extract_chain_len(p_dir):
    """
    Extract the numeric chain length from a directory name.

    The final path component is split on the literal ``chain_len`` and the
    text after the last occurrence is converted to a float.

    Args:
        p_dir (str | Path): Path to the example directory. Plain strings are
            accepted as well as Path objects, matching extract_opec_value.

    Returns:
        float: Chain length value.

    Raises:
        ValueError: If the text after ``chain_len`` is not numeric.
    """
    dir_name = Path(p_dir).name
    return float(dir_name.split('chain_len')[-1])

def plot_metrics(example_dirs, opec_values_to_explore, plot_rule_recovery=True):
    """
    Plot metric means as bold points with error bars, grouped by reasoning depth.

    Every directory in ``example_dirs`` contributes one set of metrics. Its
    OPEC value is read with ``extract_opec_value`` and its reasoning depth is
    the text after ``chain_len`` in the directory name. Depths that parse as
    floats are placed first on the x-axis in ascending order, followed by the
    non-numeric depths in alphabetical order. Within each depth the points of
    the different OPEC values are spread horizontally so they do not overlap.

    Args:
        example_dirs: Iterable of str or Path example directories.
        opec_values_to_explore: Collection of OPEC values to include;
            directories with any other OPEC value are skipped.
        plot_rule_recovery (bool): When True, draw a 2x2 grid with rule-recall
            F1-score, precision and recall plus query-completion success rate.
            When False, draw only the query-completion success rate.

    Returns:
        None. The figure is shown with ``plt.show()``. If no metrics could be
        loaded a warning is printed and nothing is drawn.
    """
    example_dirs = [Path(p) for p in example_dirs]

    # One colour per requested OPEC value, assigned in sorted order.
    colors = sns.color_palette("husl", len(opec_values_to_explore))
    color_map = {val: colors[i] for i, val in enumerate(sorted(opec_values_to_explore))}

    # data[opec_value][chain_len] -> metrics mapping
    data = {}
    numeric_labels = set()
    string_labels = set()

    for p_dir in example_dirs:
        opec_value = extract_opec_value(p_dir)
        if opec_value not in opec_values_to_explore:
            continue

        chain_len_str = p_dir.name.split('chain_len')[-1]

        # Numeric depths become floats; anything else stays a string key.
        try:
            chain_len = float(chain_len_str)
            numeric_labels.add(chain_len)
        except ValueError:
            chain_len = chain_len_str
            string_labels.add(chain_len_str)

        metrics = load_metrics_from_dir(p_dir)
        if metrics is None:
            print(f'Warning! No metrics in {p_dir}')
            continue

        data.setdefault(opec_value, {})[chain_len] = metrics

    if not data:
        # Without this guard the per-OPEC width below divides by zero.
        print(f'Warning! No metrics loaded for OPEC values {opec_values_to_explore}; nothing to plot')
        return

    # label_keys are the keys used to look metrics up in `data`; x_labels are
    # what is printed on the axis. Keeping them separate lets the last
    # "over..." bucket be displayed as '>9' without touching `data`.
    label_keys = sorted(numeric_labels) + sorted(string_labels)
    x_labels = list(label_keys)
    if x_labels and isinstance(x_labels[-1], str) and 'over' in x_labels[-1].lower():
        x_labels[-1] = '>9'

    if plot_rule_recovery:
        fig, axs = plt.subplots(2, 2, figsize=(14.5, 18))
        plot_configs = [
            ('rule_recall_fscore_mean', 'rule_recall_fscore_ci', 'F1-score'),
            ('rule_recall_precision_mean', 'rule_recall_precision_ci', 'Precision'),
            ('rule_recall_recall_mean', 'rule_recall_recall_ci', 'Recall'),
            ('query_completion_success_rate', 'query_completion_success_rate_ci', 'Success Rate'),
        ]
    else:
        fig, ax = plt.subplots(figsize=(9, 6))
        axs = np.array([[ax]])  # Same 2-D shape as the grid case
        plot_configs = [
            ('query_completion_success_rate', 'query_completion_success_rate_ci', 'Success Rate'),
        ]

    x_positions = np.arange(len(x_labels))
    blank_labels = [''] * len(x_labels)

    # Horizontal spread of the OPEC values inside one depth group.
    opec_values = sorted(data)
    total_width = 0.8
    bar_width = total_width / len(opec_values)
    first_offset = -total_width / 2 + bar_width / 2

    # A single legend is shared by all subplots.
    legend_handles = []
    legend_labels = []

    for i, (ax, (metric, ci_metric, ylabel)) in enumerate(zip(axs.flat, plot_configs)):
        ax.set_ylabel(ylabel, fontsize=30)
        ax.grid(True, alpha=0.3)

        # In the 2x2 grid only the bottom row (i >= 2) carries x labels.
        show_x_labels = (not plot_rule_recovery) or i >= 2

        # Tick positions must be fixed before the tick labels are assigned.
        ax.set_xticks(x_positions)
        ax.set_xticklabels(x_labels if show_x_labels else blank_labels)
        if show_x_labels:
            ax.set_xlabel('Reasoning Depth', fontsize=28)
        ax.tick_params(axis='x', labelsize=22)
        ax.tick_params(axis='y', labelsize=22)

        for k, opec_value in enumerate(opec_values):
            offset = first_offset + k * bar_width
            color = color_map[opec_value]
            per_depth = data[opec_value]

            x_vals = []
            y_vals = []
            y_err = []
            for j, key in enumerate(label_keys):
                metrics = per_depth.get(key)
                # Depths missing for this OPEC value, or lacking either the
                # metric or its CI, are simply left out of the plot.
                if metrics is None or metric not in metrics or ci_metric not in metrics:
                    continue
                x_vals.append(x_positions[j] + offset)
                y_vals.append(metrics[metric])
                # The stored CI value is handed to errorbar unchanged.
                y_err.append(metrics[ci_metric])

            # Bold points for the means
            line = ax.plot(x_vals, y_vals, 'o',
                           color=color,
                           markersize=10,
                           markeredgewidth=2,
                           markeredgecolor='black',
                           label=f'OPEC {opec_value}')[0]

            # Collect legend entries from the first subplot only.
            if i == 0:
                legend_handles.append(line)
                legend_labels.append(f'OPEC {opec_value}')

            # Confidence intervals as error bars; the transpose is a no-op
            # for scalar CIs and turns per-point pairs into a (2, N) array.
            ax.errorbar(x_vals, y_vals,
                        yerr=np.array(y_err).T,
                        fmt='none',
                        ecolor=color,
                        elinewidth=2,
                        capsize=5,
                        capthick=2,
                        alpha=0.7)

        ax.set_ylim(0, 1.05)

    # Single legend centred above the whole figure, all entries in one row
    fig.legend(handles=legend_handles,
               labels=legend_labels,
               loc='upper center',
               bbox_to_anchor=(0.5, 1.0),
               ncol=len(legend_labels),
               fontsize=25,
               frameon=False)

    # Leave room at the top of the figure for the legend
    plt.tight_layout(rect=[0, 0, 1, 0.5])
    plt.subplots_adjust(top=0.88)
    plt.show()
        

