import csv
import json
import logging
from typing import List

import datasets
import fire
import torch

import moe_peft

# Answer letters in dataset order: position i is the letter shown for choice i,
# and a record's integer "answer" field indexes into this list.
choices_map = list("ABCD")


def format_subject(subject):
    """Turn a snake_case subject identifier into a space-separated phrase.

    Example: ``"abstract_algebra"`` -> ``"abstract algebra"``.

    The result carries no leading or trailing padding for ordinary subject
    names. The prompt template that consumes it ("... about {}.") already
    provides the separating space, so a padded result would put a double
    space into every prompt.

    Args:
        subject: Subject name with words separated by underscores.

    Returns:
        The subject with every underscore replaced by a single space.
    """
    return " ".join(subject.split("_"))


def format_prompt(data_point, with_answer=True):
    """Render one multiple-choice record as prompt text.

    The layout is the stripped question, one "<letter>. <choice>" line per
    choice, and a final "Answer:" cue.

    Args:
        data_point: Mapping with a "question" string, a "choices" sequence and
            an integer "answer" index into ``choices_map``.
        with_answer: When true, append the answer letter followed by a blank
            line, so the text can serve as a solved few-shot example.

    Returns:
        The formatted prompt string.
    """
    lines = [data_point["question"].strip()]
    for letter, option in zip(choices_map, data_point["choices"]):
        lines.append(f"{letter}. {option}")
    lines.append("Answer:")
    text = "\n".join(lines)

    if not with_answer:
        return text

    answer_letter = choices_map[data_point["answer"]]
    return text + f" {answer_letter}\n\n"


def prepare_data(
    tokenizer: moe_peft.Tokenizer,
    subject: str,
    dev_data: datasets.Dataset,
    test_data: datasets.Dataset,
    k_shots=5,
    max_seq_len=2048,
    batch_padding=True,
):
    """Tokenize a few-shot prompt for every test question of one subject.

    For each test question the prompt is a subject header, then as many solved
    dev examples as fit (at most ``k_shots``), then the unanswered test
    question. Dev examples are added in order and adding stops at the first
    one that would push the encoded prompt past ``max_seq_len``.

    Args:
        tokenizer: Tokenizer providing ``encode``, ``pad_id_`` and ``mask_from``.
        subject: MMLU subject name, used in the prompt header.
        dev_data: Records used as solved few-shot examples.
        test_data: Records to evaluate.
        k_shots: Maximum number of dev examples per prompt; values <= 0 give a
            zero-shot prompt.
        max_seq_len: Upper bound on the token count of any prompt.
        batch_padding: When true, right-pad every prompt with ``pad_id_`` to
            the longest prompt length (capped at ``max_seq_len``).

    Returns:
        A tuple ``(sequence_lengths, batch_tokens, atten_masks, batch_labels)``:
        ``sequence_lengths`` holds the index of the last prompt token before
        padding (or -1 when padding is disabled), ``batch_tokens`` the token id
        lists, ``atten_masks`` the result of ``tokenizer.mask_from`` for each
        token list, and ``batch_labels`` the integer answer indices.
    """
    header = "The following are multiple choice questions (with answers) about {}.\n\n".format(
        format_subject(subject)
    )

    batch_tokens = []
    batch_labels = []
    max_tokens_len = 0
    for test_data_point in test_data:
        test_prompt = format_prompt(test_data_point, False)
        dev_prompt = header
        # Reset per question: a value carried over from the previous question
        # would otherwise be scored under this question's label.
        tokens = None
        for shot_idx, dev_data_point in enumerate(dev_data):
            if shot_idx >= k_shots:
                break
            prompt = format_prompt(dev_data_point)
            input_ids = tokenizer.encode(dev_prompt + prompt + test_prompt)
            if len(input_ids) > max_seq_len:
                # This example does not fit; keep what has been accepted so far.
                break
            tokens = input_ids
            dev_prompt += prompt

        if tokens is None:
            # No few-shot example was accepted: fall back to a zero-shot prompt.
            tokens = tokenizer.encode(dev_prompt + test_prompt)
            if len(tokens) > max_seq_len:
                logging.warning(
                    f"Zero-shot prompt has {len(tokens)} tokens, "
                    f"truncating to the last {max_seq_len}"
                )
                # Keep the tail so the final "Answer:" cue, whose logits are
                # read during scoring, survives the cut.
                # NOTE(review): this drops any leading special tokens — confirm
                # that is acceptable for the models being evaluated.
                tokens = tokens[len(tokens) - max_seq_len :]

        max_tokens_len = max(len(tokens), max_tokens_len)
        batch_tokens.append(tokens)
        batch_labels.append(test_data_point["answer"])

    if batch_padding:
        max_seq_len = min(max_seq_len, max_tokens_len)
        logging.info(f"Max sequence length: {max_seq_len}")

    sequence_lengths = []
    atten_masks = []
    for tokens in batch_tokens:
        if batch_padding:
            # Record the last real token position before padding is appended.
            sequence_lengths.append(len(tokens) - 1)
            tokens.extend([tokenizer.pad_id_] * (max_seq_len - len(tokens)))
        else:
            # Without padding the last position is always the final token.
            sequence_lengths.append(-1)
        atten_masks.append(tokenizer.mask_from(tokens))

    return sequence_lengths, batch_tokens, atten_masks, batch_labels


