import argparse
import pathlib
import time

import torch
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    StaticCache,
)

from test import M8aLinear


def load_checkpoint(
        pretrained_model_path: str,
        do_convert: int = 0,
        dtype: torch.dtype = torch.bfloat16,
        device=torch.device('cuda'),
        do_torch_script: bool = False,
        do_torch_compile: bool = True,
):
    """Load a causal LM and its tokenizer, optionally swapping in ``M8aLinear`` layers.

    Args:
        pretrained_model_path: Passed to every ``from_pretrained`` call. When
            ``do_convert`` is 1 or 2 it is also searched recursively for a
            ``*.pt`` record file, so it must then be a local directory.
        do_convert: 0 leaves the loaded modules untouched. 1 or 2 replaces each
            submodule named in the record's ``'data'`` mapping with an
            ``M8aLinear``; mode 1 forwards the entry's ``'oweight'`` as
            ``weight_sparse``, mode 2 passes ``None`` instead.
        dtype: ``torch_dtype`` used for loading and for the final ``model.to``.
        device: Device the finished model is moved to.
        do_torch_script: Sets ``torchscript`` on the config/tokenizer and runs
            ``torch.jit.script`` on the model.
        do_torch_compile: Wraps ``model.forward`` with ``torch.compile``.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is in eval mode.

    Raises:
        FileNotFoundError: If conversion is requested but no ``*.pt`` file
            exists under ``pretrained_model_path``.
    """
    config = AutoConfig.from_pretrained(
        pretrained_model_name_or_path=pretrained_model_path,
        torchscript=do_torch_script,
        return_dict=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=pretrained_model_path,
        use_fast=False,
        torchscript=do_torch_script,
        config=config,
    )
    # Reuse the end-of-sequence token as the padding token.
    tokenizer.pad_token = tokenizer.eos_token
    # NOTE(security): trust_remote_code=True allows modeling code shipped with
    # the checkpoint to be executed; only point this at trusted checkpoints.
    # The model is materialized on CPU first and moved to `device` after the
    # optional layer replacement below.
    model: torch.nn.Module = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=pretrained_model_path,
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map='cpu',
        config=config,
    )

    if do_convert in (1, 2):
        # The first *.pt file found is used; fail with a clear message instead
        # of leaking a bare StopIteration when there is none.
        record_path = next(pathlib.Path(pretrained_model_path).rglob('*.pt'), None)
        if record_path is None:
            raise FileNotFoundError(
                f'do_convert={do_convert} requires a *.pt record file under {pretrained_model_path!r}'
            )
        records: dict[str, dict[str, dict]] = torch.load(
            record_path,
            map_location=torch.device('cpu'),
            weights_only=True,
        )
        for key, value in tqdm(records['data'].items()):
            m8a_layer: M8aLinear = M8aLinear(
                # Bit width is recovered from the first 'qzero' element as
                # round(log2(qzero)) + 1 (presumably qzero == 2 ** (bits - 1);
                # verify against the code that writes the record).
                bit_width=int(value['qzero'].flatten()[0].log2().round()) + 1,
                # Number of weight columns that share one scale entry.
                group_size=value['qweight'].size(-1) // value['scale'].size(-1),
                weight_int=value['qweight'].to(torch.int8),
                weight_scale=value['scale'],
                weight_sparse=value['oweight'] if do_convert == 1 else None,
            )
            model.set_submodule(key, m8a_layer)

    model.to(dtype=dtype, device=device).eval()

    if do_torch_script:
        model = torch.jit.script(model)

    if do_torch_compile:
        model.forward = torch.compile(model.forward, fullgraph=True, dynamic=True, mode='reduce-overhead')

    print('Model loaded')
    return model, tokenizer


