"""
Evaluates pass@n rate for the triangle discovery task and generates plots.

For each of the n triangle graphs, generate n samples and calculate pass@n rate.
(Default n=3 evaluates on 3 graphs used in RLFT)

Args:
- Checkpoint path
- Optional plotting arguments (--plot, --plot_file, --plot_title)

Output:
- Define your data as dictionaries with n values as keys and lists of pass@n values as values
- Each list contains pass@n values from multiple seeds/runs
- Example: {1: [48.4, 47.2, 49.1], 5: [76.2, 75.8, 76.9], 10: [83.6, 84.1, 83.2]}
- Generate and display pass@n plots with error bars
"""

import argparse
import json
import os
import random
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from matplotlib.ticker import FuncFormatter

from utils import (
    TriangleTokenizer, TransformerPolicy, MODEL_PATH, FINE_TUNED_MODEL_PATH,
    is_valid_triangle, load_json, dataset_dir, DATA_ROOT, HASH_STR_LEN, T_TRIANGLES,
    parse_triangle_sequence, make_env, set_seed, generate
)


def load_model_and_tokenizer(model_path: str, device: str = "cpu"):
    """Rebuild the policy model and its tokenizer from a saved checkpoint.

    The checkpoint is read with ``torch.load`` and must be a dict holding
    ``"tokenizer"`` (``entities`` / ``special_tokens``), ``"config"`` (model
    hyper-parameters) and ``"model"`` (the state dict).

    Args:
        model_path: Path of the checkpoint file.
        device: Device used both as ``map_location`` and as the model's device.

    Returns:
        ``(model, tokenizer)``; the model is switched to eval mode.
    """
    checkpoint = torch.load(model_path, map_location=device)

    tok_spec = checkpoint["tokenizer"]
    tokenizer = TriangleTokenizer(
        entities=tok_spec["entities"],
        special_tokens=tok_spec["special_tokens"],
    )

    config = checkpoint["config"]
    # Architecture hyper-parameters copied verbatim from the saved config.
    arch_kwargs = {
        key: config[key]
        for key in ("vocab_size", "d_model", "n_layer", "n_head", "dim_ff", "dropout", "max_len")
    }
    policy = TransformerPolicy(**arch_kwargs, tie_weights=True)
    policy.load_state_dict(checkpoint["model"])
    policy = policy.to(device)
    policy.eval()

    return policy, tokenizer


def generate_triangle_sequence(model, tokenizer, graph_idx: int, device: str = "cpu", temperature: float = 1.0) -> List[int]:
    """Sample one triangle continuation for graph ``graph_idx``.

    The prompt is ``"<graph_idx> tri: "``; sampling runs without gradient
    tracking and is delegated to ``generate``.

    Args:
        model: Policy network handed to ``generate``.
        tokenizer: Tokenizer handed to ``generate``.
        graph_idx: Graph index placed at the start of the prompt.
        device: Device forwarded to ``generate``.
        temperature: Sampling temperature forwarded to ``generate``.

    Returns:
        The value returned by ``generate`` (used by callers as a token list).
    """
    prompt_text = f"{graph_idx} tri: "

    with torch.no_grad():
        # 16 new tokens is assumed to be enough room for one triangle.
        return generate(
            policy=model,
            tokenizer=tokenizer,
            prompt=prompt_text,
            max_new_tokens=16,
            temperature=temperature,
            device=device,
        )


def is_valid_triangle_sequence(sequence: List[int], tokenizer: TriangleTokenizer, graph_edges: Dict) -> bool:
    """Tell whether a generated sequence encodes a triangle present in ``graph_edges``.

    The sequence is parsed with ``parse_triangle_sequence``. An unparseable
    sequence (parse result ``None``) is reported as invalid; otherwise the
    answer comes from ``is_valid_triangle``.
    """
    parsed = parse_triangle_sequence(sequence, tokenizer)
    return parsed is not None and is_valid_triangle(parsed, graph_edges)


