import argparse
import gc
import os
from importlib.metadata import version

import numpy as np
import torch
from transformers import GPTNeoXTokenizerFast
from transformers import AutoTokenizer, AutoModelForCausalLM

from lib.prune_opt import prune_wanda, prune_magnitude, prune_sparsegpt, prune_ablate, check_sparsity, find_layers, prune_pruner_zero
from lib.eval import eval_ppl, eval_zero_shot
from lib.gptree import GPTree
from lib.rl_search import online_rl_search

# Print the installed versions of the main libraries and the number of CUDA
# devices torch reports; this runs as soon as the module is executed.
for _pkg in ('torch', 'transformers', 'accelerate'):
    print(_pkg, version(_pkg))
print('# of gpus: ', torch.cuda.device_count())

def get_llm(model_name, cache_dir="llm_weights"):
    """Load a causal language model in float16 and record its sequence length.

    Args:
        model_name: Model identifier or path handed to ``from_pretrained``.
        cache_dir: Directory passed to ``from_pretrained`` as ``cache_dir``.

    Returns:
        The loaded model, with an extra ``seqlen`` attribute copied from
        ``model.config.max_position_embeddings``.
    """
    load_options = {
        "torch_dtype": torch.float16,
        "cache_dir": cache_dir,
        "low_cpu_mem_usage": True,
        "device_map": "auto",
        "trust_remote_code": True,
    }
    llm = AutoModelForCausalLM.from_pretrained(model_name, **load_options)

    # Downstream code reads the context length from this ad-hoc attribute.
    llm.seqlen = llm.config.max_position_embeddings
    return llm

# Substrings of the model name that mark checkpoints loaded across several
# GPUs through ``device_map``.
_LARGE_MODEL_TAGS = ("30b", "66b", "70b", "33b")


def _has_any_tag(model_id, tags):
    """Return True if any of ``tags`` occurs as a substring of ``model_id``."""
    return any(tag in model_id for tag in tags)


