# AIHA 2024
# Reduced-precision LLM inference framework supporting various numeric precisions.
# The quantization framework is based on the MX format (https://github.com/microsoft/microxcaling).

import os
import torch
from datetime import datetime
import argparse
import numpy as np
import transformers
import torch.nn as nn
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from smoothquant import smooth
from utils.common import *
from utils.mx import *
import lm_eval
from lm_eval import evaluator
import warnings
import logging
import pprint
from accelerate import Accelerator
from rotation import hadamard_utils, model_utils, rotation_utils
import math
warnings.filterwarnings('ignore')


def _load_model(args):
    """Load ``args.model`` on CPU with eager attention and apply model fix-ups.

    - If the config ties the word embeddings, the input embedding and
      ``lm_head`` are separated (the log message states tied embeddings are
      not supported for rotation).
    - For Phi-3 models, ``instead_phi3_forward`` is applied.

    Returns:
        The loaded ``AutoModelForCausalLM`` instance.
    """
    kwargs = {'device_map': 'cpu', 'trust_remote_code': False,
              'attn_implementation': "eager"}
    model = AutoModelForCausalLM.from_pretrained(
        args.model, torch_dtype='auto' if args.auto_dtype else torch.bfloat16, **kwargs)
    # Break the weight sharing between embeddings and lm_head (e.g. Llama-3.2).
    if model.config.tie_word_embeddings:
        logging.info("Tying word embeddings is not supported for rotation, disabling it.")
        separate_embeddings_and_lm_head(model)
    if model_utils.model_type_extractor(model) == model_utils.PHI3_MODEL:
        instead_phi3_forward(model)
    return model


def _apply_quarot_rotation(model, args):
    """Apply the offline QuaRot rotation pipeline to ``model``.

    If ``args.pre_smooth`` is set, the saved pre-smooth scales are applied via
    ``smooth.pre_smooth_lm`` first.  Then ``rotation_utils.fuse_layer_norms``
    and ``rotation_utils.rotate_model`` are called, followed by a memory cleanup.
    """
    if args.pre_smooth:
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # scale files from a trusted source.
        pre_smooth_scales = torch.load(args.pre_smooth, weights_only=False)
        logging.info(f"Load pre-smooth scales from {args.pre_smooth}")
        smooth.pre_smooth_lm(model, pre_smooth_scales, args.pre_smooth_alpha)
    rotation_utils.fuse_layer_norms(model)
    rotation_utils.rotate_model(model, args)
    rotation_utils.cleanup_memory(verbos=True)


def _apply_gptq(model, args):
    """Load GPTQ-quantized weights, or run GPTQ calibration, then optionally save.

    When ``args.gptq_load_path`` is set, the weights are loaded with
    ``load_model_in_parts``.  Otherwise GPTQ runs on ``cuda:0`` using a
    calibration loader built from ``args.gptq_cal_*``.  If
    ``args.gptq_save_path`` is set, the model is saved in parts there.
    """
    if args.gptq_load_path:
        logging.info("Load quantized model from: {}.".format(args.gptq_load_path))
        load_model_in_parts(model, args.gptq_load_path)
    else:
        from gptq import gptq_utils
        from utils import calib
        logging.info("Quantizing model with GPTQ.")
        trainloader = calib.get_loaders(args.gptq_cal_dataset, nsamples=args.gptq_cal_nsamples,
                                        seqlen=args.gptq_cal_seqlen, model=args.model, eval_mode=False)
        gptq_utils.gptq_fwrd(model, trainloader, 'cuda:0', args)

    if args.gptq_save_path:
        folder_name = f'{model.model_name}'
        folder_name += f"alpha_{args.pre_smooth_alpha}" if args.pre_smooth else ""
        # NOTE(review): args.save_gptq is computed here, but the save below still
        # targets args.gptq_save_path itself — confirm which destination is intended.
        args.save_gptq = os.path.join(args.gptq_save_path, folder_name)
        os.makedirs(args.gptq_save_path, exist_ok=True)
        logging.info("Save quantized model to: {}.".format(args.gptq_save_path))
        save_model_in_parts(model, args.gptq_save_path, prefix=f'{model.model_name}_part')


