import argparse
import json
import math
import os
import random

import datasets
import glog
import torch
from tqdm import tqdm

from lib.linear import QuantizedLinear
from lib.utils import gptq_data_utils
from lib.utils.unsafe_import import model_from_hf_path

torch.set_grad_enabled(False)

parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--hf_path', default=None, type=str)
parser.add_argument('--orig_path', default=None, type=str)
parser.add_argument('--tokenizer', default=None, type=str)
parser.add_argument('--seqlen', default=8192, type=int)
parser.add_argument('--manifest', action='store_true')
parser.add_argument('--max_mem_ratio', default=0.7, type=float)


def main(args):
    datasets = ['wikitext2']
    orig_model = model_from_hf_path(args.orig_path, device_map='auto')[0]
    model, model_str = model_from_hf_path(args.hf_path, device_map='auto')
    if args.manifest:
        # manifest the model in BF/FP16 for faster inference
        # useful for non-kernel supported decode modes
        for module in model.modules():
            if isinstance(module, QuantizedLinear):
                module.mode = 'train-fixW'

    for dataset in datasets:
        input_tok = gptq_data_utils.get_test_tokens(
            dataset,
            seed=args.seed,
            seqlen=args.seqlen,
            model=(args.tokenizer
                   if args.tokenizer is not None else model_str))
        nsamples = input_tok.numel() // args.seqlen
        input_tok = input_tok[0, :(args.seqlen * nsamples)].view(
            nsamples, args.seqlen)

        loss_fct = torch.nn.KLDivLoss(reduction='batchmean')
        acc_loss = 0.0
        progress = tqdm(range(nsamples))
        for ii in progress:
            input = input_tok[ii, :].cuda().view(1, -1)
            output = model(input,
                           use_cache=False,
                           output_hidden_states=False,
                           output_attentions=False)[0]
            orig_output = orig_model(input,
                                     use_cache=False,
                                     output_hidden_states=False,
                                     output_attentions=False)[0]
            loss = loss_fct(
                output.reshape(-1, output.shape[-1]).log_softmax(dim=-1),
                orig_output.reshape(-1, orig_output.shape[-1]).softmax(dim=-1))
            acc_loss += loss.item()
            progress.set_description(f"avg_loss = {acc_loss/(ii+1)}")

        avg_loss = acc_loss / nsamples

        glog.info(f'{dataset} KL: {avg_loss}')


if __name__ == '__main__':
    torch.set_grad_enabled(False)
    args = parser.parse_args()
    random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    main(args)
