import argparse
import numpy as np
import torch
from transformers import AutoTokenizer
from models.hf_llama.modeling_llama import LlamaForCausalLM

from importlib.metadata import version
import functools
import openai

from lib.prune import prune_wanda_sp, check_sparsity
from lib.eval import eval_ppl_wikitext
from lib.data import get_loaders
from lib.openaiserver import call_openai_server_func

# Report the installed versions of the key libraries and the visible GPU count
# so that every run log records the environment it was produced in.
for _pkg_name in ('torch', 'transformers', 'accelerate'):
    print(_pkg_name, version(_pkg_name))
print('# of gpus: ', torch.cuda.device_count())

def get_llm(model, cache_dir="llm_weights"):
    """Load a LLaMA causal-LM checkpoint and tag it with a sequence length.

    Args:
        model: Identifier or path handed to ``LlamaForCausalLM.from_pretrained``.
        cache_dir: Cache directory forwarded to ``from_pretrained``.

    Returns:
        The loaded model, requested in float16 with ``device_map="auto"``,
        carrying an extra ``seqlen`` attribute set to 128.
    """
    load_options = {
        "torch_dtype": torch.float16,
        "cache_dir": cache_dir,
        "device_map": "auto",
    }
    llm = LlamaForCausalLM.from_pretrained(model, **load_options)

    # Sequence length read by the rest of this script via ``model.seqlen``.
    llm.seqlen = 128
    return llm

# select topk
def select(candidates, keep_top_50, all_top_k):
    """Merge new candidates into the running elite pool and keep the best ones.

    Every entry is a list whose last element is its fitness (lower is better)
    and whose preceding elements are the layer-wise pruning ratios.

    Args:
        candidates: Newly evaluated entries.
        keep_top_50: Current elite pool. It is not modified.
        all_top_k: Maximum number of entries to return.

    Returns:
        A new list holding at most ``all_top_k`` entries sorted by ascending
        fitness. Each configuration (the entry without its fitness slot)
        appears at most once; when it occurs several times the best-ranked
        occurrence is kept.
    """
    print('select ......', flush=True)

    # Build a fresh list instead of extending the caller's pool in place.
    merged = list(keep_top_50)
    merged.extend(candidates)
    ranked = sorted(merged, key=lambda can: can[-1])

    # Drop repeated configurations so the elite pool is not filled with copies
    # of the same entry when a candidate is submitted more than once.
    seen_configs = set()
    unique_ranked = []
    for can in ranked:
        config = tuple(can[:-1])
        if config in seen_configs:
            continue
        seen_configs.add(config)
        unique_ranked.append(can)

    return unique_ranked[:all_top_k]

def infer(args, model, tokenizer, dataloader, testenc, all_layer_ratio, post_evo=False):
    """Prune the model with the given layer ratios and return WikiText-2 perplexity.

    Args:
        args: Parsed command-line namespace, forwarded to ``prune_wanda_sp``.
        model: Model to prune and evaluate; its ``seqlen`` attribute is
            switched to 2048 for the evaluation and set back to 128 afterwards.
        tokenizer: Not used by this function; kept for call compatibility.
        dataloader: Calibration data forwarded to ``prune_wanda_sp``.
        testenc: Encoded test set forwarded to ``eval_ppl_wikitext``.
        all_layer_ratio: Per-layer pruning ratios forwarded to ``prune_wanda_sp``.
        post_evo: Currently has no effect (see the note in the body).

    Returns:
        The perplexity reported by ``eval_ppl_wikitext``.
    """
    # NOTE(review): whether prune_wanda_sp changes `model` persistently is not
    # visible from this file; if it does, successive calls compound — confirm.
    prune_wanda_sp(args, model, dataloader=dataloader, all_layer_ratio=all_layer_ratio)

    # Check the sparsity of the model
    print("*"*30)
    sparsity_ratio = check_sparsity(model)
    print(f"sparsity sanity check {sparsity_ratio:.4f}")
    print("*"*30)

    # The previous code chose between None (post_evo) and 10 here and then
    # unconditionally overwrote the choice with None. Only the effective value
    # is kept, so the evaluation never receives a truncation limit.
    truncated_nsamples = None

    print("evaluating on wikitext2")
    model.seqlen = 2048
    try:
        with torch.no_grad():
            ppl = eval_ppl_wikitext(model, testenc, truncated_nsamples=truncated_nsamples)
    finally:
        # Restore the short sequence length even when evaluation fails, so a
        # caller that catches the error does not keep a model stuck at 2048.
        model.seqlen = 128

    print('ppl = {:.4f}'.format(ppl), flush=True)

    return ppl

