"""
This script calculates the sliding-window perplexity of a GPT model on a
dataset split (the test split by default).
It is adapted from the original training script, trainer.py.
"""

import os
import pickle
from contextlib import nullcontext

import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

from model import GPTConfig, GPT
from tqdm import tqdm

# -----------------------------------------------------------------------------
# Default config values. Many are inherited unchanged from trainer.py (GPT-2
# 124M on OpenWebText); the comments below note which ones this script reads.
# Any of them can be overridden by configurator.py, which is exec'd below.
# I/O
# Checkpoint directory. With init_from == "resume" its name also selects the
# model implementation: names containing "DiffFPQ", "MXFP" or "DiffQ" load
# qmodel, anything else loads model.
out_dir = "gpt2-124M-300B-adamw-vanilla"
# Mixed-precision bit allocation: each learned per-block value v (the
# `...wgt_sampler.bit` tensors) is mapped to v * (init_bit - target_bit) + target_bit
# bits. Blocks above threshold_bit_high get the high-precision format; for
# "mix3d" schemes, blocks above threshold_bit_middle get the middle format.
init_bit = 6.0
target_bit = 4.0
threshold_bit_high = 6.0
threshold_bit_middle = 3.0
# training-loop settings kept from trainer.py; not read by this script
eval_interval = 2000
log_interval = 1
eval_iters = 200
eval_only = False  # trainer.py: if True, exit right after the first eval
always_save_checkpoint = (
    True  # trainer.py: if True, always save a checkpoint after each eval
)
init_from = "resume"  # 'scratch' or 'resume' or 'gpt2*'
# MX weight-quantization scheme, validated against a fixed list below. None
# disables quantization. A trailing "_block" selects the block variant.
quantize_parameters = "fp8_e3m4_fp12_e4m7_mixed_block"
# wandb logging (kept from trainer.py; not read by this script)
wandb_log = False  # disabled by default
wandb_project = "owt"
wandb_run_name = "gpt2"  # 'run' + str(time.time())
# data
# Tokens are read from data/<dataset>/{train,val,test}.bin; `split` selects
# which of the three is scored.
dataset = "wikitext-2"
split = "test"
# only used for the DDP divisibility check and the tokens-per-iteration printout
gradient_accumulation_steps = 5 * 8  # used to simulate larger batch sizes
batch_size = (
    12  # if gradient_accumulation_steps > 1, this is the micro-batch size
)
# The model is cropped to this context length when it is smaller than the
# model's own block size; the resulting block size is the perplexity window.
block_size = 1024
# model
# Architecture for init_from == "scratch". With "resume" the checkpoint's
# values override these (dropout is kept from here); with "gpt2*" the
# pretrained config is used.
n_layer = 12
n_head = 12
n_embd = 768
dropout = 0.0  # for pretraining 0 is good, for finetuning try 0.1+
bias = False  # do we use bias inside LayerNorm and Linear layers?
# adamw optimizer
# Training hyperparameters inherited from trainer.py; they do not affect the
# perplexity computed by this script.
learning_rate = 6e-4  # max learning rate
max_iters = 600000  # total number of training iterations
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0  # clip gradients at this value, or disable if == 0.0
# PPL
# Step between successive sliding windows. Each window scores only the tokens
# that earlier windows have not scored yet; the rest serve as context.
stride = 64  # sliding window ppl
# learning rate decay settings (training only; not read by this script)
decay_lr = True  # whether to decay the learning rate
warmup_iters = 2000  # how many steps to warm up for
lr_decay_iters = 600000  # should be ~= max_iters per Chinchilla
min_lr = (
    6e-5  # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
)
# DDP settings
backend = "nccl"  # 'nccl', 'gloo', etc.
# system
device = "cuda"  # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
dtype = (
    "bfloat16"
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else "float16"
)  # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
compile = False  # use PyTorch 2.0 to compile the model to be faster (shadows builtin `compile`)
# -----------------------------------------------------------------------------
# Snapshot the names of every simple-typed (int/float/bool/str) setting above.
# A setting whose default is None is therefore not captured here.
config_keys = [
    k
    for k, v in globals().items()
    if not k.startswith("_") and isinstance(v, (int, float, bool, str))
]
# configurator.py is executed in this module's namespace and may rebind any
# of the settings above. It runs arbitrary code, so only use a trusted file.
exec(
    open("configurator.py").read()
)  # overrides from command line or config file
config = {k: globals()[k] for k in config_keys}  # will be useful for logging
# -----------------------------------------------------------------------------