def e2e_benchmark(
        model: torch.nn.Module,
        n_prefill_tokens: int,
        n_decode_tokens: int,
        n_repeats: int = 1,
        dtype: torch.dtype = torch.bfloat16,
        device: torch.device = torch.device('cuda'),
        tokenizer=None,
        input_str: str = None,
) -> float:
    """Measure per-token greedy decode latency of ``model`` with a static KV cache.

    One untimed prefill pass over ``n_prefill_tokens`` tokens is followed by
    ``n_repeats`` rounds of ``n_decode_tokens`` timed single-token steps.

    Args:
        model: Callable module exposing ``.config`` and accepting the
            ``input_ids`` / ``cache_position`` / ``past_key_values`` /
            ``return_dict`` / ``use_cache`` keyword arguments; the first
            element of its output is used as the logits.
        n_prefill_tokens: Number of prompt tokens fed in the prefill pass.
        n_decode_tokens: Number of timed decode steps per repeat (>= 1).
        n_repeats: Number of timed rounds (>= 1).
        dtype: dtype of the ``StaticCache`` buffers.
        device: Device for inputs and cache. CUDA synchronization around the
            timed region is only issued for CUDA devices.
        tokenizer: Optional tokenizer. When given, ``input_str`` is tokenized,
            truncated to ``n_prefill_tokens`` and the generated text is printed.
            When ``None``, all-zero token ids are used as a synthetic prompt.
        input_str: Prompt text; required when ``tokenizer`` is given.

    Returns:
        The median over repeats of each repeat's median step time, in seconds.

    Raises:
        ValueError: If ``n_decode_tokens`` or ``n_repeats`` is below 1, or if a
            tokenizer is supplied without ``input_str``.
    """
    if n_decode_tokens < 1 or n_repeats < 1:
        raise ValueError(
            f'n_decode_tokens and n_repeats must be >= 1, got {n_decode_tokens} and {n_repeats}'
        )
    if tokenizer is not None and input_str is None:
        raise ValueError('input_str is required when a tokenizer is given')

    device = torch.device(device)
    on_cuda: bool = device.type == 'cuda'

    if tokenizer is not None:
        input_ids: torch.Tensor = tokenizer(input_str, return_tensors='pt').input_ids.to(device=device)[..., :n_prefill_tokens]
        assert input_ids.size(-1) == n_prefill_tokens, 'input_str tokenizes to fewer than n_prefill_tokens tokens'
    else:
        # torch.empty would hand back uninitialized int64 values that may fall
        # outside the vocabulary; zeros are always a valid token id.
        input_ids: torch.Tensor = torch.zeros(1, n_prefill_tokens, dtype=torch.int64, device=device)

    cache_position: torch.Tensor = torch.arange(n_prefill_tokens, dtype=torch.int64, device=device)

    if tokenizer is not None:
        # Prompt + first token from prefill + one token per decode step.
        generated_ids: torch.Tensor = torch.zeros(1, n_prefill_tokens + n_decode_tokens + 1, dtype=torch.int64, device=device)
        generated_ids[..., :n_prefill_tokens] = input_ids

    past_key_values: StaticCache = StaticCache(
        config=model.config,
        max_batch_size=1,
        max_cache_len=n_prefill_tokens + n_decode_tokens + 1,
        device=device,
        dtype=dtype,
    )

    tpot_list_list: list[list[float]] = []  # time per output token, one list per repeat
    # Prefill runs under the same inference mode as decoding so no autograd
    # graph is built for it and both phases execute in the same grad state.
    with torch.inference_mode():
        # NOTE(review): use_cache=False is passed while still supplying
        # past_key_values, as in the original benchmark — confirm the model
        # version in use updates the supplied cache in this configuration.
        logits: torch.Tensor = model(
            input_ids=input_ids,
            cache_position=cache_position,
            past_key_values=past_key_values,
            return_dict=False,
            use_cache=False,
        )[0]
        next_token: torch.Tensor = logits[..., -1:, :].argmax(dim=-1)
        if tokenizer is not None:
            generated_ids[..., n_prefill_tokens:n_prefill_tokens + 1] = next_token
            print(tokenizer.decode(generated_ids[0]))

        for _ in range(n_repeats):
            # Every repeat reuses the same cache object and restarts from the
            # same positions; next_token carries over from the previous repeat.
            seq_len: int = n_prefill_tokens + 1
            tpot_list: list[float] = []
            for _ in range(n_decode_tokens):
                # NOTE(review): the token at sequence index seq_len - 1 is fed
                # with cache_position == seq_len, so cache slot
                # n_prefill_tokens is never written. Kept as in the original
                # benchmark — confirm this offset is intended.
                cache_position = torch.tensor([seq_len], dtype=torch.int64, device=device)
                if on_cuda:
                    torch.cuda.synchronize(device)
                start_time: float = time.perf_counter()
                logits = model(
                    input_ids=next_token,
                    position_ids=None,
                    cache_position=cache_position,
                    past_key_values=past_key_values,
                    return_dict=False,
                    use_cache=False,
                )[0]
                next_token = logits[..., -1:, :].argmax(dim=-1)
                if on_cuda:
                    torch.cuda.synchronize(device)
                end_time: float = time.perf_counter()
                if tokenizer is not None:
                    generated_ids[..., seq_len:seq_len + 1] = next_token
                seq_len += 1
                tpot_list.append(end_time - start_time)
            # Record every repeat, not only the last one.
            tpot_list_list.append(tpot_list)

    if tokenizer is not None:
        print(tokenizer.decode(generated_ids[0]))

    for tpot_list in tpot_list_list:
        print(tpot_list)
    # Median over decode steps within each repeat, then median over repeats.
    # The timings are plain Python floats, so the reduction stays on the CPU.
    tpot: float = (
        torch.as_tensor(tpot_list_list, dtype=torch.float64)
        .quantile(q=.5, dim=-1)
        .quantile(q=.5)
        .item()
    )
    print(f'e2e benchmark tpot: {tpot} sec')
    return tpot