# prepare ids for testing
def test_candidates_model(args, model, tokenizer, dataloader, testenc, candidates, all_population_num, test_dict):
    for can in candidates:
        print('test No.{} model'.format(all_population_num), flush=True)

        t_can = tuple(can[:-1])
        print(t_can, flush=True)

        if t_can in test_dict.keys():
            ppl = test_dict[t_can]
            print('Already tested ppl = {:.4f}'.format(ppl))
        else:
            ppl = infer(args, model, tokenizer, dataloader, testenc, can[:-1])
            test_dict[t_can] = ppl
        all_population_num += 1
        can[-1] = ppl

    return candidates, all_population_num

# random operation in evolution algorithm
def random_can_with_gpt(args, model, random_num, test_dict, call_optimizer_server_func):
    """Collect ``random_num`` fresh pruning-rate configurations from the optimizer LLM.

    The LLM is asked for at most five configurations per request. A reply is
    parsed by taking every ``<begin>...<end>`` span and reading the decimal
    numbers inside it. A configuration is accepted only when it

    * has exactly one value per entry of ``model.model.layers``,
    * has every value strictly between 0 and 1,
    * has a mean within 0.01 of ``args.pruning_ratio``, and
    * is neither a key of ``test_dict`` nor already accepted in this call.

    The loop keeps querying until enough configurations are accepted; there is
    no cap on the number of requests.

    Args:
        args: Namespace providing ``model_name`` and ``pruning_ratio``.
        model: Object whose ``model.layers`` length gives the layer count.
        random_num: Number of configurations to return.
        test_dict: Mapping keyed by already evaluated ratio tuples.
        call_optimizer_server_func: Callable invoked as
            ``f(prompt, temperature=1.0)``; element 0 of its result is parsed.

    Returns:
        A list of exactly ``random_num`` lists (when ``random_num`` is
        positive), each holding the layer ratios followed by a placeholder
        fitness of 0.
    """
    import re

    def generate_random_numbers(args, model_layer_len, generate_random_num):
        """Send one request and return the parsed (unvalidated) number lists."""
        # prompt
        prompt_template = (
            "Let's think step by step! You are helping me prune the {model_name} model, aiming to minimize perplexity on the WikiText-2 dataset. The model has {model_layer_len} transformer layers. Layer-wise pruning rate measures how many parameters are pruned from each layer of the model. Different layers may have different pruning rates based on their importance and contribution to the performance of model. You need to generate {generate_random_num} valid layer-wise pruning rate configurations. Each configuration should:\n"
            "- Contain {model_layer_len} decimals between 0 and 1, accurate to 5 decimal places.\n"
            "- Ensure the average of these numbers equals {pruning_ratio}.\n"
            "- Be distinct, starting with <begin> and ending with <end>\n"
            "Your response should only contain the {generate_random_num} configurations without any additional text."
        )
        prompt = prompt_template.format(
            model_name=args.model_name,
            model_layer_len=model_layer_len,
            pruning_ratio=args.pruning_ratio,
            generate_random_num=generate_random_num,
        )

        # Call the optimizer server function
        gpt_outputs = call_optimizer_server_func(prompt, temperature=1.0)

        spans = re.findall(r'<begin>(.*?)<end>', gpt_outputs[0], re.DOTALL)
        random_numbers = [
            [float(number) for number in re.findall(r'\b\d+\.\d+\b', str(span))]
            for span in spans
        ]

        print(random_numbers)

        return random_numbers

    layer_count = len(model.model.layers)
    lower_bound = args.pruning_ratio - 0.01
    upper_bound = args.pruning_ratio + 0.01

    candidates = []
    accepted = set()  # ratio tuples accepted in this call, to reject repeats
    while len(candidates) < random_num:
        generate_random_num = min(5, random_num - len(candidates))
        cans = generate_random_numbers(args, layer_count, generate_random_num)

        for can in cans:
            # The LLM may return more configurations than requested.
            if len(candidates) >= random_num:
                break

            # Length is checked first: an empty span would otherwise divide
            # by zero when the mean is computed.
            if not can or len(can) != layer_count:
                continue

            if any(can_num <= 0 or can_num >= 1 for can_num in can):
                continue

            can_avg = sum(can) / len(can)
            if can_avg < lower_bound or can_avg > upper_bound:
                continue

            t_can = tuple(can)
            if t_can in test_dict or t_can in accepted:
                continue

            accepted.add(t_can)
            can.append(0)  # placeholder fitness slot
            candidates.append(can)

            print('No.{} GPT generated random num, {}'.format(len(candidates), t_can))

    print('all random num :{}'.format(candidates), flush=True)
    print('random_num = {}'.format(len(candidates)), flush=True)
    return candidates