# check parameters
# Supported MX weight-quantization schemes; None disables quantization.
# The check raises explicitly rather than using `assert`, so it also runs
# under `python -O` and reports the offending value.
_SUPPORTED_QUANTIZE_PARAMETERS = (
    None,
    "fp8_e4m3",
    "fp8_e4m3_block",
    "fp8_e3m4",
    "fp8_e3m4_block",
    "fp12_e4m7",
    "fp12_e4m7_block",
    "fp8_e3m4_fp12_e4m7_mixed_block",
    "fp4_e2m1_fp8_e3m4_fp12_e4m7_mix3d_block",
)
if quantize_parameters not in _SUPPORTED_QUANTIZE_PARAMETERS:
    raise ValueError(
        f"unsupported quantize_parameters={quantize_parameters!r}; "
        f"expected one of {_SUPPORTED_QUANTIZE_PARAMETERS}"
    )

# various inits, derived attributes, I/O setup
# A distributed launch is detected by a RANK environment variable other than -1.
ddp = int(os.environ.get("RANK", -1)) != -1
if not ddp:
    # single process, single device
    master_process = True
    seed_offset = 0
    ddp_world_size = 1
else:
    init_process_group(backend=backend)
    ddp_rank = int(os.environ["RANK"])
    ddp_local_rank = int(os.environ["LOCAL_RANK"])
    ddp_world_size = int(os.environ["WORLD_SIZE"])
    # every process is pinned to the GPU matching its local rank
    device = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(device)
    # only rank 0 creates out_dir (and would log / checkpoint)
    master_process = ddp_rank == 0
    seed_offset = ddp_rank
    # the accumulation steps are divided evenly over the processes
    assert gradient_accumulation_steps % ddp_world_size == 0
    gradient_accumulation_steps //= ddp_world_size
tokens_per_iter = (
    batch_size * block_size * gradient_accumulation_steps * ddp_world_size
)
print(f"tokens per iteration will be: {tokens_per_iter:,}")

if master_process:
    os.makedirs(out_dir, exist_ok=True)
torch.manual_seed(1337 + seed_offset)
# allow TF32 kernels for CUDA matmuls and cuDNN
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# device_type is what torch.autocast expects ("cuda" or "cpu")
if "cuda" in device:
    device_type = "cuda"
else:
    device_type = "cpu"
ptdtype = {
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
}[dtype]
# autocast context; a no-op on CPU
if device_type == "cpu":
    ctx = nullcontext()
else:
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# poor man's data loader: each split is a flat uint16 token array on disk,
# memory-mapped read-only (train, val and test must all exist)
data_dir = os.path.join("data", dataset)
train_data, val_data, test_data = (
    np.memmap(os.path.join(data_dir, fname), dtype=np.uint16, mode="r")
    for fname in ("train.bin", "val.bin", "test.bin")
)
# training-state defaults; replaced from the checkpoint when init_from='resume'
iter_num = 0
best_val_loss = 1e9

# the dataset may ship a meta.pkl declaring its vocab_size; None if absent
meta_vocab_size = None
meta_path = os.path.join(data_dir, "meta.pkl")
if os.path.exists(meta_path):
    with open(meta_path, "rb") as f:
        meta = pickle.load(f)
    meta_vocab_size = meta["vocab_size"]
    print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")