def canonicalize_sequence(seq: str) -> str:
    """Canonicalize a triangle sequence by sorting its vertices.

    A triangle is serialized as ``"<a_i><a_j><sep><a_j><a_k><sep><a_k><a_i>"``,
    and the same triangle can be written with its vertices in any order. The
    result is rewritten so that i <= j <= k. Two serializations of the same
    triangle then compare equal as strings.

    Only the first ``<a_N>`` token of each of the first three
    ``<sep>``-separated segments is read.

    Args:
        seq: Triangle text (the part after ``"tri:"``).

    Returns:
        The canonical string. If ``seq`` cannot be parsed, a message is
        printed and ``seq`` itself is returned unchanged (always a ``str``).
    """
    # Keep ``seq`` intact so the failure path can return the original string
    # rather than the intermediate list of segments.
    segments = seq.split("<sep>")
    try:
        i, j, k = (int(segments[idx].split("<a_")[1].split(">")[0]) for idx in range(3))
    except (IndexError, ValueError):
        # IndexError: fewer than 3 segments or a segment without "<a_";
        # ValueError: the vertex id is not an integer.
        print("Failed for canonicalizing:", seq)
        return seq

    smallest, middle, largest = sorted((i, j, k))
    return "<a_{}><a_{}><sep><a_{}><a_{}><sep><a_{}><a_{}>".format(
        smallest, middle, middle, largest, largest, smallest
    )


def is_unseen_triangle_sequence(sequence: List[int], tokenizer: TriangleTokenizer, graph_edges: Dict, train_sequences: List[str]) -> bool:
    """Return a truthy value only for a valid triangle that is absent from the training data.

    Validity is checked first with ``is_valid_triangle_sequence``. The decoded
    text after ``"tri:"`` (or the whole decoded text if that marker is
    missing) is canonicalized and looked up in ``train_sequences``, which is
    expected to hold canonicalized training targets. Any exception during
    decoding or lookup is printed and treated as "not unseen".
    """
    if not is_valid_triangle_sequence(sequence, tokenizer, graph_edges):
        return False

    try:
        text = tokenizer.decode(sequence)
        marker = "tri:"
        # Drop the prompt portion if the decoded text still contains it.
        body = text.split(marker)[1].strip() if marker in text else text
        return canonicalize_sequence(body) not in train_sequences
    except Exception as err:
        print(f"Error checking unseen status: {err}")
        return False


def evaluate_pass_at_n_for_graph(
    model,
    tokenizer,
    graph_idx: int,
    graph_edges: Dict,
    n_samples: int,
    train_sequences: List[str],
    device: str = "cpu",
    temperature: float = 1.0,
    use_unseen: bool = True
) -> bool:
    """
    Decide whether a single graph passes at n.

    Up to ``n_samples`` sequences are sampled for ``graph_idx``, stopping at
    the first successful one. A sample is successful when it is a valid
    triangle of ``graph_edges`` and, if ``use_unseen`` is set, its canonical
    form is also absent from ``train_sequences``. An exception while
    generating or checking a sample is printed, and that sample counts as a
    failure.

    Returns:
        True as soon as one sample succeeds, otherwise False.
    """
    def sample_succeeds(candidate):
        if use_unseen:
            # Creative mode: valid AND not copied from the training data.
            return is_unseen_triangle_sequence(candidate, tokenizer, graph_edges, train_sequences)
        # Plain mode: validity alone is enough.
        return is_valid_triangle_sequence(candidate, tokenizer, graph_edges)

    for _attempt in range(n_samples):
        try:
            candidate = generate_triangle_sequence(model, tokenizer, graph_idx, device, temperature)
            if sample_succeeds(candidate):
                return True
        except Exception as err:
            print(f"Error generating sequence for graph {graph_idx}: {err}")

    return False


def _load_graphs(ddir: str, num_graphs: int) -> List[Any]:
    """Load ``edges_0.json`` .. ``edges_{num_graphs-1}.json`` from ``ddir``; raise if any is missing."""
    graphs = []
    for i in range(num_graphs):
        graph_path = os.path.join(ddir, f"edges_{i}.json")
        if not os.path.exists(graph_path):
            raise FileNotFoundError(f"Could not find {graph_path}")
        graphs.append(load_json(graph_path))
    return graphs


def _load_train_sequences(ddir: str) -> Optional[List[str]]:
    """Return canonicalized training triangles from ``ddir/train.json``, or None if the file is absent."""
    train_path = os.path.join(ddir, "train.json")
    if not os.path.exists(train_path):
        return None
    with open(train_path, "r", encoding='utf-8') as f:
        train_data = json.load(f)
    # Only targets containing the "tri:" marker carry a triangle.
    return [
        canonicalize_sequence(item["target_text"].split("tri:")[1].strip())
        for item in train_data
        if "tri:" in item["target_text"]
    ]


