import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
from tqdm import tqdm
import os
from peft import PeftModel, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig
import wandb


def _win_rate_strip_markers(text):
    """Remove the ChatML user/assistant turn markers from ``text`` and trim whitespace."""
    return (
        text.replace("<|im_start|>user\n", "")
        .replace("<|im_end|>\n<|im_start|>assistant\n", "")
        .strip()
    )


def _win_rate_generate(model, tokenizer, prompts, max_length):
    """Generate one completion per prompt and return only the newly generated text."""
    with torch.cuda.amp.autocast():
        inputs = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        ).to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Drop the prompt tokens before decoding so the returned text holds the
    # completion only (the same slicing the judge path applies to its output).
    prompt_len = inputs["input_ids"].shape[1]
    completions = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)
    return [_win_rate_strip_markers(text) for text in completions]


def _win_rate_ask_judge(annotation_model, annotation_tokenizer, encoded):
    """Ask the annotation model for a verdict on every row of ``encoded``; return cleaned answers."""
    with torch.cuda.amp.autocast():
        outputs = annotation_model.generate(
            **encoded,
            max_new_tokens=4,  # Very short response needed
            num_return_sequences=1,
            pad_token_id=annotation_tokenizer.pad_token_id,
            # NOTE(review): temperature only matters when sampling is enabled in
            # the model's generation config; under greedy decoding a retry will
            # reproduce the same answer - confirm the intended decoding mode.
            temperature=0.1,
        )

    new_tokens = outputs[:, encoded["input_ids"].shape[1]:]
    answers = annotation_tokenizer.batch_decode(new_tokens, skip_special_tokens=True)

    cleaned = []
    for answer in answers:
        answer = _win_rate_strip_markers(answer)
        if "<|im_start|>assistant\n" in answer:
            answer = answer.split("<|im_start|>assistant\n", 1)[1].strip()
        cleaned.append(answer)
    return cleaned


def _win_rate_parse_choice(text):
    """Return 1 or 2 based on the first line of ``text`` containing that digit, else None."""
    for line in text.split('\n'):
        if '1' in line:
            return 1
        if '2' in line:
            return 2
    return None


