import json
import torch
import numpy as np
from collections import defaultdict
import argparse
import os
from datetime import datetime
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import warnings
from tqdm import tqdm

# Silence every Python warning raised from this point on so the evaluation
# output stays readable. Note that this also hides deprecation and numerical
# warnings emitted later by third-party libraries.
warnings.filterwarnings("ignore")


def load_model_and_adapter(base_model_name, adapter_path=None, device="cuda"):
    """
    Build the evaluation model: a base causal LM, optionally wrapped with a LoRA adapter.

    Args:
        base_model_name: Name of the base model (e.g., "Qwen/Qwen2.5-3B-Instruct")
        adapter_path: Path to a saved LoRA adapter. When None, or when the path
            does not exist, the plain base model is returned.
        device: Accepted for interface compatibility; this function never reads
            it because weight placement is delegated to device_map="auto".

    Returns:
        tuple: (model, tokenizer), with the model switched to eval mode
    """
    print(f"Loading base model: {base_model_name}")

    # Base weights are loaded in full precision
    base_kwargs = {
        "torch_dtype": torch.float32,
        "trust_remote_code": True,
        "device_map": "auto",
    }
    model = AutoModelForCausalLM.from_pretrained(base_model_name, **base_kwargs)

    # The tokenizer's padding token is always overwritten with its EOS token
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token

    # Attach the adapter only when a path was given and it exists on disk
    if adapter_path is None:
        print("No adapter path provided. Using base model only.")
    elif not os.path.exists(adapter_path):
        print(f"Warning: Adapter path '{adapter_path}' not found. Using base model only.")
    else:
        print(f"Loading LoRA adapter from: {adapter_path}")
        model = PeftModel.from_pretrained(model, adapter_path)
        print("LoRA adapter loaded successfully")

    model.eval()
    return model, tokenizer


