import fnmatch
import torch
import numpy as np
from tqdm import tqdm
from torch import nn
from datasets import load_dataset, load_from_disk

from .data import get_loaders, get_loaders_sample

def eval_ppl(model, tokenizer, device=torch.device("cuda"), dataset="wikitext2"):
    """Compute test-set perplexity of ``model`` on one benchmark corpus.

    Args:
        model: Causal LM with a ``seqlen`` attribute; it is used as the
            evaluation window length.
        tokenizer: Tokenizer forwarded to ``get_loaders_sample``.
        device: Device the evaluation windows are moved to.
        dataset: Corpus name forwarded to ``get_loaders_sample``. Defaults to
            ``"wikitext2"`` as before; earlier code also mentioned ``"c4"`` and
            ``"ptb"``. NOTE(review): the test loader must expose
            ``input_ids`` for ``eval_ppl_wikitext`` — confirm for those corpora.

    Returns:
        The perplexity value returned by ``eval_ppl_wikitext``.
    """
    print(f"evaluating on {dataset}")

    # Only the test split is needed; the train loader is discarded.
    _, testloader = get_loaders_sample(
        dataset, seed=0, seqlen=model.seqlen, tokenizer=tokenizer
    )

    # Evaluate without building an autograd graph.
    with torch.no_grad():
        return eval_ppl_wikitext(model, testloader, 1, device)

# Function to evaluate perplexity (ppl) specifically on the wikitext dataset
@torch.no_grad()
def eval_ppl_wikitext(model, testenc, bs=1, device=None):
    """Compute perplexity of ``model`` over a token stream cut into fixed windows.

    The tokens in ``testenc.input_ids`` are split into ``nsamples``
    non-overlapping windows of ``model.seqlen`` tokens. A trailing partial
    window is dropped. The windows are fed to the model ``bs`` at a time.

    Args:
        model: Causal LM with a ``seqlen`` attribute. Calling it on a
            ``(batch, seqlen)`` tensor of token ids must return an object with
            ``.logits``.
        testenc: Tokenizer output whose ``input_ids`` has shape
            ``(1, num_tokens)``; the reshape below requires a leading dim of 1.
        bs: Number of windows per forward pass. Must be >= 1.
        device: Device the windows are moved to. ``None`` leaves them where
            they already are.

    Returns:
        Perplexity as a Python float.

    Raises:
        ValueError: If the stream holds fewer than ``model.seqlen`` tokens, or
            if ``bs < 1``.
    """
    input_ids = testenc.input_ids
    seqlen = model.seqlen

    nsamples = input_ids.numel() // seqlen
    print(f"nsamples {nsamples}")
    if nsamples == 0:
        raise ValueError(
            f"need at least {seqlen} tokens to evaluate, got {input_ids.numel()}"
        )
    if bs < 1:
        raise ValueError(f"bs must be a positive integer, got {bs}")

    # Mean-reduced loss over every predicted token in the batch.
    loss_fct = nn.CrossEntropyLoss()
    nlls = []

    for i in range(0, nsamples, bs):
        if i % 50 == 0:
            print(f"sample {i}")

        # The last batch may hold fewer than ``bs`` windows.
        j = min(i + bs, nsamples)

        inputs = input_ids[:, i * seqlen:j * seqlen]
        if device is not None:
            inputs = inputs.to(device)
        inputs = inputs.reshape(j - i, seqlen)

        lm_logits = model(inputs).logits

        # Position t predicts token t+1.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = inputs[:, 1:]
        loss = loss_fct(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
        )

        # Weight each batch's mean loss by its window count, so a short final
        # batch is not over-counted. The seqlen factor cancels in the final
        # division.
        nlls.append(loss.float() * seqlen * (j - i))

    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen))

    # Release cached allocator blocks. This is a no-op when CUDA was never
    # initialised.
    torch.cuda.empty_cache()

    return ppl.item()

