import logging
import os
import re
import json
import argparse
from concurrent.futures import ThreadPoolExecutor
import threading
from tqdm import tqdm
import numpy as np

from dataset import BaseDataset
from dataset import SNLIDataset, MTBenchDataset, SummEvalDataset
from models import LLMModel, OPENAIModel, QWENModel, QwenLoraModel

LOGGER = logging.getLogger("Uncertainty Align")


class Judger:
    """Collect a model's label distribution for each sample of a dataset.

    For every sample, the dataset builds a prompt and the model returns token
    probabilities. Those probabilities are folded into a normalized distribution
    over the dataset's labels, which is written to a per-model JSON results file.
    """

    def __init__(self, model: LLMModel, dataset: BaseDataset, threads: int = 1):
        """
        Args:
            model: Model to query. Only models whose ``is_api`` is true are run
                with more than one worker thread.
            dataset: Dataset providing samples, prompts, ``id_key`` and
                (optionally) ``label_mapping``.
            threads: Requested worker count; forced to 1 for non-API models and
                clamped to at least 1.
        """
        self.model = model
        self.dataset = dataset
        # Local models run sequentially; only API-backed models fan out.
        # ThreadPoolExecutor rejects max_workers <= 0, hence the clamp.
        self.threads = max(1, threads) if self.model.is_api else 1
        # Serializes the read-modify-write of the shared results file.
        self.lock = threading.Lock()

        # Get all possible labels
        self.all_labels = list(self.dataset.label_mapping.keys()) if hasattr(self.dataset, 'label_mapping') else []
        LOGGER.info(f"All possible labels for {self.dataset.name}: {self.all_labels}")

    def evaluate(self, num_samples=None):
        """Evaluate the model's performance on the dataset.

        Args:
            num_samples: Evaluate only the first ``num_samples`` items. ``None`` or
                a value larger than the dataset evaluates every item.
        """
        total_samples = len(self.dataset)
        if num_samples is None or num_samples > total_samples:
            num_samples = total_samples

        LOGGER.info(f"Evaluating {num_samples}/{total_samples} samples from {self.dataset.name}")

        tasks = list(range(num_samples))
        with ThreadPoolExecutor(max_workers=self.threads) as executor:
            # Draining the iterator drives tqdm and re-raises any worker exception here.
            list(tqdm(executor.map(self.evaluate_one_sample, tasks),
                      total=len(tasks),
                      desc=f"Evaluating {self.model.name} on {self.dataset.name}"))

    def evaluate_one_sample(self, idx):
        """Query the model for sample ``idx`` and persist its label distribution.

        A failure in the model call or output parsing is logged, and a uniform
        distribution is recorded instead, so one bad sample does not abort the run.
        Errors while saving are not masked and propagate to the caller.
        """
        data = self.dataset[idx]
        prompt = self.dataset.make_prompt(data)

        try:
            prediction, probs = self.model.call(prompt, return_probs=True, do_generation=False)
            # Kept so malformed outputs still fall back to uniform; the parsed text
            # itself is not stored, only the probability distribution is.
            self.dataset.phrase_output(prediction)
            prob_distribution = self._parse_probabilities(probs)
        except Exception as e:
            LOGGER.error(f"Error evaluating item {idx}: {e}")
            prob_distribution = self._uniform_distribution()

        self.save_prediction(data, prob_distribution)

    def _uniform_distribution(self):
        """Return equal probability for every lowercased label ({} if there are no labels)."""
        if not self.all_labels:
            return {}
        share = 1.0 / len(self.all_labels)
        return {label.lower(): share for label in self.all_labels}

    def _parse_probabilities(self, probs_list):
        """Map the model's token probabilities onto a normalized label distribution.

        Args:
            probs_list: Sequence whose first element is a ``{token: probability}``
                mapping (presumably the first generated position; confirm against
                the model's ``call``). An empty value or ``None`` falls back to uniform.

        Returns:
            ``{label.lower(): probability}`` normalized to sum to 1. It is empty
            when the dataset exposes no labels.
        """
        EPSILON = 1e-8

        if not probs_list:
            LOGGER.warning("No probabilities returned, using uniform distribution.")
            return self._uniform_distribution()

        lowered_labels = [label.lower() for label in self.all_labels]
        # Every label starts with a tiny mass so none ends up with exactly zero.
        distribution = {label: EPSILON for label in lowered_labels}

        # Label matching: a token's probability goes to the first label that
        # contains it, or that it contains.
        for token, prob in probs_list[0].items():
            token_lower = token.lower().strip()
            if not token_lower:
                # Whitespace-only tokens strip to "", which is a substring of every
                # label; without this guard they would all be credited to the first label.
                continue
            for label_lower in lowered_labels:
                if token_lower in label_lower or label_lower in token_lower:
                    distribution[label_lower] += prob
                    break

        # Normalize probability distribution
        total_prob = sum(distribution.values())
        if total_prob > 0:
            distribution = {k: v / total_prob for k, v in distribution.items()}

        return distribution

    def save_prediction(self, data, prediction_distribution):
        """Record one sample's distribution in the model's results JSON file.

        The file is ``evaluation_results/raw/<dataset name>/<model type>/<model name>.json``
        next to this script and maps sample id -> distribution. It is re-read and
        rewritten under ``self.lock`` on every call, so results persist incrementally.
        """
        model_type = self.model.__class__.__name__.lower().replace('model', '')

        file_path = os.path.join(os.path.dirname(__file__), "evaluation_results", "raw", self.dataset.name, model_type)
        os.makedirs(file_path, exist_ok=True)
        full_path = os.path.join(file_path, f"{self.model.name}.json")
        tmp_path = full_path + ".tmp"

        with self.lock:  # Thread safety
            if os.path.exists(full_path):
                with open(full_path, "r", encoding="utf-8") as f:
                    file_data = json.load(f)
            else:
                file_data = {}

            sample_id = data[self.dataset.id_key]
            file_data[sample_id] = prediction_distribution

            # Write to a temporary file and atomically swap it in, so an interrupted
            # run cannot leave a truncated (unparseable) results file behind.
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(file_data, f, ensure_ascii=False)
            os.replace(tmp_path, full_path)