def load_prompt_data(prompt_file):
    """
    Load questions and possible answers from prompt file.

    The file is parsed as JSON; the questions are read from its 'x' entry and
    the candidate answers from its 'y' entry.

    Args:
        prompt_file: Path to the prompt JSON file

    Returns:
        tuple: (questions, possible_answers)

    Raises:
        KeyError: If the JSON object has no 'x' or no 'y' entry
    """
    # Decode explicitly as UTF-8 instead of the platform's locale encoding, so
    # prompts containing non-ASCII text load identically on every system.
    with open(prompt_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    questions = data['x']
    possible_answers = data['y']

    print(f"Loaded {len(questions)} questions and {len(possible_answers)} possible answers")
    return questions, possible_answers


def get_answer_token_ids(tokenizer, possible_answers):
    """
    Get the first token ID for each possible answer.

    Each answer is encoded without special tokens. Only the first resulting
    token is kept, because the evaluation scores a single next-token position;
    answers that span several tokens trigger a warning.

    Args:
        tokenizer: Model tokenizer (must provide ``encode(text, add_special_tokens=False)``)
        possible_answers: List of possible answer strings

    Returns:
        dict: Mapping from answer string to the ID of its first token

    Raises:
        ValueError: If an answer encodes to zero tokens and so cannot be scored
    """
    answer_tokens = {}

    for answer in possible_answers:
        # Tokenize the answer (without special tokens)
        tokens = tokenizer.encode(answer, add_special_tokens=False)

        if not tokens:
            raise ValueError(f"Answer {answer!r} produced no tokens and cannot be scored")

        answer_tokens[answer] = tokens[0]
        if len(tokens) > 1:
            print(f"Warning: Answer '{answer}' tokenized to {len(tokens)} tokens: {tokens}. Using first token: {tokens[0]}")

    # Answers that share a first token are indistinguishable at a single
    # next-token position: they would receive exactly the same logit.
    answers_by_token = {}
    for answer, token_id in answer_tokens.items():
        answers_by_token.setdefault(token_id, []).append(answer)
    for token_id, sharing in answers_by_token.items():
        if len(sharing) > 1:
            print(f"Warning: Answers {sharing} share first token {token_id}; their scores will be identical")

    print(f"Answer token mapping: {answer_tokens}")
    return answer_tokens


def get_answer_logits(model, tokenizer, question, answer_tokens, temperature=1.0):
    """
    Get (temperature-scaled) logits for each possible answer token given a question.

    The question is wrapped in a fixed one-word-answer prompt, run through the
    model once, and the logits at the last prompt position are read out for the
    token IDs in ``answer_tokens``.

    Args:
        model: Loaded model (with or without adapter); called as ``model(**inputs)``
            and expected to return an object with a ``logits`` attribute
        tokenizer: Model tokenizer
        question: Input question string
        answer_tokens: Dict mapping answers to token IDs
        temperature: Positive divisor applied to the logits (1.0 leaves them unchanged)

    Returns:
        dict: Mapping from answer to its logit as a Python float

    Raises:
        ValueError: If ``temperature`` is not strictly positive
    """
    # Zero would yield inf/nan and a negative value would invert the ranking.
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    # Format the prompt
    prompt = f"{question} Answer with one word only:\n"

    # Tokenize input
    # NOTE(review): if the tokenizer truncates on the right (the usual default),
    # prompts longer than 40 tokens lose the trailing instruction — confirm that
    # all questions fit, or raise max_length.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=40)

    # Send the inputs to the device the model reports, so a CPU-resident or
    # non-default-GPU model does not hit a device mismatch. Fall back to the
    # default CUDA device when the model exposes no `device` attribute.
    target_device = getattr(model, "device", None)
    if target_device is None and torch.cuda.is_available():
        target_device = torch.device("cuda")
    if target_device is not None:
        inputs = {k: v.to(target_device) for k, v in inputs.items()}

    # Get model logits
    with torch.no_grad():
        logits = model(**inputs).logits

        # Logits for the next token (after the prompt); shape: [vocab_size]
        next_token_logits = logits[0, -1, :]

        # Apply temperature scaling
        if temperature != 1.0:
            next_token_logits = next_token_logits / temperature

    # Extract logits for our target tokens
    answer_logits = {
        answer: next_token_logits[token_id].item()
        for answer, token_id in answer_tokens.items()
    }

    # Debug: Check for problematic values
    if not all(np.isfinite(val) for val in answer_logits.values()):
        print(f"WARNING: Found nan/inf in raw logits for question: {question[:50]}...")
        print(f"Raw logits: {answer_logits}")
        print(f"Token IDs: {answer_tokens}")

    return answer_logits


def compute_answer_probabilities(model, tokenizer, question, answer_tokens, temperature=1.0):
    """
    Turn the answer-token logits for one question into a probability distribution.
    (Kept for backward compatibility; delegates the model call to get_answer_logits.)

    Args:
        model: Loaded model with adapter
        tokenizer: Model tokenizer
        question: Input question string
        answer_tokens: Dict mapping answers to token IDs
        temperature: Temperature forwarded to get_answer_logits

    Returns:
        dict: Probability for each answer, normalised over the candidate answers only
    """
    scored = get_answer_logits(model, tokenizer, question, answer_tokens, temperature)

    # Softmax over just the candidate answers, in the dict's own order
    names = list(scored)
    distribution = torch.softmax(torch.tensor([scored[name] for name in names]), dim=0)

    return dict(zip(names, distribution.tolist()))


def compute_all_question_logits(model, tokenizer, questions, answer_tokens, temperature=1.0):
    """
    Run get_answer_logits once per question and collect the results in order.

    Args:
        model: Loaded model with adapter
        tokenizer: Model tokenizer
        questions: List of questions
        answer_tokens: Dict mapping answers to token IDs
        temperature: Temperature for scaling

    Returns:
        list: One logits dictionary per question, in the same order as `questions`
    """
    print(f"Computing logits for {len(questions)} questions...")

    # tqdm only adds a progress bar; iteration order is unchanged
    progress = tqdm(questions, desc="Computing logits")
    return [
        get_answer_logits(model, tokenizer, item, answer_tokens, temperature)
        for item in progress
    ]


def _policy_from_logits(logits, possible_answers):
    """Softmax one question's logits into {answer: probability}; None if any logit is nan/inf."""
    # Look each logit up by answer name so the result never depends on the
    # insertion order of the logits dict.
    values = [logits[answer] for answer in possible_answers]
    if not all(np.isfinite(val) for val in values):
        return None
    probs = torch.softmax(torch.tensor(values), dim=0)
    return {answer: probs[i].item() for i, answer in enumerate(possible_answers)}


def evaluate_model(all_question_logits, questions, possible_answers):
    """
    Evaluate model using pre-computed logits.

    Process:
    1. For each question, compute a policy (probability distribution) from logits using softmax
    2. Print each question's policy
    3. Average these question-level policies to get final averaged policy

    Args:
        all_question_logits: Pre-computed logits for all questions (one dict per
            question, keyed by answer)
        questions: List of questions
        possible_answers: List of possible answer options

    Returns:
        dict: Evaluation results including per-question and averaged policies

    Raises:
        ValueError: If `possible_answers` is empty, or if the number of logits
            dicts differs from the number of questions
        KeyError: If a logits dict lacks one of the possible answers
    """
    if not possible_answers:
        raise ValueError("possible_answers must not be empty")
    if len(all_question_logits) != len(questions):
        # zip() would silently drop the surplus and skew the average
        raise ValueError(
            f"Got {len(all_question_logits)} logits entries for {len(questions)} questions"
        )

    print(f"\nComputing policies for {len(questions)} questions...")

    # Step 1: Compute policy (softmax) for each question
    question_policies = []
    question_details = []

    for q_idx, (question, logits) in enumerate(zip(questions, all_question_logits)):
        question_policy = _policy_from_logits(logits, possible_answers)
        if question_policy is None:
            print(f"  WARNING: Found nan/inf in logits for question {q_idx}: {list(logits.values())}")
            # Use uniform distribution as fallback
            question_policy = {answer: 1.0 / len(possible_answers) for answer in possible_answers}

        question_policies.append(question_policy)

        # Print individual question policy
        print(f"\nQuestion {q_idx + 1}: {question[:60]}{'...' if len(question) > 60 else ''}")
        print(f"  Logits: {logits}")
        sorted_policy = sorted(question_policy.items(), key=lambda x: x[1], reverse=True)
        print(f"  Policy: {dict(sorted_policy)}")
        print(f"  Top answer: {sorted_policy[0][0]} ({sorted_policy[0][1]:.4f})")

        # Store question details
        question_details.append({
            'question': question,
            'question_index': q_idx,
            'individual_probabilities': question_policy,
            'logits': logits,
            'top_answer': sorted_policy[0]
        })

    # Step 2: Average the policies across all questions
    print("\n" + "=" * 60)
    print("COMPUTING AVERAGED POLICY ACROSS ALL QUESTIONS")
    print("=" * 60)

    averaged_policy = {}
    for answer in possible_answers:
        answer_probs = [policy[answer] for policy in question_policies]
        # Store a plain float rather than a NumPy scalar so the result prints
        # and serialises uniformly.
        mean_prob = float(np.mean(answer_probs))
        averaged_policy[answer] = mean_prob
        print(f"{answer}: {mean_prob:.6f} (individual: {[f'{p:.3f}' for p in answer_probs]})")

    # Verify policy sums to 1.0
    policy_sum = sum(averaged_policy.values())
    if abs(policy_sum - 1.0) > 1e-6:
        print(f"WARNING: Policy sum is {policy_sum:.6f}, should be 1.0")

    print(f"\nFinal averaged policy: {averaged_policy}")

    results = {
        'questions': questions,
        'possible_answers': possible_answers,
        'averaged_policy': averaged_policy,
        'question_policies': question_policies,  # Individual question policies
        'question_details': question_details,
        'most_likely_answer_overall': max(averaged_policy.items(), key=lambda x: x[1])
    }

    return results




def save_evaluation_results(results, output_file, model_info=None):
    """
    Write the evaluation results, wrapped with run metadata, to a JSON file
    and print a short summary.

    Args:
        results: Evaluation results dictionary; must provide 'questions',
            'averaged_policy' and 'most_likely_answer_overall'
        output_file: Output file path
        model_info: Optional model information to include; its 'temperature'
            entry (default 1.0) is echoed in the summary
    """
    info = model_info or {}
    temperature = info.get('temperature', 1.0)
    question_count = len(results['questions'])

    # Assemble the payload: metadata first, then the raw results
    summary = {
        'total_questions': question_count,
        'temperature': temperature,
        'evaluation_type': 'token_probability_based',
    }
    metadata = {
        'timestamp': datetime.now().isoformat(),
        'model_info': info,
        'evaluation_summary': summary,
    }
    payload = {'metadata': metadata, 'results': results}

    with open(output_file, 'w') as handle:
        json.dump(payload, handle, indent=2)

    print(f"\nEvaluation results saved to: {output_file}")

    # Console summary
    best = results['most_likely_answer_overall']
    print("\nEvaluation Summary:")
    print(f"  - Total questions: {question_count}")
    print(f"  - Temperature: {temperature}")
    print(f"  - Overall most likely answer: {best[0]} ({best[1]:.6f})")

    # Averaged policy, most probable answer first
    ranked = sorted(results['averaged_policy'].items(), key=lambda pair: pair[1], reverse=True)
    print("  - Averaged policy distribution:")
    for name, probability in ranked:
        print(f"    {name}: {probability:.6f}")


def main():
    """
    Command-line entry point: parse options, load the model and prompts,
    score every question, and save the resulting policies.
    """
    parser = argparse.ArgumentParser(description="Evaluate synthetic trained model using token probabilities")
    option_specs = [
        ('--model_path', dict(
            type=str, default=None,
            help='Path to saved LoRA adapter (optional - if not provided, uses base model only)')),
        ('--base_model', dict(
            type=str, default='Qwen/Qwen2.5-3B-Instruct',
            help='Base model name')),
        ('--prompt_file', dict(
            type=str, default='datasets/synthetic/color_prompt.json',
            help='Path to prompt file')),
        ('--output', dict(
            type=str, default='synthetic_evaluation_results_v2.json',
            help='Output results file')),
        ('--temperature', dict(
            type=float, default=1.0,
            help='Temperature for probability computation (1.0 for raw probabilities)')),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    cli = parser.parse_args()

    banner = "=" * 70
    print(banner)
    print("Synthetic Model Evaluation v2 (Token Probability Based)")
    print(banner)

    # Bail out early when there is nothing to evaluate
    if not os.path.exists(cli.prompt_file):
        print(f"Error: Prompt file '{cli.prompt_file}' not found!")
        return

    # Pick the device label that is reported and handed to the loader
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    model, tokenizer = load_model_and_adapter(cli.base_model, cli.model_path, device)

    # Prompts and the candidate answers to score
    questions, possible_answers = load_prompt_data(cli.prompt_file)

    # One token ID per candidate answer
    answer_tokens = get_answer_token_ids(tokenizer, possible_answers)

    # A single forward pass per question
    all_question_logits = compute_all_question_logits(
        model, tokenizer, questions, answer_tokens, temperature=cli.temperature
    )

    # Per-question policies and their average
    results = evaluate_model(all_question_logits, questions, possible_answers)

    # Record how the run was configured alongside the results
    model_info = {
        'base_model': cli.base_model,
        'adapter_path': cli.model_path or 'None (base model only)',
        'using_adapter': cli.model_path is not None,
        'temperature': cli.temperature,
        'answer_tokens': answer_tokens
    }

    save_evaluation_results(results, cli.output, model_info)

    print("\n" + banner)
    print("Evaluation completed!")
    print(banner)

# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