def PPLMetric(model, tokenizer, datasets, seq_len=128, batch_size=1, device="cuda", out_logits=False, target=None):
    """Compute perplexity of ``model`` on each named dataset.

    Args:
        model: Causal LM forwarded to ``llama_eval``.
        tokenizer: Tokenizer forwarded to ``get_loaders``.
        datasets: Iterable of dataset names understood by ``get_loaders``.
        seq_len: Window length for every dataset except ``'ptb'``, which
            always uses 256.
        batch_size: Loader batch size.
        device: Device forwarded to ``llama_eval``.
        out_logits: If True, ``llama_eval`` also collects per-batch logits.
        target: Optional reference logits forwarded to ``llama_eval``.

    Returns:
        ``(metric, logits)``. ``metric`` maps each dataset name to its
        perplexity. ``logits`` is the logits list from the LAST dataset
        evaluated, or ``[]`` when ``datasets`` is empty.
    """
    metric = {}
    # Pre-bind so an empty ``datasets`` does not leave ``logits`` unbound.
    logits = []
    for dataset in datasets:
        # PTB is always evaluated with 256-token windows, whatever seq_len is.
        dataset_seq_len = 256 if dataset == 'ptb' else seq_len
        _, test_loader = get_loaders(dataset, tokenizer, seq_len=dataset_seq_len, batch_size=batch_size)
        ppl, logits = llama_eval(model, test_loader, device, out_logits=out_logits, target=target)
        metric[dataset] = ppl
    print(metric)
    return metric, logits

@torch.no_grad()
def llama_eval(model, test_lodaer, device, out_logits=False, target=None):
    """Compute token-level perplexity of ``model`` over batches from a loader.

    Args:
        model: Causal LM. Calling it on a batch must return an object with
            ``.logits``.
        test_lodaer: Iterable of token-id tensors. The ``[:, 1:]`` slicing
            implies shape ``(batch, seqlen)``. The misspelled name is kept for
            keyword callers.
        device: Device each batch is moved to.
        out_logits: If True, collect every batch's logits (moved to CPU).
        target: Optional indexable of reference logits, where ``target[n]``
            pairs with the n-th batch; a list or a stacked tensor both work.
            When non-empty, ``1e-3`` times the mean KL term is added to the
            returned perplexity.

    Returns:
        ``(ppl, logits_lst)``. ``logits_lst`` is empty unless ``out_logits``
        is set.

    Raises:
        ValueError: If the loader yields no batches.
    """
    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
    # Explicit check instead of ``if target:``, which raises on a tensor
    # holding more than one element.
    use_target = target is not None and len(target) > 0

    nlls = []
    logits_lst = []
    kl = []
    for n, batch in enumerate(tqdm(test_lodaer)):
        batch = batch.to(device)
        lm_logits = model(batch).logits

        # Position t predicts token t+1.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = batch[:, 1:].contiguous()

        loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.view(-1))
        # Keep per-token losses in fp32 so the final mean is not rounded to a
        # half-precision model's dtype.
        nlls.append(loss.float())

        if use_target:
            p_tec = torch.softmax(lm_logits, dim=-1)
            q = torch.log_softmax(target[n].to(device), dim=-1)
            # kl_div(input=log q, target=p) computes KL(p || q), summed and
            # divided by the batch size.
            kl_loss = torch.nn.functional.kl_div(q, p_tec, reduction='batchmean')
            kl.append(kl_loss.item())

        if out_logits:
            logits_lst.append(lm_logits.to('cpu'))

    if not nlls:
        raise ValueError("test loader yielded no batches")

    ppl = np.exp(torch.cat(nlls, dim=-1).mean().item()).item()
    if use_target:
        ppl = ppl + 1e-3 * (sum(kl) / len(kl))

    return ppl, logits_lst

