import warnings
import os
import logging
import json
import yaml
import torch
import wandb
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
from utils.compute_metrics import compute_win_rate, compute_PPR

# Quiet the chatty third-party loggers: only errors get through.
for _noisy_logger in ("transformers", "peft"):
    logging.getLogger(_noisy_logger).setLevel(logging.ERROR)

# Targeted warning filters, registered in the same order as before:
# the tokenizer deprecation message first, then the per-module filters.
warnings.filterwarnings("ignore", message=".*tokenizer.*deprecated.*")
for _noisy_module in ("transformers.*", "peft.*"):
    warnings.filterwarnings("ignore", module=_noisy_module)

# Do not leave the loop variables behind in the module namespace.
del _noisy_logger, _noisy_module

# Finally, a catch-all filter that ignores every warning.
warnings.simplefilter("ignore")

def load_config(config_path):
    """
    Load configuration from YAML file.

    The file is decoded as UTF-8 explicitly, matching how the test dataset
    is read elsewhere in this file, so that non-ASCII values do not depend on
    the platform's default locale encoding.

    Args:
        config_path: Path to the YAML configuration file

    Returns:
        Whatever ``yaml.safe_load`` produces for the file (a dict for a
        mapping document; ``None`` for an empty file).
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    return config

def load_model_and_tokenizer(model_path, config, is_annotation_model=False):
    """
    Load a causal language model (plus LoRA adapter, if any) and a tokenizer.

    Tokenizer: loaded from ``config['annotation_model']['model_name']`` when
    ``is_annotation_model`` is true, otherwise from
    ``config['shared']['model']['model_name']``. Its pad token is set to its
    EOS token.

    Model: if ``model_path`` contains ``adapter_config.json`` the shared base
    model is loaded and the adapter at ``model_path`` is attached; otherwise
    ``model_path`` itself is loaded as the checkpoint. This holds for the
    QLoRA path, the regular path and the fallback alike, so the requested
    model is never silently replaced by the shared base model.

    If ``model_path`` contains ``training_config.json``, its ``model_config``,
    ``lora_config`` and optional ``qlora`` entries override the shared
    settings for this call only; the caller's ``config`` is not modified.

    Args:
        model_path: Path to the model directory or model name
        config: Configuration dictionary containing model settings
        is_annotation_model: Whether this is an annotation model (affects tokenizer selection)

    Returns:
        tuple: (model, tokenizer) loaded and configured according to settings

    Raises:
        KeyError: If a required configuration key is missing.
        Exception: Whatever the non-quantized fallback load raises when the
            first load attempt has already failed.
    """
    import copy  # local import: only needed for the private settings copy

    hf_token = config.get('auth', {}).get('hf_token', None)

    # The tokenizer is selected from the caller's config as given.
    if is_annotation_model:
        tokenizer_source = config['annotation_model']['model_name']
    else:
        tokenizer_source = config['shared']['model']['model_name']
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_source,
        trust_remote_code=config['shared']['model']['trust_remote_code'],
        token=hf_token
    )
    tokenizer.pad_token = tokenizer.eos_token

    # Configure the device
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Work on a private copy so that settings saved with one model do not
    # leak into later loads (other evaluated models, reference model,
    # annotation model) through the shared config dictionary.
    shared = copy.deepcopy(config['shared'])
    saved_config_path = os.path.join(model_path, "training_config.json")
    if os.path.exists(saved_config_path):
        with open(saved_config_path, 'r') as f:
            saved_config = json.load(f)
        shared['model'].update(saved_config['model_config'])
        if 'qlora' in saved_config:
            shared['qlora'] = saved_config['qlora']
        # Merged for completeness of the effective settings; not read below.
        shared['lora'].update(saved_config['lora_config'])

    model_settings = shared['model']
    trust_remote_code = model_settings['trust_remote_code']
    dtype = torch.float16 if model_settings.get('fp16', False) else torch.float32

    # With an adapter the weights are base model + adapter; without one,
    # model_path is itself the checkpoint to load.
    has_adapter = os.path.exists(os.path.join(model_path, "adapter_config.json"))
    weights_source = model_settings['model_name'] if has_adapter else model_path

    def _load_weights(**load_kwargs):
        # Load the underlying causal LM with the shared auth/trust settings.
        return AutoModelForCausalLM.from_pretrained(
            weights_source,
            trust_remote_code=trust_remote_code,
            token=hf_token,
            **load_kwargs
        )

    def _with_adapter(base_model):
        # Attach the LoRA adapter stored at model_path when there is one.
        if has_adapter:
            return PeftModel.from_pretrained(base_model, model_path, is_trainable=False, device_map="auto")
        return base_model

    # Load model with appropriate settings
    try:
        if model_settings.get('use_qlora', False):
            qlora_settings = shared['qlora']
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=qlora_settings['load_in_4bit'],
                bnb_4bit_quant_type=qlora_settings['bnb_4bit_quant_type'],
                bnb_4bit_compute_dtype=getattr(torch, qlora_settings['bnb_4bit_compute_dtype']),
                bnb_4bit_use_double_quant=qlora_settings['bnb_4bit_use_double_quant'],
            )
            model = _with_adapter(_load_weights(quantization_config=bnb_config))
            print("Using QLoRA for model evaluation")
        else:
            model = _with_adapter(_load_weights(torch_dtype=dtype))
            print("Using regular LoRA for model evaluation")
    except Exception as e:
        print(f"Error loading model with quantization: {e}")
        print("Trying to load without quantization...")
        # Retry the same weights without quantization, honouring the fp16
        # flag from the model settings. If this also fails the error
        # propagates instead of a different model being returned.
        model = _with_adapter(_load_weights(torch_dtype=dtype))

    # Move model to device
    # NOTE(review): some transformers/bitsandbytes version combinations reject
    # .to() on quantized models — confirm against the installed versions.
    model = model.to(device)

    return model, tokenizer

def load_test_dataset(config):
    """
    Load and prepare test dataset for evaluation.

    Reads the JSON file at ``config['dataset']['test_path']`` (UTF-8). When
    ``config['evaluation']['num_points']`` is a positive number smaller than
    the dataset size, only that many leading samples are kept. A missing,
    ``None`` (e.g. ``num_points:`` left empty in YAML), non-positive or
    too-large value keeps the full dataset.

    Args:
        config: Configuration dictionary containing dataset settings

    Returns:
        list: Test dataset with specified number of samples
    """
    with open(config['dataset']['test_path'], 'r', encoding='utf-8') as f:
        dataset = json.load(f)

    # Use first num_points if specified; None means "no limit" rather than
    # an error when compared against an int.
    num_points = config['evaluation'].get('num_points', -1)
    if num_points is not None and 0 < num_points < len(dataset):
        dataset = dataset[:num_points]
        print(f"Using first {num_points} samples from test dataset")
    else:
        print(f"Using full test dataset with {len(dataset)} samples")

    return dataset

def evaluate_model(model, tokenizer, test_dataset, config):
    """
    Run every metric listed in ``config['evaluation']['metrics']`` on a model.

    Helper models are loaded on demand: a reference model when ``win_rate``
    is requested, and an annotation model when ``win_rate`` or ``PPR`` is
    requested. Both are released again before returning. Metric names other
    than ``win_rate`` and ``PPR`` are skipped.

    Args:
        model: Model to evaluate
        tokenizer: Tokenizer for text processing
        test_dataset: Dataset for evaluation
        config: Configuration dictionary containing evaluation settings

    Returns:
        dict: ``sampled_win_rate``/``prob_win_rate`` for ``win_rate`` and
        ``current_dist``/``min_ratio`` for ``PPR``, for the metrics that ran
    """
    results = {}
    eval_settings = config['evaluation']

    wants_win_rate = "win_rate" in eval_settings['metrics']
    wants_annotator = wants_win_rate or "PPR" in eval_settings['metrics']

    # The reference model is only needed for the winning rate
    if wants_win_rate:
        reference_model, reference_tokenizer = load_model_and_tokenizer(
            config['reference_model']['model_name'],
            config
        )

    # The annotation model serves both the winning rate and PPR
    if wants_annotator:
        annotation_model, annotation_tokenizer = load_model_and_tokenizer(
            config['annotation_model']['model_name'],
            config,
            is_annotation_model=True
        )

    # Walk the requested metrics in the order given
    for metric in eval_settings['metrics']:
        if metric not in ('win_rate', 'PPR'):
            continue

        # Looked up only once a known metric is about to run
        scoring_options = {
            'batch_size': eval_settings['batch_size'],
            'max_length': eval_settings['max_length'],
            'w': eval_settings['w'],
        }

        if metric == 'PPR':
            current_dist, min_ratio = compute_PPR(
                model, tokenizer,
                annotation_model, annotation_tokenizer,
                test_dataset,
                **scoring_options
            )
            results['current_dist'] = current_dist
            results['min_ratio'] = min_ratio
            print(f"Current Distribution: {current_dist}")
            print(f"Minimum Ratio: {min_ratio:.4f}")
        else:
            sampled_win_rate, prob_win_rate = compute_win_rate(
                model, tokenizer,
                reference_model, reference_tokenizer,
                annotation_model, annotation_tokenizer,
                test_dataset,
                **scoring_options
            )
            results['sampled_win_rate'] = sampled_win_rate
            results['prob_win_rate'] = prob_win_rate
            print(f"Sampled Winning Rate: {sampled_win_rate:.4f}")
            print(f"Probability Winning Rate: {prob_win_rate:.4f}")

    # Release the helper models
    if wants_win_rate:
        del reference_model
        torch.cuda.empty_cache()

    if wants_annotator:
        del annotation_model
        torch.cuda.empty_cache()

    return results

def main():
    """
    Command-line entry point: evaluate every model listed in the config.

    Reads the YAML file given by ``--config`` (default ``config_eval.yaml``),
    loads the test dataset once, then for each entry of ``config['models']``
    loads the model, evaluates it, and rewrites
    ``<output_dir>/evaluation_results.json`` with all results so far.

    Weights & Biases tracking is enabled when ``config['auth']`` contains a
    ``wandb_token`` key. The ``auth`` section is left out of the run config
    sent to W&B so that tokens are not uploaded, and each model's metrics are
    logged in a single ``wandb.log`` call.
    """
    # Parse arguments
    parser = argparse.ArgumentParser(description="Evaluate language models")
    parser.add_argument('--config', type=str, default='config_eval.yaml', help='Path to config file')
    args = parser.parse_args()

    # Load configuration
    config = load_config(args.config)

    # Create output directory
    output_dir = config['evaluation']['output_dir']
    os.makedirs(output_dir, exist_ok=True)
    results_path = os.path.join(output_dir, 'evaluation_results.json')

    # Decide once whether tracking is on; tolerate a missing or empty
    # ``auth`` section.
    auth = config.get('auth') or {}
    use_wandb = 'wandb_token' in auth

    # Initialize wandb
    if use_wandb:
        wandb.login(key=auth['wandb_token'])
        # Never send credentials along with the run configuration.
        tracked_config = {key: value for key, value in config.items() if key != 'auth'}
        wandb.init(
            project=config['tracking']['wandb_project'],
            name=config['tracking']['run_name'],
            config=tracked_config
        )

    # Load test dataset
    test_dataset = load_test_dataset(config)

    # Initialize results dictionary
    results = {}

    # Evaluate each model
    for model_name, model_path in config['models'].items():
        print(f"\nEvaluating {model_name} model...")

        # Load model and tokenizer
        model, tokenizer = load_model_and_tokenizer(model_path, config)

        # Evaluate model
        metrics = evaluate_model(model, tokenizer, test_dataset, config)
        results[model_name] = metrics

        # Log all of this model's metrics together so they share one step
        if use_wandb and metrics:
            wandb.log({
                f"{model_name}/{metric_name}": value
                for metric_name, value in metrics.items()
            })

        # Save results after each model so partial progress survives a crash
        with open(results_path, 'w') as f:
            json.dump(results, f, indent=2)
        print(f"Results saved to {results_path}")

        # Clean up
        del model
        torch.cuda.empty_cache()

    if use_wandb:
        wandb.finish()

if __name__ == "__main__":
    main()