import utils
import torch
import model_utils
import data_utils
import transformers
import quant_utils
import sparsegpt_utils
import duogpt_utils
import eval_utils
import wanda_utils
import os
import transformers


def add_ap(model, args):
    """Configure input (activation) sparsity on every pruning wrapper in ``model``.

    Looks up all ``quant_utils.ActPruneWrapper`` modules under ``model`` and
    calls ``pruner.configure`` on each one with values taken from ``args``:

    * ``sparsity``  <- ``args.a_sparsity``
    * ``annealing`` <- ``args.enable_ap_anneal``
    * ``annealer``  <- ``args.nsamples - 1``
    """
    wrappers = quant_utils.find_layers(model, layers=[quant_utils.ActPruneWrapper])
    for wrapper in wrappers.values():
        wrapper.pruner.configure(
            sparsity=args.a_sparsity,
            annealing=args.enable_ap_anneal,
            annealer=args.nsamples - 1,
        )


def _configure_act_sparsity(model, args):
    """Run ``add_ap`` on the decoder-layer container that matches ``args.model``."""
    # Models whose name contains 'opt' are addressed through
    # ``model.model.decoder.layers``; all others through ``model.model.layers``.
    if 'opt' in args.model:
        add_ap(model.model.decoder.layers, args)
    else:
        add_ap(model.model.layers, args)


def main():
    """Prune (or load) a model, then evaluate perplexity and/or lm-eval tasks.

    Everything is driven by the namespace returned from ``utils.parser_gen()``:

    1. Optionally start a wandb run and record the arguments.
    2. Load the model and attach the activation-pruning wrappers.
    3. Either load a saved state dict (``args.load_ckpt``) or, when
       ``args.sparsity > 0``, run one weight-pruning routine on calibration
       data (DuoGPT if ``args.use_v2``, Wanda if ``args.use_wanda``,
       otherwise SparseGPT).
    4. Configure activation sparsity for inference when ``args.a_sparsity > 0``.
    5. Optionally save the state dict (``args.save_ckpt``).
    6. Optionally evaluate perplexity (``args.lm_ppl``) and lm-eval task
       accuracy (``args.lm_eval``), logging to wandb when enabled.

    Raises:
        AssertionError: if annealing is requested with ``args.a_sparsity > 0``
            but ``args.enable_ap_calibration`` is off.
    """
    args = utils.parser_gen()
    if args.wandb:
        import wandb
        wandb.init(project=args.wandb_project, entity=args.wandb_id, name=args.wandb_name, dir=args.wandb_dir)
        wandb.config.update(args)

    transformers.set_seed(args.seed)
    model = model_utils.get_model(args.model, args.hf_token)
    model.eval()
    # Attaches the act_prune wrappers to the model's linear layers.
    quant_utils.add_actprune(model)

    # Activation sparsity during calibration is opt-in.
    if args.enable_ap_calibration and args.a_sparsity > 0:
        _configure_act_sparsity(model, args)

    if args.load_ckpt:
        load_path = os.path.join(args.load_pmodel_path, args.model, args.wandb_name)
        print("Load prunned model from ", load_path)
        save_dict = torch.load(os.path.join(load_path, "calibrated_model.pt"))
        model.load_state_dict(save_dict["model"], strict=False)
    elif args.sparsity > 0:
        # Calibrate at the requested sequence length, then restore the
        # model's original value.
        original_seqlen = model.seqlen
        model.seqlen = args.seqlen
        trainloader = data_utils.get_loaders(
            args.cal_dataset, nsamples=args.nsamples,
            seed=args.seed, model=args.model,
            seqlen=args.seqlen, eval_mode=False)
        if args.use_v2:
            duogpt_utils.duogpt_fwrd(model, trainloader, utils.DEV, args)
        elif args.use_wanda:
            wanda_utils.prune_wanda(model, trainloader, utils.DEV, args)
        else:
            sparsegpt_utils.sparsegpt_fwrd(model, trainloader, utils.DEV, args)
        model.seqlen = original_seqlen

    # If calibration used activation sparsity, inference is assumed to be
    # activation-sparse as well; this branch also covers the
    # non-calibration sparsity case.
    if args.a_sparsity > 0:
        if args.enable_ap_anneal:
            assert (args.enable_ap_calibration), "Error: annealing and calibration act spa must both be enabled"
            # Annealing must be switched off for inference.
            args.enable_ap_anneal = False
        _configure_act_sparsity(model, args)

    if args.save_ckpt:
        save_path = os.path.join(args.save_pmodel_path, args.model, args.wandb_name)
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(save_path, exist_ok=True)
        torch.save({"model": model.state_dict()}, os.path.join(save_path, "calibrated_model.pt"))
        print("Succesfully saved prunned model to ", save_path)

    if args.load_ckpt:
        print('Final loaded model check before evaluation: ', model)

    # Track whether the perplexity pass overrode model.seqlen, so the lm-eval
    # stage only restores it when there is a saved value to restore. Without
    # this, lm-eval with the wanda comparison flag but no perplexity pass
    # would read an unassigned local.
    seqlen_overridden = False
    origin_seqlen = None
    if args.lm_ppl:
        if args.enable_wanda_comparison:
            origin_seqlen = model.seqlen
            model.seqlen = model.config.max_position_embeddings
            seqlen_overridden = True

        # Evaluate perplexity on the evaluation dataset.
        testloader = data_utils.get_loaders(
            args.eval_dataset,
            seed=args.seed,
            model=args.model,
            seqlen=model.seqlen,
            hf_token=args.hf_token,
            eval_mode=True,
        )
        dataset_ppl = eval_utils.evaluator(model, testloader, utils.DEV, args)
        print(dataset_ppl)
        if args.wandb:
            wandb.log({'ppl/{}'.format(args.eval_dataset.upper()): dataset_ppl})

    if not args.lm_eval:
        return

    import lm_eval
    from lm_eval.models.huggingface import HFLM

    if args.enable_wanda_comparison and seqlen_overridden:
        # Put back the sequence length changed for the perplexity pass.
        model.seqlen = origin_seqlen

    if args.distribute:
        utils.distribute_model(model)
    else:
        model.to(utils.DEV)

    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model, use_fast=False, use_auth_token=args.hf_token, cache_dir=args.hf_cache_path)
    hflm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=args.lm_eval_batch_size)

    # Task names are passed through verbatim; pattern-matching them against
    # the task registry would also include two lambda sub-tasks.
    task_names = args.tasks
    results = lm_eval.simple_evaluate(hflm, tasks=task_names, batch_size=args.lm_eval_batch_size)['results']

    metric_vals = {}
    for task, result in results.items():
        # Prefer normalized accuracy. The fallback is looked up lazily so a
        # task that reports only 'acc_norm,none' does not raise KeyError.
        if 'acc_norm,none' in result:
            score = result['acc_norm,none']
        else:
            score = result['acc,none']
        metric_vals[task] = round(score, 4)
    # The average is computed before 'acc_avg' is inserted, so it covers
    # only the per-task scores.
    metric_vals['acc_avg'] = round(sum(metric_vals.values()) / len(metric_vals), 4)
    print(metric_vals)

    if args.wandb:
        wandb.log(metric_vals)
    print(args)

# Run the pipeline only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
