from typing import Dict, List
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from .prompt_builder import PromptBuilder
from .utils import compute_f1, compute_exact_match, extract_final_number
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
import logging
import os
import json
from transformers import BitsAndBytesConfig


# Module-level logger named after this module; handlers and levels are left
# to whatever application imports it.
logger = logging.getLogger(__name__)

class BenchmarkEvaluator:
    """Singleton that loads one causal LM and scores it on several benchmarks.

    Every construction returns the same instance. The model and tokenizer are
    loaded at most once (see ``setup_model``) and are exposed through the
    ``model`` / ``tokenizer`` properties. Each ``evaluate_*`` method takes an
    already-loaded dataset split and returns percentages in the 0-100 range.

    NOTE(review): the few-shot examples are taken from the head of the very
    split that is being scored, so those rows may appear in their own prompt.
    Whether ``PromptBuilder`` compensates for that is not visible here.
    """

    _instance = None
    _model = None
    _tokenizer = None

    def __new__(cls, model_name: str = None, cache_dir: str = '/data/'):
        """Return the shared instance, creating it and loading the model if needed.

        Args:
            model_name: Hub id or path of the model. When given and no model
                is loaded yet, the model is loaded, even if the singleton was
                created earlier without a name.
            cache_dir: Directory handed to ``from_pretrained`` as ``cache_dir``.
        """
        if cls._instance is None:
            cls._instance = super(BenchmarkEvaluator, cls).__new__(cls)
            cls._instance.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        instance = cls._instance
        if model_name:
            if instance._model is None:
                # Also covers a previous attempt that failed half-way.
                instance.setup_model(model_name, cache_dir)
            elif getattr(instance, "model_name", None) not in (None, model_name):
                logger.warning(
                    "BenchmarkEvaluator already holds model %s; ignoring request for %s.",
                    instance.model_name,
                    model_name,
                )
        return instance

    def __init__(self, model_name: str = None, cache_dir: str = '/data/'):
        """Initialise per-instance state once.

        ``__init__`` runs on every ``BenchmarkEvaluator(...)`` call even though
        ``__new__`` returns the same object, so existing state is preserved
        instead of being overwritten with the defaults of a later call.
        """
        if not hasattr(self, "prompt_builder"):
            self.prompt_builder = PromptBuilder()
        if not hasattr(self, "model_name"):
            self.model_name = model_name
            self.cache_dir = cache_dir

    def setup_model(self, model_name: str, cache_dir: str):
        """Load the tokenizer and model unless a model is already loaded.

        The model is first loaded in fp16; if that raises ``RuntimeError`` the
        load is retried with 4-bit NF4 quantization. State is only stored on
        the instance after both tokenizer and model loaded successfully.

        Args:
            model_name: Hub id or path passed to ``from_pretrained``.
            cache_dir: Cache directory passed to ``from_pretrained``.
        """
        if self._model is not None:
            return

        logger.info(f"Loading model: {model_name}")

        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            cache_dir=cache_dir,
            padding_side="left",
            trust_remote_code=True
        )
        if tokenizer.pad_token is None:
            # Reuse EOS so generate() has a pad id.
            tokenizer.pad_token = tokenizer.eos_token

        try:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                device_map="auto",
                torch_dtype=torch.float16,
                trust_remote_code=True
            )
        except RuntimeError as e:
            logger.warning(
                "Loading %s in FP16 failed (%s); retrying with 4-bit NF4 quantization.",
                model_name,
                e,
            )
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=False
            )
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                device_map="auto",
                quantization_config=quantization_config,
                torch_dtype=torch.float16,
                trust_remote_code=True
            )

        model.eval()
        self._tokenizer = tokenizer
        self._model = model
        self.model_name = model_name
        self.cache_dir = cache_dir

    @property
    def model(self):
        """The loaded model, or ``None`` if ``setup_model`` has not succeeded."""
        return self._model

    @property
    def tokenizer(self):
        """The loaded tokenizer, or ``None`` if ``setup_model`` has not succeeded."""
        return self._tokenizer

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _percentage(numerator, denominator):
        """Return ``numerator / denominator`` as a percentage, 0.0 when empty."""
        if not denominator:
            return 0.0
        return (numerator / denominator) * 100

    @staticmethod
    def _few_shot(dataset, count):
        """Select the first ``count`` rows, or fewer when the dataset is shorter."""
        return dataset.select(range(min(count, len(dataset))))

    @staticmethod
    def _answer_key_to_index(answer_key):
        """Map an answer key to a zero-based option index.

        Letters map ``A -> 0, B -> 1, ...``; numeric keys map ``1 -> 0,
        2 -> 1, ...``. Anything else yields ``-1`` so it never matches a
        prediction.
        """
        key = str(answer_key).strip()
        if key.isdigit():
            return int(key) - 1
        if len(key) != 1:
            return -1
        return ord(key.upper()) - ord('A')

    @staticmethod
    def _parse_gold_number(reference):
        """Extract the numeric gold answer from a reference field.

        Handles a plain number, text ending in ``#### <number>`` and text
        containing ``\\boxed{<number>}``. Thousands separators and ``$`` are
        dropped. Returns ``None`` when no number can be parsed.
        """
        if isinstance(reference, (int, float)):
            return float(reference)

        candidate = str(reference)
        if "####" in candidate:
            candidate = candidate.rsplit("####", 1)[1]
        else:
            marker = "\\boxed{"
            start = candidate.rfind(marker)
            if start != -1:
                begin = start + len(marker)
                depth = 0
                # Walk to the brace that closes \boxed{, honouring nesting.
                for position in range(begin, len(candidate)):
                    char = candidate[position]
                    if char == "{":
                        depth += 1
                    elif char == "}":
                        if depth == 0:
                            candidate = candidate[begin:position]
                            break
                        depth -= 1

        cleaned = candidate.strip().replace(",", "").replace("$", "")
        try:
            return float(cleaned)
        except ValueError:
            return None

    @staticmethod
    def _split_targets(item, prefix):
        """Return ``(choices, labels)`` for a TruthfulQA ``mc1``/``mc2`` field.

        Accepts both the nested layout (``{prefix}_targets`` is a dict with
        ``choices`` and ``labels``) and the flat layout (``{prefix}_targets``
        list plus a ``{prefix}_labels`` list).
        """
        targets = item[f"{prefix}_targets"]
        if isinstance(targets, dict):
            return list(targets["choices"]), list(targets["labels"])
        return list(targets), list(item[f"{prefix}_labels"])

    def _option_token_ids(self, options):
        """Look up the vocabulary id of each option label.

        NOTE(review): this relies on every label being a single token under
        exactly this spelling; confirm for the tokenizer in use.
        """
        return self.tokenizer.convert_tokens_to_ids(list(options))

    def _predict_option(self, prompt, option_token_ids):
        """Return the index of the option whose token gets the highest next-token logit."""
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        next_token_logits = self.model(**inputs).logits[:, -1, :]
        return torch.argmax(next_token_logits[:, option_token_ids]).item()

    def _continuation_score(self, context, continuation):
        """Mean log-probability of ``continuation`` tokens given ``context``.

        ``context + continuation`` is scored in one forward pass; the number
        of context tokens is measured by tokenizing ``context`` on its own.
        Returns ``-inf`` when the continuation contributes no tokens.
        """
        inputs = self.tokenizer(context + continuation, return_tensors="pt").to(self.device)
        input_ids = inputs['input_ids']
        context_length = self.tokenizer(context, return_tensors="pt")['input_ids'].shape[1]

        logits = self.model(**inputs).logits
        # Position t predicts token t + 1, hence the shift by one below.
        log_probs = torch.log_softmax(logits[:, :-1, :].float(), dim=-1)
        token_log_probs = torch.gather(log_probs, 2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)

        continuation_log_probs = token_log_probs[0, max(context_length - 1, 0):]
        if continuation_log_probs.numel() == 0:
            return float("-inf")
        return continuation_log_probs.mean().item()

    def _generate_continuation(self, prompt, **generation_kwargs):
        """Generate deterministically and return only the newly generated text.

        ``do_sample=False`` is passed explicitly so decoding does not depend
        on the sampling defaults of the model's generation config. The prompt
        tokens are cut off before decoding so that numbers or answers in the
        few-shot prompt cannot be mistaken for the model's output.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        outputs = self.model.generate(
            **inputs,
            do_sample=False,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            **generation_kwargs
        )
        prompt_length = inputs['input_ids'].shape[1]
        return self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    def _numeric_generation_accuracy(self, dataset, build_prompt, gold_field):
        """Shared loop for benchmarks scored by comparing a generated number.

        An item counts as correct when ``extract_final_number`` finds a number
        in the generation that is within 1e-6 of the parsed gold value. Items
        whose gold value cannot be parsed as a number count as incorrect.
        """
        correct = 0
        total = 0
        few_shot_examples = self._few_shot(dataset, 5)

        for item in tqdm(dataset):
            prompt = build_prompt(item, few_shot_examples)
            generated_text = self._generate_continuation(prompt, max_new_tokens=256, num_beams=1)

            answer = extract_final_number(generated_text)
            gold = self._parse_gold_number(item[gold_field])

            if answer is not None and gold is not None and abs(answer - gold) < 1e-6:
                correct += 1
            total += 1

        return self._percentage(correct, total)

    def _letter_choice_accuracy(self, dataset, build_prompt, options, few_shot_count):
        """Shared loop for benchmarks answered by a letter and keyed by ``answerKey``."""
        correct = 0
        total = 0
        option_ids = self._option_token_ids(options)
        few_shot_examples = self._few_shot(dataset, few_shot_count)

        for item in tqdm(dataset):
            prompt = build_prompt(item, few_shot_examples)
            pred = self._predict_option(prompt, option_ids)
            if pred == self._answer_key_to_index(item['answerKey']):
                correct += 1
            total += 1

        return self._percentage(correct, total)

    # ------------------------------------------------------------------
    # Benchmarks
    # ------------------------------------------------------------------

    @torch.no_grad()
    def evaluate_mmlu(self, dataset) -> Dict[str, float]:
        """Score MMLU per subject using next-token logits of ``A``-``D``.

        Returns:
            Mapping of subject name to accuracy (percent) plus an ``average``
            key holding the unweighted mean over subjects (0.0 if none).
        """
        option_ids = self._option_token_ids(['A', 'B', 'C', 'D'])

        # Group row indices by subject in a single pass over the column.
        rows_by_subject = {}
        for row_index, subject in enumerate(dataset['subject']):
            rows_by_subject.setdefault(subject, []).append(row_index)

        results = {}
        for subject, row_indices in rows_by_subject.items():
            subject_data = dataset.select(row_indices)
            few_shot_examples = self._few_shot(subject_data, 5)
            correct = 0
            total = 0

            for item in tqdm(subject_data):
                prompt = self.prompt_builder.construct_mmlu_prompt(item, few_shot_examples)
                if self._predict_option(prompt, option_ids) == item['answer']:
                    correct += 1
                total += 1

            results[subject] = self._percentage(correct, total)

        results['average'] = float(np.mean(list(results.values()))) if results else 0.0
        return results

    @torch.no_grad()
    def evaluate_hellaswag(self, dataset) -> float:
        """Score HellaSwag by picking the ending with the highest mean log-prob.

        Returns:
            Accuracy in percent.
        """
        correct = 0
        total = 0

        for item in tqdm(dataset):
            context = f"{item['ctx_a']} {item['ctx_b']}"
            scores = [
                self._continuation_score(context, " " + ending)
                for ending in item['endings']
            ]

            pred = np.argmax(scores)
            label = item['label']
            correct_label = int(label) if isinstance(label, str) else label

            if pred == correct_label:
                correct += 1
            total += 1

        accuracy = self._percentage(correct, total)
        logger.info("HellaSwag accuracy: %.2f%%", accuracy)
        return accuracy

    @torch.no_grad()
    def evaluate_arc(self, dataset) -> float:
        """Score ARC with 25 few-shot examples and options ``A``-``E``.

        Both letter and numeric ``answerKey`` values are accepted.

        Returns:
            Accuracy in percent.
        """
        return self._letter_choice_accuracy(
            dataset,
            self.prompt_builder.construct_arc_prompt,
            ['A', 'B', 'C', 'D', 'E'],
            25,
        )

    @torch.no_grad()
    def evaluate_gsm8k(self, dataset) -> float:
        """Score GSM8K by greedy generation and numeric comparison.

        The gold value is parsed from ``item['answer']`` (the number after
        ``####`` when present).

        Returns:
            Accuracy in percent.
        """
        return self._numeric_generation_accuracy(
            dataset, self.prompt_builder.construct_gsm8k_prompt, 'answer'
        )

    @torch.no_grad()
    def evaluate_math(self, dataset) -> float:
        """Score MATH-style items by greedy generation and numeric comparison.

        The gold value is parsed from ``item['solution']`` (the content of the
        last ``\\boxed{...}`` when present). Non-numeric gold answers cannot
        be matched and count as incorrect.

        Returns:
            Accuracy in percent.
        """
        return self._numeric_generation_accuracy(
            dataset, self.prompt_builder.construct_math_prompt, 'solution'
        )

    @torch.no_grad()
    def evaluate_piqa(self, dataset) -> float:
        """Score PIQA by comparing the likelihood of the two solutions.

        Each solution is scored by the mean log-probability of its tokens
        given ``"{goal}\\nAnswer:"``; ties go to the second solution.

        Returns:
            Accuracy in percent.
        """
        correct = 0
        total = 0

        for item in tqdm(dataset):
            context = f"{item['goal']}\nAnswer:"
            score1 = self._continuation_score(context, f" {item['sol1']}")
            score2 = self._continuation_score(context, f" {item['sol2']}")

            pred = 0 if score1 > score2 else 1
            if pred == item['label']:
                correct += 1
            total += 1

        return self._percentage(correct, total)

    @torch.no_grad()
    def evaluate_winogrande(self, dataset) -> float:
        """Score WinoGrande using next-token logits of ``1`` and ``2``.

        Returns:
            Accuracy in percent.
        """
        correct = 0
        total = 0
        option_ids = self._option_token_ids(['1', '2'])
        few_shot_examples = self._few_shot(dataset, 5)

        for item in tqdm(dataset):
            prompt = self.prompt_builder.construct_winogrande_prompt(item, few_shot_examples)
            # Options are 1-based in the dataset.
            pred = self._predict_option(prompt, option_ids) + 1

            if pred == int(item['answer']):
                correct += 1
            total += 1

        return self._percentage(correct, total)

    @torch.no_grad()
    def evaluate_boolq(self, dataset) -> float:
        """Score BoolQ by comparing next-token probability of `` Yes`` vs `` No``.

        The first ten examples are logged at DEBUG level.

        Returns:
            Accuracy in percent.
        """
        correct = 0
        total = 0

        few_shot_examples = self._few_shot(dataset, 5)
        # The last sub-token of " Yes" / " No" stands in for the whole answer.
        yes_token = self.tokenizer.encode(" Yes", add_special_tokens=False)[-1]
        no_token = self.tokenizer.encode(" No", add_special_tokens=False)[-1]

        for item in tqdm(dataset):
            prompt = self.prompt_builder.construct_boolq_prompt(item, few_shot_examples)
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            logits = self.model(**inputs).logits[:, -1, :]

            yes_no_probs = torch.softmax(logits[:, [yes_token, no_token]], dim=-1)[0]
            yes_prob = yes_no_probs[0].item()
            no_prob = yes_no_probs[1].item()

            predicted = "Yes" if yes_prob > no_prob else "No"
            true_answer = "Yes" if item['answer'] else "No"
            is_correct = predicted == true_answer

            if total < 10:
                logger.debug(
                    "BoolQ example %d | passage=%r | question=%r | true=%s | "
                    "predicted=%s | yes_prob=%.4f | no_prob=%.4f | correct=%s",
                    total,
                    item['passage'][:100],
                    item['question'],
                    true_answer,
                    predicted,
                    yes_prob,
                    no_prob,
                    is_correct,
                )

            if is_correct:
                correct += 1
            total += 1

        return self._percentage(correct, total)

    @torch.no_grad()
    def evaluate_truthfulqa(self, dataset) -> Dict[str, float]:
        """Score TruthfulQA multiple choice via next-token logits of option letters.

        For MC1 and MC2 alike, an item is correct when the highest-scoring
        option carries label 1 (MC2 may have several such options).

        Returns:
            ``{"mc1_accuracy": ..., "mc2_accuracy": ...}`` in percent.
        """
        correct = {"mc1": 0, "mc2": 0}
        total = 0

        for item in tqdm(dataset):
            for prefix in ("mc1", "mc2"):
                choices, labels = self._split_targets(item, prefix)
                letters = [chr(65 + idx) for idx in range(len(choices))]

                prompt = f"Q: {item['question']}\n"
                prompt += "".join(
                    f"{letter}) {choice}\n" for letter, choice in zip(letters, choices)
                )
                prompt += "A:"

                pred = self._predict_option(prompt, self._option_token_ids(letters))
                if labels[pred] == 1:
                    correct[prefix] += 1

            total += 1

        return {
            "mc1_accuracy": self._percentage(correct["mc1"], total),
            "mc2_accuracy": self._percentage(correct["mc2"], total)
        }

    @torch.no_grad()
    def evaluate_openbookqa(self, dataset) -> float:
        """Score OpenBookQA with 5 few-shot examples and options ``A``-``D``.

        Returns:
            Accuracy in percent.
        """
        return self._letter_choice_accuracy(
            dataset,
            self.prompt_builder.construct_openbookqa_prompt,
            ['A', 'B', 'C', 'D'],
            5,
        )

    @torch.no_grad()
    def evaluate_squad_v2(self, dataset) -> Dict[str, float]:
        """Score SQuAD v2 with beam-search generation.

        Answerable questions take the best exact-match / F1 over all gold
        answers. Unanswerable questions are correct only when the prediction
        is ``"no answer"`` (case-insensitive).

        Returns:
            ``{"exact_match": ..., "f1": ...}`` in percent.
        """
        exact_matches = 0
        f1_scores = []
        total = 0

        few_shot_examples = self._few_shot(dataset, 5)

        for item in tqdm(dataset):
            prompt = self.prompt_builder.construct_squad_prompt(item, few_shot_examples)
            predicted_answer = self._generate_continuation(
                prompt,
                max_new_tokens=50,
                num_beams=4,
                early_stopping=True
            ).strip()

            gold_answers = item['answers']['text']
            if not gold_answers:
                exact_match = predicted_answer.lower() == "no answer"
                f1 = 1.0 if exact_match else 0.0
            else:
                exact_match = any(
                    compute_exact_match(predicted_answer, gold) for gold in gold_answers
                )
                f1 = max(compute_f1(predicted_answer, gold) for gold in gold_answers)

            if exact_match:
                exact_matches += 1
            f1_scores.append(f1)
            total += 1

        return {
            "exact_match": self._percentage(exact_matches, total),
            "f1": self._percentage(sum(f1_scores), total)
        }


def get_available_benchmarks(verbose=True):
    """Return the registry of supported benchmarks.

    A fresh dict is built on every call. Each entry maps a benchmark name to
    its dataset id, optional dataset ``config``, ``split``, reported
    ``metrics`` and ``num_few_shot``.

    Args:
        verbose: When true, print one ``- <name>`` line per benchmark.

    Returns:
        Dict keyed by benchmark name.
    """
    benchmarks = {
        "mmlu": {
            "dataset": "cais/mmlu",
            "config": "all",
            "split": "test",
            "metrics": ["accuracy"],
            "num_few_shot": 5,
        },
        "hellaswag": {
            "dataset": "hellaswag",
            "split": "validation",
            "metrics": ["accuracy"],
            "num_few_shot": 0,
        },
        "truthfulqa": {
            "dataset": "truthful_qa",
            # truthful_qa ships several configs, so one must be named; this
            # is the one that carries the mc1/mc2 targets.
            "config": "multiple_choice",
            "split": "validation",
            "metrics": ["accuracy", "truth_score"],
            "num_few_shot": 0,
        },
        "arc": {
            "dataset": "ai2_arc",
            "split": "test",
            "config": "ARC-Challenge",
            "metrics": ["accuracy"],
            "num_few_shot": 25,
        },
        "gsm8k": {
            "dataset": "gsm8k",
            "split": "test",
            "config": "main",
            "metrics": ["accuracy"],
            "num_few_shot": 5,
        },
        "piqa": {
            "dataset": "piqa",
            "split": "validation",
            "metrics": ["accuracy"],
            "num_few_shot": 0,  # zero-shot
        },
        "winogrande": {
            "dataset": "winogrande",
            "split": "validation",
            "config": "winogrande_xl",
            "metrics": ["accuracy"],
            "num_few_shot": 5,
        },
        "boolq": {
            "dataset": "boolq",
            "split": "validation",
            "metrics": ["accuracy"],
            "num_few_shot": 5,
        },
        "openbookqa": {
            "dataset": "openbookqa",
            "split": "test",
            "metrics": ["accuracy"],
            "num_few_shot": 5,
        },
        "squad_v2": {
            "dataset": "squad_v2",
            "split": "validation",
            "metrics": ["exact_match", "f1"],
            "num_few_shot": 5,
        },
    }
    if verbose:
        for name in benchmarks:
            print(f"- {name}")
    return benchmarks
    
       
def run_benchmarks(
    model_name: str,
    selected_benchmarks: List[str],
    cache_dir: str = '/data/'
):
    """Load each selected benchmark's dataset and evaluate the model on it.

    Names that are not in the registry returned by
    ``get_available_benchmarks`` are logged as a warning and skipped, and no
    dataset is downloaded for them.

    Args:
        model_name: Model id or path handed to ``BenchmarkEvaluator``.
        selected_benchmarks: Registry keys to run, in order.
        cache_dir: Cache directory for both the model and the datasets.

    Returns:
        Dict mapping each evaluated benchmark name to whatever its
        ``evaluate_<name>`` method returned.
    """
    BENCHMARKS = get_available_benchmarks(verbose=False)

    evaluator = BenchmarkEvaluator(model_name, cache_dir)
    results = {}

    for benchmark in selected_benchmarks:
        spec = BENCHMARKS.get(benchmark)
        # Only registry keys are turned into method names.
        evaluate = getattr(evaluator, f"evaluate_{benchmark}", None) if spec else None
        if evaluate is None:
            logger.warning(f"Unknown benchmark: {benchmark}")
            continue

        logger.info(f"\nEvaluating {benchmark}...")

        # The dataset config is positional and only present for some entries.
        load_args = [spec["dataset"]]
        if "config" in spec:
            load_args.append(spec["config"])
        dataset = load_dataset(*load_args, split=spec["split"], cache_dir=cache_dir)

        results[benchmark] = evaluate(dataset)

        logger.info(f"Completed {benchmark} evaluation")

        torch.cuda.empty_cache()

    return results