import os,utils,lm_eval,logging
import quant._get_args as get_args
import quant.btc_model_utils as btc_model_utils,utils.data_utils as data_utils,utils.eval_utils as eval_utils,utils.gptq_utils as gptq_utils
import torch,torch.nn as nn,torch.nn.functional as F,transformers,sys,datetime,datasets,lm_eval,lm_eval.tasks,accelerate
from loguru import logger
from quant.btc_train import rotate_smooth_train
from lm_eval.tasks import TaskManager
from lm_eval.utils import make_table
from lm_eval.models.huggingface import HFLM,eval_logger
eval_logger.level = logging.ERROR

def main(args):
    transformers.set_seed(args.seed)
    lm = btc_model_utils.LM(args)
    lm.model.eval()
    test_loader = data_utils.get_loaders(args.eval_dataset,seed=args.seed,model=args.model,seqlen=lm.seqlen,eval_mode=True)  
    if args.pre_eval:
        pre_ppl = eval_utils.evaluator(lm.model,test_loader,utils.DEV,args,),
        logger.info(f"Float ppl:{pre_ppl}")
        eval_tasks(lm,args)
    if args.rotate:
        lm.fuse_layer_norms()
        lm.generate_rotate_parameters()
        
        if False: 
            lm.rotate_smooth_model_inplace()
            pre_ppl = eval_utils.evaluator(lm.model,test_loader,utils.DEV,args)
            logger.info(f"Fuse Layer Norms PPL:{pre_ppl}")
        if args.train_rotate:
            lm.set_quant_state(use_weight_quant=args.train_enable_wquant,use_act_quant=True,use_fully_quant=args.fully_quant)
            lm = rotate_smooth_train(args,lm)
        elif args.resume_path is not None:
            q = torch.load(args.resume_path)
            r=lm.model.load_state_dict(q, strict=False)
            logger.info(f"resume from {args.resume_path}")
        lm.rotate_smooth_model_inplace()
        if False: 
            dynamic_ppl = dynamic_eval(lm,test_loader,args,use_act_quant=True,use_fully_quant=args.fully_quant,use_weight_quant=args.train_enable_wquant)
            logger.info(f"After Dynamic ActQuant PPL:{dynamic_ppl}")
            if args.test_static:
                eval_tasks(lm,args)
                static_ppl = static_eval(lm,test_loader,args)
                logger.info(f"After Static ActQuant PPL:{static_ppl}")

    
    trainloader = data_utils.get_loaders(
        args.cal_dataset, nsamples=args.nsamples,
        seed=args.seed, model=args.model,
        seqlen=lm.seqlen, eval_mode=False
    )
    gptq_utils.quant_sequential(lm.model, trainloader, utils.DEV, args)

    if True:
        dynamic_ppl = dynamic_eval(lm,test_loader,args,use_act_quant=True,use_fully_quant=args.fully_quant)
        logger.info(f"After Dynamic Act&Weight Quant PPL:{dynamic_ppl}")
        if args.test_static:
            static_ppl = static_eval(lm,test_loader,args)
            logger.info(f"After Static Act&Weight Quant PPL:{static_ppl}")

    if not args.lm_eval:
        return
    else:
        eval_tasks(lm,args)


def eval_tasks(lm,args):
    utils.cleanup_memory()
    if args.distribute:
        utils.distribute_model(lm.model)
    else:
        lm.model.to(utils.DEV)
    task_manager = TaskManager()
    tasks = task_manager.match_tasks(args.tasks)
    hflm = HFLM(pretrained=lm.model, tokenizer=lm.tokenizer,batch_size="auto",max_batch_size=256)
    results = lm_eval.simple_evaluate(hflm,tasks=tasks,batch_size="auto",max_batch_size=256)
    metric_vals = {task: round(result.get('acc_norm,none', result['acc,none']), 4) for task, result in results['results'].items()}
    mean_acc_val = round(sum(metric_vals.values()) / len(metric_vals.values()), 4)
    std_vals = {task: round(result.get('acc_norm_stderr,none', result['acc_stderr,none']), 4) for task, result in results['results'].items()}
    mean_std_val =round(sum(std_vals.values()) / len(std_vals.values()), 4) 
    metric_vals['acc_avg'] = mean_acc_val
    results['results']['AVERAGE'] = {
        "acc,none":mean_acc_val,
        "acc_stderr,none":mean_std_val
    }
    logger.info("\n" + make_table(results))




def dynamic_eval(lm:btc_model_utils.LM,dataloader,args,use_act_quant=True,use_fully_quant=False,use_weight_quant=False):
    lm.set_quant_state(use_weight_quant=use_weight_quant,use_act_quant=use_act_quant,use_fully_quant=use_fully_quant)
    lm.set_dynamic(True)
    ppl = eval_utils.evaluator(lm.model,dataloader,utils.DEV,args)
    return ppl

def static_eval(lm:btc_model_utils.LM,dataloader,args):
    lm.set_quant_state(False,True,args.fully_quant)
    lm.set_dynamic(False)
    if not lm.calibrated:
        lm.set_observer(True)
        eval_utils.evaluator(lm.model,dataloader,args.dev,args,eval_samples=2) 
    lm.set_observer(False) 
    ppl = eval_utils.evaluator(lm.model,dataloader,args.dev,args)
    return ppl

if __name__ == "__main__":
    args = get_args.parse_args()
    main(args)

