import argparse
import os

from importlib.metadata import version

from compression.model_prune import *
from compression.preprocess import *
from model_impl.llama import LlamaForCausalLM_SALS

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from utils.model_loader import *
from utils.data import set_seed
# Never summarise tensors with "..." when printing them, then log the
# versions of the key libraries and how many CUDA devices are visible.
torch.set_printoptions(threshold=torch.inf)
for _dist_name in ('torch', 'transformers', 'accelerate'):
    print(_dist_name, version(_dist_name))
del _dist_name
print('# of gpus: ', torch.cuda.device_count())

def generate_preprocess(model_name, tokenizer, model, config, nsamples, method, device):
    """Load cached calibration statistics for pruning, or compute and cache them.

    The result is cached as ``./preprocess_output/{model_name}_{nsamples}{method_name}.pth``.
    There is deliberately no separator between ``nsamples`` and the method
    name; it is kept that way so existing cache files remain valid. The method
    ``"label_sparse_gqa_before"`` shares its cache file with ``"label_sparse_gqa"``.

    Args:
        model_name: Short model name used in the cache filename.
        tokenizer, model, config, device: Forwarded to
            ``feature_base_key_allocate_label_sparse`` when no cache exists.
        nsamples: Number of calibration samples; part of the cache filename and
            forwarded to ``feature_base_key_allocate_label_sparse``.
        method: Method name. Only ``"label_sparse"`` and names containing
            ``"SALS"`` can be computed here. Any other method (including
            ``None``) only works when its cache file already exists.

    Returns:
        The object produced by ``feature_base_key_allocate_label_sparse``,
        either freshly computed (and then saved) or loaded from the cache.

    Raises:
        NotImplementedError: no cache file exists and ``method`` is unsupported.
    """
    preprocess_path = './preprocess_output/'
    # exist_ok=True avoids the race of a separate exists() check followed by mkdir().
    os.makedirs(preprocess_path, exist_ok=True)

    method_name = method
    if method == "label_sparse_gqa_before":
        # The "_before" variant reuses the statistics cached for "label_sparse_gqa".
        method_name = "label_sparse_gqa"
    file_name = f'{model_name}_{nsamples}{method_name if method_name is not None else ""}.pth'
    preprocess_pth_path = os.path.join(preprocess_path, file_name)
    print(preprocess_pth_path)

    if os.path.exists(preprocess_pth_path):
        print("loading preprocess pth from : ", preprocess_pth_path)
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True. If
        # the cached object holds arbitrary Python objects, this may need
        # weights_only=False. Only do that for trusted, locally generated files.
        return torch.load(preprocess_pth_path)

    print("counting preprocess data")
    # Guard against None: "SALS" in None would raise TypeError instead of the
    # NotImplementedError used for every other unsupported method.
    if method is not None and (method == "label_sparse" or "SALS" in method):
        layer_lst = feature_base_key_allocate_label_sparse(model, tokenizer, nsamples, config, device)
    else:
        raise NotImplementedError(f"unsupported preprocess method: {method!r}")
    torch.save(layer_lst, preprocess_pth_path)
    return layer_lst

def compress(model_name_or_path, config, device, save=True):
    """Prune a model using the method named by ``config.method``.

    Args:
        model_name_or_path: Model path or identifier. Its last ``/``-separated
            component (ignoring trailing slashes) names the preprocess cache.
        config: Config object. ``method`` and ``nsamples`` are read here, and
            ``is_compress`` is set to ``False`` before the model is loaded. The
            config is also forwarded to the loader, preprocessing and pruning
            helpers.
        device: Device forwarded to preprocessing and pruning.
        save: If True, save the pruned model and tokenizer with
            ``save_model_and_tokenizer`` and return ``None``. Otherwise return
            ``(model, tokenizer)``.

    Raises:
        NotImplementedError: ``config.method`` is neither ``"label_sparse"``
            nor a name containing ``"SALS"``. This is checked before the model
            is loaded, so an unsupported method fails immediately.
    """
    method = config.method
    if method is None or not (method == "label_sparse" or "SALS" in method):
        raise NotImplementedError(f"unsupported compression method: {method!r}")

    config.is_compress = False
    model, tokenizer = load_model_and_tokenizer(model_name_or_path, config, preprocessing=True)
    # rstrip so that "path/to/model/" yields "model" rather than an empty name.
    model_name = model_name_or_path.rstrip("/").split("/")[-1]
    layer_lst = generate_preprocess(model_name, tokenizer, model, config, config.nsamples, config.method, device)
    prune_mix_label_sparse(model, layer_lst, config, device)

    if save:
        save_model_and_tokenizer(model, tokenizer, model_name_or_path, config)
        return None
    return model, tokenizer

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, help='LLaMA model')
    parser.add_argument('--seed', type=int, default=42, help='Seed for sampling the calibration data.')
    parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration samples.')
    parser.add_argument('--k_rank', type=int, default=512, help='Number of calibration samples.')
    parser.add_argument('--k_high_rank', type=int, default=2048, help='Number of calibration samples.')
    parser.add_argument('--v_bits', type=int, default=2, help='Number of calibration samples.')
    parser.add_argument('--groupsize', type=int, default=32, help='Number of calibration samples.')
    parser.add_argument('--residual_length', type=int, default=128, help='Number of calibration samples.')
    parser.add_argument('--method', type=str, default="before_rope") 
    parser.add_argument('--adapt_method', type=str, default=None) 
    parser.add_argument('--sparsity', type=int, default=16, help='Number of calibration samples.')
    parser.add_argument('--sparse_rank', type=int, default=8, help='Number of calibration samples.')

    args = parser.parse_args()
    print(args)
    set_seed(args.seed)

    if torch.cuda.device_count() > 1:
        parallel = True
        low_cpu_mem_usage=True
    else:
        parallel = False
        low_cpu_mem_usage=True
    
    dtype = torch.float16
    model_name = args.model.split("/")[-1]
    config_path = os.path.join(args.model, "config.json")
    config = AutoConfig.from_pretrained(config_path)
    config.k_rank = args.k_rank
    config.k_high_rank = args.k_high_rank
    config.k_bits = args.k_bits
    config.v_bits = args.v_bits
    config.group_size = args.groupsize
    config.residual_length = args.residual_length
    config.nsamples = args.nsamples
    config.sparsity = args.sparsity
    config.sparse_rank = args.sparse_rank
    config.method = args.method
    config.finetuned = args.withfinetune
    config.withsoftmax = args.withsoftmax
    config.skip_layers = args.skip_layers
    config.budget = [int(n) for n in args.budget.split(',')] if args.budget is not None else None
    compress(args.model, config, "cuda", save=True)



    