def _attach_online_hadamard(model, args):
    """Tag modules with the attributes used for runtime Hadamard transforms.

    - Down-projection modules (``down_proj``, or ``w2`` for Mixtral) get the
      full or group Hadamard factors, unless ``rotate_mode`` is ``'identity'``.
    - ``o_proj`` modules get a partial per-head transform when
      ``args.online_partial_had`` is set.

    Raises:
        NotImplementedError: for model types other than Llama, Mistral, Qwen2,
            Phi-3 and Mixtral.
    """
    model_type = model_utils.model_type_extractor(model)
    if model_type in [model_utils.LLAMA_MODEL, model_utils.MISTRAL_MODEL,
                      model_utils.QWEN2_MODEL, model_utils.PHI3_MODEL]:
        down_proj_key = 'down_proj'
    elif model_type == model_utils.MIXTRAL_MODEL:
        down_proj_key = 'w2'
    else:
        raise NotImplementedError

    for name, module in model.named_modules():
        if down_proj_key in name and args.rotate_mode != 'identity':
            if args.rotate_mode == 'hadamard':
                had_K, K = hadamard_utils.get_hadK(model.config.intermediate_size)
                setattr(module, "online_full_had", True)
            elif args.rotate_mode == 'group_hadamard':
                had_K, K = hadamard_utils.get_hadK(args.had_dim)
                setattr(module, "online_group_had", True)
                setattr(module, "had_dim", args.had_dim)
            setattr(module, "had_K", had_K)
            setattr(module, "K", K)
        if 'o_proj' in name and args.online_partial_had:
            had_K, K = hadamard_utils.get_hadK(model.config.num_attention_heads)
            setattr(module, "online_partial_had", True)
            setattr(module, "had_K", had_K)
            setattr(module, "K", K)
            setattr(module, "had_dim", model.config.hidden_size //
                    model.config.num_attention_heads)


def _wrap_kv_rotation(model, args):
    """Wrap every layer's ``self_attn`` with a Q/K rotation after the RoPE call.

    The rotation is grouped with size ``args.block_size_matmul`` when
    ``args.group_rotate_kv`` is set; otherwise ``had_dim=-1`` is passed.
    """
    rope_function_name = model_utils.get_rope_function_name(model)
    had_dim = args.block_size_matmul if args.group_rotate_kv else -1
    for layer in model_utils.get_layers(model):
        rotation_utils.add_qk_rotation_wrapper_after_function_call_in_forward(
            layer.self_attn,
            rope_function_name,
            config=model.config,
            had_dim=had_dim,
        )


def main(args):
    """Build the (optionally rotated / GPTQ-quantized) MX-quantized model.

    Steps:
      1. Load the model on CPU. With ``args.rotate``, apply the QuaRot
         rotation (arXiv:2404.00456), optionally preceded by pre-smoothing.
      2. With ``args.gptq``, load or compute GPTQ weights.
      3. Load the (slow) tokenizer.
      4. Call ``get_mx_model`` with MX specs parsed for the linear, matmul,
         layer-norm and lm-head groups.  Its return value is not used, so it
         presumably modifies the model in place — confirm.
      5. With ``args.rotate_kv``, wrap each attention layer for Q/K rotation.

    Side effects on ``args``: sets ``had_dim`` in ``group_hadamard`` rotate
    mode, and ``save_gptq`` when GPTQ saving is requested.

    Returns:
        ``(model, tokenizer)``.
    """
    # ============================ Load model
    # QuaRot: Outlier-Free 4-Bit Inference in Rotated LLMs (arXiv:2404.00456)
    if args.rotate:
        if args.rotate_mode == 'group_hadamard':
            args.had_dim = args.block_size_linear
        logging.info(
            f"Applying {args.rotate_mode} transform dim {args.had_dim if args.rotate_mode == 'group_hadamard' else '-1'} ...")
        model = _load_model(args)
        _apply_quarot_rotation(model, args)
    else:
        model = _load_model(args)

    if args.gptq:
        _apply_gptq(model, args)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model, trust_remote_code=False, use_fast=False)

    # ============================ MX format
    mx_specs_linear = parse_mx_specs(args, 'linear')
    mx_specs_matmul = parse_mx_specs(args, 'matmul')
    mx_specs_ln = parse_mx_specs(args, 'ln')
    mx_specs_head = parse_mx_specs(args, 'head')
    get_mx_model(
        model.eval(),
        mx_specs_linear=mx_specs_linear,
        mx_specs_matmul=mx_specs_matmul,
        mx_specs_ln=mx_specs_ln,
        mx_specs_head=mx_specs_head,
        args=args,
    )

    # ============================ Runtime Hadamard Transform for QuaRot
    # Intentionally disabled (was `if args.rotate:`); switch the condition back
    # to args.rotate to re-enable the online transforms.
    if False:  # args.rotate:
        _attach_online_hadamard(model, args)

    # ============================ KV Cache
    if args.rotate_kv:
        _wrap_kv_rotation(model, args)
    return model, tokenizer


