# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""Sample Generate GPT."""
import functools
import os
import sys
import warnings

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")))

import torch
from datasets import load_dataset

from megatron.post_training.arguments import add_modelopt_args
from megatron.post_training.checkpointing import load_modelopt_checkpoint
from megatron.post_training.generate import simple_generate
from megatron.post_training.model_provider import model_provider
from megatron.post_training.utils import report_current_memory_info
from megatron.training import get_args, get_model, get_tokenizer, initialize_megatron
from megatron.training.utils import print_rank_0, unwrap_model

warnings.filterwarnings('ignore')


def add_mmlu_args(parser):
    """Register the MMLU-evaluation command-line arguments on ``parser``.

    Adds ``--disable-tqdm``, ``--percentage`` and ``--lower-bound`` in their own
    argument group, then calls ``add_modelopt_args`` to register the ModelOpt
    options on the same parser.

    Args:
        parser: An ``argparse.ArgumentParser`` (or compatible) to extend.

    Returns:
        The same ``parser`` object, after the arguments have been added.
    """
    group = parser.add_argument_group(title='ModelOpt text generation ptq')
    group.add_argument("--disable-tqdm", action="store_true", help="Disable tqdm.")
    group.add_argument(
        "--percentage",
        type=float,
        default=1.0,
        help="Fraction of each subject's test split to evaluate (1.0 = all examples).",
    )
    group.add_argument(
        "--lower-bound",
        type=float,
        default=None,
        help="If set, fail unless the average accuracy is strictly above this value.",
    )
    add_modelopt_args(parser)
    return parser


def get_all_subjects():
    """List every MMLU subject name, in alphabetical order.

    Each name is used as the configuration name passed to
    ``load_dataset("cais/mmlu", ...)``. A new list is built on every call, so
    callers may mutate the result without affecting later calls.
    """
    subject_table = """
        abstract_algebra
        anatomy
        astronomy
        business_ethics
        clinical_knowledge
        college_biology
        college_chemistry
        college_computer_science
        college_mathematics
        college_medicine
        college_physics
        computer_security
        conceptual_physics
        econometrics
        electrical_engineering
        elementary_mathematics
        formal_logic
        global_facts
        high_school_biology
        high_school_chemistry
        high_school_computer_science
        high_school_european_history
        high_school_geography
        high_school_government_and_politics
        high_school_macroeconomics
        high_school_mathematics
        high_school_microeconomics
        high_school_physics
        high_school_psychology
        high_school_statistics
        high_school_us_history
        high_school_world_history
        human_aging
        human_sexuality
        international_law
        jurisprudence
        logical_fallacies
        machine_learning
        management
        marketing
        medical_genetics
        miscellaneous
        moral_disputes
        moral_scenarios
        nutrition
        philosophy
        prehistory
        professional_accounting
        professional_law
        professional_medicine
        professional_psychology
        public_relations
        security_studies
        sociology
        us_foreign_policy
        virology
        world_religions
    """
    # Whitespace split turns the table above into a fresh list of names.
    return subject_table.split()


def format_example(example, include_answer: bool = True):
    """Format an example as a lettered multiple-choice prompt.

    Args:
        example: Mapping with ``"question"`` (str) and ``"choices"`` (sequence
            of answer strings, paired in order with letters A-D). When
            ``include_answer`` is True it must also have ``"answer"``. That
            value may be an integer index into ``choices``, as the MMLU rows
            are indexed elsewhere in this file, or an already-formatted
            answer string.
        include_answer: If True, append the gold answer (for few-shot
            demonstrations). Otherwise end with an open ``"Answer:"`` cue for
            the model to complete.

    Returns:
        The formatted prompt string.
    """
    letters = ["A", "B", "C", "D"]
    prompt = example["question"]
    for letter, choice in zip(letters, example["choices"]):
        prompt += "\n{}. {}".format(letter, choice)
    if include_answer:
        answer = example["answer"]
        # Integer labels are indices into the choices; show them as the same
        # letter the prompt uses, so demonstrations match the expected output.
        if isinstance(answer, int):
            answer = letters[answer]
        # The leading newline keeps "Answer:" on its own line (as in the branch
        # below). The trailing blank line separates consecutive demonstrations.
        prompt += "\nAnswer: {}\n\n".format(answer)
    else:
        prompt += "\nAnswer:"
    return prompt