@torch.inference_mode()
def evaluate(
    subject: str,
    tokenizer: moe_peft.Tokenizer,
    model: moe_peft.LLMModel,
    adapter_names: List[str],
    batch_size: int = 2,
    max_seq_len: int = 2048,
):
    """Score every adapter on the test split of one MMLU subject.

    Loads the "cais/mmlu" dataset for ``subject``, builds 5-shot prompts with
    ``prepare_data`` (padding is enabled only when ``batch_size > 1``), and
    runs the model batch by batch. Each batch is replicated once per adapter
    so that all adapters are evaluated in a single forward call. The predicted
    choice is the arg-max over the logits of the four answer-letter tokens at
    the last prompt position.

    Args:
        subject: MMLU subject / dataset configuration name.
        tokenizer: Tokenizer used for prompts and answer letters.
        model: Model exposing ``device_`` and ``forward``.
        adapter_names: Adapters to evaluate.
        batch_size: Number of questions per step (per adapter).
        max_seq_len: Maximum prompt length in tokens.

    Returns:
        Dict mapping adapter name to a list of 0/1 correctness flags, one per
        test question.
    """
    mmlu = datasets.load_dataset("cais/mmlu", subject)
    sequence_lengths, batch_tokens, atten_masks, batch_labels = prepare_data(
        tokenizer, subject, mmlu["dev"], mmlu["test"], 5, max_seq_len, batch_size > 1
    )

    results = {name: [] for name in adapter_names}

    # Position of the last prompt token for every question.
    last_positions = torch.tensor(
        sequence_lengths, dtype=torch.long, device=model.device_
    )

    # Vocabulary id of each answer letter (last token of its encoding).
    label_indices = torch.tensor(
        [tokenizer.encode(letter)[-1] for letter in choices_map],
        dtype=torch.long,
        device=model.device_,
    )

    num_adapters = len(adapter_names)
    start_pos = 0
    while start_pos < len(batch_tokens):
        end_pos = min(len(batch_tokens), start_pos + batch_size)
        logging.info(f"evaluation step: {start_pos}/{len(batch_tokens)}")
        bsz = end_pos - start_pos

        # Adapter i owns rows [i * bsz, (i + 1) * bsz) of the stacked batch.
        batch_data_config = [
            moe_peft.LLMBatchConfig(
                adapter_name_=name,
                batch_start_idx_=slot * bsz,
                batch_end_idx_=(slot + 1) * bsz,
            )
            for slot, name in enumerate(adapter_names)
        ]

        input_args = moe_peft.LLMModelInput(
            batch_configs_=batch_data_config,
            batch_tokens_=batch_tokens[start_pos:end_pos] * num_adapters,
            batch_masks_=atten_masks[start_pos:end_pos] * num_adapters,
            inference_mode_=True,
        )

        outputs = model.forward(input_args)

        labels = torch.tensor(
            batch_labels[start_pos:end_pos], dtype=torch.long, device=model.device_
        )

        for output in outputs:
            step_logits = output.logits
            row_index = torch.arange(bsz, device=step_logits.device)
            # Pick the logits at each question's last prompt token ...
            last_token_logits = step_logits[
                row_index, last_positions[start_pos:end_pos]
            ]
            # ... and compare only the four answer-letter tokens.
            predictions = last_token_logits[:, label_indices].softmax(-1).argmax(-1)
            hits = (predictions == labels).int().tolist()
            results[output.adapter_name].extend(hits)

        # Running accuracy over everything scored so far.
        for name, result in results.items():
            acc = sum(result) / len(result)
            logging.info(f"    {name} accuracy: {acc}")

        start_pos = end_pos

    return results