# crossover operation in evolution algorithm
def get_crossover_with_gpt(args, model, keep_top_k, crossover_num, test_dict, call_optimizer_server_func):
    """Collect ``crossover_num`` crossover configurations from the optimizer LLM.

    The current elite configurations and their fitness values are appended to
    the prompt, and the LLM is asked for at most five new configurations per
    request. A reply is parsed by taking every ``<begin>...<end>`` span and
    reading the decimal numbers inside it. A configuration is accepted only
    when it

    * has exactly one value per entry of ``model.model.layers``,
    * has every value strictly between 0 and 1,
    * has a mean within 0.01 of ``args.pruning_ratio``, and
    * is neither a key of ``test_dict`` nor already accepted in this call.

    The loop keeps querying until enough configurations are accepted; there is
    no cap on the number of requests.

    Args:
        args: Namespace providing ``model_name`` and ``pruning_ratio``.
        model: Object whose ``model.layers`` length gives the layer count.
        keep_top_k: Elite entries (ratios followed by fitness) shown to the LLM.
        crossover_num: Number of configurations to return.
        test_dict: Mapping keyed by already evaluated ratio tuples.
        call_optimizer_server_func: Callable invoked as
            ``f(prompt, temperature=1.0)``; element 0 of its result is parsed.

    Returns:
        A list of exactly ``crossover_num`` lists (when ``crossover_num`` is
        positive), each holding the layer ratios followed by a placeholder
        fitness of 0.
    """
    import re

    def generate_random_numbers(args, model_layer_len, keep_top_k, generate_random_num):
        """Send one request and return the parsed (unvalidated) number lists."""
        # prompt
        prompt_template = (
            "Let's think step by step! You will receive {keep_top_k_len} lists representing the layer-wise pruning rates of the {model_name} and a fitness value for each list. The lower the fitness value, the better. Your task is to perform the crossover operation in the evolutionary algorithm to generate new configurations. Each new pruning rate configuration list should:\n"
            "- Contain {model_layer_len} decimals between 0 and 1, accurate to 5 decimal places.\n"
            "- Ensure the average of these numbers equals {pruning_ratio}.\n"
            "- Be distinctive, with each configuration starting with <begin> and ending with <end>.\n"
            "Please provide exactly {generate_random_num} new configurations based on the existing data provided below without any additional text.\n"
            "Here are the existing layer-wise pruning rate configurations and their fitness values:\n"
        )
        prompt = prompt_template.format(
            keep_top_k_len=len(keep_top_k),
            model_name=args.model_name,
            model_layer_len=model_layer_len,
            pruning_ratio=args.pruning_ratio,
            generate_random_num=generate_random_num,
        )

        # One line per elite entry: its ratios and its fitness.
        prompt += "".join(
            f'\nConfiguration: {keep_k[:-1]}, Fitness: {keep_k[-1]}'
            for keep_k in keep_top_k
        )

        # Call the optimizer server function
        gpt_outputs = call_optimizer_server_func(prompt, temperature=1.0)

        spans = re.findall(r'<begin>(.*?)<end>', gpt_outputs[0], re.DOTALL)
        return [
            [float(number) for number in re.findall(r'\b\d+\.\d+\b', str(span))]
            for span in spans
        ]

    print('crossover ......', flush=True)

    layer_count = len(model.model.layers)
    lower_bound = args.pruning_ratio - 0.01
    upper_bound = args.pruning_ratio + 0.01

    crossover_res = []
    accepted = set()  # ratio tuples accepted in this call, to reject repeats
    while len(crossover_res) < crossover_num:
        generate_random_num = min(5, crossover_num - len(crossover_res))
        cans = generate_random_numbers(args, layer_count, keep_top_k, generate_random_num)

        for can in cans:
            # The LLM may return more configurations than requested.
            if len(crossover_res) >= crossover_num:
                break

            # Length is checked first: an empty span would otherwise divide
            # by zero when the mean is computed.
            if not can or len(can) != layer_count:
                continue

            if any(can_num <= 0 or can_num >= 1 for can_num in can):
                continue

            can_avg = sum(can) / len(can)
            if can_avg < lower_bound or can_avg > upper_bound:
                continue

            t_can = tuple(can)
            if t_can in test_dict or t_can in accepted:
                continue

            accepted.add(t_can)
            can.append(0)  # placeholder fitness slot
            crossover_res.append(can)
            print('No.{} GPT generated crossover num, {}'.format(len(crossover_res), t_can))

    print('all crossover num :{}'.format(crossover_res), flush=True)
    print('crossover_num = {}'.format(len(crossover_res)), flush=True)
    return crossover_res

