import functools
import time

import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from tqdm import tqdm

from utils.flops_utils import FLOPsModelWrapper
from utils.prompter import (
    GSM8KPrompter,
    OpenBookQAPrompter,
    PIQAPrompter,
    MMLUPrompter,
    BoolQPrompter,
    SIQAPrompter,
    HellaswagPrompter,
    MedQAPrompter,
    ArcPrompter,
    WinograndePrompter
)

class IndexDataset(Dataset):
    """Map-style dataset that serves the rows of an indexable container.

    The container is kept unchanged in the public ``tensors`` attribute.
    Item ``i`` is ``tensors[i]``, and the dataset length is ``len(tensors)``.
    """

    def __init__(self, tensors):
        # Stored as-is; no copy or conversion is performed.
        self.tensors = tensors

    def __len__(self):
        return len(self.tensors)

    def __getitem__(self, index):
        rows = self.tensors
        return rows[index]

def process_ppl_data(samples, tokenizer, seq_len, field_name):
    """Tokenize a text corpus and cut it into half-overlapping windows for PPL.

    All entries of ``samples[field_name]`` are joined with ``"\\n\\n"`` and
    tokenized once. The resulting token stream is split into windows of
    ``seq_len`` tokens whose start offsets advance by ``seq_len // 2``, so
    consecutive windows overlap by half. Trailing tokens that do not fill a
    complete window are dropped.

    Args:
        samples: Object where ``samples[field_name]`` yields the text strings
            (e.g. a Hugging Face dataset split).
        tokenizer: Callable that returns an object with ``input_ids`` when
            called with ``return_tensors='pt'``.
        seq_len: Window length in tokens. Must be at least 2, because the
            stride is ``seq_len // 2``.
        field_name: Name of the column that holds the raw text.

    Returns:
        IndexDataset over a ``(num_windows, seq_len)`` tensor of token ids.

    Raises:
        ValueError: If ``seq_len < 2``, or if the tokenized corpus is shorter
            than a single window.
    """
    if seq_len < 2:
        # A stride of 0 would never advance (and previously caused ZeroDivisionError).
        raise ValueError(f"seq_len must be >= 2 for half-overlapping windows, got {seq_len}")

    test_ids = tokenizer("\n\n".join(samples[field_name]), return_tensors='pt').input_ids[0]
    num_tokens = test_ids.numel()
    stride = seq_len // 2

    # Every start offset whose window still fits completely inside the token stream.
    windows = [
        test_ids[begin:begin + seq_len]
        for begin in range(0, num_tokens - seq_len + 1, stride)
    ]
    if not windows:
        raise ValueError(
            f"corpus tokenized to {num_tokens} tokens, fewer than seq_len={seq_len}; "
            "no complete window can be formed"
        )

    return IndexDataset(tensors=torch.stack(windows))

def process_acc_data(samples, tokenizer, prompter, sqlen):
    """Tokenize each example's prompt plus answer into a fixed-length id row.

    Each record from ``samples.to_list()`` is formatted with
    ``prompter.generate_prompt``, and the ``"prompt"`` and ``"answer"`` fields
    are concatenated. The text is then tokenized with truncation and
    ``padding="max_length"`` to ``sqlen`` tokens.

    Args:
        samples: Dataset-like object exposing ``to_list()``.
        tokenizer: Callable tokenizer that returns an object with ``input_ids``.
        prompter: Object whose ``generate_prompt(record)`` returns a mapping
            with ``"prompt"`` and ``"answer"`` string entries.
        sqlen: Target length passed as ``max_length``.

    Returns:
        IndexDataset over the stacked token-id rows, one per example.
    """
    def _encode(record):
        parts = prompter.generate_prompt(record)
        text = parts["prompt"] + parts["answer"]
        encoded = tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=sqlen,
            return_tensors='pt',
        )
        return encoded.input_ids[0]

    rows = [_encode(record) for record in samples.to_list()]
    return IndexDataset(tensors=torch.stack(rows))