# Maps each MMLU subject name to a one-element list holding its subcategory
# label. `do_evaluate` iterates this dict in insertion order to run every
# subject, and uses the last list entry to look the subject up in
# `mmlu_categories`.
mmlu_subcategories = {
    "abstract_algebra": ["math"],
    "anatomy": ["health"],
    "astronomy": ["physics"],
    "business_ethics": ["business"],
    "clinical_knowledge": ["health"],
    "college_biology": ["biology"],
    "college_chemistry": ["chemistry"],
    "college_computer_science": ["computer science"],
    "college_mathematics": ["math"],
    "college_medicine": ["health"],
    "college_physics": ["physics"],
    "computer_security": ["computer science"],
    "conceptual_physics": ["physics"],
    "econometrics": ["economics"],
    "electrical_engineering": ["engineering"],
    "elementary_mathematics": ["math"],
    "formal_logic": ["philosophy"],
    "global_facts": ["other"],
    "high_school_biology": ["biology"],
    "high_school_chemistry": ["chemistry"],
    "high_school_computer_science": ["computer science"],
    "high_school_european_history": ["history"],
    "high_school_geography": ["geography"],
    "high_school_government_and_politics": ["politics"],
    "high_school_macroeconomics": ["economics"],
    "high_school_mathematics": ["math"],
    "high_school_microeconomics": ["economics"],
    "high_school_physics": ["physics"],
    "high_school_psychology": ["psychology"],
    "high_school_statistics": ["math"],
    "high_school_us_history": ["history"],
    "high_school_world_history": ["history"],
    "human_aging": ["health"],
    "human_sexuality": ["culture"],
    "international_law": ["law"],
    "jurisprudence": ["law"],
    "logical_fallacies": ["philosophy"],
    "machine_learning": ["computer science"],
    "management": ["business"],
    "marketing": ["business"],
    "medical_genetics": ["health"],
    "miscellaneous": ["other"],
    "moral_disputes": ["philosophy"],
    "moral_scenarios": ["philosophy"],
    "nutrition": ["health"],
    "philosophy": ["philosophy"],
    "prehistory": ["history"],
    "professional_accounting": ["other"],
    "professional_law": ["law"],
    "professional_medicine": ["health"],
    "professional_psychology": ["psychology"],
    "public_relations": ["politics"],
    "security_studies": ["politics"],
    "sociology": ["culture"],
    "us_foreign_policy": ["politics"],
    "virology": ["health"],
    "world_religions": ["philosophy"],
}


# Groups the subcategory labels used in `mmlu_subcategories` into the
# top-level category names written to the first column of the score CSV.
mmlu_categories = {
    "STEM": [
        "physics",
        "chemistry",
        "biology",
        "computer science",
        "math",
        "engineering",
    ],
    "humanities": ["history", "philosophy", "law"],
    "social sciences": ["politics", "culture", "economics", "geography", "psychology"],
    "other (business, health, misc.)": ["other", "business", "health"],
}


