import logging
import os
import sys
import time

import lm_eval
from lm_eval.models.huggingface import HFLM
from lm_eval.tasks import TaskManager
from lm_eval.utils import make_table
import torch
from transformers import AutoTokenizer

try:
    import wandb
    has_wandb: bool = True
except ModuleNotFoundError:
    has_wandb: bool = False

from data_utils import get_data, set_seed
from gptq import get_pre_trained_model, quantize_model, evaluate_model
from parse_args import parse_args

# Pin every matmul/convolution to full-precision arithmetic so quantization
# error measurements are not contaminated by reduced-precision reductions or
# TF32 tensor-core math. The cuBLAS flags are cleared first, then cuDNN, then
# the global float32 matmul precision, exactly in that order.
for _flag in (
    "allow_fp16_reduced_precision_reduction",
    "allow_bf16_reduced_precision_reduction",
    "allow_tf32",
):
    setattr(torch.backends.cuda.matmul, _flag, False)
del _flag
torch.backends.cudnn.allow_tf32 = False
torch.set_float32_matmul_precision('highest')


def _model_basename(model_dir: str) -> str:
    """Return the final path component of ``model_dir``.

    ``os.path.normpath`` drops a trailing separator first, so ``"models/llama/"``
    yields ``"llama"`` rather than an empty string.
    """
    return os.path.basename(os.path.normpath(model_dir))


def _init_wandb(args) -> None:
    """Derive ``args.exp_name`` from the run configuration and start a W&B run.

    The name is prefixed by the ``WANDB_NAME`` environment variable (default
    ``"babai_v2"``). Every public attribute of ``args`` (including the new
    ``exp_name``) is recorded as the run config, and
    ``wandb.run.log_code(".")`` snapshots source from the working directory.

    Raises:
        ImportError: if ``wandb`` could not be imported when this module loaded.
    """
    # Raise rather than assert so the check is not stripped under ``python -O``.
    if not has_wandb:
        raise ImportError("`wandb` not installed, try pip install `wandb`")
    args.exp_name = (
        os.environ.get("WANDB_NAME", "babai_v2")
        + f"_model_{_model_basename(args.model_dir)}"
        + f"_wbits_{args.quant_bit_width}"
        + f"_groupsize_{args.quant_group_size}"
        + f"_quant_order_{args.quant_order}"
        + f"_do_rtn_{args.do_rtn}"
        + f"_entropy_{args.quant_use_entropy_mode}"
        + f"_clip_{args.quant_do_clip}"
        + f"_mse_{args.quant_use_mse}"
        + f"_seqlen_{args.seqlen}"
        + f"_train_samples_{args.data_train_n_samples}"
        + f"_do_quant_{args.do_quant}"
        + f"_batch_{args.batch_size}"
        + f"_seed_{args.data_seed}"
        + f"_lm_eval_batch_{args.lm_eval_batch_size}"
        + f"_evaltasks_{'_'.join(args.lm_eval_tasks)}"
        + f"_save_{args.save_model}"
    )
    wandb.init(
        # The descriptive name was previously computed but never handed to W&B.
        name=args.exp_name,
        config={a: getattr(args, a) for a in dir(args) if not a.startswith("_")},
    )
    wandb.run.log_code(".")


def _build_save_model_name(args) -> str:
    """Return the name (directory and ``.pt`` stem) for a saved quantized model.

    Encodes the base model name, bit width and quantization settings, plus
    optional ``-rtn`` and ``-outlier<pct>`` suffixes.
    """
    # ``quant_use_entropy_mode`` may be None (the original check listed None as
    # a "disabled" value), but calling ``.lower()`` on None raised
    # AttributeError, so normalize to a string before comparing.
    entropy_mode = (args.quant_use_entropy_mode or '').lower()
    name = f'{_model_basename(args.model_dir)}-GPTQ-{args.quant_bit_width}b-'
    if entropy_mode in ('', 'none'):
        scale = 'mse' if args.quant_use_mse else 'absmax'
        clip = '' if args.quant_do_clip else 'no_'
        name += f'{args.quant_group_size}g-{args.quant_order}_order-{scale}_scale-{clip}clip'
    elif entropy_mode in ('grouped_e', 'grouped_f'):
        name += (f'{args.quant_group_size}g-{args.quant_order}_order-'
                 f'entropy_{args.quant_use_entropy_mode}_scale')
    else:
        # Non-grouped entropy modes leave the group size out of the name.
        name += f'{args.quant_order}_order-entropy_{args.quant_use_entropy_mode}_scale'
    if args.do_rtn:
        name += '-rtn'
    if args.outlier_percentage is not None:
        name += f'-outlier{args.outlier_percentage}'
    return name


def _quantize(model, tokenizer, device: torch.device, args):
    """Quantize ``model`` on fineweb-edu calibration data; return ``quantize_model``'s result.

    NOTE(review): the caller keeps using ``model`` afterwards, so
    ``quantize_model`` is presumably expected to update it in place — confirm.
    """
    encodings_train = get_data(
        dataset_name="fineweb-edu",
        tokenizer=tokenizer,
        max_sequence_length=args.seqlen,
        num_calibration_samples=args.data_train_n_samples,
        seed=args.data_seed,
        eval_mode=False,
    ).to(device)
    tick = time.perf_counter()  # monotonic clock for elapsed-time measurement
    results = quantize_model(
        model=model,
        encodings=encodings_train,
        device=device,
        quant_group_size=args.quant_group_size,
        quant_bit_width=args.quant_bit_width,
        quant_order=args.quant_order,
        quant_use_entropy_mode=args.quant_use_entropy_mode,
        quant_do_clip=args.quant_do_clip,
        quant_use_mse=args.quant_use_mse,
        batch_size=args.batch_size,
        save_gpu_mem_level=args.save_gpu_mem_level,
        do_rtn=args.do_rtn,
        outlier_percentage=args.outlier_percentage,
    )
    logging.info(f'finished quantizing in {time.perf_counter() - tick:.2f} s')
    return results