def generate_prompt(test_example, dev_examples, few_shots=0):
    """Build an MMLU prompt for ``test_example``, optionally with few-shot demos.

    The prompt is made of three parts, in order:
    1. A header naming the subject, with underscores shown as spaces.
    2. The first ``few_shots`` entries of ``dev_examples``, formatted with
       their answers.
    3. ``test_example`` with its answer left open for the model.
    """
    subject_words = test_example["subject"].replace("_", " ")
    header = "The following are multiple choice questions (with answers) about {}.\n\n".format(
        subject_words
    )
    demonstrations = "".join(format_example(dev_examples[k]) for k in range(few_shots))
    question = format_example(test_example, include_answer=False)
    return header + demonstrations + question


def _accuracy(correct):
    """Return the fraction of truthy entries in ``correct``, or 0.0 if it is empty."""
    return sum(correct) / len(correct) if correct else 0.0


def _print_accuracy_row(name, correct):
    """Print one ``name | accuracy | hits/total`` row of the results table."""
    print(
        "{:48}| {:.3f} | {:5}/{:5}".format(name, _accuracy(correct), sum(correct), len(correct)),
        flush=True,
    )


if __name__ == "__main__":
    # Evaluate a (ModelOpt) Megatron GPT model on MMLU.
    # - Zero-shot, one subject at a time.
    # - A prediction is correct when the decoded 2-token continuation starts
    #   with the gold answer letter.
    initialize_megatron(
        extra_args_provider=add_mmlu_args,
        args_defaults={
            'tokenizer_type': 'HuggingFaceTokenizer',
            'no_load_rng': True,
            'no_load_optim': True,
        },
    )

    args = get_args()

    # Meta device initialization for ParallelLinear only works if using cpu initialization.
    # Meta device initialization is used such that models can be materialized in low-precision
    # directly when ModelOpt real quant is used. Otherwise, the model is first initialized
    # as BF16 in memory which may result in OOM and defeat the purpose of real quant.
    if args.init_model_with_meta_device:
        args.use_cpu_initialization = True
    else:
        warnings.warn(
            "--init-model-with-meta-device is not set. If you would like to resume the "
            "model in low-bit directly (low-memory initialization and skipping 16-bit), "
            "--init-model-with-meta-device must be set.",
            UserWarning,
        )

    model = get_model(functools.partial(model_provider, parallel_output=True), wrap_with_ddp=False)
    report_current_memory_info()

    # Materialize the model from meta device to gpu before loading the checkpoint.
    # NOTE(review): to_empty() also runs when --init-model-with-meta-device is not set.
    # In that case it replaces already-initialized weights with uninitialized storage,
    # which is only safe if the checkpoint load below overwrites every tensor. Confirm.
    unwrapped_model = unwrap_model(model)[0]
    unwrapped_model.to_empty(device="cuda")
    # Inference only: turn off training-mode behavior such as dropout. This is
    # idempotent, so it is harmless if simple_generate also switches modes.
    unwrapped_model.eval()
    report_current_memory_info()

    is_rank0 = torch.distributed.get_rank() == 0
    disable_tqdm = args.disable_tqdm or not is_rank0

    tokenizer = get_tokenizer()._tokenizer

    if args.load is not None:
        load_modelopt_checkpoint(model, strict=not args.untie_embeddings_and_output_weights)
        print_rank_0("Done loading checkpoint")

    choice_letters = ["A", "B", "C", "D"]
    all_correct = {}

    for subject in get_all_subjects():
        test_data = load_dataset("cais/mmlu", subject, split="test")
        dev_data = load_dataset("cais/mmlu", subject, split="dev")

        # Evaluate at most ceil(percentage * len) examples. `>=` makes
        # percentage=0 evaluate nothing and an integral limit exact.
        example_limit = args.percentage * len(test_data)
        correct = []
        for idx, test_example in enumerate(test_data):
            if idx >= example_limit:
                break
            prompt = generate_prompt(test_example, dev_data, few_shots=0)
            label = choice_letters[test_example["answer"]]
            tokens = tokenizer(prompt, return_tensors="pt")
            generated_ids = simple_generate(
                unwrapped_model, tokens.input_ids.cuda(), osl=2, disable_tqdm=disable_tqdm
            )
            predict = tokenizer.batch_decode(generated_ids)[0].strip()
            correct.append(predict.startswith(label))
        all_correct[subject] = correct

        if is_rank0:
            _print_accuracy_row(subject, correct)

    # Micro-average: pool every evaluated example across subjects.
    avg_correct = [hit for correct in all_correct.values() for hit in correct]

    if is_rank0:
        _print_accuracy_row("average", avg_correct)

    if args.lower_bound is not None:
        average_accuracy = _accuracy(avg_correct)
        # Explicit raise instead of `assert` so this gate still runs under `python -O`.
        if not average_accuracy > args.lower_bound:
            raise AssertionError(
                "MMLU average accuracy {:.3f} is not above --lower-bound {}".format(
                    average_accuracy, args.lower_bound
                )
            )