# model init
# Start from the command-line architecture; each init_from branch below
# overrides it from its own source (checkpoint or pretrained GPT-2).
model_args = dict(
    n_layer=n_layer,
    n_head=n_head,
    n_embd=n_embd,
    block_size=block_size,
    bias=bias,
    vocab_size=None,
    dropout=dropout,
)  # start with model_args from command line
if init_from == "scratch":
    # init a new model from scratch
    print("Initializing a new model from scratch")
    # determine the vocab size we'll use for from-scratch training
    if meta_vocab_size is None:
        print(
            "defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)"
        )
    model_args["vocab_size"] = (
        meta_vocab_size if meta_vocab_size is not None else 50304
    )
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
elif init_from == "resume":
    print(f"Resuming training from {out_dir}")
    # The checkpoint directory name decides which implementation is loaded:
    # quantization-aware runs ("DiffFPQ", "MXFP", "DiffQ") use qmodel.
    model_type = "vanilla"
    is_diffq = False  # NOTE(review): not read anywhere else in this script
    if "DiffFPQ" in out_dir or "MXFP" in out_dir:
        model_type = "quant"
    elif "DiffQ" in out_dir:
        model_type = "quant"
        is_diffq = True

    if model_type == "vanilla":
        from model import GPTConfig, GPT
    else:
        from qmodel import GPTConfig, GPT

        import inspect
        import sys

        # make <parent dir>/mx-amp-kernel importable (it provides q_config)
        currentdir = os.path.dirname(
            os.path.abspath(inspect.getfile(inspect.currentframe()))
        )
        pdir = os.path.dirname(currentdir)
        sys.path.insert(0, os.path.join(pdir, "mx-amp-kernel"))

        # QConfig is not referenced directly; presumably it is imported so
        # that torch.load can unpickle quant checkpoints -- TODO confirm.
        from q_config import QConfig
        from safetensors.torch import load_model

    ckpt_path = os.path.join(out_dir, "ckpt.pt")
    checkpoint = torch.load(ckpt_path, map_location=device)
    checkpoint_model_args = checkpoint["model_args"]
    # force these config attributes to be equal otherwise we can't even resume training
    # the rest of the attributes (e.g. dropout) can stay as desired from command line
    for k in [
        "n_layer",
        "n_head",
        "n_embd",
        "block_size",
        "bias",
        "vocab_size",
    ]:
        model_args[k] = checkpoint_model_args[k]
    # Build the model exactly once, then load its weights from the source
    # that matches its type (building a second model would discard them).
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
    if model_type == "quant":
        # quantization-aware runs keep their weights in model.safetensors
        load_model(model, os.path.join(out_dir, "model.safetensors"))
    else:
        state_dict = checkpoint["model"]
        # state dicts taken from a torch.compile'd module typically carry
        # this prefix; strip it so the keys match the plain GPT module
        unwanted_prefix = "_orig_mod."
        for k in list(state_dict):
            if k.startswith(unwanted_prefix):
                state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
        model.load_state_dict(state_dict)
    iter_num = checkpoint["iter_num"]
    best_val_loss = checkpoint["best_val_loss"]
elif init_from.startswith("gpt2"):
    print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
    # initialize from OpenAI GPT-2 weights
    override_args = dict(dropout=dropout)
    model = GPT.from_pretrained(init_from, override_args)
    # read off the created config params, so we can store them into checkpoint correctly
    for k in [
        "n_layer",
        "n_head",
        "n_embd",
        "block_size",
        "bias",
        "vocab_size",
    ]:
        model_args[k] = getattr(model.config, k)
else:
    # without this, `model` would be undefined and fail later with a NameError
    raise ValueError(
        f"unknown init_from={init_from!r}; expected 'scratch', 'resume' or 'gpt2*'"
    )
# crop down the model block size if desired, using model surgery
if block_size < model.config.block_size:
    model.crop_block_size(block_size)
    model_args["block_size"] = (
        block_size  # so that the checkpoint will have the right value
    )