@torch.no_grad()
def evaluate_model(model, tokenizer, dataset_name, device="cuda"):
    """Measure zero-shot multiple-choice accuracy on the local ARC-Easy test split.

    Each choice is scored by the model's loss on ``question + " " + choice``,
    with the input used as its own labels. The lowest-loss choice is the
    prediction.

    Args:
        model: Causal LM that accepts ``labels=`` and returns ``.loss``.
        tokenizer: Tokenizer providing ``encode(text, return_tensors="pt")``.
        dataset_name: Currently unused. The split is always read from the
            local parquet file below; the parameter is kept for compatibility.
        device: Device for the encoded inputs. Defaults to ``"cuda"``, as
            before.

    Returns:
        Fraction of questions answered correctly.

    Raises:
        ValueError: If the loaded split is empty.
    """
    testdata = load_dataset("parquet", data_files={"test":'ai2_arc/ARC-Easy/test-00000-of-00001.parquet'}, split='test')
    if len(testdata) == 0:
        raise ValueError("ARC-Easy test split is empty")

    correct = 0
    for example in tqdm(testdata):
        question = example['question']
        choice_texts = example['choices']['text']
        # Compare against the dataset's own choice labels. Not every ARC item
        # uses letters (some use "1".."4"), so a hard-coded 'A'.. mapping
        # would miss them. Fall back to letters if labels are absent.
        choice_labels = example['choices'].get('label') or [
            chr(ord('A') + k) for k in range(len(choice_texts))
        ]

        # Mean per-token negative log-likelihood of each candidate sequence.
        losses = []
        for choice in choice_texts:
            inputs = tokenizer.encode(question + " " + choice, return_tensors="pt").to(device)
            losses.append(model(inputs, labels=inputs).loss.item())

        # Lower loss means the model finds the sequence more likely.
        best = losses.index(min(losses))
        if choice_labels[best] == example['answerKey']:
            correct += 1

    return correct / len(testdata)

def eval_zero_shot(model_name, model, tokenizer, task_list=("boolq", "rte", "hellaswag", "winogrande", "arc_challenge", "arc_easy", "openbookqa"),
        num_fewshot=0, use_accelerate=False, add_special_tokens=False):
    """Run lm-eval harness tasks on an already-loaded model.

    Args:
        model_name: Checkpoint name. It is used in ``model_args``, and names
            containing "70b" or "65b" cap each task at 2000 examples.
        model: Loaded model, passed to the evaluator as ``pretrained_model``.
        tokenizer: Tokenizer passed to the evaluator.
        task_list: Glob patterns matched against ``tasks.ALL_TASKS``. A single
            string is treated as one pattern.
        num_fewshot: Number of few-shot examples per prompt.
        use_accelerate: Append ``use_accelerate=True`` to ``model_args``.
        add_special_tokens: Forwarded to the evaluator.

    Returns:
        Whatever ``evaluator.simple_evaluate`` returns.

    NOTE(review): the ``pretrained_model`` / ``tokenizer`` /
    ``add_special_tokens`` kwargs and the ``hf-causal-experimental`` model
    type presumably come from a patched lm_eval version; confirm against the
    installed package.
    """
    from lm_eval import tasks, evaluator

    def pattern_match(patterns, source_list):
        """Return the sorted, de-duplicated names in source_list matching any glob pattern."""
        return sorted({
            name
            for pattern in patterns
            for name in fnmatch.filter(source_list, pattern)
        })

    # Iterating a bare string would match single characters; wrap it instead.
    if isinstance(task_list, str):
        task_list = [task_list]
    task_names = pattern_match(task_list, tasks.ALL_TASKS)

    model_args = f"pretrained={model_name},cache_dir=./llm_weights"
    if use_accelerate:
        model_args += ",use_accelerate=True"

    # Cap the number of examples per task for the largest checkpoints.
    limit = 2000 if ("70b" in model_name or "65b" in model_name) else None

    return evaluator.simple_evaluate(
        model="hf-causal-experimental",
        model_args=model_args,
        tasks=task_names,
        num_fewshot=num_fewshot,
        batch_size=None,
        device=None,
        no_cache=True,
        limit=limit,
        description_dict={},
        decontamination_ngrams_path=None,
        check_integrity=False,
        pretrained_model=model,
        tokenizer=tokenizer,
        add_special_tokens=add_special_tokens,
    )