def evaluate_pass_at_n(
    model_path: str,
    num_graphs: int = 10,
    n_values: Optional[List[int]] = None,
    num_seeds: int = 3,
    seeds: Optional[List[int]] = None,
    device: str = "cpu",
    temperature: float = 1.0,
    use_unseen: bool = True
) -> Dict[int, List[float]]:
    """
    Evaluate pass@n for different n values across multiple seeds.

    For every seed and every n, each of the ``num_graphs`` graphs gets up to
    n samples. The pass@n rate is the percentage of graphs with at least one
    successful sample.

    Args:
        model_path: Path to the model checkpoint.
        num_graphs: Number of graphs (``edges_0.json`` ... ``edges_{num_graphs-1}.json``)
            to evaluate on; must be at least 1.
        n_values: n values to evaluate; defaults to ``[1, 5, 10, 20, 40, 80]``.
            Duplicates are ignored (first occurrence order is kept).
        num_seeds: Number of evaluation rounds, each with its own seed.
        seeds: Optional explicit seeds. Each entry is converted with ``int()``,
            so numeric strings are accepted. The list is cycled if it is shorter
            than ``num_seeds``. When empty or None, seeds ``999, 1000, ...`` are used.
        device: Device to run on.
        temperature: Sampling temperature.
        use_unseen: If True, success means valid AND unseen (creative).
            If False, success means just valid. The function falls back to
            valid-only when ``train.json`` is missing.

    Returns:
        Dictionary mapping each n to a list with one pass@n rate (percent) per seed.

    Raises:
        ValueError: If ``num_graphs`` is less than 1.
        FileNotFoundError: If one of the graph files is missing.
    """
    # None stands in for the list defaults so no mutable default is shared between calls.
    if n_values is None:
        n_values = [1, 5, 10, 20, 40, 80]
    if seeds is None:
        seeds = []
    # A repeated n would append extra entries per seed to the same result list
    # and skew its mean/std, so evaluate each n only once.
    n_values = list(dict.fromkeys(n_values))

    if num_graphs < 1:
        # Fail before loading the model instead of dividing by zero later.
        raise ValueError(f"num_graphs must be at least 1, got {num_graphs}")

    model, tokenizer = load_model_and_tokenizer(model_path, device)

    ddir = dataset_dir(DATA_ROOT, HASH_STR_LEN, T_TRIANGLES)
    print(f"Using {num_graphs} triangle graphs")
    graphs = _load_graphs(ddir, num_graphs)

    # Training sequences are only needed to decide whether a triangle is unseen.
    train_sequences: List[str] = []
    if use_unseen:
        loaded = _load_train_sequences(ddir)
        if loaded is None:
            print("Warning: Could not find training data. Using valid-only evaluation.")
            use_unseen = False
        else:
            train_sequences = loaded

    print(f"Loaded {len(graphs)} graphs")
    print(f"Loaded {len(train_sequences)} training sequences")
    print(f"Evaluating pass@n for n_values: {n_values}")
    print(f"Using {num_seeds} seeds")
    print(f"Evaluation mode: {'unseen (creative)' if use_unseen else 'valid only'}")

    results: Dict[int, List[float]] = {n: [] for n in n_values}

    for i in range(num_seeds):
        print(f"\n--- Seed {i + 1}/{num_seeds} ---")
        if seeds:
            seed = int(seeds[i % len(seeds)])
            print("Using seed:", seed)
        else:
            seed = 999 + i
        set_seed(seed)

        for n in n_values:
            passed_graphs = 0

            for graph_idx, graph_edges in enumerate(graphs):
                success = evaluate_pass_at_n_for_graph(
                    model=model,
                    tokenizer=tokenizer,
                    graph_idx=graph_idx,
                    graph_edges=graph_edges,
                    n_samples=n,
                    train_sequences=train_sequences,
                    device=device,
                    temperature=temperature,
                    use_unseen=use_unseen
                )
                if success:
                    passed_graphs += 1
                print(f"  Graph {graph_idx}: {'PASS' if success else 'FAIL'}")

            pass_rate = (passed_graphs / len(graphs)) * 100
            results[n].append(pass_rate)
            print(f"Pass@{n} rate: {pass_rate:.1f}% ({passed_graphs}/{len(graphs)})")

    return results


