"""
Code for measuring loss surface sharpness of TinyLlama models.

Command example:
    find "$ROOT_PATH/e8m3" -name '*.pth' | \
    xargs -I CKPT python -m fire ../../landscape/measure_sharpness.py main \
        --model_name tiny_LLaMA_120M \
        --checkpoint_path CKPT \
        --eval_iters 32 \
        --sig_bits 3 \
        --num_processes 4 \
    >> tinyllama_120m_e8m3_out.log

"""
from pathlib import Path

import fire
import numpy as np
import torch
from torch import Tensor
# import torch.multiprocessing as mp
from joblib import Parallel, delayed  # Using joblib because multiprocessing does not work.
from lightning import Fabric
from torch.utils.data import DataLoader
import torch.nn.functional as F
from beartype import beartype
from loguru import logger
from scipy.optimize import minimize, Bounds

from third_party.TinyLlama.lit_gpt.model import GPT
from third_party.TinyLlama.lit_gpt.config import Config
from custom.layer import MaskedLinearBF16
from custom.packed_dataset import PackedDataset


def fun(x: np.ndarray, *args) -> float:
    """Objective handed to ``scipy.optimize.minimize``: negated cross-entropy.

    The optimiser minimises, so returning the negative loss makes it search
    for the perturbation that *maximises* the cross-entropy.

    Args:
        x: Additive perturbation applied to the logits (a float64 array, as
            supplied by SciPy).
        *args: Exactly two items, ``(logits, label)``: the unperturbed logit
            tensor and the target class tensor.

    Returns:
        The negative cross-entropy of ``logits + x`` against ``label``.
    """
    base_logit, target = args
    perturbation = torch.from_numpy(x)
    shifted = base_logit + perturbation
    # Type promotion with the float64 perturbation must yield a float64 result.
    assert shifted.dtype == torch.float64, "Must be FP64 for L-BFGS-B."
    loss = F.cross_entropy(input=shifted, target=target)
    return -loss.item()


@beartype
def measure(
        last_logit: Tensor,
        last_label: Tensor,
        use_random_init: bool,
        max_lbfgsb_iters: int | None,
        checkpoint_path: str,
        epsilon: float,
        k: int,
) -> tuple[str, float]:
    """Measure the sharpness of the loss around one logit vector.

    L-BFGS-B searches for the additive logit perturbation ``x`` that maximises
    the cross-entropy, subject to the per-coordinate box constraint
    ``|x_i| <= epsilon * (|last_logit_i| + 1)``. The sharpness is then
    ``(peak - loss) / (1 + loss)``, where ``loss`` is the unperturbed
    cross-entropy and ``peak`` is the maximum found.

    Args:
        last_logit: Logit vector whose last dimension indexes the classes.
        last_label: Target class for ``last_logit``.
        use_random_init: If true, start the search from a uniformly random
            point inside the box; otherwise start from zero perturbation.
        max_lbfgsb_iters: Optional ``maxiter`` passed to L-BFGS-B; ``None``
            keeps SciPy's default.
        checkpoint_path: Only used to label the report string.
        epsilon: Scale of the perturbation box.
        k: Sample index, only used to label the report string.

    Returns:
        ``(report, sharpness)``: a one-line human-readable summary (with the
        sharpness printed as a percentage) and the sharpness as a fraction.
    """
    # Do everything on the CPU in float64. `Tensor.numpy()` rejects bfloat16,
    # CUDA and grad-tracking tensors, and using a single precision for both
    # the baseline loss and the optimiser's objective keeps `peak - loss`
    # free of precision-mismatch noise.
    logit64 = last_logit.detach().to(device="cpu", dtype=torch.float64)
    label = last_label.detach().cpu()

    # Half-width of the box constraint for every logit coordinate.
    c = epsilon * (logit64.abs() + 1).numpy()
    last_loss = F.cross_entropy(input=logit64, target=label).item()
    if use_random_init:
        x0 = (2 * np.random.rand(logit64.size(-1)) - 1) * c
    else:
        x0 = np.zeros(logit64.size(-1))

    result = minimize(  # This part is very compute-intensive and on CPU.
        fun=fun,  # Intel MKL-optimized scipy may be useful here.
        x0=x0,
        args=(logit64, label),
        method="L-BFGS-B",
        bounds=Bounds(lb=-c, ub=c),
        options=None if max_lbfgsb_iters is None else {"maxiter": max_lbfgsb_iters},
    )
    # `fun` returns the negated loss, so flip the sign to recover the peak.
    peak = -float(result.fun)
    last_sharp = (peak - last_loss) / (1 + last_loss)
    out = (f"{checkpoint_path}, {epsilon=}, {k=:03d}, peak: {peak:5.2f}, "
           f"loss: {last_loss:5.2f}, sharpness: {last_sharp * 100:5.2f}, success: {result.success}.")
    return out, last_sharp