# mutation operation in evolution algorithm
def get_mutation_with_gpt(args, model, keep_top_k, mutation_num, test_dict, call_optimizer_server_func):
    """Collect ``mutation_num`` mutated configurations from the optimizer LLM.

    The current elite configurations and their fitness values are appended to
    the prompt, and the LLM is asked for at most five new configurations per
    request. A reply is parsed by taking every ``<begin>...<end>`` span and
    reading the decimal numbers inside it. A configuration is accepted only
    when it

    * has exactly one value per entry of ``model.model.layers``,
    * has every value strictly between 0 and 1,
    * has a mean within 0.01 of ``args.pruning_ratio``, and
    * is neither a key of ``test_dict`` nor already accepted in this call.

    The loop keeps querying until enough configurations are accepted; there is
    no cap on the number of requests.

    Args:
        args: Namespace providing ``model_name`` and ``pruning_ratio``.
        model: Object whose ``model.layers`` length gives the layer count.
        keep_top_k: Elite entries (ratios followed by fitness) shown to the LLM.
        mutation_num: Number of configurations to return.
        test_dict: Mapping keyed by already evaluated ratio tuples.
        call_optimizer_server_func: Callable invoked as
            ``f(prompt, temperature=1.0)``; element 0 of its result is parsed.

    Returns:
        A list of exactly ``mutation_num`` lists (when ``mutation_num`` is
        positive), each holding the layer ratios followed by a placeholder
        fitness of 0.
    """
    import re

    def generate_random_numbers(args, model_layer_len, keep_top_k, generate_random_num):
        """Send one request and return the parsed (unvalidated) number lists."""
        # prompt
        prompt_template = (
            "Let's think step by step! You will receive {keep_top_k_len} lists representing the layer-wise pruning rates of the {model_name} and a fitness value for each list. The lower the fitness value, the better. Your task is to perform the mutation operation in the evolutionary algorithm to generate new configurations. Each new pruning rate configuration list should:\n"
            "- Contain {model_layer_len} decimals between 0 and 1, accurate to 5 decimal places.\n"
            "- Ensure the average of these numbers equals {pruning_ratio}.\n"
            "- Be distinctive, with each configuration starting with <begin> and ending with <end>.\n"
            "Please provide exactly {generate_random_num} new configurations based on the existing data provided below without any additional text.\n"
            "Here are the existing layer-wise pruning rate configurations and their fitness values:\n"
        )
        prompt = prompt_template.format(
            keep_top_k_len=len(keep_top_k),
            model_name=args.model_name,
            model_layer_len=model_layer_len,
            pruning_ratio=args.pruning_ratio,
            generate_random_num=generate_random_num,
        )

        # One line per elite entry: its ratios and its fitness.
        prompt += "".join(
            f'\nConfiguration: {keep_k[:-1]}, Fitness: {keep_k[-1]}'
            for keep_k in keep_top_k
        )

        # Call the optimizer server function
        gpt_outputs = call_optimizer_server_func(prompt, temperature=1.0)

        spans = re.findall(r'<begin>(.*?)<end>', gpt_outputs[0], re.DOTALL)
        random_numbers = [
            [float(number) for number in re.findall(r'\b\d+\.\d+\b', str(span))]
            for span in spans
        ]

        print(random_numbers)

        return random_numbers

    print('mutation ......', flush=True)

    layer_count = len(model.model.layers)
    lower_bound = args.pruning_ratio - 0.01
    upper_bound = args.pruning_ratio + 0.01

    mutation_res = []
    accepted = set()  # ratio tuples accepted in this call, to reject repeats
    while len(mutation_res) < mutation_num:
        generate_random_num = min(5, mutation_num - len(mutation_res))
        cans = generate_random_numbers(args, layer_count, keep_top_k, generate_random_num)

        for can in cans:
            # The LLM may return more configurations than requested.
            if len(mutation_res) >= mutation_num:
                break

            # Length is checked first: an empty span would otherwise divide
            # by zero when the mean is computed.
            if not can or len(can) != layer_count:
                continue

            if any(can_num <= 0 or can_num >= 1 for can_num in can):
                continue

            can_avg = sum(can) / len(can)
            if can_avg < lower_bound or can_avg > upper_bound:
                continue

            t_can = tuple(can)
            if t_can in test_dict or t_can in accepted:
                continue

            accepted.add(t_can)
            can.append(0)  # placeholder fitness slot
            mutation_res.append(can)
            # The message previously said "crossover" here (copy-paste slip).
            print('No.{} GPT generated mutation num, {}'.format(len(mutation_res), t_can))

    print('all mutation num :{}'.format(mutation_res), flush=True)
    print('mutation_num = {}'.format(len(mutation_res)), flush=True)
    return mutation_res