def run_e2e_benchmark(
        ckpt_path: str,
        do_convert: int,
        do_torch_compile: bool = True,
        n_prefill_tokens: int = 1,
        n_decode_tokens: int = 128,
        n_repeats: int = 1,
) -> None:
    """Load a checkpoint onto CUDA in bfloat16 and benchmark its decode latency.

    Args:
        ckpt_path: Checkpoint location forwarded to ``load_checkpoint``.
        do_convert: Conversion mode forwarded to ``load_checkpoint``.
        do_torch_compile: Whether ``load_checkpoint`` compiles the forward pass.
        n_prefill_tokens: Prompt length used by ``e2e_benchmark``.
        n_decode_tokens: Timed decode steps per repeat.
        n_repeats: Number of timed rounds.
    """
    compute_dtype = torch.bfloat16
    cuda_device = torch.device('cuda')

    # TorchScript is always disabled for this benchmark.
    model, _tokenizer = load_checkpoint(
        pretrained_model_path=ckpt_path,
        do_convert=do_convert,
        dtype=compute_dtype,
        device=cuda_device,
        do_torch_script=False,
        do_torch_compile=do_torch_compile,
    )

    # The tokenizer is deliberately not forwarded, so the benchmark runs on a
    # synthetic prompt; pass tokenizer=... and input_str=... to decode real text.
    e2e_benchmark(
        model=model,
        n_prefill_tokens=n_prefill_tokens,
        n_decode_tokens=n_decode_tokens,
        n_repeats=n_repeats,
        dtype=compute_dtype,
        device=cuda_device,
    )


def main() -> None:
    """Parse the command-line options, echo them, and launch the benchmark."""
    cli = argparse.ArgumentParser(add_help=True)
    cli.add_argument('--ckpt', type=str, required=True)
    # --do-convert modes:
    #   0: full precision weights
    #   1: low-bit weights + full precision outliers
    #   2: low-bit weights only
    cli.add_argument('--do-convert', type=int, choices=[0, 1, 2], default=0)
    cli.add_argument('--n-repeats', type=int, default=1)

    options = cli.parse_args()
    print(options)

    run_e2e_benchmark(
        ckpt_path=options.ckpt,
        do_convert=options.do_convert,
        n_repeats=options.n_repeats,
    )


# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