def get_mmlu_grouped_loaders(tokenizer, seq_len=128, batch_size=4):
    """Return a dict mapping each MMLU category name to a DataLoader.

    The MMLU "all" test split is partitioned by subject into four categories
    (humanities, other, social_sciences, stem). Any subject missing from the
    table below falls into "other". Each partition is tokenized with
    ``process_acc_data`` and wrapped in an unshuffled DataLoader. Categories
    whose partition is empty are omitted from the result.

    Args:
        tokenizer: Tokenizer forwarded to ``process_acc_data``.
        seq_len: Fixed token length per example.
        batch_size: DataLoader batch size.

    Returns:
        ``{category: DataLoader}``, with categories in sorted order.
    """
    SUBJECTS = {
        "abstract_algebra": "stem",
        "anatomy": "stem",
        "astronomy": "stem",
        "business_ethics": "other",
        "clinical_knowledge": "other",
        "college_biology": "stem",
        "college_chemistry": "stem",
        "college_computer_science": "stem",
        "college_mathematics": "stem",
        "college_medicine": "other",
        "college_physics": "stem",
        "computer_security": "stem",
        "conceptual_physics": "stem",
        "econometrics": "social_sciences",
        "electrical_engineering": "stem",
        "elementary_mathematics": "stem",
        "formal_logic": "humanities",
        "global_facts": "other",
        "high_school_biology": "stem",
        "high_school_chemistry": "stem",
        "high_school_computer_science": "stem",
        "high_school_european_history": "humanities",
        "high_school_geography": "social_sciences",
        "high_school_government_and_politics": "social_sciences",
        "high_school_macroeconomics": "social_sciences",
        "high_school_mathematics": "stem",
        "high_school_microeconomics": "social_sciences",
        "high_school_physics": "stem",
        "high_school_psychology": "social_sciences",
        "high_school_statistics": "stem",
        "high_school_us_history": "humanities",
        "high_school_world_history": "humanities",
        "human_aging": "other",
        "human_sexuality": "social_sciences",
        "international_law": "humanities",
        "jurisprudence": "humanities",
        "logical_fallacies": "humanities",
        "machine_learning": "stem",
        "management": "other",
        "marketing": "other",
        "medical_genetics": "other",
        "miscellaneous": "other",
        "moral_disputes": "humanities",
        "moral_scenarios": "humanities",
        "nutrition": "other",
        "philosophy": "humanities",
        "prehistory": "humanities",
        "professional_accounting": "other",
        "professional_law": "humanities",
        "professional_medicine": "other",
        "professional_psychology": "social_sciences",
        "public_relations": "social_sciences",
        "security_studies": "social_sciences",
        "sociology": "social_sciences",
        "us_foreign_policy": "social_sciences",
        "virology": "other",
        "world_religions": "humanities",
    }

    CATEGORY_NAMES = sorted(set(SUBJECTS.values()))  # ['humanities', 'other', 'social_sciences', 'stem']

    test_data = load_dataset("cais/mmlu", "all", split="test")

    loaders = {}
    for cat in CATEGORY_NAMES:
        # Bind the current category as a default argument so the predicate
        # never depends on late binding of the loop variable (defensive).
        cat_ds = test_data.filter(
            lambda x, c=cat: SUBJECTS.get(x["subject"], "other") == c
        )

        if len(cat_ds) == 0:  # Safety check: a category may end up empty.
            continue

        # Reuse the shared pipeline: subset -> IndexDataset -> DataLoader.
        idx_ds = process_acc_data(cat_ds, tokenizer, MMLUPrompter(), seq_len)
        loaders[cat] = DataLoader(idx_ds, batch_size=batch_size, shuffle=False)

    return loaders

