"""Prune a Hugging Face causal LM and evaluate the result.

Command-line entry point: loads ``--model``, optionally prunes it with one of the
routines from ``lib.prune``, prints the measured sparsity and WikiText perplexity,
and can additionally run zero-shot evaluation, write a results log, and save the
pruned model.
"""
import argparse
import os

# Must be set before torch first queries CUDA (device_count() below does so).
# ``setdefault`` keeps a CUDA_VISIBLE_DEVICES exported by the caller (e.g. a job
# scheduler pinning specific GPUs) instead of silently overwriting it.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7")

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from importlib.metadata import version

from lib.prune import prune_wanda, prune_magnitude, prune_sparsegpt, prune_ablate, prune_rpca, check_sparsity, find_layers
from lib.eval import eval_ppl, eval_zero_shot

import matplotlib.pyplot as plt
import re
import json

# Record library versions and the visible GPU count at import time so logs are
# reproducible.
print('torch', version('torch'))
print('transformers', version('transformers'))
print('accelerate', version('accelerate'))
print('# of gpus: ', torch.cuda.device_count())

def get_llm(model_name, cache_dir="llm_weights", device='cuda'):
    """Load ``model_name`` as a float16 causal LM and place it on ``device``.

    Args:
        model_name: Hugging Face model id or local path passed to
            ``AutoModelForCausalLM.from_pretrained``.
        cache_dir: Download/cache directory for the weights.
        device: Target device the whole model is moved to after loading.

    Returns:
        The loaded model, with an extra ``seqlen`` attribute set to
        ``config.max_position_embeddings``.
    """
    load_kwargs = dict(
        torch_dtype=torch.float16,
        cache_dir=cache_dir,
        # No automatic dispatch: the full model is moved to ``device`` below.
        device_map=None,
    )
    llm = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)

    # NOTE(review): ``seqlen`` is presumably read by the calibration / perplexity
    # helpers in lib; for long-context models max_position_embeddings can be very
    # large — confirm that is the intended window.
    llm.seqlen = llm.config.max_position_embeddings
    llm.to(device)
    return llm

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, help='LLaMA model')
    parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
    parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration samples.')
    parser.add_argument('--sparsity_ratio', type=float, default=0, help='Sparsity level')
    parser.add_argument("--sparsity_type", type=str, default="unstructured", choices=["unstructured", "4:8", "2:4"])
    parser.add_argument("--prune_method", type=str, choices=["magnitude", "wanda", "sparsegpt",
                        "ablate_mag_seq", "ablate_wanda_seq", "ablate_mag_iter", "ablate_wanda_iter", "search", "rpca", "rpca_info"])
    parser.add_argument("--cache_dir", default="llm_weights", type=str )
    parser.add_argument('--use_variant', action="store_true", help="whether to use the wanda variant described in the appendix")
    parser.add_argument('--save', type=str, default=None, help='Path to save results.')
    parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.')

    parser.add_argument("--eval_zero_shot", action="store_true")
    # RPCA parameters
    parser.add_argument("--linear_layers", type=str, default="q_proj,k_proj,v_proj,o_proj,gate_proj,down_proj,up_proj",
                        help="Comma-separated list of linear layers to process")
    parser.add_argument('--rpca_device', type=str, default='cuda', help='Device to run RPCA on')
    parser.add_argument('--rpca_max_iter', type=int, default=1000, help='Maximum iterations for RPCA')
    parser.add_argument('--rpca_tol', type=float, default=1e-7, help='Convergence threshold for RPCA')
    parser.add_argument('--rpca_lambda', type=float, default=None, help='Lambda parameter for RPCA')
    parser.add_argument('--rpca_mu', type=float, default=None, help='Mu parameter for RPCA')
    parser.add_argument('--tau_multiplier', type=float, default=1.0, help='Multiplier for tau in RPCA') 
    parser.add_argument('--target_sparsity', type=float, default=None, help='Target sparsity level for the sparse matrix (e.g., 0.5 for 50% sparsity)')
    parser.add_argument('--target_rank', type=int, default=None, help='Target rank for low-rank matrix after SVD')
    parser.add_argument('--eval_after_each_layer', action='store_true', help='Evaluate after each layer replacement during RPCA')
    parser.add_argument('--use_nuclear_norm', action='store_true', help='Use nuclear norm in RPCA')
    parser.add_argument('--finetune', action='store_true', help='Finetune model after RPCA')
    parser.add_argument('--learning_rate', type=float, default=5e-5, help='Learning rate for finetuning')
    parser.add_argument('--num_train_epochs', type=int, default=10, help='Number of epochs for finetuning')
    parser.add_argument('--train_samples', type=int, default=512, help='Number of training samples for finetuning')
    parser.add_argument('--gpu_id', type=int, default=1, help='GPU id to use')

    # Added arguments
    parser.add_argument('--singular_value_threshold', type=float, default=1e-4, help='Threshold for singular values to determine rank')
    parser.add_argument('--enforce_target_sparsity', action='store_true', help='Whether to enforce target sparsity on S after RPCA')
    parser.add_argument('--zeroing_method', type=str, choices=['random', 'magnitude'], default='magnitude', help='Method to zero out additional elements in S to reach target sparsity')
    parser.add_argument('--save_rpca', action='store_true', help='Save RPCA results to disk')  
    parser.add_argument('--load_rpca', action='store_true', help='Load RPCA results from disk') 

    args = parser.parse_args()

    # Set GPU device
    if torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
        print('# of gpus: ', num_gpus)
        if args.gpu_id >= num_gpus:
            raise ValueError(f"Invalid gpu_id {args.gpu_id}, only {num_gpus} GPUs available.")
        device = torch.device(f'cuda:{args.gpu_id}')
        # Do not set torch.cuda.set_device(device) here
    else:
        print("CUDA is not available")
        device = torch.device('cpu')

    # Set random seed
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    # Handle n:m sparsity
    prune_n, prune_m = 0, 0
    if args.sparsity_type != "unstructured":
        assert args.sparsity_ratio == 0.5, "sparsity ratio must be 0.5 for structured N:M sparsity"
        prune_n, prune_m = map(int, args.sparsity_type.split(":"))

    model_name = args.model.split("/")[-1]
    print(f"loading llm model {args.model}")
    model = get_llm(args.model, args.cache_dir, device)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)

    print("use device ", device)

    if args.sparsity_ratio != 0 or args.prune_method in ['rpca', 'rpca_info']:
        print("pruning starts")
        if args.prune_method == "wanda":
            prune_wanda(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
        elif args.prune_method == "magnitude":
            prune_magnitude(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
        elif args.prune_method == "sparsegpt":
            prune_sparsegpt(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
        elif "ablate" in args.prune_method:
            prune_ablate(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
        elif args.prune_method == "rpca":
            prune_rpca(args, model, tokenizer, device)
        elif args.prune_method == "rpca_info":
            prune_rpca_info(args, model, tokenizer, device)

    ################################################################
    print("*" * 30)
    sparsity_ratio = check_sparsity(model)
    print(f"sparsity sanity check {sparsity_ratio:.4f}")
    print("*" * 30)
    ################################################################
    ppl_test = eval_ppl(args, model, tokenizer, device)
    print(f"wikitext perplexity {ppl_test}")

    if args.save and not os.path.exists(args.save):
        os.makedirs(args.save)
    if args.save:
        save_filepath = os.path.join(args.save, f"log_{args.prune_method}.txt")
        with open(save_filepath, "w") as f:
            print("method\tactual_sparsity\tppl_test", file=f, flush=True)
            print(f"{args.prune_method}\t{sparsity_ratio:.4f}\t{ppl_test:.4f}", file=f, flush=True)

    if args.eval_zero_shot:
        accelerate = False
        if "30b" in args.model or "65b" in args.model or "70b" in args.model:
            accelerate = True

        task_list = ["boolq", "rte", "hellaswag", "winogrande", "arc_easy", "arc_challenge", "openbookqa"]
        num_shot = 0
        results = eval_zero_shot(args.model, model, tokenizer, task_list, num_shot, accelerate)
        print("********************************")
        print("zero_shot evaluation results")
        print(results)

    if args.save_model:
        # Save the model state dict
        model.save_pretrained(args.save_model)
        tokenizer.save_pretrained(args.save_model)
        # Additionally, save the model with custom layers
        torch.save(model.state_dict(), os.path.join(args.save_model, 'pytorch_model.bin'))
        print(f"Model saved to {args.save_model}")

if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