def generate_pass_at_n_plot(results: Dict[int, List[float]], output_file: str = None, title: str = "Pass@n Evaluation Results"):
    """
    Plot mean pass@n against n, with a shaded ±1 standard-deviation band across seeds.

    The n values are drawn at equally spaced x positions and labelled with
    their real values. The y range is the data range padded by 10 points on
    each side and clamped to [0, 100].

    Args:
        results: Mapping from n to the list of per-seed pass rates (percent).
        output_file: If given, the figure is also saved there as a PDF.
        title: Plot title.
    """
    # Iris colors palette (only the first color is used for a single curve)
    palette = ["#134E6F", "#1AC0C6", "#FFA822", "#FF6150", "#DEE0E6", "#091A29"]
    curve_color = palette[0]

    plt.figure(figsize=(10, 8))

    n_axis = sorted(results.keys())
    xs = np.arange(len(n_axis))

    means = [np.mean(results[n]) for n in n_axis]
    spreads = [np.std(results[n]) for n in n_axis]
    mean_arr = np.array(means)
    spread_arr = np.array(spreads)

    plt.fill_between(xs, mean_arr - spread_arr, mean_arr + spread_arr, alpha=0.1, color=curve_color)
    plt.plot(xs, means, linewidth=2, marker='o', markersize=6, color=curve_color)

    plt.xlabel('n (Pass@n)', fontsize=14)
    plt.ylabel('Pass Rate', fontsize=14)
    plt.title(title, fontsize=16)
    plt.grid(True, alpha=0.2)

    plt.xlim(-0.2, len(n_axis) - 0.8)
    if means:
        # Pad the observed range by 10 percentage points, clamped to [0, 100].
        plt.ylim(max(0, min(means) - 10), min(100, max(means) + 10))

    # Tick labels show the actual n values rather than the positions.
    plt.xticks(xs, n_axis)

    plt.gca().yaxis.set_major_formatter(FuncFormatter(lambda value, _pos: f'{value:.0f}%'))
    plt.tight_layout()

    if output_file:
        plt.savefig(output_file, format='pdf', dpi=300, bbox_inches='tight', pad_inches=0.1)
        print(f"Plot saved to: {output_file}")

    plt.show()


def generate_multi_model_plot(all_results: Dict[str, Dict[int, List[float]]], output_file: str = None, title: str = "Pass@n Model Comparison"):
    """
    Generate a pass@n comparison plot for multiple models.

    Every n value that appears in any model's results gets one equally spaced
    x position, labelled with the actual n. Each model is drawn as its mean
    pass rate across seeds, with a shaded ±1 standard-deviation band. A model
    only gets points at the n values it was actually evaluated at. Missing n
    values are left out of its curve instead of being drawn as a fake 0%.

    Args:
        all_results: Dictionary mapping model names to their results
            dictionaries (n -> list of per-seed pass rates in percent).
        output_file: Optional output file path; the figure is saved as PDF.
        title: Title for the plot.
    """
    # Iris colors palette
    iris_colors = ["#134E6F", "#1AC0C6", "#FFA822", "#FF6150", "#DEE0E6", "#091A29"]

    plt.figure(figsize=(12, 8))

    # Shared x axis: union of every model's n values, equally spaced.
    sorted_n_values = sorted({n for model_results in all_results.values() for n in model_results})
    x_positions = np.arange(len(sorted_n_values))
    n_to_position = {n: pos for pos, n in enumerate(sorted_n_values)}

    plotted_any = False
    for i, (model_name, results) in enumerate(all_results.items()):
        if not results:  # Skip empty results
            continue

        # The color follows the model's index in all_results, so a skipped
        # model still uses up a palette slot (keeps colors stable across runs).
        color = iris_colors[i % len(iris_colors)]

        # Only the n values this model was evaluated at.
        model_n_values = [n for n in sorted_n_values if n in results]
        plot_positions = [n_to_position[n] for n in model_n_values]
        pass_means = np.array([np.mean(results[n]) for n in model_n_values])
        pass_stds = np.array([np.std(results[n]) for n in model_n_values])

        # Plot shaded region for standard deviation
        plt.fill_between(plot_positions,
                         pass_means - pass_stds,
                         pass_means + pass_stds,
                         alpha=0.1,
                         color=color)

        # Plot mean line
        plt.plot(plot_positions, pass_means,
                 label=model_name,
                 linewidth=2,
                 marker='o',
                 markersize=6,
                 color=color)
        plotted_any = True

    plt.xlabel('n (Pass@n)', fontsize=14)
    plt.ylabel('Pass Rate', fontsize=14)
    plt.title(title, fontsize=16)
    if plotted_any:
        # legend() warns when there are no labelled artists to show.
        plt.legend(fontsize=12, loc='lower right')
    plt.grid(True, alpha=0.2)

    plt.xlim(-0.2, len(sorted_n_values) - 0.8)
    plt.ylim(0, 100)

    # Tick labels show the actual n values rather than the positions.
    plt.xticks(x_positions, sorted_n_values)

    def percentage_formatter(x, pos):
        return f'{x:.0f}%'

    plt.gca().yaxis.set_major_formatter(FuncFormatter(percentage_formatter))
    plt.tight_layout()

    if output_file:
        plt.savefig(output_file, format='pdf', dpi=300, bbox_inches='tight', pad_inches=0.1)
        print(f"Multi-model plot saved to: {output_file}")

    plt.show()