def _wikitext_cache_path(model_name, seqlen, seed):
    """Return the cache file for the tokenized wikitext2 test split.

    The tokenizer family is inferred from substrings of ``model_name``.  The
    checks run in order, so 'Llama-3' wins over 'Llama-2' and 'mixtral' is
    tested before 'mistral'.

    Raises:
        NotImplementedError: if no known model family matches.
    """
    lowered = model_name.lower()
    if 'Llama-3' in model_name:
        family = 'llama3'
    elif 'Llama-2' in model_name:
        family = 'llama2'
    elif 'mixtral' in lowered:
        family = 'mixtral'
    elif 'mistral' in lowered:
        family = 'mistral'
    elif 'Qwen2' in model_name:
        family = 'qwen2'
    elif 'opt' in model_name:
        family = 'opt'
    elif 'phi-3' in lowered:
        family = 'phi3'
    else:
        raise NotImplementedError(
            f"Model {model_name} is not supported for wikitext2 evaluation.")
    return f'calibset/wikitext_test_{seqlen}_{seed}_{family}.cache'


def _task_accuracy(result):
    """Return ``acc_norm,none`` if reported, otherwise ``acc,none``.

    The fallback is looked up lazily, so tasks that report only ``acc_norm``
    do not raise KeyError.
    """
    if 'acc_norm,none' in result:
        return result['acc_norm,none']
    return result['acc,none']


def _wikitext_perplexity(model, args, seqlen=2048):
    """Compute wikitext2 test perplexity over non-overlapping ``seqlen`` windows.

    The tokenized test split is cached under ``calibset/`` and reused on later
    runs.  Windows with a NaN loss are skipped and counted in a warning.  If no
    finite loss is collected (all windows NaN, or the split is shorter than
    one window), the perplexity is reported as NaN instead of raising.

    Returns:
        The perplexity as a Python float.
    """
    args.limit = -1  # whole samples: -1 never matches a window index
    cache_testloader = _wikitext_cache_path(args.model, seqlen, args.seed)
    if os.path.exists(cache_testloader):
        # NOTE: weights_only=False unpickles arbitrary objects; only reuse cache
        # files this script produced.
        testloader = torch.load(cache_testloader, weights_only=False)
        logging.info(f"load calibration from {cache_testloader}")
    else:
        from utils.calib import get_wikitext2_test
        testloader = get_wikitext2_test(
            seed=args.seed, seqlen=seqlen, model=args.model)
        os.makedirs('calibset', exist_ok=True)
        torch.save(testloader, cache_testloader)
    testenc = testloader.input_ids
    nsamples = testenc.numel() // seqlen
    # For OPT the decoder stack is reached via model.model.decoder.
    is_opt = "opt" in args.model.lower()
    loss_fct = nn.CrossEntropyLoss()

    use_cache = model.config.use_cache
    model.config.use_cache = False
    model.eval()
    try:
        nlls = []
        n_nan = 0
        with torch.no_grad():
            pbar = tqdm(range(nsamples))
            for i in pbar:
                window = testenc[:, (i * seqlen): ((i + 1) * seqlen)]
                batch = window.to('cuda')
                outputs = model.model.decoder(batch) if is_opt else model.model(batch)
                hidden_states = outputs[0]
                logits = model.lm_head(hidden_states)
                # Next-token objective: logits at position t are scored against token t+1.
                shift_logits = logits[:, :-1, :]
                shift_labels = window[:, 1:].to(model.lm_head.weight.device)
                loss = loss_fct(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                )
                neg_log_likelihood = loss.float()
                if math.isnan(neg_log_likelihood):
                    n_nan += 1
                else:
                    nlls.append(neg_log_likelihood)
                # torch.stack raises on an empty list, so report the running
                # PPL only once a finite loss has been collected.
                if nlls:
                    running_ppl = torch.exp(torch.stack(nlls).mean())
                    pbar.set_description(f"PPL: {running_ppl.item():.2f}")
                if i == args.limit:
                    break

        if n_nan:
            logging.warning(f"Skipped {n_nan}/{nsamples} wikitext windows with NaN loss.")
        if nlls:
            ppl = torch.exp(torch.stack(nlls).mean()).item()
        else:
            logging.warning("No finite wikitext loss was collected; perplexity is NaN.")
            ppl = float('nan')
        logging.info(f'wikitext ppl : {ppl:.2f}')
    finally:
        # Restore the caller's cache setting even if evaluation fails.
        model.config.use_cache = use_cache
    return ppl