def compute_win_rate(
    model, tokenizer,
    reference_model, reference_tokenizer,
    annotation_model, annotation_tokenizer,
    test_dataset,
    batch_size=8,
    max_length=1030,
    w=0.6
):
    """
    Compute win rate by comparing model responses against reference model using annotation model.

    For every prompt, both the evaluated model and the reference model generate
    a response. A context instruction is drawn at random (the "concise" context
    with probability ``w``, otherwise the "informative" one) and the annotation
    model is asked which response matches that context better. The evaluated
    model's response is always presented as "Response 1", so an answer of ``1``
    counts as a win.

    Args:
        model: Model to evaluate
        tokenizer: Tokenizer for text processing
        reference_model: Reference model for comparison
        reference_tokenizer: Tokenizer for reference model
        annotation_model: Model used for annotation/comparison
        annotation_tokenizer: Tokenizer for annotation model
        test_dataset: Dataset containing prompts and responses; slicing it must
            yield an iterable of samples that each expose the prompt under ``'x'``
        batch_size: Number of samples to process in each batch
        max_length: Maximum sequence length for tokenization; also used as the
            ``max_new_tokens`` budget for the generated responses
        w: Weight parameter for context selection (probability of the first context)

    Returns:
        tuple: (sampled_win_rate, prob_win_rate) - Two metrics of model performance.
        ``sampled_win_rate`` is wins divided by the number of comparisons with a
        parseable verdict. ``prob_win_rate`` is kept for interface compatibility:
        no probability mass is accumulated here, so it is always 0.
    """
    model.eval()
    reference_model.eval()
    annotation_model.eval()

    # Initialize counters
    wins = 0
    total_prob = 0.0  # Never accumulated; kept so the reported prob rate stays defined
    total_valid_samples = 0  # Comparisons with a parseable verdict
    total_invalid_samples = 0  # Comparisons without a verdict (unparseable or errored)

    # Context prompts
    #group1_context = "Generate a response that can be easily understood by an elementary school student."
    #group2_context = "Generate a response that only a PhD Student in that specific field could understand."
    #group1_context = "Generate a response that is friendly, witty, funny, and humorous, like a close friend."
    #group2_context = "Generate a response (that answers) in an unfriendly manner."
    group1_context = "Generate a response that is concise and to the point, without being verbose."
    group2_context = "Generate a response that is very informative, without missing any background information."

    sub_batch_size = 8  # Judge prompts are processed in smaller chunks to avoid OOM
    max_attempts = 5  # Parse attempts per comparison (the first answer plus retries)

    with torch.no_grad():
        pbar = tqdm(range(0, len(test_dataset), batch_size), desc="Computing win rate")
        for i in pbar:
            end = min(i + batch_size, len(test_dataset))
            batch = test_dataset[i:end]
            batch_prompts = []
            batch_contexts = []

            # Prepare batch data
            for sample in batch:
                x = sample['x']
                batch_prompts.append(f"<|im_start|>user\n{x}<|im_end|>\n<|im_start|>assistant\n")

                # Choose context based on probability w
                use_group1 = torch.rand(1).item() < w
                batch_contexts.append(group1_context if use_group1 else group2_context)

            clean_prompts = [_win_rate_strip_markers(prompt) for prompt in batch_prompts]
            clean_responses = _win_rate_generate(model, tokenizer, batch_prompts, max_length)
            clean_ref_responses = _win_rate_generate(
                reference_model, reference_tokenizer, batch_prompts, max_length
            )

            batch_wins = 0
            batch_valid_samples = 0
            batch_invalid_samples = 0

            for j in range(0, len(clean_responses), sub_batch_size):
                window = slice(j, j + sub_batch_size)

                # Create comparison prompts for the annotation model
                comparison_prompts = [
                    f"Context: {context}\n\nPrompt: {prompt}\n\nResponse 1: {resp1}\n\nResponse 2: {resp2}\n\nWhich response better matches the context? Please answer with only '1' or '2'."
                    for prompt, context, resp1, resp2 in zip(
                        clean_prompts[window],
                        batch_contexts[window],
                        clean_responses[window],
                        clean_ref_responses[window],
                    )
                ]

                try:
                    # Add chat template markers for the annotation model
                    formatted_comparison_prompts = [
                        f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n{prompt}<|start_header_id|>assistant<|end_header_id|>\n"
                        for prompt in comparison_prompts
                    ]

                    # NOTE(review): the judge prompt is truncated to max_length
                    # tokens although it embeds two responses of up to max_length
                    # new tokens each; depending on the tokenizer's truncation
                    # side the trailing question may be cut - confirm.
                    inputs_comparison = annotation_tokenizer(
                        formatted_comparison_prompts,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        max_length=max_length
                    ).to(annotation_model.device)

                    answers = _win_rate_ask_judge(
                        annotation_model, annotation_tokenizer, inputs_comparison
                    )

                    sub_batch_wins = 0
                    sub_batch_valid = 0
                    sub_batch_invalid = 0
                    for idx, answer in enumerate(answers):
                        chosen_response = _win_rate_parse_choice(answer)

                        for _ in range(max_attempts - 1):
                            if chosen_response is not None:
                                break
                            # Re-query the judge for THIS sample only, so the retry
                            # answer belongs to the comparison being scored.
                            single_input = {
                                name: tensor[idx:idx + 1]
                                for name, tensor in inputs_comparison.items()
                            }
                            answer = _win_rate_ask_judge(
                                annotation_model, annotation_tokenizer, single_input
                            )[0]
                            chosen_response = _win_rate_parse_choice(answer)

                        # Track valid and invalid samples
                        if chosen_response is None:
                            sub_batch_invalid += 1
                        else:
                            sub_batch_valid += 1
                            # The evaluated model's response is "Response 1"
                            if chosen_response == 1:
                                sub_batch_wins += 1

                    batch_wins += sub_batch_wins
                    batch_valid_samples += sub_batch_valid
                    batch_invalid_samples += sub_batch_invalid

                except Exception as e:
                    # Best effort: report the failure, count the whole sub-batch
                    # as invalid and keep evaluating the remaining samples.
                    print(f"Error comparing responses: {e}")
                    batch_invalid_samples += len(comparison_prompts)
                    continue

            # Update counters
            wins += batch_wins
            total_valid_samples += batch_valid_samples
            total_invalid_samples += batch_invalid_samples

            # Update progress bar with both win rates
            current_sampled_rate = wins / total_valid_samples if total_valid_samples > 0 else 0
            current_prob_rate = total_prob / total_valid_samples if total_valid_samples > 0 else 0
            pbar.set_postfix({
                "sampled_rate": f"{current_sampled_rate:.4f}",
                "prob_rate": f"{current_prob_rate:.4f}",
                "batch_wins": f"{batch_wins}/{batch_valid_samples}",
                "valid/invalid": f"{total_valid_samples}/{total_invalid_samples}"
            })

    sampled_win_rate = wins / total_valid_samples if total_valid_samples > 0 else 0
    prob_win_rate = total_prob / total_valid_samples if total_valid_samples > 0 else 0

    print(f"\nFinal Results:")
    print(f"Sampled-based win Rate: {sampled_win_rate:.4f}")
    print(f"Probability-based win Rate: {prob_win_rate:.4f}")
    print(f"Total Wins: {wins}/{total_valid_samples}")
    print(f"Valid/Invalid Samples: {total_valid_samples}/{total_invalid_samples}")

    return sampled_win_rate, prob_win_rate