def _save_quantized_model(model, tokenizer, quant_results, args) -> None:
    """Save model, tokenizer and quantization results under ``args.output_base``."""
    save_model_name = _build_save_model_name(args)
    output_dir = os.path.join(args.output_base, save_model_name)
    logging.info(f'saving model to {output_dir}')
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    torch.save(quant_results, os.path.join(output_dir, f'{save_model_name}.pt'))


def _evaluate_perplexity(model, tokenizer, device: torch.device, args) -> None:
    """Log ``evaluate_model``'s perplexity on wikitext2 and c4 (and to W&B if enabled)."""
    for dataset_name in ('wikitext2', 'c4'):
        encodings = get_data(
            dataset_name=dataset_name,
            tokenizer=tokenizer,
            max_sequence_length=args.seqlen,
            num_calibration_samples=args.data_train_n_samples,
            seed=args.data_seed,
            eval_mode=True,
        )
        logging.info(f'evaluating {dataset_name}')
        ppl = evaluate_model(
            model=model,
            encodings=encodings,
            device=device,
            batch_size=args.batch_size,
            save_gpu_mem=args.save_gpu_mem_level >= 3,
        ).item()
        logging.info(f'{dataset_name} = ppl: {ppl:.4f}')
        if args.wandb:
            wandb.log({dataset_name: ppl})


def _run_lm_eval_task(lm, task: str, batch_size, task_manager, **eval_kwargs) -> dict:
    """Run one lm-eval task, print its results table and return the ``results`` mapping."""
    task_results = lm_eval.simple_evaluate(
        model=lm,
        tasks=task,
        batch_size=batch_size,
        task_manager=task_manager,
        **eval_kwargs,
    )["results"]
    print(make_table({"results": task_results, "versions": {}, "n-shot": {}, "higher_is_better": {}}))
    return task_results


def _evaluate_openllm(model, tokenizer, device: torch.device, args) -> dict:
    """Run the lm-eval tasks selected in ``args.lm_eval_tasks``; return merged results.

    Numeric metrics are also logged to W&B as ``"<task>/<metric>"`` when
    ``args.wandb`` is set.
    """
    # (task, extra ``simple_evaluate`` kwargs), run in this order when selected.
    task_specs = (
        ("hellaswag", {"num_fewshot": 10}),
        ("truthfulqa", {"num_fewshot": 0}),
        ("winogrande", {"num_fewshot": 5}),
        # 0-shot without a chat template. ``fewshot_as_multiturn`` used to be
        # set here, but it only applies with a chat template and few-shot
        # examples, so it had no effect and has been dropped.
        ("mmlu", {"num_fewshot": 0, "apply_chat_template": False}),
    )
    lm = HFLM(
        pretrained=model.to(device=device),
        tokenizer=tokenizer,
        batch_size=args.lm_eval_batch_size,
        max_length=4096,  # Open LLM Leaderboard setting
        add_bos_token=False,
    )
    task_manager = TaskManager()
    results = {}
    for task, eval_kwargs in task_specs:
        if task in args.lm_eval_tasks:
            results.update(
                _run_lm_eval_task(lm, task, args.lm_eval_batch_size, task_manager, **eval_kwargs)
            )
    if args.wandb and results:
        # Skip non-numeric entries (e.g. task aliases, "N/A" stderr strings).
        wandb.log({
            f"{task}/{metric}": value
            for task, metrics in results.items()
            for metric, value in metrics.items()
            if isinstance(value, (int, float))
        })
    return results


def main() -> None:
    """Optionally quantize, then evaluate, the model at ``args.model_dir``.

    Steps, driven by ``parse_args()``:
      1. start a W&B run when ``args.wandb`` is set;
      2. load the model and tokenizer from ``args.model_dir`` and seed RNGs;
      3. if ``args.do_quant``, quantize on fineweb-edu calibration data and,
         if ``args.save_model``, save model, tokenizer and results;
      4. log perplexity on wikitext2 and c4;
      5. if ``args.eval_openllm``, run the selected lm-eval tasks.
    """
    logging.basicConfig(format='%(levelname)s %(asctime)s %(message)s', level=logging.INFO)
    logging.info(' '.join(sys.argv))
    args = parse_args()
    logging.info(args)
    if args.wandb:
        _init_wandb(args)

    model = get_pre_trained_model(args.model_dir)
    model.eval()
    device: torch.device = torch.device('cuda')

    tokenizer = AutoTokenizer.from_pretrained(args.model_dir, use_fast=False)

    torch.manual_seed(args.data_seed)
    set_seed(args.data_seed)

    if args.do_quant:
        quant_results = _quantize(model, tokenizer, device, args)
        if args.save_model:
            # A freshly loaded tokenizer is saved and then used for the
            # evaluations below, as before. NOTE(review): presumably guards
            # against get_data() having modified the tokenizer — confirm.
            tokenizer = AutoTokenizer.from_pretrained(
                pretrained_model_name_or_path=args.model_dir, use_fast=False)
            _save_quantized_model(model, tokenizer, quant_results, args)

    _evaluate_perplexity(model, tokenizer, device, args)
    if args.eval_openllm:
        _evaluate_openllm(model, tokenizer, device, args)


if __name__ == "__main__":
    # Script entry point: run the quantize-then-evaluate pipeline.
    main()
