"""
Metrics Computation Module for Evaluation

- Win Rate: Evaluates model performance against reference models using both sampled and probability-based metrics
- PPR: Assesses the model's distribution of responses across conciseness levels by comparing it with a target distribution (minimum ratio of observed to target probability)
"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
from tqdm import tqdm
import os
from peft import PeftModel, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig
import wandb


def compute_win_rate(
    model, tokenizer,
    reference_model, reference_tokenizer,
    annotation_model, annotation_tokenizer,
    test_dataset,
    batch_size=8,
    max_length=1030,
    w=0.6
):
    """
    Compute win rate by comparing model responses against reference model using annotation model.

    For each prompt, the evaluated model and the reference model each generate one
    response. A context instruction is then drawn at random: ``group1_context``
    with probability ``w``, otherwise ``group2_context``. The annotation model is
    asked which response better matches that context. The evaluated model's
    response is always shown as "Response 1" and the reference response as
    "Response 2".

    Args:
        model: Model to evaluate
        tokenizer: Tokenizer for ``model``
        reference_model: Reference model for comparison
        reference_tokenizer: Tokenizer for reference model
        annotation_model: Model used for annotation/comparison
        annotation_tokenizer: Tokenizer for annotation model
        test_dataset: Dataset whose samples expose the user prompt under key 'x'
        batch_size: Number of prompts generated per batch
        max_length: Maximum prompt length (tokens) for the two generating models,
            and maximum number of new tokens each of them generates
        w: Probability of choosing ``group1_context`` for a sample

    Returns:
        tuple: (sampled_win_rate, prob_win_rate)
            sampled_win_rate: fraction of parseable judgements that chose "1".
            prob_win_rate: mean over samples of P("1") / (P("1") + P("2")), read
                from the scores ``generate`` returns for the annotation model's
                first generated token (scores are post logits-processing).
    """
    model.eval()
    reference_model.eval()
    annotation_model.eval()

    # Context prompts
    #group1_context = "Generate a response that can be easily understood by an elementary school student."
    #group2_context = "Generate a response that only a PhD Student in that specific field could understand."
    #group1_context = "Generate a response that is friendly, witty, funny, and humorous, like a close friend."
    #group2_context = "Generate a response (that answers) in an unfriendly manner." 
    group1_context = "Generate a response that is concise and to the point, without being verbose."
    group2_context = "Generate a response that is very informative, without missing any background information."

    sub_batch_size = 8  # annotation sub-batch size, kept small to avoid OOM
    max_attempts = 5    # parse attempts per sample: 1 initial + up to 4 regenerations

    # Vocabulary ids of the two possible answers, used for the probability-based rate.
    # The last id is taken in case the tokenizer splits the string into several pieces.
    answer_one_id = annotation_tokenizer.encode("1", add_special_tokens=False)[-1]
    answer_two_id = annotation_tokenizer.encode("2", add_special_tokens=False)[-1]

    def _generate_responses(gen_model, gen_tokenizer, prompts):
        """Generate one response per prompt and return only the newly generated text."""
        with torch.cuda.amp.autocast():
            enc = gen_tokenizer(
                prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length
            ).to(gen_model.device)
            outputs = gen_model.generate(
                **enc,
                max_new_tokens=max_length,
                num_return_sequences=1,
                pad_token_id=gen_tokenizer.pad_token_id
            )
        # generate() returns prompt + continuation; dropping the prompt tokens keeps
        # the prompt text from leaking into the response that gets judged.
        new_tokens = outputs[:, enc["input_ids"].shape[1]:]
        texts = gen_tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
        # If <|im_end|> is not a special token it survives decoding; end the turn there.
        return [text.split("<|im_end|>", 1)[0].strip() for text in texts]

    def _parse_choice(reply):
        """Return 1 or 2 for whichever of those digits appears first in ``reply``, else None."""
        for char in reply:
            if char == "1":
                return 1
            if char == "2":
                return 2
        return None

    wins = 0
    total_valid_samples = 0    # judgements that could be parsed as "1" or "2"
    total_invalid_samples = 0  # judgements still unparseable after max_attempts
    total_prob = 0.0           # sum of P("1") / (P("1") + P("2")) over scored samples
    total_prob_samples = 0     # number of samples contributing to total_prob

    with torch.no_grad():
        pbar = tqdm(range(0, len(test_dataset), batch_size), desc="Computing win rate")
        for i in pbar:
            end = min(i + batch_size, len(test_dataset))
            # NOTE(review): assumes slicing yields an iterable of dict-like samples
            # (e.g. a list of dicts); a HF `datasets.Dataset` slice is a dict of columns.
            batch = test_dataset[i:end]

            batch_prompts = []
            clean_prompts = []
            batch_contexts = []
            for sample in batch:
                x = sample['x']
                clean_prompts.append(x.strip())
                batch_prompts.append(f"<|im_start|>user\n{x}<|im_end|>\n<|im_start|>assistant\n")
                # Choose context based on probability w
                use_group1 = torch.rand(1).item() < w
                batch_contexts.append(group1_context if use_group1 else group2_context)

            clean_responses = _generate_responses(model, tokenizer, batch_prompts)
            clean_ref_responses = _generate_responses(reference_model, reference_tokenizer, batch_prompts)

            batch_wins = 0
            batch_valid_samples = 0
            batch_invalid_samples = 0
            batch_total_prob = 0.0
            batch_prob_samples = 0

            for j in range(0, len(clean_responses), sub_batch_size):
                sub = slice(j, j + sub_batch_size)

                # Build the judge prompts (chat markers for the annotation model included)
                formatted_comparison_prompts = []
                for prompt, context, resp1, resp2 in zip(
                    clean_prompts[sub], batch_contexts[sub], clean_responses[sub], clean_ref_responses[sub]
                ):
                    comparison_prompt = f"Context: {context}\n\nPrompt: {prompt}\n\nResponse 1: {resp1}\n\nResponse 2: {resp2}\n\nWhich response better matches the context? Please answer with only '1' or '2'."
                    # NOTE(review): the Llama-3 chat format also expects <|eot_id|> after the
                    # user message; the template is kept unchanged here — confirm intent.
                    formatted_comparison_prompts.append(
                        f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n{comparison_prompt}<|start_header_id|>assistant<|end_header_id|>\n"
                    )

                try:
                    with torch.cuda.amp.autocast():
                        # No truncation: the question and the assistant header sit at the
                        # end of the prompt, so right-truncation would remove exactly what
                        # the judge has to answer.
                        inputs_comparison = annotation_tokenizer(
                            formatted_comparison_prompts,
                            return_tensors="pt",
                            padding=True
                        ).to(annotation_model.device)

                        comparison_outputs = annotation_model.generate(
                            **inputs_comparison,
                            max_new_tokens=4,  # Very short response needed
                            num_return_sequences=1,
                            pad_token_id=annotation_tokenizer.pad_token_id,
                            temperature=0.1,  # Low temperature for more deterministic outputs
                            return_dict_in_generate=True,
                            output_scores=True
                        )
                    prompt_len = inputs_comparison["input_ids"].shape[1]
                    comparison_responses = annotation_tokenizer.batch_decode(
                        comparison_outputs.sequences[:, prompt_len:],
                        skip_special_tokens=True
                    )

                    # Probability-based win: P("1") renormalised over the two answers,
                    # taken from the scores of the first generated token.
                    first_step_scores = comparison_outputs.scores[0].float()
                    pair_scores = first_step_scores[:, [answer_one_id, answer_two_id]]
                    prob_one = torch.softmax(pair_scores, dim=-1)[:, 0]
                    finite = torch.isfinite(prob_one)  # both scores at -inf would give NaN
                    sub_total_prob = prob_one[finite].sum().item()
                    sub_prob_samples = int(finite.sum().item())

                    # Sampled-based win: parse each reply, regenerating unclear ones
                    sub_wins = 0
                    sub_valid = 0
                    sub_invalid = 0
                    for row, reply in enumerate(comparison_responses):
                        chosen_response = _parse_choice(reply)
                        attempt = 1
                        while chosen_response is None and attempt < max_attempts:
                            # Regenerate only this sample; sampling is required, otherwise
                            # greedy decoding would reproduce the same unparseable reply.
                            row_inputs = {key: value[row:row + 1] for key, value in inputs_comparison.items()}
                            with torch.cuda.amp.autocast():
                                retry_output = annotation_model.generate(
                                    **row_inputs,
                                    max_new_tokens=4,
                                    num_return_sequences=1,
                                    pad_token_id=annotation_tokenizer.pad_token_id,
                                    do_sample=True,
                                    temperature=0.1
                                )
                            reply = annotation_tokenizer.decode(
                                retry_output[0, prompt_len:],
                                skip_special_tokens=True
                            )
                            chosen_response = _parse_choice(reply)
                            attempt += 1

                        if chosen_response is None:
                            sub_invalid += 1
                        else:
                            sub_valid += 1
                            # Response 1 is always the evaluated model's response
                            if chosen_response == 1:
                                sub_wins += 1

                    # Commit the sub-batch only once it has been fully processed
                    batch_wins += sub_wins
                    batch_valid_samples += sub_valid
                    batch_invalid_samples += sub_invalid
                    batch_total_prob += sub_total_prob
                    batch_prob_samples += sub_prob_samples

                except Exception as e:
                    # Best effort: a failing sub-batch is reported and skipped
                    print(f"Error comparing responses: {e}")

            # Update counters
            wins += batch_wins
            total_valid_samples += batch_valid_samples
            total_invalid_samples += batch_invalid_samples
            total_prob += batch_total_prob
            total_prob_samples += batch_prob_samples

            # Update progress bar with both win rates
            current_sampled_rate = wins / total_valid_samples if total_valid_samples > 0 else 0
            current_prob_rate = total_prob / total_prob_samples if total_prob_samples > 0 else 0
            pbar.set_postfix({
                "sampled_rate": f"{current_sampled_rate:.4f}",
                "prob_rate": f"{current_prob_rate:.4f}",
                "batch_wins": f"{batch_wins}/{batch_valid_samples}",
                "valid/invalid": f"{total_valid_samples}/{total_invalid_samples}"
            })

    sampled_win_rate = wins / total_valid_samples if total_valid_samples > 0 else 0
    prob_win_rate = total_prob / total_prob_samples if total_prob_samples > 0 else 0

    print(f"\nFinal Results:")
    print(f"Sampled-based win Rate: {sampled_win_rate:.4f}")
    print(f"Probability-based win Rate: {prob_win_rate:.4f}")
    print(f"Total Wins: {wins}/{total_valid_samples}")
    print(f"Valid/Invalid Samples: {total_valid_samples}/{total_invalid_samples}")

    return sampled_win_rate, prob_win_rate

def compute_PPR(model, tokenizer, annotation_model, annotation_tokenizer, test_dataset, batch_size=16, max_length=500, w=0.6):
    """
    Compute PPR metrics by analyzing the distribution of responses across conciseness levels.

    The model generates one response per prompt, and the annotation model labels
    each response with one of ``LEVELS``. The empirical label distribution is
    compared with the target ``[w, 1 - w]`` through the minimum ratio
    ``probs / target``. The results are printed and logged to wandb.

    Args:
        model: Model to evaluate
        tokenizer: Tokenizer for ``model``
        annotation_model: Model used for conciseness classification
        annotation_tokenizer: Tokenizer for annotation model
        test_dataset: Dataset whose samples expose the user prompt under key 'x'
        batch_size: Number of prompts generated per batch
        max_length: Maximum prompt length (tokens) for ``model`` and maximum number
            of new tokens it generates
        w: Weight parameter for target distribution (w for concise, 1-w for informative)

    Returns:
        tuple: (current_dist, min_ratio). ``current_dist`` is the label distribution
        as a list of floats in LEVELS order, and ``min_ratio`` is the minimum ratio
        between current and target probabilities. Returns ``([0.0, 0.0], 0.0)``
        when no response could be classified.
    """
    device = model.device
    annotation_device = annotation_model.device

    model.eval()
    annotation_model.eval()

    # Define the set of response types
    # LEVELS = ['elementary', 'phd']
    LEVELS = ['concise', 'informative']
    # LEVELS = ['friendly', 'unfriendly']    
    level_counts = {level: 0 for level in LEVELS}
    total = 0            # number of generated responses
    sub_batch_size = 8   # annotation sub-batch size, kept small to avoid OOM
    max_attempts = 5     # parse attempts per sample: 1 initial + up to 4 regenerations

    def _match_level(reply):
        """Return the level mentioned earliest in ``reply`` (longer name wins ties), else None.

        Position-based matching keeps e.g. 'unfriendly' from being read as 'friendly'.
        """
        text = reply.lower()
        hits = []
        for level in LEVELS:
            pos = text.find(level)
            if pos >= 0:
                hits.append((pos, -len(level), level))
        return min(hits)[2] if hits else None

    with torch.no_grad():
        pbar = tqdm(range(0, len(test_dataset), batch_size), desc="Computing PPR")

        for i in pbar:
            end = min(i + batch_size, len(test_dataset))
            # NOTE(review): assumes slicing yields an iterable of dict-like samples
            # (e.g. a list of dicts); a HF `datasets.Dataset` slice is a dict of columns.
            batch = test_dataset[i:end]

            batch_prompts = []
            clean_prompts = []
            for sample in batch:
                x = sample['x']
                clean_prompts.append(x.strip())
                batch_prompts.append(f"<|im_start|>user\n{x}<|im_end|>\n<|im_start|>assistant\n")

            # Batch process model responses with mixed precision
            with torch.cuda.amp.autocast():
                inputs = tokenizer(
                    batch_prompts,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length
                ).to(device)

                responses = model.generate(
                    **inputs,
                    max_new_tokens=max_length,
                    num_return_sequences=1,
                    pad_token_id=tokenizer.pad_token_id
                )
            # Keep only the continuation so the prompt never leaks into the response;
            # if <|im_end|> is not a special token it survives decoding, so end the turn there.
            new_tokens = responses[:, inputs["input_ids"].shape[1]:]
            clean_responses = [
                text.split("<|im_end|>", 1)[0].strip()
                for text in tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
            ]

            for j in range(0, len(clean_responses), sub_batch_size):
                sub = slice(j, j + sub_batch_size)

                # Create classification prompts for the annotation model
                classification_prompts = []
                for prompt, resp in zip(clean_prompts[sub], clean_responses[sub]):
                    #classification_prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\nPrompt: {prompt}\n\nResponse: {resp}\n\nIs this response elementary or phd level? Please answer with only one of these exact options: 'elementary' or 'phd'.<|start_header_id|>assistant<|end_header_id|>\n"
                    classification_prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\nPrompt: {prompt}\n\nResponse: {resp}\n\nIs this response concise (but less informative) or informative (but less concise)? Please answer with only one of these exact options: 'concise' or 'informative'.<|start_header_id|>assistant<|end_header_id|>\n"
                    #classification_prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\nPrompt: {prompt}\n\nResponse: {resp}\n\nIs this response friendly or unfriendly? Please answer with only one of these exact options: 'friendly' or 'unfriendly'.<|start_header_id|>assistant<|end_header_id|>\n"
                    classification_prompts.append(classification_prompt)

                try:
                    with torch.cuda.amp.autocast():
                        # No truncation: the question and the assistant header sit at the
                        # end of the prompt, so right-truncation would remove them.
                        inputs_classification = annotation_tokenizer(
                            classification_prompts,
                            return_tensors="pt",
                            padding=True
                        ).to(annotation_device)

                        classification_outputs = annotation_model.generate(
                            **inputs_classification,
                            max_new_tokens=20,  # Short response needed
                            num_return_sequences=1,
                            pad_token_id=annotation_tokenizer.pad_token_id,
                            temperature=0.1  # Low temperature for more deterministic outputs
                        )
                    prompt_len = inputs_classification["input_ids"].shape[1]
                    # Decode only the annotator's answer: the prompt itself names every
                    # level, so decoding the full sequence would always "find" a level.
                    classification_responses = annotation_tokenizer.batch_decode(
                        classification_outputs[:, prompt_len:],
                        skip_special_tokens=True
                    )

                    for row, reply in enumerate(classification_responses):
                        classification_level = _match_level(reply)
                        attempt = 1
                        while classification_level is None and attempt < max_attempts:
                            # Regenerate only this sample; sampling is required, otherwise
                            # greedy decoding would reproduce the same unusable reply.
                            row_inputs = {key: value[row:row + 1] for key, value in inputs_classification.items()}
                            with torch.cuda.amp.autocast():
                                retry_output = annotation_model.generate(
                                    **row_inputs,
                                    max_new_tokens=20,
                                    num_return_sequences=1,
                                    pad_token_id=annotation_tokenizer.pad_token_id,
                                    do_sample=True,
                                    temperature=0.1
                                )
                            reply = annotation_tokenizer.decode(
                                retry_output[0, prompt_len:],
                                skip_special_tokens=True
                            )
                            classification_level = _match_level(reply)
                            attempt += 1

                        if classification_level is not None:
                            level_counts[classification_level] += 1

                except Exception as e:
                    # Best effort: report the failure and skip this sub-batch
                    print(f"Error classifying responses: {e}")

            # Update total count
            total += len(clean_responses)

            # Update progress bar
            pbar.set_postfix({"total": total})

    classified = sum(level_counts.values())
    if total > 0 and classified > 0:
        # Convert counts to tensor (LEVELS order) and normalise
        counts_tensor = torch.tensor(list(level_counts.values()), device=device)
        probs = counts_tensor / counts_tensor.sum()

        # Calculate target distribution
        target_dist = torch.tensor([w, 1-w], device=device)  # [0.6, 0.4] for concise/informative

        # Calculate minimum ratio between current and target probabilities
        ratios = probs / target_dist
        min_ratio = torch.min(ratios).item()

        print(f"\nFinal Distribution Metrics:")
        print(f"Current Distribution: {probs.tolist()}")
        print(f"Target Distribution: {target_dist.tolist()}")
        print(f"Minimum Ratio: {min_ratio:.4f}")

        # Log metrics to wandb (counts, not level names)
        wandb.log({
            "PPR/level_counts": {f"level_{i}": count for i, count in enumerate(level_counts.values())},
            "PPR/probabilities": {f"level_{i}": prob.item() for i, prob in enumerate(probs)},
            "PPR/min_ratio": min_ratio
        })

        # Convert tensors to Python native types before returning
        return probs.tolist(), min_ratio

    if total > 0:
        # Avoid 0/0 -> NaN distribution when every classification attempt failed
        print("No responses could be classified; returning zero distribution.")
    return [0.0] * len(LEVELS), 0.0