def _ppr_strip_markers(text):
    """Remove the ChatML user/assistant turn markers from ``text`` and trim whitespace."""
    return (
        text.replace("<|im_start|>user\n", "")
        .replace("<|im_end|>\n<|im_start|>assistant\n", "")
        .strip()
    )


def _ppr_generate(model, tokenizer, prompts, max_length):
    """Generate one completion per prompt and return only the newly generated text."""
    with torch.cuda.amp.autocast():
        inputs = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        ).to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Drop the prompt tokens before decoding so the returned text holds the
    # completion only.
    prompt_len = inputs["input_ids"].shape[1]
    completions = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)
    return [_ppr_strip_markers(text) for text in completions]


def _ppr_ask_classifier(annotation_model, annotation_tokenizer, encoded):
    """Ask the annotation model to label every row of ``encoded``; return cleaned answers."""
    with torch.cuda.amp.autocast():
        outputs = annotation_model.generate(
            **encoded,
            max_new_tokens=20,  # Short response needed
            num_return_sequences=1,
            pad_token_id=annotation_tokenizer.pad_token_id,
            # NOTE(review): temperature only matters when sampling is enabled in
            # the model's generation config; under greedy decoding a retry will
            # reproduce the same answer - confirm the intended decoding mode.
            temperature=0.1,
        )

    # Decode the answer only. The classification prompt itself names every
    # level, so leaving it in the decoded text would make the level matcher hit
    # the prompt's own wording instead of the annotation model's answer.
    new_tokens = outputs[:, encoded["input_ids"].shape[1]:]
    answers = annotation_tokenizer.batch_decode(new_tokens, skip_special_tokens=True)

    cleaned = []
    for answer in answers:
        answer = _ppr_strip_markers(answer)
        if "<|im_start|>assistant\n" in answer:
            answer = answer.split("<|im_start|>assistant\n", 1)[1].strip()
        cleaned.append(answer)
    return cleaned


def _ppr_match_level(text, levels):
    """Return the first level (in ``levels`` order) found on the first matching line, else None."""
    for line in text.split('\n'):
        line = line.strip().lower()
        for level in levels:
            if level in line:
                return level
    return None