def _model_name_from_path(path: str) -> str:
    """Derive a display name from a checkpoint path: its basename without a trailing ``.pt``."""
    name = os.path.basename(path)
    # Strip only a real suffix; str.replace would also mangle names such as "model.pth".
    return name[:-len(".pt")] if name.endswith(".pt") else name


def _run_multi_model(args: argparse.Namespace, use_unseen: bool) -> None:
    """Evaluate every checkpoint in --model_paths, print a comparison, and optionally plot/save it."""
    if not args.model_paths:
        print("Error: --model_paths must be provided when using --multi_model")
        return

    if args.model_names and len(args.model_names) != len(args.model_paths):
        print("Error: --model_names must have the same length as --model_paths")
        return

    # Use the given model names, or derive them from the checkpoint paths.
    model_names = args.model_names or [_model_name_from_path(path) for path in args.model_paths]

    # Results are keyed by model name, so a repeated name would silently
    # overwrite an earlier model's results.
    if len(set(model_names)) != len(model_names):
        print(f"Error: model names must be unique, got {model_names}; pass distinct --model_names")
        return

    print("Multi-model evaluation mode")
    print(f"Models: {list(zip(model_names, args.model_paths))}")
    print(f"Device: {args.device}")
    print(f"Temperature: {args.temperature}")
    print(f"Evaluation mode: {'unseen (creative)' if use_unseen else 'valid only'}")

    all_results = {}
    for model_name, model_path in zip(model_names, args.model_paths):
        print(f"\n{'='*60}")
        print(f"Evaluating: {model_name} ({model_path})")
        print(f"{'='*60}")

        all_results[model_name] = evaluate_pass_at_n(
            model_path=model_path,
            num_graphs=args.num_graphs,
            n_values=args.n_values,
            num_seeds=args.num_seeds,
            seeds=args.seeds,
            device=args.device,
            temperature=args.temperature,
            use_unseen=use_unseen
        )

    if args.comparison_plot:
        print(f"\n{'='*60}")
        print("GENERATING COMPARISON PLOT")
        print(f"{'='*60}")
        generate_multi_model_plot(
            all_results=all_results,
            output_file=args.comparison_plot_file,
            title="Pass@n Model Comparison"
        )

    print(f"\n{'='*60}")
    print("MULTI-MODEL COMPARISON RESULTS")
    print(f"{'='*60}")

    for n in args.n_values:
        print(f"\nPass@{n:2d}:")
        for model_name in model_names:
            if n in all_results[model_name]:
                values = all_results[model_name][n]
                mean_val = np.mean(values)
                std_val = np.std(values)
                print(f"  {model_name:15s}: {mean_val:5.1f}% ± {std_val:4.1f}% (seeds: {values})")

    if args.output_file:
        with open(args.output_file, 'w', encoding='utf-8') as f:
            json.dump(all_results, f, indent=2)
        print(f"\nResults saved to: {args.output_file}")