def load_datasets(dataset_names):
    """Build dataset objects for the requested names (case-insensitive).

    Each known name maps to a JSONL file under ``dataset/test`` next to this
    script. Unknown names are logged and skipped, and input order is preserved.
    """
    test_dir = os.path.join(os.path.dirname(__file__), "dataset", "test")
    # Canonical (lowercase) dataset name -> test-split file name.
    test_files = {
        "snli": "snli_test.jsonl",
        "multinli": "multinli_test.jsonl",
        "mtbench": "mt_bench_test.jsonl",
        "summeval": "summeval_test.jsonl",
    }

    loaded = []
    for name in dataset_names:
        key = name.lower()
        if key not in test_files:
            LOGGER.warning(f"Unknown dataset name: {name}")
            continue

        # SNLI and MultiNLI share the same NLI loader.
        if key in ("snli", "multinli"):
            dataset_cls = SNLIDataset
        elif key == "mtbench":
            dataset_cls = MTBenchDataset
        else:
            dataset_cls = SummEvalDataset

        loaded.append(dataset_cls(file_path=os.path.join(test_dir, test_files[key]), name=key))

    return loaded


def load_openai_models(model_names):
    """Create one OPENAIModel per requested model name, in the given order."""
    models = []
    for model_name in model_names:
        models.append(OPENAIModel(model_name=model_name))
    return models


def load_qwen_models(model_names, model_path=None):
    """Create one QWENModel per requested model name, all sharing ``model_path``."""
    models = []
    for model_name in model_names:
        models.append(QWENModel(model_name=model_name, model_path=model_path))
    return models