@beartype
@torch.inference_mode()
def main(
        model_name: str,
        checkpoint_path: str,
        gpu: list[int] | None = None,
        eval_iters: int | None = None,
        fp32_layers: int | None = None,
        exp_bits: int = 8,
        sig_bits: int = 7,
        epsilons: tuple[float, ...] = (5e-4, ),
        use_random_init: bool = False,  # Whether l-bfgs-b should have a random initial guess.
        max_lbfgsb_iters: int | None = None,  # Default is 10 in the paper.
        val_data_path: str = "TinyLlama/processed/slim_star_combined",
        num_processes: int | None = None,
):
    """Report last-token loss sharpness for one checkpoint on validation data.

    The model is built, optionally wrapped with masked linear layers, loaded
    from ``checkpoint_path`` and run on the validation set with batch size 1.
    For every sample, the logits and label at the final position are sent to
    ``measure``. One report line is printed per sample and the average
    sharpness is logged per epsilon.

    Args:
        model_name: Name passed to ``Config.from_name``.
        checkpoint_path: Checkpoint file; its ``"model"`` entry is loaded.
        gpu: Device indices handed to Fabric; defaults to ``[0]``.
        eval_iters: Maximum number of validation samples; ``None`` uses all.
        fp32_layers: When masking is active, only transformer blocks from
            this index onward are masked. ``None`` masks the whole model.
        exp_bits: Exponent bits for the masked linear layers.
        sig_bits: Significand bits for the masked linear layers. Masking is
            applied only if ``exp_bits < 8`` or ``sig_bits < 7``.
        epsilons: Perturbation scales to evaluate.
        use_random_init: Whether L-BFGS-B starts from a random initial guess.
        max_lbfgsb_iters: Optional iteration cap for L-BFGS-B.
        val_data_path: Directory containing the ``validation*`` files.
        num_processes: ``n_jobs`` for the joblib worker pool.

    Raises:
        ValueError: If ``fp32_layers`` is negative.
        RuntimeError: If no validation sample was evaluated.
    """
    logger.info("Getting started.")
    torch.set_grad_enabled(False)  # No gradient calculations.
    torch.set_float32_matmul_precision("high")
    gpu = [0] if gpu is None else gpu
    fabric = Fabric(devices=gpu, accelerator="gpu", precision="bf16-mixed")
    config = Config.from_name(model_name)

    with fabric.init_module(empty_init=False):
        model = GPT(config)

    model.eval()
    if exp_bits < 8 or sig_bits < 7:
        if fp32_layers is None:
            model = MaskedLinearBF16.mask_linear_layers(
                module=model,
                exp_bits=exp_bits,
                sig_bits=sig_bits,
            )
        else:
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if fp32_layers < 0:
                raise ValueError(f"fp32_layers must be non-negative, got {fp32_layers}.")
            for block in model.transformer.h[fp32_layers:]:
                MaskedLinearBF16.mask_linear_layers(
                    module=block,
                    exp_bits=exp_bits,
                    sig_bits=sig_bits,
                )
    model = fabric.setup(model)
    logger.info("Finished initializing the model.")

    block_size = model.config.block_size
    filenames = sorted(str(p) for p in Path(val_data_path).glob("validation*"))

    dataset = PackedDataset(
        filenames,
        # n_chunks control the buffer size.
        # Note that the buffer size also impacts the random shuffle
        # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer)
        n_chunks=8,
        block_size=block_size + 1,  # See `TinyLlama/lit_gpt/config.py` for block sizes.
        shuffle=False,
    )
    loader = DataLoader(
        dataset=dataset,
        shuffle=False,
        pin_memory=True,
        batch_size=1,  # This is important.
    )

    logger.info("Loading checkpoint {}.", checkpoint_path)
    # NOTE(review): `torch.load` unpickles the file; only use trusted checkpoints.
    sd = torch.load(checkpoint_path, map_location="cpu", mmap=True)["model"]
    # Strip the prefix that `torch.compile` adds to parameter names.
    sd = {k.replace("_orig_mod.", ""): v for k, v in sd.items()}
    model.load_state_dict(state_dict=sd, strict=True)
    logger.info("Loaded checkpoint {}.", checkpoint_path)

    # The logits do not depend on epsilon, so run the forward passes once and
    # reuse the results for every epsilon.
    samples = []
    for k, val_data in enumerate(loader, start=1):
        if (eval_iters is not None) and (k > eval_iters):
            break
        assert isinstance(val_data, Tensor)
        val_data = val_data.to(fabric.device, non_blocking=True)

        # Inputs and labels are the same token block shifted by one position.
        inputs = val_data[:, 0:block_size]
        labels = val_data[:, 1:block_size + 1]
        logits = model(inputs)
        # Keep only the final position of the single sequence in the batch.
        samples.append((k, logits[0, -1, ...].cpu(), labels[0, -1].cpu()))

    if not samples:
        # Fail early with a clear message instead of dividing by zero below.
        raise RuntimeError(
            f"No validation samples were evaluated "
            f"(val_data_path={val_data_path!r}, eval_iters={eval_iters})."
        )

    with Parallel(n_jobs=num_processes, return_as="generator") as parallel:
        for epsilon in epsilons:
            jobs = [
                delayed(measure)(
                    last_logit=last_logit,
                    last_label=last_label,
                    use_random_init=use_random_init,
                    max_lbfgsb_iters=max_lbfgsb_iters,
                    checkpoint_path=checkpoint_path,
                    epsilon=epsilon,
                    k=k,
                )
                for k, last_logit, last_label in samples
            ]
            last_sharps = list()
            for out, last_sharp in parallel(jobs):
                print(out)
                last_sharps.append(last_sharp)
            logger.info("Average sharpness: {:5.2f}", sum(last_sharps) / len(last_sharps) * 100)



# When run as a script, expose `main` as a command-line interface through python-fire.
if __name__ == '__main__':
    fire.Fire(main)