def _lm_eval_accuracy(model, tokenizer, args):
    """Run lm-eval on ``args.tasks`` and return the rounded per-task accuracies.

    The returned dict also contains an ``'average'`` entry computed by
    ``calculate_avg_accuracy``.
    """
    batch_size = 8 if '70' in args.model else "auto"
    lm = lm_eval.models.huggingface.HFLM(
        pretrained=model,
        tokenizer=tokenizer,
        backend='causal',
        trust_remote_code=True,
        batch_size=batch_size,
    )
    with torch.no_grad():
        results = evaluator.simple_evaluate(
            model=lm,
            tasks=args.tasks,
            num_fewshot=args.num_fewshot,
            batch_size=batch_size
        )
    results = results['results']
    logging.info(pprint.pformat(results))
    metric_vals = {task: round(_task_accuracy(result), 4) for task, result in results.items()}
    acc_avg = calculate_avg_accuracy(args.tasks, results)
    metric_vals['average'] = round(acc_avg, 4)

    for task, result in metric_vals.items():
        logging.info(f'after reconstruction {task} acc: {result * 1e2:.2f}')
    return metric_vals


def evaluate_model(model, tokenizer, args):
    """Evaluate ``model`` on wikitext2 perplexity and/or lm-eval tasks.

    A model still on CPU is first spread over the available devices with
    ``distribute_model``.

    Args:
        model: the causal LM to evaluate.
        tokenizer: its tokenizer (used by lm-eval).
        args: parsed CLI namespace; reads ``eval_ppl``, ``tasks``, ``model``,
            ``seed`` and ``num_fewshot``.

    Returns:
        A dict with ``'wiki_ppl'`` when ``args.eval_ppl`` is set, plus per-task
        accuracies and ``'average'`` when ``args.tasks`` is non-empty.  The
        dict is empty if neither evaluation ran.
    """
    # Load into GPU
    if model.device.type == 'cpu':
        distribute_model(model)

    results = {}
    # ============================ Evaluation
    if args.eval_ppl:
        results['wiki_ppl'] = _wikitext_perplexity(model, args)
    if args.tasks:  # lm-eval
        results.update(_lm_eval_accuracy(model, tokenizer, args))
    return results


def calculate_avg_accuracy(task_names: list, results: dict) -> float:
    """Average accuracy over ``task_names``, collapsing MMLU tasks into one term.

    Each task's score is ``acc_norm,none`` when reported, otherwise
    ``acc,none``.  Every non-MMLU entry of ``results`` contributes one term.
    All tasks whose name contains ``'mmlu'`` are merged into a single term,
    weighted by the number of test questions in each (taken from lm-eval's
    task registry).

    NOTE(review): the numerator iterates ``results`` while the denominator
    counts ``task_names``.  They agree only when lm-eval reports exactly the
    requested tasks (no group/subtask expansion) — confirm for grouped tasks.

    Args:
        task_names: task names passed to lm-eval.
        results: the ``['results']`` mapping from ``evaluator.simple_evaluate``.

    Returns:
        The averaged accuracy (same scale as the per-task metrics).
    """
    def task_acc(result):
        # Fall back to 'acc,none' lazily so tasks reporting only
        # 'acc_norm,none' do not raise KeyError.
        if 'acc_norm,none' in result:
            return result['acc_norm,none']
        return result['acc,none']

    logging.debug(f"Averaging accuracy over tasks: {task_names}")
    n_tasks = len(task_names)
    acc_cumul = sum(
        task_acc(result) for task, result in results.items() if 'mmlu' not in task
    )

    mmlu_task_names = [name for name in task_names if 'mmlu' in name]
    if not mmlu_task_names:
        return acc_cumul / n_tasks

    # Only needed for MMLU weighting; imported lazily so plain averaging does
    # not touch lm-eval's task registry.
    from lm_eval.tasks import get_task_dict
    questions_per_mmlu_task = {
        task_name: get_task_dict([task_name])[task_name].dataset["test"].num_rows
        for task_name in mmlu_task_names
    }

    # Average accuracy over MMLU tasks, weighted by the number of questions in each.
    acc_mmlu = sum(
        task_acc(result) * questions_per_mmlu_task[task]
        for task, result in results.items()
        if 'mmlu' in task
    )
    acc_mmlu_avg = acc_mmlu / sum(questions_per_mmlu_task.values())

    # All MMLU tasks count as a single entry in the overall average.
    return (acc_cumul + acc_mmlu_avg) / (n_tasks - len(questions_per_mmlu_task) + 1)


