import logging
import os
import pathlib

import sys
import torch
import copy

from config_files.prune_configs import prune_exper
from general_utils import data_utils, gpu_utils, hf_utils
from pruning_methods.wanda.prune_wanda import prune_wanda
from pruning_methods.sparsegpt.prune_sparsegpt import prune_sparsegpt
from pruning_methods.OATS.prune_oats import prune_oats
from pruning_methods.OATS.prune_oats_compress import prune_oats_compress
from pruning_methods.QR.prune_qr_compress import prune_qr_compress
from pruning_methods.QR.rank_prune import prune_QR_rank
from pruning_methods.OATS.oats_utils import load_oats
from general_utils.config import config
from measure_utils.measure_utils import DumpJSON
from ml_collections import ConfigDict
import lm_eval
import os
from lm_eval.models.huggingface import HFLM
from general_utils.utils import calculate_avg_accuracy
from pruning_methods.pruning_utils import calc_outlier_ratio
# os.environ["WANDB_MODE"] = "offline"
def process_pruning_args(args):
    """Validate the pruning arguments and apply device/dtype to the global ``config``.

    Args:
        args: Experiment configuration (an ``ml_collections.ConfigDict`` when the
            module is run as a script). Reads ``sparsity``, ``device`` and ``dtype``.

    Side effects:
        Sets ``config.device`` (when ``args.device`` is truthy) and ``config.dtype``.

    Raises:
        ValueError: if ``args.sparsity`` is outside ``[0, 1]`` or ``args.dtype`` is
            not one of ``"fp16"``, ``"bf16"``, ``"fp32"``.
    """
    # A ConfigDict keeps its fields in an internal mapping, so ``vars()`` would
    # only show implementation attributes; prefer ``items()`` when available.
    arg_items = args.items() if hasattr(args, "items") else vars(args).items()
    for arg, argv in arg_items:
        logging.debug('%s = %s', arg, argv)

    # Chained comparison is also False for NaN, so NaN is rejected too.
    if not 0 <= args.sparsity <= 1:
        raise ValueError(f"sparsity must be in [0, 1], got {args.sparsity!r}")

    if args.device:
        config.device = torch.device(args.device)

    dtypes = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }
    if args.dtype not in dtypes:
        raise ValueError(
            f"unsupported dtype {args.dtype!r}; expected one of {sorted(dtypes)}"
        )
    config.dtype = dtypes[args.dtype]


def _pick_metric(result, preferred, fallback):
    """Return ``result[preferred]`` if present, otherwise ``result[fallback]``.

    Unlike ``result.get(preferred, result[fallback])``, the fallback key is only
    looked up when it is actually needed, so a task that reports only the
    preferred metric does not raise ``KeyError``.
    """
    if preferred in result:
        return result[preferred]
    return result[fallback]


def _evaluate(hflm, tasks, num_fewshot, batch_size):
    """Run ``lm_eval.simple_evaluate`` on ``tasks`` without autograd; return its ``'results'``."""
    with torch.no_grad():
        return lm_eval.simple_evaluate(
            hflm, tasks=tasks, num_fewshot=num_fewshot, batch_size=batch_size
        )['results']


def _record_task_scores(results, results_stats, metric_vals):
    """Append one row per ``(task, score)`` in ``metric_vals`` to ``results``.

    Each row is a deep copy of ``results_stats`` with ``"Task Name"`` and
    ``"Task Score"`` filled in. ``results.save()`` is called after every row
    and ``results.to_csv()`` once at the end. Note that ``results_stats`` is
    mutated in place, so the last task name/score remain set on it.
    """
    for task_name, score in metric_vals.items():
        results_stats["Task Name"] = task_name
        results_stats["Task Score"] = score
        results.append(copy.deepcopy(results_stats))
        results.save()
    results.to_csv()