def compute_PPR(model, tokenizer, annotation_model, annotation_tokenizer, test_dataset, batch_size=16, max_length=500, w=0.6):
    """
    Compute PPR metrics by analyzing the distribution of responses across conciseness levels.

    The evaluated model answers every prompt, the annotation model labels each
    answer as one of the levels, and the resulting empirical distribution is
    compared with the target distribution ``[w, 1 - w]``.

    Args:
        model: Model to evaluate
        tokenizer: Tokenizer for text processing
        annotation_model: Model used for conciseness classification
        annotation_tokenizer: Tokenizer for annotation model
        test_dataset: Dataset containing prompts and responses; slicing it must
            yield an iterable of samples that each expose the prompt under ``'x'``
        batch_size: Number of samples to process in each batch
        max_length: Maximum sequence length for tokenization; also used as the
            ``max_new_tokens`` budget for the generated responses
        w: Weight parameter for target distribution (w for concise, 1-w for informative).
            A target probability of zero (``w`` equal to 0 or 1) makes the
            corresponding ratio infinite or NaN.

    Returns:
        tuple: (current_dist, min_ratio) - Current distribution (list of
        probabilities in level order) and minimum ratio between current and
        target probabilities. ``([0.0, 0.0], 0.0)`` is returned when no response
        could be classified.
    """
    device = model.device
    annotation_device = annotation_model.device

    model.eval()
    annotation_model.eval()

    # Define the set of response types
    # LEVELS = ['elementary', 'phd']
    LEVELS = ['concise', 'informative']
    # LEVELS = ['friendly', 'unfriendly']
    level_counts = {level: 0 for level in LEVELS}
    total = 0  # Number of generated responses (shown on the progress bar)

    sub_batch_size = 8  # Classification prompts are processed in smaller chunks to avoid OOM
    max_attempts = 5  # Parse attempts per response (the first answer plus retries)

    with torch.no_grad():
        pbar = tqdm(range(0, len(test_dataset), batch_size), desc="Computing PPR")

        for i in pbar:
            end = min(i + batch_size, len(test_dataset))
            batch = test_dataset[i:end]

            # Prepare batch data
            batch_prompts = [
                f"<|im_start|>user\n{sample['x']}<|im_end|>\n<|im_start|>assistant\n"
                for sample in batch
            ]

            clean_prompts = [_ppr_strip_markers(prompt) for prompt in batch_prompts]
            clean_responses = _ppr_generate(model, tokenizer, batch_prompts, max_length)

            for j in range(0, len(clean_responses), sub_batch_size):
                sub_clean_responses = clean_responses[j:j+sub_batch_size]
                sub_clean_prompts = clean_prompts[j:j+sub_batch_size]

                # Create classification prompts for the annotation model
                classification_prompts = []
                for prompt, resp in zip(sub_clean_prompts, sub_clean_responses):
                    #classification_prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\nPrompt: {prompt}\n\nResponse: {resp}\n\nIs this response elementary or phd level? Please answer with only one of these exact options: 'elementary' or 'phd'.<|start_header_id|>assistant<|end_header_id|>\n"
                    classification_prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\nPrompt: {prompt}\n\nResponse: {resp}\n\nIs this response concise (but less informative) or informative (but less concise)? Please answer with only one of these exact options: 'concise' or 'informative'.<|start_header_id|>assistant<|end_header_id|>\n"
                    #classification_prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\nPrompt: {prompt}\n\nResponse: {resp}\n\nIs this response friendly or unfriendly? Please answer with only one of these exact options: 'friendly' or 'unfriendly'.<|start_header_id|>assistant<|end_header_id|>\n"

                    classification_prompts.append(classification_prompt)

                try:
                    # NOTE(review): the classification prompt is truncated to
                    # max_length tokens although it embeds a response of up to
                    # max_length new tokens; depending on the tokenizer's
                    # truncation side the trailing question may be cut - confirm.
                    inputs_classification = annotation_tokenizer(
                        classification_prompts,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        max_length=max_length
                    ).to(annotation_device)

                    answers = _ppr_ask_classifier(
                        annotation_model, annotation_tokenizer, inputs_classification
                    )

                    # Process classification responses
                    for idx, answer in enumerate(answers):
                        classification_level = _ppr_match_level(answer, LEVELS)

                        for _ in range(max_attempts - 1):
                            if classification_level is not None:
                                break
                            # Re-query the classifier for THIS sample only, so the
                            # retry answer belongs to the response being labelled.
                            single_input = {
                                name: tensor[idx:idx + 1]
                                for name, tensor in inputs_classification.items()
                            }
                            answer = _ppr_ask_classifier(
                                annotation_model, annotation_tokenizer, single_input
                            )[0]
                            classification_level = _ppr_match_level(answer, LEVELS)

                        if classification_level is not None:
                            level_counts[classification_level] += 1

                except Exception as e:
                    print(f"Error classifying responses: {e}")
                    # Skip this batch on error
                    continue

            # Update total count
            total += len(clean_responses)

            # Update progress bar
            pbar.set_postfix({"total": total})

    # Nothing was classified (empty dataset, or every classification failed):
    # the distribution is undefined, so return the zero result instead of 0/0.
    if sum(level_counts.values()) == 0:
        return [0.0, 0.0], 0.0

    # Convert counts to tensor
    counts_tensor = torch.tensor(list(level_counts.values()), device=device)
    probs = counts_tensor / counts_tensor.sum()

    # Calculate target distribution
    target_dist = torch.tensor([w, 1-w], device=device)  # [0.6, 0.4] for concise/informative

    # Calculate minimum ratio between current and target probabilities
    ratios = probs / target_dist
    min_ratio = torch.min(ratios).item()

    print(f"\nFinal Distribution Metrics:")
    print(f"Current Distribution: {probs.tolist()}")
    print(f"Target Distribution: {target_dist.tolist()}")
    print(f"Minimum Ratio: {min_ratio:.4f}")

    # Log metrics to wandb only when a run is active; wandb.log requires
    # wandb.init() to have been called.
    if wandb.run is not None:
        wandb.log({
            "PPR/level_counts": {f"level_{i}": count for i, count in enumerate(level_counts.values())},
            "PPR/probabilities": {f"level_{i}": prob.item() for i, prob in enumerate(probs)},
            "PPR/min_ratio": min_ratio
        })

    # Convert tensors to Python native types before returning
    return probs.tolist(), min_ratio