# Keyed by the `model_dtype` argument of `do_evaluate`; the selected dict is
# expanded as keyword arguments to `moe_peft.LLMModel.from_pretrained`.
# NOTE(review): "bits" presumably selects quantized loading — confirm against
# the moe_peft API.
model_dtypes = {
    "4bit": {"bits": 4, "load_dtype": torch.float32},
    "8bit": {"bits": 8, "load_dtype": torch.float32},
    "16bit": {"load_dtype": torch.bfloat16},
}


def do_evaluate(
    model_name: str,
    model_dtype: str,
    adapter_names: List[str],
    batch_size: int = 2,
    device: str = moe_peft.executor.default_device_name(),
    output: str = "mmlu_scores.csv",
):
    """Run the full MMLU benchmark and write per-subject accuracy to a CSV.

    Loads the tokenizer and model, registers every adapter (the name
    "default" creates a fresh adapter, any other name is loaded), then
    evaluates each subject in ``mmlu_subcategories``. The CSV at ``output`` is
    rewritten after every subject, so partial results are on disk while the
    run is still in progress.

    Args:
        model_name: Model identifier passed to the tokenizer and model loader.
        model_dtype: Key into ``model_dtypes`` selecting the loading options.
        adapter_names: Adapters to evaluate.
        batch_size: Questions per evaluation step.
        device: Device name passed to the model loader.
        output: Path of the CSV file to write.
    """
    tokenizer = moe_peft.Tokenizer(model_name)
    load_kwargs = model_dtypes[model_dtype]
    model = moe_peft.LLMModel.from_pretrained(
        model_name, device=device, **load_kwargs
    )

    for name in adapter_names:
        logging.info(f"Loading adapter {name}")
        if name != "default":
            model.load_adapter(name)
            continue
        model.init_adapter(moe_peft.AdapterConfig(adapter_name=name))

    header = ["mmlu_categories", "mmlu_subcategories", "adapter_name", "acc_score"]
    csv_data = [header]
    for subject, subcategory in mmlu_subcategories.items():
        logging.info(f"Performing MMLU/{subject} Benchmark")
        results = evaluate(
            subject,
            tokenizer,
            model,
            adapter_names,
            batch_size,
            model.config_.max_seq_len_,
        )

        # Resolve the top-level category; the last match wins, None if absent.
        matches = [
            category_name
            for category_name, members in mmlu_categories.items()
            if subcategory[-1] in members
        ]
        category = matches[-1] if matches else None

        csv_data.extend(
            [category, subject, name, sum(result) / len(result)]
            for name, result in results.items()
        )

        # Rewrite the whole file with everything collected so far.
        with open(output, "w", newline="") as csvfile:
            csv.writer(csvfile).writerows(csv_data)


def main(config: str):
    """Command-line entry point: run the MMLU benchmark described by a JSON file.

    Seeds the executor, configures logging at INFO level, verifies that the
    executor backend is available, then forwards the JSON object's keys as
    keyword arguments to ``do_evaluate``.

    Args:
        config: Path to a UTF-8 JSON file whose keys match the parameters of
            ``do_evaluate``.

    Raises:
        SystemExit: With status -1 when the executor reports it is unavailable.
    """
    moe_peft.executor.manual_seed(66)
    moe_peft.setup_logging("INFO")
    if not moe_peft.executor.check_available():
        # The bare `exit` helper is injected by the `site` module for
        # interactive sessions and may be missing (e.g. `python -S`), so raise
        # SystemExit directly; the exit status is unchanged.
        raise SystemExit(-1)
    with open(config, "r", encoding="utf8") as fp:
        mmlu_config = json.load(fp)
    do_evaluate(**mmlu_config)


# Script entry point: hand `main` to python-fire when run as a script.
if __name__ == "__main__":
    fire.Fire(main)