def main():
    """Run the LLM-driven evolutionary search over layer-wise pruning ratios.

    Steps:
        1. Parse the command line and seed NumPy and PyTorch.
        2. Load the model and tokenizer, the "c4" calibration loader (128
           samples at the model's short sequence length) and the "wikitext2"
           test encoding (at sequence length 2048).
        3. Ask the optimizer LLM for an initial population.
        4. For ``max_iters`` iterations: score the current candidates, merge
           them into the elite pool, print the pool, and (except after the
           last iteration) request mutation and crossover offspring as the
           next generation.

    ``--nsamples``, ``--save_model``, ``--input-path`` and ``--output-path``
    are parsed but not read in this function; ``args`` is handed on to the
    pruning code, which may use them.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, help='LLaMA model')    # Huggingface model name
    parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
    parser.add_argument('--nsamples', type=int, default=2048, help='Number of calibration samples.')
    parser.add_argument('--pruning_ratio', type=float, default=0, help='Pruning ratio.')
    parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.')
    parser.add_argument('--input-path', type=str, default=None)
    parser.add_argument('--output-path', type=str, default=None)

    # options for loading models
    parser.add_argument('--model_name', type=str, default=None)
    parser.add_argument("--cache-dir", type=str, default=None)
    parser.add_argument('--gpt', type=str, default=None)
    parser.add_argument("--api-key", type=str, default=None)
    args = parser.parse_args()

    # Setting seeds for reproducibility
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    # Build the model and tokenizer
    print(f"loading llm model {args.model}")
    # args.cache_dir is passed explicitly, so when the flag is omitted the
    # loader receives None rather than get_llm's own default directory.
    model = get_llm(args.model, args.cache_dir)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)

    print("loading calibdation data")
    dataloader, _ = get_loaders("c4",nsamples=128,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer)
    print("dataset loading complete")

    # get wikitext2 loaders (built at the long evaluation sequence length)
    model.seqlen = 2048
    _, testenc = get_loaders("wikitext2", seed=0, seqlen=model.seqlen, tokenizer=tokenizer)
    model.seqlen = 128

    # Evolutionary-search hyper-parameters.
    population_num = 30
    mutation_num = 10
    crossover_num = 10
    max_iters = 20
    top_k = 30
    all_top_k = 30
    all_population_num = 1

    test_dict = {}      # ratio tuple -> perplexity, shared cache of evaluations
    keep_top_k = []     # elite entries handed to the mutation/crossover prompts
    keep_top_50 = []    # running elite pool, at most all_top_k entries
    print('population_num = {} all_top_k = {} mutation_num = {} crossover_num = {} max_iters = {}'
          .format(population_num, all_top_k, mutation_num, crossover_num, max_iters))

    # ================ load LLM settings ===================
    optimizer_llm_name = args.gpt
    openai.api_key = args.api_key

    # ====================== optimizer model configs ============================
    optimizer_gpt_max_decode_steps = 1024
    optimizer_gpt_temperature = 1.0

    call_optimizer_server_func = functools.partial(
        call_openai_server_func,
        model=optimizer_llm_name,
        max_decode_steps=optimizer_gpt_max_decode_steps,
        temperature=optimizer_gpt_temperature,
    )

    print('*********candidates are first generated randomly*********')
    candidates = random_can_with_gpt(args, model, population_num,
                                     test_dict, call_optimizer_server_func)

    for iteration in range(max_iters):
        candidates, all_population_num = test_candidates_model(args, model, tokenizer, dataloader, testenc,
                                                candidates, all_population_num, test_dict)

        keep_top_50 = select(candidates, keep_top_50, all_top_k)
        keep_top_k = keep_top_50[0:top_k]

        print('iter = {} : top {} result'.format(iteration+1, all_top_k), flush=True)

        # Iterate over what is actually in the pool: it can hold fewer than
        # all_top_k entries, and indexing by range(all_top_k) would then fail.
        for rank, res in enumerate(keep_top_50, start=1):
            print('No.{} {} ppl = {}'.format(rank, res[:-1], res[-1]))

        if iteration < max_iters-1:
            mutation = get_mutation_with_gpt(args, model, keep_top_k, mutation_num,
                                            test_dict, call_optimizer_server_func)
            crossover = get_crossover_with_gpt(args, model, keep_top_k, crossover_num,
                                            test_dict, call_optimizer_server_func)
            # The next generation is only the new offspring. Carrying the
            # previous candidates forward would re-submit entries that are
            # already in the elite pool and insert them a second time.
            candidates = mutation + crossover

    print(keep_top_k)
    print('Finish Evolutionary Search!')

# Start the evolutionary search only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()