def _find_checkpoint_dirs(checkpoint_dir):
    """Return checkpoint sub-directory names inside ``checkpoint_dir`` in training order.

    Directories named exactly ``checkpoint-<N>`` are returned sorted by ``N``. If
    there are none, ``["final_checkpoint"]`` is returned when that directory exists;
    otherwise the result is ``[]``. Raises ``OSError`` if ``checkpoint_dir`` cannot be listed.
    """
    numbered = [
        item for item in os.listdir(checkpoint_dir)
        if re.fullmatch(r'checkpoint-\d+', item) and os.path.isdir(os.path.join(checkpoint_dir, item))
    ]
    if numbered:
        # fullmatch guarantees the suffix after the first '-' is all digits.
        numbered.sort(key=lambda name: int(name.split('-', 1)[1]))
        return numbered
    if os.path.isdir(os.path.join(checkpoint_dir, "final_checkpoint")):
        return ["final_checkpoint"]
    return []


def load_qwen_lora_models(params):
    """
    Load QwenLora models

    Args:
        params: List of dictionaries containing the following keys:
            - dataset: Dataset name
            - epsilon: Perturbation magnitude (string or number)
            - alpha: KL loss weight (string or number)
            - model_name: Base Qwen model name (required); the first directory under outputs/
            - model_path: Optional base-model path forwarded to QwenLoraModel
            - checkpoint_name: Checkpoint name, automatically search all checkpoints if None

    Returns:
        List of QwenLoraModel instances. A param entry is skipped with a log
        message if required fields are missing, or if (in search mode) its
        checkpoint directory cannot be listed or contains no checkpoints.
        Errors raised while constructing a model are not swallowed.
    """
    models = []
    base_path = os.path.dirname(__file__)

    for param in params:
        dataset = (param.get('dataset') or '').lower()
        epsilon = param.get('epsilon')
        alpha = param.get('alpha')
        qwen_model_name = param.get('model_name', None)

        if not dataset or epsilon in (None, '') or alpha in (None, '') or qwen_model_name is None:
            LOGGER.warning("Dataset, epsilon, and alpha must be specified for QwenLora models. Model name is also required.")
            continue

        # Accept numeric values too; directory names spell '.' as '_'.
        epsilon = str(epsilon)
        alpha = str(alpha)
        checkpoint_dir = os.path.join(
            base_path,
            "outputs",
            qwen_model_name,
            dataset,
            f"eps_{epsilon.replace('.', '_')}",
            f"alpha_{alpha.replace('.', '_')}",
        )
        model_path = param.get('model_path', None)
        checkpoint_name = param.get('checkpoint_name', None)

        if checkpoint_name is not None:
            models.append(QwenLoraModel(
                model_name=f"{qwen_model_name}_eps_{epsilon}_alpha_{alpha}",
                model_path=model_path,
                lora_checkpoint_path=checkpoint_dir,
                checkpoint_name=checkpoint_name,
            ))
            continue

        # Only the directory search is best-effort; model construction errors
        # propagate, matching the explicit-checkpoint branch above.
        try:
            checkpoint_dirs = _find_checkpoint_dirs(checkpoint_dir)
        except OSError as e:
            LOGGER.error(f"Error searching checkpoints in {checkpoint_dir}: {e}")
            continue

        if not checkpoint_dirs:
            LOGGER.warning(f"No checkpoints found in {checkpoint_dir}")
            continue

        # Checkpoints are labelled epoch 1..N by sorted position
        # (assumes one checkpoint was saved per epoch; TODO confirm with the training script).
        for epoch, ckpt in enumerate(checkpoint_dirs, start=1):
            models.append(QwenLoraModel(
                model_name=f"{qwen_model_name}_eps_{epsilon}_alpha_{alpha}_epoch_{epoch}",
                model_path=model_path,
                lora_checkpoint_path=checkpoint_dir,
                checkpoint_name=ckpt,
            ))

    return models


def _split_single_value(values):
    """Expand a one-element list holding a whitespace-separated string into its parts.

    This lets ``--epsilon_values "0.1 0.2"`` (one quoted argument) behave like
    ``--epsilon_values 0.1 0.2``. ``None`` and multi-element lists are returned unchanged.
    """
    if values and len(values) == 1:
        # split() with no argument ignores repeated/leading/trailing whitespace,
        # so no empty entries are produced.
        return values[0].split()
    return values


