"""
Downloads and evaluates HellaSwag in Python.
https://github.com/rowanz/hellaswag

Example HellaSwag json item:

{"ind": 24, "activity_label": "Roof shingle removal", "ctx_a": "A man is sitting on a roof.", "ctx_b": "he", "ctx": "A man is sitting on a roof. he", "split": "val", "split_type": "indomain", "label": 3, "endings": ["is using wrap to wrap a pair of skis.", "is ripping level tiles off.", "is holding a rubik's cube.", "starts pulling up roofing on a roof."], "source_id": "activitynet~v_-JhWjGDPHMY"}

ind: dataset ID
activity_label: The ActivityNet or WikiHow label for this example
context: There are two formats. The full context is in ctx. When the context ends in an (incomplete) noun phrase, like for ActivityNet, this incomplete noun phrase is in ctx_b, and the context up until then is in ctx_a. This can be useful for models such as BERT that need the last sentence to be complete. However, it's never required. If ctx_b is nonempty, then ctx is the same thing as ctx_a, followed by a space, then ctx_b.
endings: a list of 4 endings. The correct index is given by label (0,1,2, or 3)
split: train, val, or test.
split_type: indomain if the activity label is seen during training, else zeroshot
source_id: Which video or WikiHow article this example came from

gpt2 (124M)
- eleuther harness reports acc 28.92%, acc_norm 31.14% (multiple choice style)
- this script: 10042 acc: 0.2859 acc_norm: 0.2955 (completion style)

gpt2-xl (1558M)
- eleuther harness reports acc 40.04%, acc_norm 50.89% (multiple choice style)
- this script: 10042 acc: 0.3842 acc_norm: 0.4893 (completion style)

The validation set of HellaSwag has a total of 10,042 examples.
"""

import os
import json
import requests
import tiktoken
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.nn import functional as F
from safetensors.torch import load_model

import click

# -----------------------------------------------------------------------------
# Local cache directory for the downloaded dataset files: a "hellaswag" folder
# located next to this source file.
DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), "hellaswag")