def _run_single_model(args: argparse.Namespace, use_unseen: bool) -> None:
    """Evaluate --model_path, print a summary, and optionally save JSON results and a plot."""
    print(f"Evaluating model: {args.model_path}")
    print(f"Device: {args.device}")
    print(f"Temperature: {args.temperature}")
    print(f"Evaluation mode: {'unseen (creative)' if use_unseen else 'valid only'}")

    results = evaluate_pass_at_n(
        model_path=args.model_path,
        num_graphs=args.num_graphs,
        n_values=args.n_values,
        num_seeds=args.num_seeds,
        seeds=args.seeds,
        device=args.device,
        temperature=args.temperature,
        use_unseen=use_unseen
    )

    print("\n" + "="*50)
    print("PASS@N EVALUATION RESULTS")
    print("="*50)

    for n in args.n_values:
        values = results[n]
        mean_val = np.mean(values)
        std_val = np.std(values)
        print(f"Pass@{n:2d}: {mean_val:5.1f}% ± {std_val:4.1f}% (seeds: {values})")

    if args.output_file:
        # num_graphs and the requested evaluation mode are recorded so saved
        # numbers can be interpreted later (valid-only vs. valid-and-unseen
        # rates are not comparable).
        output_data = {
            "model_path": args.model_path,
            "num_graphs": args.num_graphs,
            "n_values": args.n_values,
            "num_seeds": args.num_seeds,
            "seeds": args.seeds,
            "device": args.device,
            "temperature": args.temperature,
            "use_unseen": use_unseen,
            "results": results
        }

        with open(args.output_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2)
        print(f"\nResults saved to: {args.output_file}")

    if args.plot:
        print("\n" + "="*50)
        print("GENERATING PLOT")
        print("="*50)
        generate_pass_at_n_plot(
            results=results,
            output_file=args.plot_file,
            title=args.plot_title
        )


def main():
    """Command-line entry point: parse arguments and run single- or multi-model pass@n evaluation."""
    parser = argparse.ArgumentParser(description="Evaluate pass@n for triangle discovery task")
    parser.add_argument("--model_path", type=str, default=MODEL_PATH,
                        help="Path to model checkpoint")
    parser.add_argument("--num_graphs", type=int, default=10,
                        help="Number of triangle graphs to use (3 or 10)")
    parser.add_argument("--n_values", type=int, nargs="+", default=[1, 5, 10, 20, 40, 80],
                        help="List of n values to evaluate")
    parser.add_argument("--num_seeds", type=int, default=3,
                        help="Number of random seeds to use")
    parser.add_argument("--seeds", type=str, nargs="+", default=[],
                        help="List of seeds")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to run on")
    parser.add_argument("--temperature", type=float, default=1.0,
                        help="Sampling temperature")
    parser.add_argument("--use_unseen", action="store_true",
                        help="Use unseen evaluation (valid AND creative). If False, use valid-only evaluation.")
    parser.add_argument("--use_valid_only", action="store_true", default=False,
                        help="Use valid-only evaluation (overrides --use_unseen)")
    parser.add_argument("--output_file", type=str, default=None,
                        help="Output file to save results (JSON format)")
    parser.add_argument("--plot", action="store_true", default=False,
                        help="Generate and display a pass@n plot")
    parser.add_argument("--plot_file", type=str, default=None,
                        help="Output file for the plot (PDF format)")
    parser.add_argument("--plot_title", type=str, default="Pass@n Evaluation Results",
                        help="Title for the generated plot")
    parser.add_argument("--multi_model", action="store_true", default=False,
                        help="Enable multi-model comparison mode")
    parser.add_argument("--model_paths", type=str, nargs="+", default=[],
                        help="List of model paths for multi-model comparison")
    parser.add_argument("--model_names", type=str, nargs="+", default=[],
                        help="List of model names for multi-model comparison (must match model_paths)")
    parser.add_argument("--comparison_plot", action="store_true", default=False,
                        help="Generate comparison plot for multiple models")
    parser.add_argument("--comparison_plot_file", type=str, default=None,
                        help="Output file for the comparison plot (PDF format)")

    args = parser.parse_args()

    # --use_valid_only wins over --use_unseen; with neither flag the mode is valid-only.
    use_unseen = args.use_unseen and not args.use_valid_only

    if args.multi_model:
        _run_multi_model(args, use_unseen)
    else:
        _run_single_model(args, use_unseen)


if __name__ == "__main__":
    main()