def main():
    """Parse CLI arguments, load the requested models and datasets, and evaluate every pair."""
    parser = argparse.ArgumentParser(description="Evaluate model performance on datasets")
    
    parser.add_argument('--model_type', type=str, required=True, choices=['openai', 'qwen', 'qwen_lora'], help='Model type to evaluate')
    parser.add_argument('--datasets', type=str, nargs='+', required=True,
                        choices=['snli', 'multinli', 'mtbench', 'summeval', 'SNLI', 'MultiNLI', 'MTBench', 'SummEval'],
                        help='Dataset names to evaluate')
    parser.add_argument('--model_names', type=str, nargs='+', help='Model names (for openai and qwen models)')
    parser.add_argument('--model_path', type=str, help='Model path')
    parser.add_argument('--num_samples', type=int, default=None, 
                        help='Number of samples to evaluate per dataset, default is to evaluate the entire dataset')
    parser.add_argument('--threads', type=int, default=4,
                        help='Number of threads used for evaluation (only OpenAI models support multi-threading)')
    
    # QwenLora specific parameters
    parser.add_argument('--dataset', type=str, 
                        help='Dataset name for QwenLora model training')
    parser.add_argument('--model_name', type=str, 
                        help='Model name for QwenLora model training')
    parser.add_argument('--epsilon_values', type=str, nargs='+',
                        help='List of perturbation sizes for QwenLora models')
    parser.add_argument('--alpha_values', type=str, nargs='+',
                        help='List of KL weights for QwenLora models')
    parser.add_argument('--checkpoint_name', type=str, default=None,
                        help='Checkpoint name for QwenLora models, default is all checkpoints in the directory')

    args = parser.parse_args()
    # These flags are optional (and only used for qwen_lora), so they may be None;
    # the helper leaves None untouched instead of crashing on len(None).
    args.epsilon_values = _split_single_value(args.epsilon_values)
    args.alpha_values = _split_single_value(args.alpha_values)
    datasets = load_datasets(args.datasets)
    models = []

    # Load models based on model type; argument problems are reported via
    # parser.error (usage message + exit status 2) for every model type.
    if args.model_type == 'openai':
        if not args.model_names:
            parser.error("Model names must be provided for OpenAI models")
        models = load_openai_models(args.model_names)
    elif args.model_type == "qwen":
        if not args.model_names:
            parser.error("Model names must be provided for Qwen models")
        models = load_qwen_models(args.model_names, args.model_path)
    elif args.model_type == 'qwen_lora':
        if not args.epsilon_values or not args.alpha_values:
            parser.error("QwenLora models require epsilon_values and alpha_values parameters")
        if not args.model_name:
            # Without it every parameter combination would be skipped by the loader.
            parser.error("QwenLora models require the model_name parameter")
        qwen_lora_params = [
            {
                'dataset': args.dataset or args.datasets[0],
                'epsilon': epsilon,
                'alpha': alpha,
                'checkpoint_name': args.checkpoint_name,
                'model_path': args.model_path,
                'model_name': args.model_name,
            }
            for epsilon in args.epsilon_values
            for alpha in args.alpha_values
        ]
        models = load_qwen_lora_models(qwen_lora_params)

    if not models:
        LOGGER.warning("No models were loaded; nothing to evaluate.")

    for model in models:
        for dataset in datasets:
            LOGGER.info(f"Evaluating {model.name} on {dataset.name}")
            judger = Judger(model, dataset, threads=args.threads)
            judger.evaluate(args.num_samples)
            LOGGER.info(f"Completed evaluating {model.name} on {dataset.name}")


if __name__ == "__main__":
    # Configure console output for this script's named logger, then run the CLI.
    log_format = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(log_format)
    stream_handler.setLevel(logging.INFO)

    LOGGER.setLevel(logging.INFO)
    LOGGER.addHandler(stream_handler)

    main()