# quantize model parameters via microscaling (MX) format
if quantize_parameters is not None:
    import inspect
    import re
    import sys

    # make the parent directory importable so that the `mx` package resolves
    currentdir = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe()))
    )
    pdir = os.path.dirname(currentdir)
    sys.path.insert(0, pdir)

    import mx

    # MX block size passed to the quantizer; it is also the tile edge used to
    # expand the 2-D per-block bit maps to the full weight shape.
    mx_block_size = 32

    # state_dict() tensors share storage with the model parameters, so the
    # param.copy_() below writes the quantized values into the model.
    state_dict = model.state_dict()

    if "block" in quantize_parameters:
        # "_block" variants quantize along the last two axes without the
        # custom CUDA kernel; strip the suffix to get the element format(s).
        custom_cuda = False
        axes = [-1, -2]
        quantize_parameters = "_".join(quantize_parameters.split("_")[:-1])
    else:
        custom_cuda = True
        axes = [-1]
    mx_specs = dict(
        scale_bits=8,
        shared_exp_method="max",
        mx_flush_fp32_subnorms=True,
        block_size=mx_block_size,
        custom_cuda=custom_cuda,
    )

    # Parse the element formats once:
    #   "<fmt>_<fmt>_mixed"       -> low / high (no middle band)
    #   "<fmt>_<fmt>_<fmt>_mix3d" -> low / middle / high
    #   anything else             -> a single format for every block
    is_mixed = "mixed" in quantize_parameters
    is_mix3d = "mix3d" in quantize_parameters
    middle_threshold = threshold_bit_middle
    if is_mixed:
        fp_1, em_1, fp_2, em_2, _ = quantize_parameters.split("_")
        low = "_".join([fp_1, em_1])
        high = "_".join([fp_2, em_2])
        middle = low
        # two-level scheme: using the high threshold makes the middle band empty
        middle_threshold = threshold_bit_high
    elif is_mix3d:
        fp_1, em_1, fp_2, em_2, fp_3, em_3, _ = quantize_parameters.split("_")
        low = "_".join([fp_1, em_1])
        middle = "_".join([fp_2, em_2])
        high = "_".join([fp_3, em_3])
    else:
        low = middle = high = quantize_parameters

    def _quantize_mx(tensor, elem_format):
        """Quantize `tensor` to MX `elem_format` using the settings chosen above."""
        return mx.mx_ops.quantize_mx_op(
            tensor,
            mx_specs,
            elem_format=elem_format,
            block_size=mx_block_size,
            axes=axes,
            round="nearest",
        )

    def _expand_block_mask(block_mask):
        """Repeat each entry of a 2-D per-block mask over an mx_block_size^2 tile."""
        n_rows, n_cols = block_mask.shape
        return (
            block_mask.view((n_rows, 1, n_cols, 1))
            .broadcast_to((n_rows, mx_block_size, n_cols, mx_block_size))
            .reshape((n_rows * mx_block_size, n_cols * mx_block_size))
        )

    def _fractions(high_count, middle_count, low_count):
        """Return (high, middle, low) fractions of blocks, or None if none were counted."""
        total = high_count + middle_count + low_count
        if total == 0:
            return None
        return tuple(
            float(count) / float(total)
            for count in (high_count, middle_count, low_count)
        )

    # Patterns are unchanged from the original selection; note that "." is a
    # regex wildcard here, not a literal dot.
    weight_patterns = (
        re.compile("transformer.h.*.attn.*.weight"),
        re.compile("transformer.h.*.mlp.*.weight"),
    )

    total_high_count = 0
    total_middle_count = 0
    total_low_count = 0
    for name, param in state_dict.items():
        if not any(pattern.search(name) for pattern in weight_patterns):
            continue

        # Quantize once per distinct format (low/middle/high often coincide).
        quantized = {}
        for elem_format in (low, middle, high):
            if elem_format not in quantized:
                quantized[elem_format] = _quantize_mx(param, elem_format)

        # With no bit map, every element falls through to the low format.
        mask_high = torch.zeros_like(param, dtype=torch.bool)
        mask_middle = torch.zeros_like(param, dtype=torch.bool)
        if is_mixed or is_mix3d:
            high_count = middle_count = low_count = 0
            bit_name = ".".join([*name.split(".")[:-1], "wgt_sampler", "bit"])
            bits = state_dict.get(bit_name)
            if bits is not None:
                assert bits.dim() == 2
                # map each per-block value v to v * (init_bit - target_bit) + target_bit bits
                block_bits = bits * (init_bit - target_bit) + target_bit
                high_mask = block_bits > threshold_bit_high
                middle_mask = block_bits > middle_threshold
                high_count = torch.sum(high_mask)
                middle_count = torch.sum(middle_mask) - high_count
                low_count = bits.numel() - torch.sum(middle_mask)
                mask_high = _expand_block_mask(high_mask)
                mask_middle = _expand_block_mask(middle_mask)
            total_high_count += high_count
            total_middle_count += middle_count
            total_low_count += low_count

            fractions = _fractions(high_count, middle_count, low_count)
            if fractions is None:
                print(f"{name}:\tno {bit_name} found; all blocks use {low}")
            elif is_mixed:
                print(f"{name}:\t{high} {fractions[0]}\t{low} {fractions[2]}")
            else:
                print(
                    f"{name}:\t{high} {fractions[0]}\t{middle} {fractions[1]}\t{low} {fractions[2]}"
                )

        transformed_param = torch.where(
            mask_high,
            quantized[high],
            torch.where(mask_middle, quantized[middle], quantized[low]),
        )
        param.copy_(transformed_param)

    if is_mixed or is_mix3d:
        fractions = _fractions(total_high_count, total_middle_count, total_low_count)
        if fractions is None:
            print(f"no wgt_sampler.bit tensors found; all blocks use {low}")
        elif is_mixed:
            print(f"{high}:\t{fractions[0]}\t{low}:\t{fractions[2]}")
        else:
            print(
                f"{high}:\t{fractions[0]}\t{middle}:\t{fractions[1]}\t{low}:\t{fractions[2]}"
            )