def get_test_loaders(name, tokenizer, seq_len=2048, batch_size=8):
    """Build an unshuffled DataLoader for the evaluation dataset named by ``name``.

    ``name`` is matched by substring, and the first matching branch wins, so
    check order matters. "wikitext2" and "ptb" are tokenized as one continuous
    corpus with ``process_ppl_data``. All other benchmarks are tokenized
    per-example with ``process_acc_data`` and their matching prompter.

    Args:
        name: Dataset identifier, e.g. "wikitext2", "piqa", "arc_c".
        tokenizer: Tokenizer forwarded to the processing helpers.
        seq_len: Token length per sample or window.
        batch_size: DataLoader batch size.

    Returns:
        DataLoader over the processed dataset.

    Raises:
        ValueError: If ``name`` matches no supported dataset. Previously an
            unknown name silently produced an empty loader, which failed much
            later during evaluation.
    """
    if 'wikitext2' in name:
        test_data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
        test_dataset = process_ppl_data(test_data, tokenizer, seq_len, 'text')
    elif 'ptb' in name:
        test_data = load_dataset('ptb_text_only', 'penn_treebank', split='validation')
        test_dataset = process_ppl_data(test_data, tokenizer, seq_len, 'sentence')
    elif 'openbookqa' in name:
        test_data = load_dataset('openbookqa', 'main', split='test')
        test_dataset = process_acc_data(test_data, tokenizer, OpenBookQAPrompter(), seq_len)
    elif 'gsm8k' in name:
        test_data = load_dataset('gsm8k', 'main', split='test')
        test_dataset = process_acc_data(test_data, tokenizer, GSM8KPrompter(), seq_len)
    elif 'piqa' in name:
        test_data = load_dataset('piqa', 'plain_text', split='test')
        test_dataset = process_acc_data(test_data, tokenizer, PIQAPrompter(), seq_len)
    elif 'mmlu' in name:
        test_data = load_dataset("cais/mmlu", "all", split='test')
        test_dataset = process_acc_data(test_data, tokenizer, MMLUPrompter(), seq_len)
    elif 'boolq' in name:
        test_data = load_dataset("super_glue", "boolq", split='test')
        test_dataset = process_acc_data(test_data, tokenizer, BoolQPrompter(), seq_len)
    elif 'social_iqa' in name:
        test_data = load_dataset("social_i_qa", split='validation')
        test_dataset = process_acc_data(test_data, tokenizer, SIQAPrompter(), seq_len)
    elif 'hellaswag' in name:
        test_data = load_dataset("hellaswag", split='validation')
        test_dataset = process_acc_data(test_data, tokenizer, HellaswagPrompter(), seq_len)
    elif 'medqa_4options' in name:
        test_data = load_dataset("GBaker/MedQA-USMLE-4-options-hf", split='test')
        test_dataset = process_acc_data(test_data, tokenizer, MedQAPrompter(), seq_len)
    elif 'arc_c' in name:
        test_data = load_dataset("allenai/ai2_arc", name="ARC-Challenge", split='test')
        test_dataset = process_acc_data(test_data, tokenizer, ArcPrompter(), seq_len)
    elif 'arc_e' in name:
        test_data = load_dataset("allenai/ai2_arc", name="ARC-Easy", split='test')
        test_dataset = process_acc_data(test_data, tokenizer, ArcPrompter(), seq_len)
    elif 'winogrande' in name:
        test_data = load_dataset("winogrande", name="winogrande_xl", split="test")
        test_dataset = process_acc_data(test_data, tokenizer, WinograndePrompter(), seq_len)
    else:
        raise ValueError(f"Unsupported evaluation dataset name: {name!r}")

    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return test_loader