def download_file(url: str, fname: str, chunk_size=1024):
    """Download ``url`` to the local path ``fname`` with a tqdm progress bar.

    Args:
        url: address to fetch with a streaming HTTP GET.
        fname: destination path; it only appears once the download completed.
        chunk_size: number of bytes requested per streamed chunk.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
    """
    # Stream into a sibling temp file first so that an interrupted download
    # never leaves a truncated file under the final name (callers use the
    # existence of ``fname`` as the "already downloaded" marker). A leftover
    # ".part" file is simply overwritten by the next attempt.
    tmp_fname = fname + ".part"
    with requests.get(url, stream=True) as resp:
        # fail loudly instead of saving an HTTP error page as the dataset
        resp.raise_for_status()
        # content-length may be absent; tqdm then gets total=0
        total = int(resp.headers.get("content-length", 0))
        with open(tmp_fname, "wb") as file, tqdm(
            desc=fname,
            total=total,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in resp.iter_content(chunk_size=chunk_size):
                size = file.write(data)
                bar.update(size)
    os.replace(tmp_fname, fname)


# Download URL of the JSONL file for each HellaSwag split, keyed by split name.
hellaswags = {
    "train": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_train.jsonl",
    "val": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl",
    "test": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_test.jsonl",
}

# tiktoken "gpt2" encoding, created once at import time and shared by
# render_example for tokenizing contexts and endings.
enc = tiktoken.get_encoding("gpt2")


def download(split):
    """Fetch the HellaSwag file for ``split`` into DATA_CACHE_DIR if it is missing.

    ``split`` must be a key of ``hellaswags``. Nothing is downloaded when the
    target file already exists.
    """
    os.makedirs(DATA_CACHE_DIR, exist_ok=True)
    source_url = hellaswags[split]
    target_path = os.path.join(DATA_CACHE_DIR, f"hellaswag_{split}.jsonl")
    if os.path.exists(target_path):
        # already cached from an earlier run
        return
    print(f"Downloading {source_url} to {target_path}...")
    download_file(source_url, target_path)


def render_example(example):
    """
    Given the example as a dictionary, render it for evaluation.

    Returns a 4-tuple (data, tokens, mask, label):
    - data (dict of plain Python values with keys "label", "ctx_tokens" and
      "ending_tokens": the unpadded token lists of the context and of each ending)
    - tokens (long tensor with the tokens of context + completion, one row per
      candidate ending; shorter rows are right-padded with 0)
    - mask (long tensor of the same shape as tokens; is 1 in the region of the
      candidate completion, where we evaluate likelihoods, and 0 over the
      context and the padding)
    - label (the index of the correct completion, which we hope has the highest
      likelihood; passed through unchanged from example["label"])
    """
    ctx = example["ctx"]
    label = example["label"]
    endings = example["endings"]

    # gather up all the tokens
    ctx_tokens = enc.encode(ctx)
    tok_rows = []
    mask_rows = []
    ending_tokens = []
    for end in endings:
        # note: prepending " " because GPT-2 tokenizer
        end_tokens = enc.encode(" " + end)
        tok_rows.append(ctx_tokens + end_tokens)
        mask_rows.append([0] * len(ctx_tokens) + [1] * len(end_tokens))
        ending_tokens.append(end_tokens)

    # data needed to reproduce this eval on the C side
    data = {
        "label": label,
        "ctx_tokens": ctx_tokens,
        "ending_tokens": ending_tokens,
    }

    # have to be careful during the collation because the number of tokens in
    # each row can differ; size the batch by the actual number of endings
    # (4 for HellaSwag) instead of hard-coding it
    num_rows = len(tok_rows)
    max_len = max(len(row) for row in tok_rows)
    tokens = torch.zeros((num_rows, max_len), dtype=torch.long)
    mask = torch.zeros((num_rows, max_len), dtype=torch.long)
    for i, (tok_row, mask_row) in enumerate(zip(tok_rows, mask_rows)):
        tokens[i, : len(tok_row)] = torch.tensor(tok_row)
        mask[i, : len(mask_row)] = torch.tensor(mask_row)

    return data, tokens, mask, label


def iterate_examples(split):
    """Yield every example of the given HellaSwag split as a dict.

    The split's JSONL file is downloaded into DATA_CACHE_DIR first if it is not
    cached yet; each non-blank line of the file is parsed as one JSON object.
    """
    # there are 10,042 examples in total in val
    download(split)
    path = os.path.join(DATA_CACHE_DIR, f"hellaswag_{split}.jsonl")
    # JSON text is UTF-8; do not depend on the platform's locale encoding
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # tolerate blank lines (e.g. a trailing newline at EOF)
                continue
            yield json.loads(line)


def _quantize_model_parameters(
    model,
    quantize_parameters,
    init_bit,
    target_bit,
    threshold_bit_high,
    threshold_bit_middle,
):
    """Quantize the attention/MLP weights of ``model`` in place via the ``mx`` package.

    Every state-dict tensor whose name matches the attention or MLP weight
    patterns is passed through ``mx.mx_ops.quantize_mx_op`` and overwritten with
    the result. For the "mixed" (two formats) and "mix3d" (three formats)
    schemes the element format is picked per 32x32 block: the tensor
    ``<module>.wgt_sampler.bit`` is mapped to
    ``bit * (init_bit - target_bit) + target_bit`` and compared against
    ``threshold_bit_high`` / ``threshold_bit_middle``. The share of blocks per
    format is printed per tensor and in total.

    Raises:
        KeyError: if a mixed scheme is requested but a weight has no matching
            ``wgt_sampler.bit`` tensor in the state dict.
    """
    import inspect
    import re
    import sys

    # `mx` is imported from the parent of this file's directory
    currentdir = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe()))
    )
    pdir = os.path.dirname(currentdir)
    sys.path.insert(0, pdir)

    import mx

    block_size = 32
    state_dict = model.state_dict()

    if "block" in quantize_parameters:
        custom_cuda = False
        axes = [-1, -2]
        # drop the trailing "_block" so only the element format(s) remain
        quantize_parameters = "_".join(quantize_parameters.split("_")[:-1])
    else:
        custom_cuda = True
        axes = [-1]
    mx_specs = dict(
        scale_bits=8,
        shared_exp_method="max",
        mx_flush_fp32_subnorms=True,
        block_size=block_size,
        custom_cuda=custom_cuda,
    )

    # The format split depends only on the scheme name, so parse it once
    # instead of once per tensor.
    is_mixed = "mixed" in quantize_parameters
    is_mix3d = not is_mixed and "mix3d" in quantize_parameters
    if is_mixed:
        fp_1, em_1, fp_2, em_2, _ = quantize_parameters.split("_")
        low = "_".join([fp_1, em_1])
        high = "_".join([fp_2, em_2])
        # only two formats: collapse the middle tier onto the low one
        middle = low
        threshold_bit_middle = threshold_bit_high
    elif is_mix3d:
        fp_1, em_1, fp_2, em_2, fp_3, em_3, _ = quantize_parameters.split("_")
        low = "_".join([fp_1, em_1])
        middle = "_".join([fp_2, em_2])
        high = "_".join([fp_3, em_3])
    else:
        low = middle = high = quantize_parameters

    def quantize(tensor, elem_format):
        return mx.mx_ops.quantize_mx_op(
            tensor,
            mx_specs,
            elem_format=elem_format,
            block_size=block_size,
            axes=axes,
            round="nearest",
        )

    def expand_block_mask(block_mask):
        # (rows, cols) per-block mask -> (rows * 32, cols * 32) per-element mask
        rows, cols = block_mask.shape
        return (
            block_mask.view((rows, 1, cols, 1))
            .broadcast_to((rows, block_size, cols, block_size))
            .reshape((rows * block_size, cols * block_size))
        )

    weight_patterns = (
        "transformer.h.*.attn.*.weight",
        "transformer.h.*.mlp.*.weight",
    )
    num_quantized = 0
    total_high_count = 0
    total_middle_count = 0
    total_low_count = 0
    for name, param in state_dict.items():
        if not any(re.search(pattern, name) for pattern in weight_patterns):
            continue
        num_quantized += 1

        transformed_param_low = quantize(param, low)
        transformed_param_middle = quantize(param, middle)
        transformed_param_high = quantize(param, high)

        # all-False masks select the low format everywhere
        mask_high = torch.zeros_like(param, dtype=torch.bool)
        mask_middle = torch.zeros_like(param, dtype=torch.bool)
        if is_mixed or is_mix3d:
            # direct lookup of the sibling "<module>.wgt_sampler.bit" tensor
            bit_name = ".".join([*name.split(".")[:-1], "wgt_sampler", "bit"])
            bit = state_dict.get(bit_name)
            if bit is None:
                raise KeyError(
                    f"{bit_name} is missing from the model state dict; "
                    f"it is required to quantize {name} with a mixed format"
                )
            assert len(bit.shape) == 2
            effective_bit = bit * (init_bit - target_bit) + target_bit
            high_blocks = effective_bit > threshold_bit_high
            middle_blocks = effective_bit > threshold_bit_middle

            # middle_blocks also covers the high blocks, hence the subtraction
            high_count = torch.sum(high_blocks)
            middle_count = torch.sum(middle_blocks) - torch.sum(high_blocks)
            low_count = bit.shape[0] * bit.shape[1] - torch.sum(middle_blocks)

            # assumes the weight has shape (rows * 32, cols * 32) so that the
            # expanded masks line up with it - TODO confirm
            mask_high = expand_block_mask(high_blocks)
            mask_middle = expand_block_mask(middle_blocks)

            total_high_count += high_count
            total_middle_count += middle_count
            total_low_count += low_count
            if is_mixed:
                total_count = high_count + low_count
                print(
                    f"{name}:\t{high} {high_count / total_count}\t{low} {low_count / total_count}"
                )
            else:
                total_count = high_count + middle_count + low_count
                print(
                    f"{name}:\t{high} {high_count / total_count}\t{middle} {middle_count / total_count}\t{low} {low_count / total_count}"
                )

        transformed_param = torch.where(
            mask_high,
            transformed_param_high,
            torch.where(
                mask_middle,
                transformed_param_middle,
                transformed_param_low,
            ),
        )
        param.copy_(transformed_param)

    if num_quantized == 0:
        # nothing matched the weight patterns, so there are no ratios to report
        return
    if is_mixed:
        total_count = total_high_count + total_low_count
        print(
            f"{high}:\t{total_high_count / total_count}\t{low}:\t{total_low_count / total_count}"
        )
    elif is_mix3d:
        total_count = total_high_count + total_middle_count + total_low_count
        print(
            f"{high}:\t{total_high_count / total_count}\t{middle}:\t{total_middle_count / total_count}\t{low}:\t{total_low_count / total_count}"
        )