model.to(device)

# The GradScaler, the AdamW optimizer and its restored state (leftovers from
# trainer.py) are not created: nothing in this evaluation script uses them,
# and restoring the optimizer state only costs extra memory.
checkpoint = None  # free up memory

# compile the model
if compile:
    print("compiling the model... (takes a ~minute)")
    model = torch.compile(model)  # requires PyTorch 2.0

# wrap model into DDP container
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])

# Sliding-window perplexity over the selected split.
if split not in ("train", "val", "test"):
    raise ValueError(f"split must be 'train', 'val' or 'test', got {split!r}")
data = (
    test_data if split == "test" else val_data if split == "val" else train_data
)
# A DDP wrapper hides the underlying module (and its `config`) behind `.module`.
raw_model = model.module if isinstance(model, DDP) else model
max_length = raw_model.config.block_size
seq_len = len(data)
if seq_len < 2:
    raise ValueError(f"the {split!r} split needs at least 2 tokens, got {seq_len}")
if stride < 1:
    raise ValueError(f"stride must be a positive integer, got {stride}")

model.eval()
# Token-weighted accumulators: summed NLL of scored targets and their count.
nll_sum = torch.zeros((), dtype=torch.float64, device=device)
n_scored = 0
prev_end_loc = 0
for begin_loc in tqdm(range(0, seq_len, stride)):
    # Inputs are data[begin:end]; targets are the same span shifted by one.
    end_loc = min(begin_loc + max_length, seq_len - 1)
    trg_len = end_loc - prev_end_loc  # targets not scored yet; may differ from stride on the last window
    input_ids = (
        torch.from_numpy(data[begin_loc:end_loc].astype(np.int64))
        .unsqueeze(0)
        .to(device)
    )
    target_ids = (
        torch.from_numpy(data[begin_loc + 1 : end_loc + 1].astype(np.int64))
        .unsqueeze(0)
        .to(device)
    )
    # Mask the targets already scored by earlier windows so they only act as
    # context. This assumes the model's loss ignores targets equal to -1
    # (nanoGPT uses ignore_index=-1) -- confirm for qmodel.
    n_context = max(target_ids.size(1) - trg_len, 0)
    target_ids[:, :n_context] = -1
    n_valid = target_ids.size(1) - n_context

    with torch.no_grad():
        _, loss = model(input_ids, target_ids)

    # `loss` is the mean NLL over the n_valid unmasked targets. Weighting it by
    # n_valid yields a true per-token average; a plain mean over windows would
    # over-weight windows that score fewer tokens (e.g. the last one).
    if n_valid > 0:
        nll_sum += loss.double() * n_valid
        n_scored += n_valid

    prev_end_loc = end_loc
    if end_loc == seq_len - 1:
        break

ppl = torch.exp(nll_sum / n_scored).item()

if master_process:
    print(f"Perplexity: {ppl:.2f}")

# destroy DDP process group
if ddp:
    destroy_process_group()
