"""
Evaluation Script

This module implements the evaluation pipeline for the following metrics:
- Win Rate: Compares the model's performance against a reference model
- PPR: Evaluates the model's distribution of responses

"""

import warnings
import os
import logging
import json
import yaml
import torch
import wandb
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
from utils.compute_metrics import compute_win_rate, compute_PPR

# Raise the log level of the transformers and peft loggers to ERROR so that
# their INFO/WARNING records are not emitted.
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("peft").setLevel(logging.ERROR)

# Suppress specific warnings: messages matching the tokenizer-deprecation
# pattern, and any warning attributed to the transformers / peft modules.
warnings.filterwarnings("ignore", message=".*tokenizer.*deprecated.*")
warnings.filterwarnings("ignore", module="transformers.*")
warnings.filterwarnings("ignore", module="peft.*")

# Disable all warnings. NOTE: this installs a catch-all "ignore" filter, so
# the targeted filters above are redundant for as long as it is in place.
warnings.simplefilter("ignore")

def load_config(config_path):
    """
    Load configuration from a YAML file.

    The file is decoded explicitly as UTF-8 so that parsing does not depend on
    the platform's default locale encoding (the test dataset in this module is
    read the same way).

    Args:
        config_path: Path to the YAML configuration file

    Returns:
        The parsed document as produced by ``yaml.safe_load`` -- a dict for a
        mapping document.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)

def load_model_and_tokenizer(model_path, config, is_annotation_model=False):
    """
    Load model and tokenizer from specified path with appropriate settings.

    If ``model_path`` contains a ``training_config.json``, its ``model_config``,
    ``qlora`` and ``lora_config`` sections override the corresponding entries
    of ``config['shared']`` for this load only; the caller's ``config`` is
    never modified. If ``model_path`` contains an ``adapter_config.json`` the
    shared base model is loaded and the adapter is applied on top of it.

    Args:
        model_path: Path to the model directory or model name
        config: Configuration dictionary containing model settings
        is_annotation_model: Whether this is an annotation model (affects tokenizer selection)

    Returns:
        tuple: (model, tokenizer) loaded and configured according to settings.
        The model is moved to CUDA when available, otherwise to CPU.
    """
    # Function-scope import: only needed for the defensive copy below.
    import copy

    hf_token = config.get('auth', {}).get('hf_token', None)

    # The annotation model uses its own tokenizer; every other model uses the
    # tokenizer of the shared base model.
    if is_annotation_model:
        tokenizer_name = config['annotation_model']['model_name']
    else:
        tokenizer_name = config['shared']['model']['model_name']
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name,
        trust_remote_code=config['shared']['model']['trust_remote_code'],
        token=hf_token
    )
    tokenizer.pad_token = tokenizer.eos_token

    # Configure the device
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Work on a private copy of the shared settings: the saved training config
    # of one model must not leak into later loads that reuse the same
    # ``config`` object.
    shared = copy.deepcopy(config['shared'])
    model_cfg = shared['model']

    # Overlay the settings saved next to the model, if any
    saved_config_path = os.path.join(model_path, "training_config.json")
    if os.path.exists(saved_config_path):
        with open(saved_config_path, 'r', encoding='utf-8') as f:
            saved_config = json.load(f)
        # Sections absent from the saved file (or from the config) are simply
        # skipped instead of raising KeyError.
        model_cfg.update(saved_config.get('model_config', {}))
        if 'qlora' in saved_config:
            shared['qlora'] = saved_config['qlora']
        shared.setdefault('lora', {}).update(saved_config.get('lora_config', {}))

    base_model_name = model_cfg['model_name']
    has_adapter = os.path.exists(os.path.join(model_path, "adapter_config.json"))
    # fp16 is read from shared.model for every code path, including the
    # fallback below.
    torch_dtype = torch.float16 if model_cfg.get('fp16', False) else torch.float32
    hub_kwargs = {
        'trust_remote_code': model_cfg['trust_remote_code'],
        'token': hf_token,
    }

    # Load model with appropriate settings
    try:
        if model_cfg.get('use_qlora', False):
            # Configure QLoRA
            qlora_cfg = shared['qlora']
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=qlora_cfg['load_in_4bit'],
                bnb_4bit_quant_type=qlora_cfg['bnb_4bit_quant_type'],
                bnb_4bit_compute_dtype=getattr(torch, qlora_cfg['bnb_4bit_compute_dtype']),
                bnb_4bit_use_double_quant=qlora_cfg['bnb_4bit_use_double_quant'],
            )

            # Load base model with QLoRA
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                quantization_config=bnb_config,
                **hub_kwargs
            )

            # Load the trained adapter from the local path
            if has_adapter:
                model = PeftModel.from_pretrained(base_model, model_path, is_trainable=False, device_map="auto")
            else:
                # NOTE(review): without an adapter, model_path itself is not
                # loaded here -- the quantized shared base model is returned,
                # unlike the non-quantized branch below which loads
                # model_path. Confirm this is intended.
                model = base_model
            print("Using QLoRA for model evaluation")
        else:
            # Load model without quantization
            if has_adapter:
                # Load base model first, then the LoRA adapter
                base_model = AutoModelForCausalLM.from_pretrained(
                    base_model_name,
                    torch_dtype=torch_dtype,
                    **hub_kwargs
                )
                model = PeftModel.from_pretrained(base_model, model_path, is_trainable=False, device_map="auto")
            else:
                model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    torch_dtype=torch_dtype,
                    **hub_kwargs
                )
            print("Using regular LoRA for model evaluation")
    except Exception as e:
        # Deliberate best-effort: any failure above triggers one retry that
        # loads the shared base model without quantization.
        print(f"Error loading model with quantization: {e}")
        print("Trying to load without quantization...")
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch_dtype,
            **hub_kwargs
        )
        # Then try to load LoRA adapter if it exists
        if has_adapter:
            model = PeftModel.from_pretrained(base_model, model_path, is_trainable=False, device_map="auto")
        else:
            model = base_model

    # Move model to device
    model = model.to(device)

    return model, tokenizer

def load_test_dataset(config):
    """
    Load and prepare test dataset for evaluation.

    Reads the JSON file at ``config['dataset']['test_path']`` and optionally
    truncates it to the first ``config['evaluation']['num_points']`` entries.
    A missing, null, non-positive, or too-large ``num_points`` keeps the full
    dataset.

    Args:
        config: Configuration dictionary containing dataset settings

    Returns:
        list: Test dataset with specified number of samples
    """
    with open(config['dataset']['test_path'], 'r', encoding='utf-8') as f:
        dataset = json.load(f)

    # ``num_points:`` left empty in YAML parses to None; treat it like a
    # missing key ("no limit") instead of failing on ``None > 0``.
    num_points = config['evaluation'].get('num_points')
    if num_points is None:
        num_points = -1

    # Use first num_points if specified
    if 0 < num_points < len(dataset):
        dataset = dataset[:num_points]
        print(f"Using first {num_points} samples from test dataset")
    else:
        print(f"Using full test dataset with {len(dataset)} samples")

    return dataset

def evaluate_model(model, tokenizer, test_dataset, config):
    """
    Run every metric listed in ``config['evaluation']['metrics']`` on a model.

    A reference model is loaded only when ``win_rate`` is requested; an
    annotation model is loaded when ``win_rate`` or ``PPR`` is requested.
    Metric names other than ``win_rate`` and ``PPR`` are skipped.

    Args:
        model: Model to evaluate
        tokenizer: Tokenizer for text processing
        test_dataset: Dataset for evaluation
        config: Configuration dictionary containing evaluation settings

    Returns:
        dict: ``sampled_win_rate`` / ``prob_win_rate`` for ``win_rate`` and
        ``current_dist`` / ``min_ratio`` for ``PPR``
    """
    eval_cfg = config['evaluation']
    requested = eval_cfg['metrics']
    needs_reference = "win_rate" in requested
    needs_annotator = needs_reference or "PPR" in requested

    def _run_options():
        # Looked up lazily so the settings are only required when a metric
        # is actually computed.
        return {
            'batch_size': eval_cfg['batch_size'],
            'max_length': eval_cfg['max_length'],
            'w': eval_cfg['w'],
        }

    # Reference model: only the win-rate comparison needs it
    if needs_reference:
        reference_model, reference_tokenizer = load_model_and_tokenizer(
            config['reference_model']['model_name'], config
        )

    # Annotation model: shared by win rate and PPR
    if needs_annotator:
        annotation_model, annotation_tokenizer = load_model_and_tokenizer(
            config['annotation_model']['model_name'], config, is_annotation_model=True
        )

    metrics = {}
    for metric in requested:
        if metric == 'win_rate':
            sampled_win_rate, prob_win_rate = compute_win_rate(
                model, tokenizer,
                reference_model, reference_tokenizer,
                annotation_model, annotation_tokenizer,
                test_dataset,
                **_run_options()
            )
            metrics.update(sampled_win_rate=sampled_win_rate, prob_win_rate=prob_win_rate)
            print(f"Sampled Winning Rate: {sampled_win_rate:.4f}")
            print(f"Probability Winning Rate: {prob_win_rate:.4f}")
        elif metric == 'PPR':
            current_dist, min_ratio = compute_PPR(
                model, tokenizer,
                annotation_model, annotation_tokenizer,
                test_dataset,
                **_run_options()
            )
            metrics.update(current_dist=current_dist, min_ratio=min_ratio)
            print(f"Current Distribution: {current_dist}")
            print(f"Minimum Ratio: {min_ratio:.4f}")

    # Drop the helper models and release cached GPU memory
    if needs_reference:
        del reference_model
        torch.cuda.empty_cache()
    if needs_annotator:
        del annotation_model
        torch.cuda.empty_cache()

    return metrics

def main():
    """
    Command-line entry point: evaluate every model listed in the config.

    Reads the YAML file given by ``--config`` (default ``config_eval.yaml``),
    evaluates each entry of ``config['models']`` in turn, and rewrites
    ``evaluation_results.json`` in ``config['evaluation']['output_dir']``
    after every model so partial results survive a later failure. Metrics are
    also logged to Weights & Biases when ``config['auth']['wandb_token']`` is
    present.
    """
    # Parse arguments
    parser = argparse.ArgumentParser(description="Evaluate language models")
    parser.add_argument('--config', type=str, default='config_eval.yaml', help='Path to config file')
    args = parser.parse_args()

    # Load configuration
    config = load_config(args.config)

    # Create output directory
    output_dir = config['evaluation']['output_dir']
    os.makedirs(output_dir, exist_ok=True)
    results_path = os.path.join(output_dir, 'evaluation_results.json')

    # Initialize wandb
    use_wandb = 'auth' in config and 'wandb_token' in config['auth']
    if use_wandb:
        wandb.login(key=config['auth']['wandb_token'])
        # Never upload credentials: the run config is stored with the run,
        # so the 'auth' section (HF / W&B tokens) is stripped first.
        tracked_config = {key: value for key, value in config.items() if key != 'auth'}
        wandb.init(
            project=config['tracking']['wandb_project'],
            name=config['tracking']['run_name'],
            config=tracked_config
        )

    # Load test dataset
    test_dataset = load_test_dataset(config)

    # Initialize results dictionary
    results = {}

    # Evaluate each model
    for model_name, model_path in config['models'].items():
        print(f"\nEvaluating {model_name} model...")

        # Load model and tokenizer
        model, tokenizer = load_model_and_tokenizer(model_path, config)

        # Evaluate model
        metrics = evaluate_model(model, tokenizer, test_dataset, config)
        results[model_name] = metrics

        # Log metrics to wandb
        if use_wandb:
            for metric_name, value in metrics.items():
                wandb.log({
                    f"{model_name}/{metric_name}": value
                })

        # Save results after each model
        with open(results_path, 'w') as f:
            json.dump(results, f, indent=2)
        print(f"Results saved to {results_path}")

        # Clean up
        del model
        torch.cuda.empty_cache()

    if use_wandb:
        wandb.finish()

# Entry-point guard: run the evaluation only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()