def compute_time(func):
    """Decorator that returns ``func``'s result together with its run time.

    The wrapped callable returns ``(result, seconds)``. The duration comes
    from ``time.perf_counter``, a monotonic high-resolution clock that suits
    interval measurement, unlike ``time.time``. ``functools.wraps`` keeps the
    wrapped function's name and docstring.

    Only host-side elapsed time is measured. Asynchronous device work is
    included only if ``func`` synchronizes before returning.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        time_cost = time.perf_counter() - start_time
        return result, time_cost

    return wrapper

@torch.no_grad()
@compute_time
def llama_eval(model, dataset, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` for perplexity and FLOPs.

    Each batch of token ids is passed through ``FLOPsModelWrapper.compute_flops``.
    Next-token cross-entropy is computed for every position; the loss is not
    masked, so every label position, including any padding tokens in the
    batch, is scored. Perplexity is ``exp`` of the mean loss over all scored
    tokens.

    If the model output is a tuple, ``output[0]`` is taken as the logits and
    ``output[1]`` as an activation tensor. The mean of that activation tensor
    over all batches is reported as ``decision_result``. Otherwise the logits
    are read from ``output.logits`` and ``decision_result`` is 0.

    Args:
        model: Model to evaluate, wrapped with ``FLOPsModelWrapper``.
        dataset: Name used in the progress-bar label.
        test_loader: Iterable of 2-D token-id tensors ``(batch, seq)``.
        device: Device each batch is moved to.

    Returns:
        ``(ppl, avg_flops, n_samples, decision_result)``, where ``avg_flops``
        is total FLOPs reported by the wrapper divided by ``n_samples`` and by
        1e9. Because of ``@compute_time``, callers receive
        ``((ppl, avg_flops, n_samples, decision_result), time_cost)``.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    nlls = []
    decision_sum = 0
    decision_count = 0
    n_samples = 0
    total_flops = 0.0

    wrapper_model = FLOPsModelWrapper(model)
    # Loop-invariant; per-token losses are kept so the mean is over tokens, not batches.
    loss_fc = torch.nn.CrossEntropyLoss(reduction="none")

    bar_format = "Calculating PPL for " + dataset + ":" + "{l_bar}{bar}{r_bar}"
    for batch in tqdm(test_loader, bar_format=bar_format, ncols=100):
        batch = batch.to(device)

        output, flops, _params = wrapper_model.compute_flops(batch)

        if isinstance(output, tuple):
            lm_logits = output[0]
            activation = output[1]
            decision_sum += activation.sum()
            decision_count += activation.numel()
        else:
            lm_logits = output.logits

        n_samples += batch.shape[0]
        total_flops += flops

        # Position t predicts token t + 1.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = batch[:, 1:].contiguous()

        loss = loss_fc(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.view(-1))
        nlls.append(loss)

    if n_samples == 0:
        raise ValueError(f"test loader for {dataset!r} yielded no batches")

    ppl = np.exp(torch.cat(nlls, dim=-1).mean().item())
    avg_flops = total_flops / n_samples / 1e9
    decision_result = (decision_sum / decision_count).item() if decision_count > 0 else 0
    return (ppl.item(), avg_flops, n_samples, decision_result)

def Metric(model, tokenizer, datasets, seq_len=128, batch_size=4, device="cuda"):
    """Evaluate ``model`` on every dataset in ``datasets`` and collect results.

    ``"mmlu"`` is evaluated separately per subject category, under the keys
    ``"mmlu_<category>"``. Every other name is loaded with ``get_test_loaders``
    and stored under its own key. The accumulated result dict is printed after
    each requested dataset finishes.

    Returns:
        dict mapping evaluation name to
        ``dict(ppl, flops, n_samples, time_cost, decisions)``.
    """
    def _evaluate(eval_name, loader):
        (ppl, flops, n_samples, decision_result), time_cost = llama_eval(model, eval_name, loader, device)
        return dict(ppl=ppl, flops=flops, n_samples=n_samples, time_cost=time_cost,
                    decisions=decision_result)

    results = {}
    for ds_name in datasets:
        if ds_name != "mmlu":
            loader = get_test_loaders(ds_name, tokenizer, seq_len=seq_len, batch_size=batch_size)
            results[ds_name] = _evaluate(ds_name, loader)
        else:
            # One evaluation per MMLU category, e.g. "mmlu_stem".
            per_category = get_mmlu_grouped_loaders(tokenizer, seq_len, batch_size)
            for category, loader in per_category.items():
                key = f"{ds_name}_{category}"
                results[key] = _evaluate(key, loader)

        print(results)
    return results
