"""
Evaluates diff@k rate for the triangle discovery task and generates plots.

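For each triangle graph (10 by default), generate k samples and compute the diff@k rate.
Diff@k is the number of unique, valid samples among k attempts (optionally also requiring
that the triangle does not appear in the training data), averaged over graphs.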

Args:
- --model_path: checkpoint path (or --multi_model with --model_paths / --model_names to compare several checkpoints)
- Optional plotting arguments (--plot, --plot_file, --plot_title, --comparison_plot, --comparison_plot_file)

Output:
- Results as a dictionary mapping each k value to a list of diff@k values
- Each list holds one diff@k value per seed/run
- Example: {1: [0.8, 0.7, 0.9], 5: [2.1, 2.0, 2.3], 10: [3.2, 3.1, 3.4]}
- Optionally saved as JSON (--output_file) and shown as diff@k plots (mean across seeds with a shaded ±1 std band)
"""

import argparse
import json
import os
import random
from typing import Dict, List, Any, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter

from utils import (
    TriangleTokenizer, TransformerPolicy, MODEL_PATH, FINE_TUNED_MODEL_PATH,
    is_valid_triangle, load_json, dataset_dir, DATA_ROOT, HASH_STR_LEN, T_TRIANGLES,
    parse_triangle_sequence, make_env, set_seed, generate
)


def load_model_and_tokenizer(model_path: str, device: str = "cpu"):
    """Rebuild the policy model and its tokenizer from a saved checkpoint.

    The object returned by ``torch.load`` is read as a dict providing
    ``"tokenizer"`` (``entities`` / ``special_tokens``), ``"config"`` (model
    hyper-parameters) and ``"model"`` (the state dict).

    Args:
        model_path: Checkpoint file to load.
        device: Device used as ``map_location`` and for the returned model.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is switched to eval mode.
    """
    checkpoint = torch.load(model_path, map_location=device)

    tokenizer_cfg = checkpoint["tokenizer"]
    tokenizer = TriangleTokenizer(
        entities=tokenizer_cfg["entities"],
        special_tokens=tokenizer_cfg["special_tokens"],
    )

    # Architecture hyper-parameters are copied verbatim from the checkpoint config.
    arch_cfg = checkpoint["config"]
    arch_keys = ("vocab_size", "d_model", "n_layer", "n_head", "dim_ff", "dropout", "max_len")
    model = TransformerPolicy(**{key: arch_cfg[key] for key in arch_keys}, tie_weights=True)

    model.load_state_dict(checkpoint["model"])
    model = model.to(device)
    model.eval()
    return model, tokenizer


def generate_triangle_sequence(model, tokenizer, graph_idx: int, device: str = "cpu", temperature: float = 1.0) -> List[int]:
    """Sample one continuation of the prompt ``"<graph_idx> tri: "``.

    Args:
        model: Policy passed to ``generate`` as ``policy``.
        tokenizer: Tokenizer passed through to ``generate``.
        graph_idx: Graph index placed at the start of the prompt.
        device: Device passed through to ``generate``.
        temperature: Sampling temperature passed through to ``generate``.

    Returns:
        Whatever ``generate`` returns for the prompt; callers treat it as a
        token sequence and pass it to ``tokenizer.decode``.
    """
    generation_kwargs = dict(
        policy=model,
        tokenizer=tokenizer,
        max_new_tokens=16,  # token budget for the triangle after the prompt
        prompt=f"{graph_idx} tri: ",
        device=device,
        temperature=temperature,
    )
    # Pure sampling: gradients are not needed.
    with torch.no_grad():
        return generate(**generation_kwargs)


def is_valid_triangle_sequence(sequence: List[int], tokenizer: TriangleTokenizer, graph_edges: Dict) -> bool:
    """Tell whether ``sequence`` parses to a triangle present in ``graph_edges``.

    Parsing is delegated to ``parse_triangle_sequence``; a ``None`` result
    (output that could not be parsed) counts as invalid. Otherwise the
    verdict of ``is_valid_triangle`` is returned unchanged.
    """
    vertices = parse_triangle_sequence(sequence, tokenizer)
    return vertices is not None and is_valid_triangle(vertices, graph_edges)


def canonicalize_sequence(seq: str) -> str:
    """Canonicalize a triangle sequence by sorting its vertices.

    A triangle is written as three edges separated by ``<sep>``, e.g.
    ``"<a_i><a_j><sep><a_j><a_k><sep><a_k><a_i>"``. The first vertex id of
    each of the first three ``<sep>``-separated parts is read as an integer.
    The ids are then sorted numerically into ``lo <= mid <= hi``. Finally the
    string is rebuilt as ``"<a_lo><a_mid><sep><a_mid><a_hi><sep><a_hi><a_lo>"``.
    As a result, every rotation/orientation of one triangle maps to the same key.

    Args:
        seq: Decoded triangle text.

    Returns:
        The canonical string. When ``seq`` cannot be parsed (fewer than three
        parts, a part without ``<a_``, or a non-integer id) a message is
        printed and ``seq`` itself is returned, so the result is always a
        hashable ``str``.
    """
    parts = seq.split("<sep>")
    try:
        ids = [int(parts[n].split("<a_")[1].split(">")[0]) for n in range(3)]
    except (IndexError, ValueError):
        # Malformed sample (e.g. truncated generation). Return the raw string,
        # not the split list, so callers can still hash / compare it.
        print("Failed for canonicalizing:", seq)
        return seq
    lo, mid, hi = sorted(ids)
    return f"<a_{lo}><a_{mid}><sep><a_{mid}><a_{hi}><sep><a_{hi}><a_{lo}>"


def is_unseen_triangle_sequence(sequence: List[int], tokenizer: TriangleTokenizer, graph_edges: Dict, train_sequences: List[str]) -> bool:
    """Decide whether a sample is a *novel* valid triangle.

    A sample qualifies only if it is a valid triangle of ``graph_edges``. In
    addition, its canonical form must not be in ``train_sequences``. The
    canonical form is the text after the first ``"tri:"`` (or the whole
    decoded text if there is none), passed through ``canonicalize_sequence``.
    Any exception while decoding, canonicalizing or looking up is printed and
    treated as "not novel".
    """
    # Validity is a precondition; reject early without decoding.
    if not is_valid_triangle_sequence(sequence, tokenizer, graph_edges):
        return False

    try:
        text = tokenizer.decode(sequence)
        triangle_part = text.split("tri:")[1].strip() if "tri:" in text else text
        seen = canonicalize_sequence(triangle_part) in train_sequences
    except Exception as err:
        print(f"Error checking unseen status: {err}")
        return False
    return not seen


def evaluate_diff_at_k_for_graph(
    model,
    tokenizer,
    graph_idx: int,
    graph_edges: Dict,
    k_samples: int,
    train_sequences: List[str],
    device: str = "cpu",
    temperature: float = 1.0,
    use_unseen: bool = True
) -> int:
    """
    Count distinct successful samples for one graph.

    Draws ``k_samples`` samples and returns how many *distinct* canonical
    triangles among them pass the success test. Any exception raised while
    handling a sample is printed and that sample is skipped.

    Args:
        use_unseen: If True, success means valid AND unseen (creative).
                   If False, success means just valid.
    """
    # Choose the success predicate once instead of branching per sample.
    if use_unseen:
        def succeeded(sample):
            return is_unseen_triangle_sequence(sample, tokenizer, graph_edges, train_sequences)
    else:
        def succeeded(sample):
            return is_valid_triangle_sequence(sample, tokenizer, graph_edges)

    distinct_hits = set()
    for _ in range(k_samples):
        try:
            sample = generate_triangle_sequence(model, tokenizer, graph_idx, device, temperature)

            # Canonical key (text after "tri:" if present) used for de-duplication.
            text = tokenizer.decode(sample)
            triangle_part = text.split("tri:")[1].strip() if "tri:" in text else text
            key = canonicalize_sequence(triangle_part)

            if succeeded(sample):
                distinct_hits.add(key)
        except Exception as err:
            print(f"Error generating sequence for graph {graph_idx}: {err}")

    return len(distinct_hits)


def _load_graphs(ddir: str, num_graphs: int) -> List[Any]:
    """Load ``edges_0.json`` ... ``edges_{num_graphs - 1}.json`` from ``ddir`` via ``load_json``.

    Raises:
        FileNotFoundError: If any requested graph file is missing.
    """
    graphs = []
    for i in range(num_graphs):
        graph_path = os.path.join(ddir, f"edges_{i}.json")
        if not os.path.exists(graph_path):
            raise FileNotFoundError(f"Could not find {graph_path}")
        graphs.append(load_json(graph_path))
    return graphs


def _load_train_sequences(train_path: str) -> List[str]:
    """Return the canonicalized triangle part of every training target containing ``"tri:"``.

    ``train_path`` is parsed with ``json.load`` and iterated as records that
    have a ``"target_text"`` field (presumably a JSON list of dicts).
    """
    with open(train_path, "r", encoding="utf-8") as f:
        train_data = json.load(f)

    sequences = []
    for item in train_data:
        target = item["target_text"]
        if "tri:" in target:
            sequences.append(canonicalize_sequence(target.split("tri:")[1].strip()))
    return sequences


def evaluate_diff_at_k(
    model_path: str,
    num_graphs: int = 10,
    k_values: Optional[List[int]] = None,
    num_seeds: int = 3,
    seeds: Optional[List[int]] = None,
    device: str = "cpu",
    temperature: float = 1.0,
    use_unseen: bool = True
) -> Dict[int, List[float]]:
    """
    Evaluate diff@k for different k values across multiple seeds.

    Args:
        model_path: Path to the model checkpoint
        num_graphs: Number of graphs (``edges_0.json`` ... ``edges_{num_graphs-1}.json``) to evaluate on
        k_values: k values to evaluate; defaults to [1, 5, 10, 20, 40, 80]
        num_seeds: Number of evaluation rounds (one seed per round)
        seeds: Optional explicit seeds (anything ``int()`` accepts), cycled when
               shorter than num_seeds. When empty/None, seeds 999, 1000, ... are used.
        device: Device to run on
        temperature: Sampling temperature
        use_unseen: If True, success means valid AND unseen (creative).
                   If False, success means just valid. Falls back to False
                   when the training data file is missing.

    Returns:
        Dictionary mapping each k to one value per seed: the number of unique
        successful samples, averaged over graphs.

    Raises:
        ValueError: If num_graphs < 1.
        FileNotFoundError: If a requested graph file is missing.
    """
    if num_graphs < 1:
        raise ValueError(f"num_graphs must be >= 1, got {num_graphs}")
    # None defaults avoid sharing mutable default lists between calls.
    k_values = [1, 5, 10, 20, 40, 80] if k_values is None else list(k_values)
    seeds = [] if seeds is None else list(seeds)

    model, tokenizer = load_model_and_tokenizer(model_path, device)

    # Load the requested graphs (num_graphs was previously ignored: always 10).
    ddir = dataset_dir(DATA_ROOT, HASH_STR_LEN, T_TRIANGLES)
    graphs = _load_graphs(ddir, num_graphs)

    # Training sequences are only needed to decide whether a sample is unseen.
    train_sequences: List[str] = []
    if use_unseen:
        train_path = os.path.join(ddir, "train.json")
        if os.path.exists(train_path):
            train_sequences = _load_train_sequences(train_path)
        else:
            print("Warning: Could not find training data. Using valid-only evaluation.")
            use_unseen = False

    print(f"Loaded {len(graphs)} graphs")
    print(f"Loaded {len(train_sequences)} training sequences")
    print(f"Evaluating diff@k for k_values: {k_values}")
    print(f"Using {num_seeds} seeds")
    print(f"Evaluation mode: {'unseen (creative)' if use_unseen else 'valid only'}")

    results: Dict[int, List[float]] = {k: [] for k in k_values}

    for i in range(num_seeds):
        print(f"\n--- Seed {i + 1}/{num_seeds} ---")
        if seeds:
            seed = int(seeds[i % len(seeds)])
            print("Using seed:", seed)
            set_seed(seed)
        else:
            set_seed(999 + i)

        for k in k_values:
            print(f"Evaluating diff@{k}...")
            total_unique_valid = 0

            for graph_idx, graph_edges in enumerate(graphs):
                unique_count = evaluate_diff_at_k_for_graph(
                    model=model,
                    tokenizer=tokenizer,
                    graph_idx=graph_idx,
                    graph_edges=graph_edges,
                    k_samples=k,
                    train_sequences=train_sequences,
                    device=device,
                    temperature=temperature,
                    use_unseen=use_unseen
                )

                total_unique_valid += unique_count
                print(f"  Graph {graph_idx}: {unique_count} unique valid samples")

            avg_diff_rate = total_unique_valid / len(graphs)
            results[k].append(avg_diff_rate)
            # Denominator is the number of samples drawn (k per graph).
            print(f"Diff@{k} rate: {avg_diff_rate:.2f} unique valid samples per graph (total: {total_unique_valid}/{k * len(graphs)} samples)")

    return results


def generate_diff_at_k_plot(results: Dict[int, List[float]], output_file: str = None, title: str = "Diff@k Evaluation Results"):
    """
    Generate a diff@k plot from evaluation results.

    The mean across seeds is drawn as a line, with a shaded ±1 std band, at
    equally spaced x positions labelled with the k values.

    Args:
        results: Dictionary mapping k values to lists of diff@k rates (one per seed)
        output_file: Optional output file path for the plot. The format follows
            the file extension when matplotlib supports it; otherwise PDF is written.
        title: Title for the plot
    """
    # Iris colors palette (a single model only uses the first colour)
    iris_colors = ["#134E6F", "#1AC0C6", "#FFA822", "#FF6150", "#DEE0E6", "#091A29"]

    fig = plt.figure(figsize=(10, 8))

    # Equally spaced positions, one per k value
    sorted_k_values = sorted(results.keys())
    x_positions = np.arange(len(sorted_k_values))

    # Mean and std across seeds for each k
    diff_means = np.array([np.mean(results[k]) for k in sorted_k_values], dtype=float)
    diff_stds = np.array([np.std(results[k]) for k in sorted_k_values], dtype=float)
    lower_band = diff_means - diff_stds
    upper_band = diff_means + diff_stds

    plt.fill_between(x_positions, lower_band, upper_band, alpha=0.1, color=iris_colors[0])
    plt.plot(x_positions, diff_means, linewidth=2, marker='o', markersize=6, color=iris_colors[0])

    plt.xlabel('k (Diff@k)', fontsize=14)
    plt.ylabel('Unique Valid Samples per Graph', fontsize=14)
    plt.title(title, fontsize=16)
    plt.grid(True, alpha=0.2)

    plt.xlim(-0.2, len(sorted_k_values) - 0.8)
    if len(sorted_k_values) > 0:
        # Fit the whole std band (not just the means) so it is not clipped.
        lower = float(np.nanmin(lower_band)) - 0.5
        upper = float(np.nanmax(upper_band)) + 0.5
        if np.isfinite(lower) and np.isfinite(upper):
            plt.ylim(max(0, lower), upper)

    # Show the actual k values as tick labels
    plt.xticks(x_positions, sorted_k_values)

    def number_formatter(x, pos):
        return f'{x:.1f}'

    plt.gca().yaxis.set_major_formatter(FuncFormatter(number_formatter))
    plt.tight_layout()

    if output_file:
        # Honour a recognised extension (e.g. ".png"); anything else is written as PDF.
        ext = os.path.splitext(output_file)[1].lstrip(".").lower()
        fmt = ext if ext in fig.canvas.get_supported_filetypes() else "pdf"
        plt.savefig(output_file, format=fmt, dpi=300, bbox_inches='tight', pad_inches=0.1)
        print(f"Plot saved to: {output_file}")

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures
    # (e.g. under a non-interactive backend where show() returns immediately).
    plt.close(fig)


def generate_multi_model_plot(all_results: Dict[str, Dict[int, List[float]]], output_file: str = None, title: str = "Diff@k Model Comparison"):
    """
    Generate a diff@k comparison plot for multiple models.

    For each model, the mean across seeds is drawn as a line with a shaded
    ±1 std band. Lines are placed at equally spaced x positions labelled with
    the union of all models' k values.

    Args:
        all_results: Dictionary mapping model names to their results dictionaries
            (k -> list of per-seed diff@k values). Models with empty results are
            skipped; k values a model was not evaluated at are left out of its
            line instead of being drawn as 0.
        output_file: Optional output file path for the plot. The format follows
            the file extension when matplotlib supports it; otherwise PDF is written.
        title: Title for the plot
    """
    # Iris colors palette; the i-th model uses colour i (cycled)
    iris_colors = ["#134E6F", "#1AC0C6", "#FFA822", "#FF6150", "#DEE0E6", "#091A29"]

    fig = plt.figure(figsize=(12, 8))

    # Union of k values across models, shown at equally spaced positions
    sorted_k_values = sorted({k for model_results in all_results.values() for k in model_results})
    x_positions = np.arange(len(sorted_k_values))
    k_to_position = {k: pos for pos, k in enumerate(sorted_k_values)}

    # Extent of every plotted band, so the y-limits fit all models rather
    # than only the last one drawn.
    band_lows: List[float] = []
    band_highs: List[float] = []

    for i, (model_name, results) in enumerate(all_results.items()):
        if not results:  # Skip empty results
            continue
        color = iris_colors[i % len(iris_colors)]

        model_ks = [k for k in sorted_k_values if k in results]
        positions = [k_to_position[k] for k in model_ks]
        means = np.array([np.mean(results[k]) for k in model_ks], dtype=float)
        stds = np.array([np.std(results[k]) for k in model_ks], dtype=float)

        plt.fill_between(positions, means - stds, means + stds, alpha=0.1, color=color)
        plt.plot(positions, means, label=model_name, linewidth=2, marker='o', markersize=6, color=color)

        band_lows.extend((means - stds).tolist())
        band_highs.extend((means + stds).tolist())

    plt.xlabel('k (Diff@k)', fontsize=14)
    plt.ylabel('Unique Valid Samples per Graph', fontsize=14)
    plt.title(title, fontsize=16)
    plt.legend(fontsize=12, loc='lower right')
    plt.grid(True, alpha=0.2)

    plt.xlim(-0.2, len(sorted_k_values) - 0.8)
    if band_lows:
        lower = float(np.nanmin(band_lows)) - 0.5
        upper = float(np.nanmax(band_highs)) + 0.5
        if np.isfinite(lower) and np.isfinite(upper):
            plt.ylim(max(0, lower), upper)

    # Show the actual k values as tick labels
    plt.xticks(x_positions, sorted_k_values)

    def number_formatter(x, pos):
        return f'{x:.1f}'

    plt.gca().yaxis.set_major_formatter(FuncFormatter(number_formatter))
    plt.tight_layout()

    if output_file:
        # Honour a recognised extension (e.g. ".png"); anything else is written as PDF.
        ext = os.path.splitext(output_file)[1].lstrip(".").lower()
        fmt = ext if ext in fig.canvas.get_supported_filetypes() else "pdf"
        plt.savefig(output_file, format=fmt, dpi=300, bbox_inches='tight', pad_inches=0.1)
        print(f"Multi-model plot saved to: {output_file}")

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures
    # (e.g. under a non-interactive backend where show() returns immediately).
    plt.close(fig)


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="Evaluate diff@k for triangle discovery task")
    parser.add_argument("--model_path", type=str, default=MODEL_PATH,
                        help="Path to model checkpoint")
    parser.add_argument("--num_graphs", type=int, default=10,
                        help="Number of triangle graphs to use (3 or 10)")
    parser.add_argument("--k_values", type=int, nargs="+", default=[1, 5, 10, 20, 40, 80],
                        help="List of k values to evaluate")
    parser.add_argument("--num_seeds", type=int, default=3,
                        help="Number of random seeds to use")
    parser.add_argument("--seeds", type=str, nargs="+", default=[],
                        help="List of seeds")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to run on")
    parser.add_argument("--temperature", type=float, default=1.0,
                        help="Sampling temperature")
    parser.add_argument("--use_unseen", action="store_true",
                        help="Use unseen evaluation (valid AND creative). If False, use valid-only evaluation.")
    parser.add_argument("--use_valid_only", action="store_true", default=False,
                        help="Use valid-only evaluation (overrides --use_unseen)")
    parser.add_argument("--output_file", type=str, default=None,
                        help="Output file to save results (JSON format)")
    parser.add_argument("--plot", action="store_true", default=False,
                        help="Generate and display a diff@k plot")
    parser.add_argument("--plot_file", type=str, default=None,
                        help="Output file for the plot (PDF format)")
    parser.add_argument("--plot_title", type=str, default="Diff@k Evaluation Results",
                        help="Title for the generated plot")
    parser.add_argument("--multi_model", action="store_true", default=False,
                        help="Enable multi-model comparison mode")
    parser.add_argument("--model_paths", type=str, nargs="+", default=[],
                        help="List of model paths for multi-model comparison")
    parser.add_argument("--model_names", type=str, nargs="+", default=[],
                        help="List of model names for multi-model comparison (must match model_paths)")
    parser.add_argument("--comparison_plot", action="store_true", default=False,
                        help="Generate comparison plot for multiple models")
    parser.add_argument("--comparison_plot_file", type=str, default=None,
                        help="Output file for the comparison plot (PDF format)")
    return parser


def _default_model_name(model_path: str) -> str:
    """Derive a display name from a checkpoint path by dropping a trailing ``.pt``.

    Only a trailing suffix is removed, so e.g. ``model.pth`` is kept intact
    instead of becoming ``modelh``.
    """
    name = os.path.basename(model_path)
    return name[:-len(".pt")] if name.endswith(".pt") else name


def _run_multi_model(args: argparse.Namespace, use_unseen: bool) -> None:
    """Evaluate every checkpoint in ``args.model_paths``, then print, optionally plot, and save."""
    model_names = args.model_names or [_default_model_name(path) for path in args.model_paths]

    print("Multi-model evaluation mode")
    print(f"Models: {list(zip(model_names, args.model_paths))}")
    print(f"Device: {args.device}")
    print(f"Temperature: {args.temperature}")
    print(f"Evaluation mode: {'unseen (creative)' if use_unseen else 'valid only'}")

    all_results = {}
    for model_name, model_path in zip(model_names, args.model_paths):
        print(f"\n{'='*60}")
        print(f"Evaluating: {model_name} ({model_path})")
        print(f"{'='*60}")

        all_results[model_name] = evaluate_diff_at_k(
            model_path=model_path,
            num_graphs=args.num_graphs,
            k_values=args.k_values,
            num_seeds=args.num_seeds,
            seeds=args.seeds,
            device=args.device,
            temperature=args.temperature,
            use_unseen=use_unseen
        )

    if args.comparison_plot:
        print(f"\n{'='*60}")
        print("GENERATING COMPARISON PLOT")
        print(f"{'='*60}")
        generate_multi_model_plot(
            all_results=all_results,
            output_file=args.comparison_plot_file,
            title="Diff@k Model Comparison"
        )

    print(f"\n{'='*60}")
    print("MULTI-MODEL COMPARISON RESULTS")
    print(f"{'='*60}")

    for k in args.k_values:
        print(f"\nDiff@{k:2d}:")
        for model_name in model_names:
            if k in all_results[model_name]:
                values = all_results[model_name][k]
                mean_val = np.mean(values)
                std_val = np.std(values)
                print(f"  {model_name:15s}: {mean_val:5.2f} ± {std_val:4.2f} unique valid samples per graph (seeds: {values})")

    if args.output_file:
        with open(args.output_file, 'w') as f:
            json.dump(all_results, f, indent=2)
        print(f"\nResults saved to: {args.output_file}")


def _run_single_model(args: argparse.Namespace, use_unseen: bool) -> None:
    """Evaluate ``args.model_path``, then print, optionally save, and optionally plot."""
    print(f"Evaluating model: {args.model_path}")
    print(f"Device: {args.device}")
    print(f"Temperature: {args.temperature}")
    print(f"Evaluation mode: {'unseen (creative)' if use_unseen else 'valid only'}")

    results = evaluate_diff_at_k(
        model_path=args.model_path,
        num_graphs=args.num_graphs,
        k_values=args.k_values,
        num_seeds=args.num_seeds,
        seeds=args.seeds,
        device=args.device,
        temperature=args.temperature,
        use_unseen=use_unseen
    )

    print("\n" + "="*50)
    print("DIFF@K EVALUATION RESULTS")
    print("="*50)

    for k in args.k_values:
        values = results[k]
        mean_val = np.mean(values)
        std_val = np.std(values)
        print(f"Diff@{k:2d}: {mean_val:5.2f} ± {std_val:4.2f} unique valid samples per graph (seeds: {values})")

    if args.output_file:
        output_data = {
            "model_path": args.model_path,
            "k_values": args.k_values,
            "num_seeds": args.num_seeds,
            "seeds": args.seeds,
            "device": args.device,
            "temperature": args.temperature,
            "results": results
        }

        with open(args.output_file, 'w') as f:
            json.dump(output_data, f, indent=2)
        print(f"\nResults saved to: {args.output_file}")

    if args.plot:
        print("\n" + "="*50)
        print("GENERATING PLOT")
        print("="*50)
        generate_diff_at_k_plot(
            results=results,
            output_file=args.plot_file,
            title=args.plot_title
        )


def main():
    """Command-line entry point: parse arguments and run single- or multi-model evaluation."""
    parser = _build_arg_parser()
    args = parser.parse_args()

    # --use_valid_only wins over --use_unseen; with neither flag the run is valid-only.
    use_unseen = args.use_unseen and not args.use_valid_only

    if args.multi_model:
        # parser.error() prints usage and exits with status 2, so invalid
        # argument combinations are reported as failures instead of exiting 0.
        if not args.model_paths:
            parser.error("--model_paths must be provided when using --multi_model")
        if args.model_names and len(args.model_names) != len(args.model_paths):
            parser.error("--model_names must have the same length as --model_paths")
        _run_multi_model(args, use_unseen)
        return

    _run_single_model(args, use_unseen)


if __name__ == "__main__":
    main()