def pruning_main(args, checkpoint_path, results_path) -> None:
    """Prune ``args.model`` with ``args.prune_type`` and evaluate the pruned model.

    Args:
        args: Experiment configuration. Reads ``sparsity``, ``model``, ``prune_type``,
            ``prune_hyper`` (incl. ``'sparsity_type'`` and, for OATS, ``'compress'``),
            the calibration settings (``cal_dataset``, ``cal_max_seqlen``,
            ``cal_batch_size``, ``cal_nsamples``, ``varied_seqlen``, ``seed``),
            ``use_owl``, ``distribute_model`` and the ``eval_ppl`` /
            ``eval_zero_shot`` / ``eval_mmlu`` switches.
        checkpoint_path: Directory used for pruning checkpoints. If it contains
            ``prune_chkpt.pt`` the model is loaded from this directory instead of
            the pre-trained weights, and OWL ratios are not recomputed.
        results_path: Base path of the results file; ``'.json'`` is appended.

    Raises:
        ValueError: if a structured ``N:M`` sparsity type is requested with a
            target sparsity other than 0.5.
    """
    results = DumpJSON(read_path=(results_path + '.json'),
                       write_path=(results_path + '.json'))

    results_stats = {
        'target_sparsity': args.sparsity,
        'model': args.model,
        'prune_type': args.prune_type,
    }

    print("Target Sparsity:", args.sparsity)

    # Store every pruning hyper-parameter alongside each result row.
    for key, value in args.prune_hyper.items():
        results_stats[key] = value

    train_dataset = data_utils.get_dataset(args.cal_dataset, train=True)
    print("Finished Train Set", flush=True)

    # ---------------- Model loading (resume from checkpoint if present) ----------
    has_checkpoint = os.path.exists(os.path.join(checkpoint_path, "prune_chkpt.pt"))
    if has_checkpoint and args.prune_type == "OATS" and args.prune_hyper['compress']:
        # Compressed OATS checkpoints need their dedicated loader.
        model_adapter, tokenizer = load_oats(
            args.model,
            args.sparsity,
            args.prune_hyper,
            checkpoint_path,
            dtype=config.dtype
        )
    else:
        # ``None`` selects the pre-trained model; otherwise load the checkpoint dir.
        model_path = checkpoint_path if has_checkpoint else None
        model_adapter, tokenizer = hf_utils.get_model_and_tokenizer(
            args.model, model_path, dtype=config.dtype
        )

    model = model_adapter.model

    def reset_model_device() -> None:
        """Move the model to ``config.device`` or spread it over the available GPUs."""
        if args.distribute_model:
            gpu_utils.distribute_model(model_adapter)
        else:
            model.to(config.device)

    print("Start loading data", flush=True)

    train_loader = data_utils.prepare_dataloader(
        dataset=train_dataset,
        tokenizer=tokenizer,
        max_seqlen=args.cal_max_seqlen,
        batch_size=args.cal_batch_size,
        nsamples=args.cal_nsamples,
        varied_seqlen=args.varied_seqlen,
        seed=args.seed,
    )

    print("Finished Calibration Loader", flush=True)

    original_param_count = sum(p.numel() for p in model.parameters())
    print(f'Original model parameters: {original_param_count:,d}')

    # ========================= Pruning ========================================
    # OWL layer-wise sparsity ratios are only computed for a fresh (non-resumed) run.
    if args.use_owl and not has_checkpoint:
        layerwise_sparsity_ratios = calc_outlier_ratio(model_adapter, args.sparsity, train_loader)
    else:
        layerwise_sparsity_ratios = None

    prune_n, prune_m = 0, 0
    sparsity_type = args.prune_hyper['sparsity_type']
    if sparsity_type != "unstructured":
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if args.sparsity != 0.5:
            raise ValueError("sparsity ratio must be 0.5 for structured N:M sparsity")
        prune_n, prune_m = map(int, sparsity_type.split(":"))

    prune_type = args.prune_type
    if prune_type == "wanda":
        prune_wanda(model_adapter, tokenizer, args.sparsity, layerwise_sparsity_ratios, train_loader, checkpoint_path, prune_n=prune_n, prune_m=prune_m)
    elif prune_type == "sparse_gpt":
        prune_sparsegpt(model_adapter, tokenizer, args.sparsity, layerwise_sparsity_ratios, train_loader, checkpoint_path, prune_n=prune_n, prune_m=prune_m)
    elif prune_type == "OATS":
        if args.prune_hyper['compress']:
            prune_oats_compress(model_adapter, args.sparsity, layerwise_sparsity_ratios, train_loader, args.prune_hyper, checkpoint_path, prune_n=prune_n, prune_m=prune_m)
        else:
            prune_oats(model_adapter, tokenizer, args.sparsity, layerwise_sparsity_ratios, train_loader, args.prune_hyper, checkpoint_path, prune_n=prune_n, prune_m=prune_m)
    elif prune_type == "QR":
        prune_qr_compress(model_adapter, args.sparsity, layerwise_sparsity_ratios, train_loader, args.prune_hyper, checkpoint_path, prune_n=prune_n, prune_m=prune_m)
    elif prune_type == "QR_rank":
        prune_QR_rank(model_adapter, args.sparsity, layerwise_sparsity_ratios, train_loader, args.prune_hyper, checkpoint_path, prune_n=prune_n, prune_m=prune_m)
    else:
        # Kept non-fatal (as before) so unpruned baselines still run, but made visible.
        logging.warning("Unknown prune_type %r: evaluating the model without pruning", prune_type)

    pruned_param_count = sum(int(torch.count_nonzero(p).item()) for p in model.parameters())
    pruned_fraction = 1.0 - (float(pruned_param_count) / float(original_param_count))
    results_stats["final_sparsity"] = pruned_fraction
    results_stats["pruned nnz"] = pruned_param_count
    results_stats["original nnz"] = original_param_count

    print(f'Pruned model parameters: {pruned_param_count:,d} (Sparsity: {pruned_fraction:.4f})')

    # ========================= Evaluation =====================================
    reset_model_device()
    eval_batch_size = "auto"

    if args.eval_ppl:
        hflm = HFLM(pretrained=model_adapter.model, tokenizer=tokenizer, batch_size=eval_batch_size)
        ppl_results = _evaluate(hflm, ["wikitext"], None, eval_batch_size)
        print(ppl_results)

        ppl_vals = {task: round(result['word_perplexity,none'], 4) for task, result in ppl_results.items()}
        _record_task_scores(results, results_stats, ppl_vals)

    if args.eval_zero_shot:
        hflm = HFLM(pretrained=model_adapter.model, tokenizer=tokenizer, batch_size=eval_batch_size)
        zero_shot_tasks = ["piqa", "hellaswag", "arc_easy", "arc_challenge", "winogrande", "rte", "openbookqa", "boolq"]

        zs_results = _evaluate(hflm, zero_shot_tasks, 0, eval_batch_size)
        print(zs_results)

        # Prefer length-normalised accuracy where the task reports it.
        metric_vals = {task: round(_pick_metric(result, 'acc_norm,none', 'acc,none'), 4)
                       for task, result in zs_results.items()}
        metric_vals['average_zero_shot'] = round(calculate_avg_accuracy(zero_shot_tasks, zs_results), 4)

        print(metric_vals)
        _record_task_scores(results, results_stats, metric_vals)

    if args.eval_mmlu:
        hflm = HFLM(pretrained=model_adapter.model, tokenizer=tokenizer, batch_size=eval_batch_size)
        print("Evaluating MMLU!")

        mmlu_tasks = ["mmlu_abstract_algebra", "mmlu_business_ethics", "mmlu_college_computer_science",
                      "mmlu_college_mathematics", "mmlu_conceptual_physics", "mmlu_formal_logic", "mmlu_machine_learning",
                      "mmlu_miscellaneous", "mmlu_philosophy", "mmlu_global_facts"]

        mmlu_results = _evaluate(hflm, mmlu_tasks, 5, eval_batch_size)
        print(mmlu_results)

        metric_vals = {task: round(_pick_metric(result, 'acc_norm,none', 'acc,none'), 4)
                       for task, result in mmlu_results.items()}
        metric_vals['average_mmlu'] = round(calculate_avg_accuracy(mmlu_tasks, mmlu_results), 4)

        print(metric_vals)
        _record_task_scores(results, results_stats, metric_vals)

if __name__ == "__main__":
    # Usage: python <script> RUN_ID, where RUN_ID is the 1-based index of the
    # experiment to run in ``prune_exper``.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} RUN_ID  (1..{len(prune_exper)})")

    run_id = int(sys.argv[1])
    # Reject out-of-range ids explicitly: for run_id <= 0, ``prune_exper[run_id - 1]``
    # would silently pick an experiment counted from the end.
    if not 1 <= run_id <= len(prune_exper):
        sys.exit(f"RUN_ID must be between 1 and {len(prune_exper)}, got {run_id}")

    exper_config = prune_exper[run_id - 1]

    pruning_args = ConfigDict(exper_config)
    process_pruning_args(pruning_args)

    results_path = "../" # Please fill in the desired directory for saving results

    checkpoint_path = "../" # Please fill in the desired directory for checkpointing the pruned model
    pathlib.Path(checkpoint_path).mkdir(parents=True, exist_ok=True)

    pruning_main(pruning_args, checkpoint_path, results_path)