def _completion_losses(model, tokens, mask):
    """Return ``(sum_loss, avg_loss)`` of the masked completion region per row.

    ``tokens`` and ``mask`` are the tensors produced by ``render_example``. The
    loss at position t scores the prediction of token t + 1, so both tokens and
    mask are shifted by one before the masked reduction.
    """
    # NOTE(review): targets are passed although the returned loss is unused;
    # presumably this makes the model return logits for every position -
    # confirm against GPT.forward.
    logits, _ = model(tokens, targets=tokens)
    # evaluate the autoregressive loss at all positions
    shift_logits = logits[..., :-1, :].contiguous()
    shift_tokens = tokens[..., 1:].contiguous()
    shift_losses = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_tokens.view(-1),
        reduction="none",
    ).view(tokens.size(0), -1)
    # we must shift mask too, so we start at the last prompt token
    shift_mask = mask[..., 1:].contiguous()
    masked_shift_losses = shift_losses * shift_mask
    # sum and divide by the number of 1s in the mask
    sum_loss = masked_shift_losses.sum(dim=1)
    avg_loss = sum_loss / shift_mask.sum(dim=1)
    return sum_loss, avg_loss


# no_grad is the outermost decorator, i.e. it wraps the click command itself.
@torch.no_grad()
@click.command()
@click.option("--ckpt_dir", required=True, help="checkpoint directory")
@click.option("--quantize_parameters", default=None, help="quantize parameters")
@click.option("--init_bit", default=6.0, help="init bit")
@click.option("--target_bit", default=4.0, help="target bit")
@click.option(
    "--threshold_bit_high", default=6.0, help="threshold FP12_e4m7 vs. FP8_e3m4"
)
@click.option(
    "--threshold_bit_middle",
    default=3.0,
    help="threshold FP8_e3m4 vs. FP4_e2m1",
)
@click.option("--device", default="cuda")
def evaluate(
    ckpt_dir,
    quantize_parameters,
    init_bit,
    target_bit,
    threshold_bit_high,
    threshold_bit_middle,
    device,
):
    """Evaluate a GPT checkpoint on the HellaSwag validation split.

    Loads the model from --ckpt_dir, optionally re-quantizes its attention/MLP
    weights (--quantize_parameters), scores the candidate endings of every
    example by their cross-entropy and prints the running and final acc_norm.
    """
    supported_formats = (
        None,
        "fp8_e4m3",
        "fp8_e4m3_block",
        "fp8_e3m4",
        "fp8_e3m4_block",
        "fp12_e4m7",
        "fp12_e4m7_block",
        "fp8_e3m4_fp12_e4m7_mixed_block",
        "fp4_e2m1_fp8_e3m4_fp12_e4m7_mix3d_block",
    )
    # explicit check rather than `assert`, which is stripped under `python -O`
    if quantize_parameters not in supported_formats:
        choices = ", ".join(fmt for fmt in supported_formats if fmt is not None)
        raise click.BadParameter(
            f"{quantize_parameters!r} is not supported; choose one of: {choices}",
            param_hint="--quantize_parameters",
        )

    # The checkpoint flavour is inferred from the directory name.
    is_quant_ckpt = any(
        tag in ckpt_dir for tag in ("DiffFPQ", "MXFP", "DiffQ")
    )
    if is_quant_ckpt:
        from qmodel import GPTConfig, GPT

        import inspect
        import sys

        currentdir = os.path.dirname(
            os.path.abspath(inspect.getfile(inspect.currentframe()))
        )
        pdir = os.path.dirname(currentdir)
        sys.path.insert(0, os.path.join(pdir, "mx-amp-kernel"))

        # NOTE(review): QConfig is not referenced below; the import is kept
        # because it may be needed for its side effects or for unpickling the
        # checkpoint - confirm before removing.
        from q_config import QConfig  # noqa: F401
    else:
        from model import GPTConfig, GPT

    torch.set_float32_matmul_precision("high")  # use tf32
    ckpt_path = os.path.join(ckpt_dir, "ckpt.pt")
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint["model_args"])
    model = GPT(gptconf)
    if is_quant_ckpt:
        load_model(model, os.path.join(ckpt_dir, "model.safetensors"))
    else:
        state_dict = checkpoint["model"]
        # strip the "_orig_mod." key prefix before loading the weights
        unwanted_prefix = "_orig_mod."
        for k in list(state_dict):
            if k.startswith(unwanted_prefix):
                state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
        model.load_state_dict(state_dict)

    # quantize model parameters via microscaling (MX) format
    if quantize_parameters is not None:
        _quantize_model_parameters(
            model,
            quantize_parameters,
            init_bit,
            target_bit,
            threshold_bit_high,
            threshold_bit_middle,
        )

    model.to(device)
    # model = torch.compile(model) # optionally torch compile the model
    model.eval()

    num_correct_norm = 0
    num_correct = 0  # un-normalized accuracy is tracked but not printed
    num_total = 0
    for example in iterate_examples("val"):
        _, tokens, mask, label = render_example(example)
        tokens = tokens.to(device)
        mask = mask.to(device)

        # now we have a loss for each of the completions
        # the one with the lowest loss should be the most likely
        sum_loss, avg_loss = _completion_losses(model, tokens, mask)
        pred = sum_loss.argmin().item()
        pred_norm = avg_loss.argmin().item()

        # accumulate stats
        num_total += 1
        num_correct += int(pred == label)
        num_correct_norm += int(pred_norm == label)
        if num_total % 100 == 0:
            print(
                f"{num_total} acc_norm: {num_correct_norm}/{num_total}={num_correct_norm/num_total:.4f}"
            )

        # debug: pretty print a few examples, and the losses in each case
        if num_total < 10:
            print("---")
            print(f"Context:\n {example['ctx']}")
            print("Endings:")
            for i, end in enumerate(example["endings"]):
                print(f"{i} (loss: {avg_loss[i].item():.4f}) {end}")
            print(f"predicted: {pred_norm}, actual: {label}")
    print(
        f"{num_total} acc_norm: {num_correct_norm}/{num_total}={num_correct_norm/num_total:.4f}"
    )


if __name__ == "__main__":
    # Run the evaluation CLI only when executed as a script, not on import.
    evaluate()