def _build_arg_parser():
    """Create the command-line parser for the pruning script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True, help='LLaMA model')
    parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
    parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration samples.')
    parser.add_argument('--sparsity_ratio', type=float, default=0, help='Sparsity level')
    # Defaulting to "unstructured" keeps a missing flag from reaching
    # ``None.split(":")`` in the N:M handling.
    parser.add_argument("--sparsity_type", type=str, default="unstructured",
                        choices=["unstructured", "4:8", "2:4"])
    parser.add_argument("--prune_method", type=str, choices=["magnitude", "wanda", "sparsegpt",
                        "ablate_mag_seq", "ablate_wanda_seq", "ablate_mag_iter", "ablate_wanda_iter", "search", "pruner-zero", "x-pruner"])
    parser.add_argument("--cache_dir", default="llm_weights", type=str)
    parser.add_argument('--use_variant', action="store_true", help="whether to use the wanda variant described in the appendix")
    parser.add_argument('--save', type=str, default=None, help='Path to save results.')
    parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.')
    parser.add_argument("--gradient_path", type=str, default=None, help="Path to save the gradient.")
    parser.add_argument("--eval_zero_shot", action="store_true")

    parser.add_argument("--search_global", type=str, default=None, help='grid search or not')
    parser.add_argument('--w_range', type=str, default="0.5,2.5,0.1", help='w exponent range: start,end,step')
    parser.add_argument('--g_range', type=str, default="0.5,2.5,0.1", help='g exponent range: start,end,step')
    return parser


def _parse_range(range_str):
    """Expand a ``"start,end,step"`` string into an array rounded to 4 decimals.

    ``end + step`` is used as the ``np.arange`` stop so that ``end`` itself is
    normally part of the result. Raises ``ValueError`` when the string is not
    exactly three comma-separated floats.
    """
    start, end, step = map(float, range_str.split(','))
    return np.round(np.arange(start, end + step, step), 4)


def _load_tokenizer(model_id):
    """Pick the tokenizer class from the model name and load it."""
    if "Qwen" in model_id or "Llama" in model_id:
        return AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    return GPTNeoXTokenizerFast.from_pretrained(model_id, use_fast=False)


def _resolve_device(model_id, model):
    """Return the device handed to the pruning and evaluation routines."""
    if _has_any_tag(model_id, _LARGE_MODEL_TAGS):
        # These checkpoints are spread over several GPUs by device_map, so
        # use whichever device ended up holding lm_head.
        return model.hf_device_map["lm_head"]
    return torch.device("cuda:0")


def _load_prune_x_pruner():
    """Import ``prune_x_pruner`` on first use and return it.

    NOTE(review): the module never imported this name, so every call raised
    NameError. It is assumed to live in ``lib.prune_opt`` next to
    ``prune_pruner_zero`` -- confirm. The import is kept local so that a wrong
    guess only affects the x-pruner code paths.
    """
    from lib.prune_opt import prune_x_pruner
    return prune_x_pruner


def _run_global_search(args, model, tokenizer, device):
    """Walk ``layer_order`` and set per-layer exponents manually or by search.

    ``layer_order`` and ``manual_layers`` are empty placeholders, so with the
    code as written the loop body never executes.
    """
    engine = GPTree.load_tree("./data/x-pruner.json")
    manual_layers = {}
    layer_sparsity = {}
    layer_order = []

    for layer_idx in layer_order:
        if layer_idx in manual_layers:
            best_w, best_g = manual_layers[layer_idx]
            print(f"Using Manual Setting: w={best_w:.4f}, g={best_g:.4f}")

            GPTree.PowerExponents.set_layer_exponents(layer_idx, best_w, best_g)
            prune_x_pruner = _load_prune_x_pruner()
            prune_x_pruner(args, model, tokenizer, device, engine=engine, target_layers=[layer_idx], layer_sparsity=layer_sparsity)

            # The perplexity used to be computed and silently dropped.
            current_ppl = eval_ppl(args, model, tokenizer, device)
            print(f"layer {layer_idx} perplexity after manual setting: {current_ppl}")
            continue

        # NOTE(review): `online_rl_search_new` is not defined or imported in
        # this file (only `online_rl_search` is) -- confirm the intended name
        # before putting anything into `layer_order`.
        online_rl_search_new(args, model, tokenizer, device, engine, layer_idx, layer_sparsity=layer_sparsity)


def _apply_pruning(args, model, tokenizer, device, prune_n, prune_m):
    """Dispatch to the pruning routine selected by ``args.prune_method``.

    NOTE(review): the "search" choice matches no branch, so nothing is pruned
    for it -- confirm that is intended.
    """
    method = args.prune_method
    layer_sparsity = {}

    if method == "wanda":
        prune_wanda(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
    elif method == "magnitude":
        prune_magnitude(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
    elif method == "sparsegpt":
        prune_sparsegpt(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
    elif "ablate" in method:
        prune_ablate(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
    elif method == "pruner-zero":
        engine = GPTree.load_tree("./data/pruner-zero.json")
        prune_pruner_zero(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m, engine=engine, layer_sparsity=layer_sparsity)
    elif method == "x-pruner":
        engine = GPTree.load_tree("./data/x-pruner.json")
        prune_x_pruner = _load_prune_x_pruner()
        prune_x_pruner(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m, engine=engine, layer_sparsity=layer_sparsity)


def main():
    """Command-line entry point: load a model, optionally prune it, evaluate it.

    Steps: parse and validate the arguments, seed numpy and torch, load the
    model and tokenizer, optionally run the global exponent search, prune when
    ``--sparsity_ratio`` is non-zero, print the measured sparsity and the
    perplexity returned by ``eval_ppl``, then optionally write a results log,
    run the zero-shot tasks and save the model.

    Invalid argument combinations are reported through ``parser.error``,
    which exits with status 2.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.sparsity_ratio != 0 and args.prune_method is None:
        parser.error("--prune_method is required when --sparsity_ratio is non-zero")

    # Setting seeds for reproducibility
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    # Handling n:m sparsity
    prune_n, prune_m = 0, 0
    if args.sparsity_type != "unstructured":
        # parser.error instead of assert: asserts disappear under `python -O`.
        if args.sparsity_ratio != 0.5:
            parser.error("sparsity ratio must be 0.5 for structured N:M sparsity")
        prune_n, prune_m = map(int, args.sparsity_type.split(":"))

    # The expanded ranges are not used in this function; parsing them here
    # still rejects a malformed value, and does so before the model is loaded.
    _parse_range(args.w_range)
    _parse_range(args.g_range)

    print(f"\nloading llm model {args.model}")
    model = get_llm(args.model, args.cache_dir)
    model.eval()
    tokenizer = _load_tokenizer(args.model)

    device = _resolve_device(args.model, model)
    print("use device ", device)

    if args.search_global is not None and args.search_global.lower() == "true":
        _run_global_search(args, model, tokenizer, device)

        # The search works on `model` directly, so the final run starts again
        # from freshly loaded weights. Drop the old copy first: otherwise both
        # copies are alive during the reload. The device is recomputed because
        # the placement belongs to the new copy.
        del model
        gc.collect()
        torch.cuda.empty_cache()
        model = get_llm(args.model, args.cache_dir)
        model.eval()
        device = _resolve_device(args.model, model)

    GPTree.PowerExponents.layer_exponents = {}

    if args.sparsity_ratio != 0:
        print("pruning starts")
        _apply_pruning(args, model, tokenizer, device, prune_n, prune_m)

    ################################################################
    print("*"*30)
    sparsity_ratio = check_sparsity(model)
    print(f"sparsity sanity check {sparsity_ratio:.4f}")
    print("*"*30)
    ################################################################
    model.config.use_cache = False
    model.seqlen = 512
    ppl_test = eval_ppl(args, model, tokenizer, device)
    print(f"wikitext perplexity:{ppl_test}")

    if args.save:
        os.makedirs(args.save, exist_ok=True)
        save_filepath = os.path.join(args.save, f"log_{args.prune_method}.txt")
        with open(save_filepath, "w") as f:
            print("method\tactual_sparsity\tppl_test", file=f, flush=True)
            print(f"{args.prune_method}\t{sparsity_ratio:.4f}\t{ppl_test:.4f}", file=f, flush=True)
    else:
        # --save defaults to None; without this guard the run crashed here,
        # after all the expensive work, and never reached the steps below.
        print("no --save directory given; skipping results log")

    if args.eval_zero_shot:
        accelerate = _has_any_tag(args.model, _LARGE_MODEL_TAGS + ("6.7b",))

        task_list = ["boolq", "rte","hellaswag","winogrande", "arc_easy","arc_challenge", "openbookqa"]
        num_shot = 3
        results = eval_zero_shot(args.model, model, tokenizer, task_list, num_shot, accelerate)
        print("********************************")
        print("zero_shot evaluation results")
        print(results)

    if args.save_model:
        model.save_pretrained(args.save_model)
        tokenizer.save_pretrained(args.save_model)

# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