def parse_args():
    """Parse the command line, seed RNGs, configure logging, and return ``args``.

    Side effects:
      - Fills in the default Kurtail rotation path when ``--kurtail`` is given
        without ``--r_path``.
      - Raises ``ValueError`` when ``--spinquant`` is given without ``--r_path``.
      - Calls ``set_seed(args.seed)``.
      - Replaces the root logging handlers with a file handler under
        ``./log/<date>-<model basename>/`` plus a stream handler, then logs
        the parsed arguments.
    """
    parser = argparse.ArgumentParser()
    # Model and Datasets
    # --model is required: the log directory (and Kurtail default path) are
    # derived from it unconditionally.
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--tasks', type=str2list, default=[])
    parser.add_argument('--num_fewshot', type=str2int, default='none')
    parser.add_argument('--eval_ppl', type=str2bool, default=False)
    # Bit-configuration (Linear)
    parser.add_argument('--w_elem_format_linear', type=str, default='none')
    parser.add_argument('--a_elem_format_linear', type=str, default='none')
    parser.add_argument('--scale_bits_linear', type=int, default=8)
    parser.add_argument('--block_size_linear', type=int, default=32)
    parser.add_argument('--double_quant_linear', type=str2bool, default=False)
    parser.add_argument('--w_dq_only', type=str2bool, default=False,
                        help='Only double quantize weights, not activations')
    # Bit-configuration (MatMul)
    parser.add_argument('--A_elem_format_matmul', type=str, default='none')
    parser.add_argument('--B_elem_format_matmul', type=str, default='none')
    parser.add_argument('--scale_bits_matmul', type=int, default=8)
    parser.add_argument('--block_size_matmul', type=int, default=32)
    parser.add_argument('--double_quant_matmul', type=str2bool, default=False)
    # Bit-configuration (LayerNorm)
    parser.add_argument('--w_elem_format_ln', type=str, default='none')
    parser.add_argument('--a_elem_format_ln', type=str, default='none')
    parser.add_argument('--scale_bits_ln', type=int, default=8)
    parser.add_argument('--block_size_ln', type=int, default=32)
    # Bit-configuration (LM-Head)
    parser.add_argument('--w_elem_format_head', type=str, default='none')
    parser.add_argument('--a_elem_format_head', type=str, default='none')
    parser.add_argument('--scale_bits_head', type=int, default=8)
    parser.add_argument('--block_size_head', type=int, default=32)
    # Others
    parser.add_argument('--auto_dtype', type=str2bool, default=True)
    parser.add_argument('--custom_cuda', type=str2bool, default=False)
    parser.add_argument('--a_scale_mode', type=int, default=0)
    parser.add_argument('--w_scale_mode', type=int, default=0)
    parser.add_argument('--A_scale_mode', type=int, default=0)
    parser.add_argument('--B_scale_mode', type=int, default=0)
    parser.add_argument('--per_tensor', type=str2bool, default=False)
    # Rotation and Hadamard Transform
    parser.add_argument('--rotate', type=str2bool, default=False)
    parser.add_argument('--rotate_mode', type=str, default='hadamard',
                        choices=['hadamard', 'group_hadamard', 'identity'])
    parser.add_argument('--online_partial_had', type=str2bool, default=False)
    parser.add_argument('--rotate_kv', type=str2bool, default=False)
    parser.add_argument('--group_rotate_kv', type=str2bool, default=False)
    parser.add_argument('--kv_quant_only', type=str2bool, default=False)
    parser.add_argument('--kv_tokenwise', type=str2bool, default=False)
    # Pre-smooth and Post-smooth
    parser.add_argument('--pre_smooth', type=str2path, default=None,
                        help='Path to the pre-smooth scales file')
    parser.add_argument('--pre_smooth_alpha', type=float, default=0.5,
                        help='Alpha for pre-smooth scales')
    # Smooth Quantization
    parser.add_argument('--smooth_quant', type=str2bool, default=False)
    parser.add_argument('--smooth_quant_alpha', type=float, default=0.85,
                        help='Alpha for smooth quant')
    # GPTQ Quantization
    parser.add_argument('--gptq', type=str2bool, default=False)
    parser.add_argument('--gptq_percdamp', type=float, default=0.01)
    parser.add_argument('--gptq_cal_dataset', type=str, default='wikitext2',
                        choices=['wikitext2', 'ptb', 'c4'])
    parser.add_argument('--gptq_cal_nsamples', type=int, default=2048)
    parser.add_argument('--gptq_cal_seqlen', type=int, default=128)
    parser.add_argument('--gptq_load_path', type=str2path, default=None,
                        help='Path to load quantized model')
    parser.add_argument('--gptq_save_path', type=str2path, default=None,
                        help='Path to save quantized model')
    # OmniQuant
    parser.add_argument("--omni_quant", type=str2bool, default=False)
    parser.add_argument("--aug_loss", type=str2bool, default=False,
                        help="calculate additional loss with same input")
    parser.add_argument("--let", default=False, type=str2bool,
                        help="activate learnable equivalent transformation")
    parser.add_argument("--lwc", default=False, type=str2bool,
                        help="activate learnable weight clipping")
    parser.add_argument("--epochs", type=str2int, default=10)
    parser.add_argument("--let_lr", type=float, default=5e-3)
    parser.add_argument("--lwc_lr", type=float, default=1e-2)
    parser.add_argument("--wd", type=float, default=0)
    parser.add_argument('--omni_cal_dataset', type=str, default='wikitext2',
                        choices=['wikitext2', 'ptb', 'c4'])
    parser.add_argument('--omni_cal_nsamples', type=str2int, default=128)
    parser.add_argument('--omni_cal_seqlen', type=str2int, default=2048)
    parser.add_argument("--omni_save_dir", type=str2path, default='./omniquant/log',
                        help='Path to save OmniQuant model')
    parser.add_argument("--omni_resume", type=str2path, default=None)
    parser.add_argument("--omni_batch_size", type=str2int, default=1)
    # Kurtail Quantization
    parser.add_argument('--kurtail', type=str2bool, default=False,
                        help='Enable Kurtail quantization')
    # SpinQuant
    parser.add_argument('--spinquant', type=str2bool, default=False,
                        help='Enable SpinQuant quantization')
    parser.add_argument('--r_path', type=str2path, default=None,
                        help='''Path to the R1 rotation matrix (default: None).
                        If not specified, R1 is generated according to --rotate_mode.''')

    args = parser.parse_args()
    # Strip a trailing slash so local directory paths still yield a basename.
    model_basename = args.model.rstrip("/").split("/")[-1]
    if args.kurtail and not args.r_path:
        r1_suffix = '-r1.pt' if args.rotate_mode == 'hadamard' else '-r1_SBRQ.pt'
        args.r_path = f'./kurtail/trained_rotation/wikitext2_500samples/{model_basename}{r1_suffix}'
    elif args.spinquant and not args.r_path:
        raise ValueError("Please specify the --r_path for SpinQuant.")
    set_seed(args.seed)

    # local logging configuration
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    # One timestamp so the directory date and the file stamp always agree.
    now = datetime.now()
    log_dir = os.path.join('./log', f'{now.strftime("%Y-%m-%d")}-{model_basename}')
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, f'{now.strftime("%Y-%m-%d_%H:%M:%S")}.txt')
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(levelname)s %(message)s',
        handlers=[
            logging.FileHandler(log_file, mode='w', encoding='utf-8'),
            logging.StreamHandler()
        ]
    )
    logging.info(pprint.pformat(vars(args)))

    return args


if __name__ == '__main__':
    def _build_model_and_tokenizer(cli_args):
        """Run the quantization pipeline selected on the CLI; return (model, tokenizer)."""
        if cli_args.omni_quant:
            logging.info("Using OmniQuant for quantization")
            from omniquant import omniquant
            omni_lm = omniquant.main_omniquant(cli_args)
            return omni_lm.model, omni_lm.tokenizer
        if cli_args.smooth_quant:
            logging.info("Using SmoothQuant for quantization")
            from smoothquant.smooth import test_smoothquant
            return test_smoothquant(cli_args)
        return main(cli_args)

    args = parse_args()

    # --w_dq_only: turn on double quantization for linear layers and turn it
    # off for matmuls.
    if args.w_dq_only:
        args.double_quant_linear, args.double_quant_matmul = True, False
        logging.info("Only double quantize weights per-channel, not activations")

    model, tokenizer = _build_model_and_tokenizer(args)
    evaluate_model(model, tokenizer, args)
