from data_utils import load_dataset
from utils import construct_prompt, random_sampling, construct_prompt_without_test, construct_prompt_instance_prompt_text
import numpy as np
import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, LlamaForSequenceRank,  AutoConfig, AutoModelForCausalLM, AutoTokenizer #, LLaMATokenizer
import argparse
from typing import Dict, Optional, Sequence
import itertools
import copy
import json
import random
# import deepcopy
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
):
    """Register special tokens on ``tokenizer`` and report how many were newly added.

    Despite the name, only the tokenizer is touched: no model is passed in, so resizing the
    embedding matrix for any newly added tokens is left to the caller.

    Args:
        special_tokens_dict: mapping such as ``{"pad_token": "[PAD]"}`` that is forwarded
            unchanged to ``tokenizer.add_special_tokens``.
        tokenizer: the tokenizer to extend.

    Returns:
        The count returned by ``tokenizer.add_special_tokens`` (also printed).
    """
    added_count = tokenizer.add_special_tokens(special_tokens_dict)
    print("num_new_tokens = ", added_count)
    return added_count

def load_model_lora(base_model, device, lora_weights, compression, compression_length, use_partial_mask):
    """Load a LLaMA causal LM and wrap it with LoRA adapter weights via PEFT.

    Args:
        base_model: hub id or local path of the base LLaMA checkpoint.
        device: "cuda", "mps", or any other device string (e.g. "cpu"), which is used as a
            single explicit placement.
        lora_weights: hub id or local path of the PEFT/LoRA adapter.
        compression: must be falsy; the compression model class is not available here.
        compression_length: only relevant to the (unavailable) compression model.
        use_partial_mask: only relevant to the (unavailable) compression model.

    Returns:
        The ``PeftModel`` wrapping the loaded base model.

    Raises:
        NotImplementedError: if ``compression`` is truthy.
    """
    if compression:
        # Every compression branch used to be `pass` (its LlamaForCompressionCausalLM loader is
        # commented out), leaving `model` unbound and crashing with UnboundLocalError inside
        # PeftModel.from_pretrained. Fail fast with an explicit message instead.
        raise NotImplementedError(
            "compression=True is not supported: LlamaForCompressionCausalLM is not available "
            f"(compression_length={compression_length}, use_partial_mask={use_partial_mask})"
        )

    if device == "cuda":
        # fp16 weights, placed across the available GPUs by device_map="auto".
        base_kwargs = {"load_in_8bit": False, "torch_dtype": torch.float16, "device_map": "auto"}
        peft_kwargs = {"torch_dtype": torch.float16}
    elif device == "mps":
        base_kwargs = {"device_map": {"": device}, "torch_dtype": torch.float16}
        peft_kwargs = {"device_map": {"": device}, "torch_dtype": torch.float16}
    else:
        # Any other device: no torch_dtype is passed, so the checkpoint loads in the default dtype.
        base_kwargs = {"device_map": {"": device}, "low_cpu_mem_usage": True}
        peft_kwargs = {"device_map": {"": device}}

    model = LlamaForCausalLM.from_pretrained(base_model, **base_kwargs)
    return PeftModel.from_pretrained(model, lora_weights, **peft_kwargs)

def _load_tokenizer(model_name, base_model):
    """Load the LLaMA tokenizer and register any special tokens it is missing.

    Returns:
        ``(tokenizer, num_new_tokens)`` where ``num_new_tokens`` is how many special tokens
        were added to the tokenizer.
    """
    # These two hub ids load their own tokenizer even when LoRA weights force a different
    # base_model; every other model uses base_model's tokenizer.
    if model_name in ('decapoda-research/llama-7b-hf', 'openlm-research/open_llama_3b'):
        tokenizer = LlamaTokenizer.from_pretrained(model_name)
    else:
        tokenizer = LlamaTokenizer.from_pretrained(base_model)

    default_special_tokens = {
        "pad_token": "[PAD]",
        "eos_token": "</s>",
        "bos_token": "<s>",
        "unk_token": "<unk>",
    }
    special_tokens_dict = {
        name: value
        for name, value in default_special_tokens.items()
        if getattr(tokenizer, name) is None
    }
    num_new_tokens = smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict,
        tokenizer=tokenizer,
    )
    return tokenizer, num_new_tokens


def _load_model(base_model, device, lora_weight, tokenizer, compression_length, use_partial_mask,
                compression_token_initialization):
    """Load the model to evaluate (compression is assumed to be disabled by the caller).

    With ``lora_weight`` the LoRA-wrapped causal LM from ``load_model_lora`` is returned;
    otherwise a ``LlamaForSequenceRank`` checkpoint is loaded from ``base_model``, quantized
    to 8 bit for "13b" names and 4 bit for "65b"/"70b" names.
    """
    if lora_weight is not None:
        model = load_model_lora(base_model, device, lora_weight, False, compression_length, use_partial_mask)
        print("model = ", model)
        return model

    config = transformers.AutoConfig.from_pretrained(
        base_model,
        cache_dir='.cache',
    )
    config.compression_token_initialization = compression_token_initialization
    if compression_token_initialization:
        config.initialize_ids = tokenizer("Article: N/A \n\n Answer: N/A \n\n", return_tensors="pt")['input_ids'].tolist()

    # NOTE(review): the 13b and 65b/70b branches do not pass `config`, so the
    # compression_token_initialization settings above only reach the default branch.
    if '13b' in base_model:
        return LlamaForSequenceRank.from_pretrained(
            base_model,
            torch_dtype=torch.float16,
            load_in_8bit=True,
            trust_remote_code=True,
            device_map="auto",
            cache_dir='.cache',
        )
    if '70b' in base_model or '65b' in base_model:
        return LlamaForSequenceRank.from_pretrained(
            base_model,
            torch_dtype=torch.float16,
            load_in_4bit=True,
            trust_remote_code=True,
            device_map="auto",
            cache_dir='.cache',
        )
    return LlamaForSequenceRank.from_pretrained(
        base_model,
        config=config,
        torch_dtype=torch.float16,
        device_map="auto",
        cache_dir='.cache',
    )


def _select_test_set(sentences, labels, subsample_test_set):
    """Return the full test split when ``subsample_test_set`` is None/0, else a random subset.

    The subset size is capped at the number of available test examples.
    """
    if not subsample_test_set:
        return sentences, labels
    return random_sampling(sentences, labels, min(subsample_test_set, len(labels)))


def _generate_rankings(n):
    """Return every ordered selection of distinct 1-based indices from ``1..n``.

    Selections are grouped by size (1 up to n), each size in ``itertools.permutations`` order.
    The count is sum_{k=1..n} n!/(n-k)!, which grows factorially (15 for n=3, 1956 for n=6).
    """
    indices = range(1, n + 1)
    return [ranking for size in range(1, n + 1) for ranking in itertools.permutations(indices, size)]


def _evaluate_lama(model, tokenizer, device, dataset, num_seeds, all_shots, subsample_test_set):
    """Measure few-shot generation accuracy pooled over all LAMA relations and print ``ACC``."""
    # Each relation has its own train/test split, so relations are loaded and sampled one at a
    # time and only their correct/total counts are pooled.
    lama_relation_ids = (
        1001, 101, 103, 106, 108, 127, 1303, 131, 136, 1376, 138, 140, 1412, 159, 17, 176, 178,
        19, 190, 20, 264, 27, 276, 279, 30, 31, 36, 361, 364, 37, 39, 407, 413, 449, 463, 47,
        495, 527, 530, 740, 937,
    )
    generation_config = GenerationConfig(
        temperature=0.8,
        top_p=0.75,
        top_k=40,
        num_beams=4,
    )

    correct = 0
    samples_num = 0
    for which_lama in lama_relation_ids:
        params = {'dataset': f"lama_{which_lama}"}
        orig_train_sentences, orig_train_labels, orig_test_sentences, orig_test_labels = load_dataset(params)

        # Re-seeded per relation; presumably random_sampling draws from NumPy's global RNG.
        np.random.seed(num_seeds)
        test_sentences, test_labels = _select_test_set(orig_test_sentences, orig_test_labels, subsample_test_set)
        train_sentences, train_labels = random_sampling(orig_train_sentences, orig_train_labels, all_shots)
        samples_num += len(test_labels)

        for each_test, label in zip(test_sentences, test_labels):
            prompt = construct_prompt(params, train_sentences, train_labels, each_test)
            input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)
            with torch.no_grad():
                generation_output = model.generate(
                    input_ids=input_ids,
                    generation_config=generation_config,
                    return_dict_in_generate=True,
                    output_scores=True,
                    max_new_tokens=1,
                )
            output = tokenizer.decode(generation_output.sequences[0])
            # The answer is whatever follows the last ':' of the decoded prompt + continuation.
            answer = output.split(':')[-1].strip()
            # Prefix match in either direction, but an empty answer must not count as correct:
            # every label starts with "".
            if answer and (answer.startswith(label) or label.startswith(answer)):
                correct += 1

    print("model = ", model)
    print("dataset = ", dataset)
    print("num_seeds = ", num_seeds)
    print("all_shots = ", all_shots)
    if samples_num == 0:
        print("ACC = ", "n/a (no test samples)")
    else:
        print("ACC = ", correct / samples_num)


def _rank_train_permutations(model, tokenizer, device, dataset, num_seeds, all_shots, subsample_test_set,
                             add_final_article):
    """Build one prompt per ordered selection of sampled train examples and print the model's logits."""
    params = {'dataset': dataset}
    orig_train_sentences, orig_train_labels, orig_test_sentences, orig_test_labels = load_dataset(params)

    np.random.seed(num_seeds)
    # The test subset is not used below; it is still drawn so the RNG stream, and therefore the
    # sampled training examples, match the evaluation runs for the same seed.
    _select_test_set(orig_test_sentences, orig_test_labels, subsample_test_set)
    train_sentences, train_labels = random_sampling(orig_train_sentences, orig_train_labels, all_shots)
    print("----------Train-----------")
    print(train_sentences[:8], train_labels[:8])

    all_rank = _generate_rankings(all_shots)
    print("all_rank = ", all_rank)

    prompt_list = []
    for each_rank in all_rank:
        ranked_sentences = [train_sentences[i - 1] for i in each_rank]
        ranked_labels = [train_labels[i - 1] for i in each_rank]
        prompt, _ = construct_prompt_without_test(params, ranked_sentences, ranked_labels, "")
        if add_final_article:
            prompt += "Article: "
        print("prompt = ", prompt)
        prompt_list.append(prompt)

    # NOTE(review): if _load_tokenizer added a new [PAD] token, its id may fall outside the
    # model's embedding table (embeddings are never resized on this path), and
    # padding="longest" inserts it whenever prompts differ in length. Confirm the checkpoint's
    # vocab already covers the pad token.
    inputs = tokenizer(
        prompt_list,
        return_tensors="pt",
        padding="longest",
        max_length=tokenizer.model_max_length,
        truncation=True,
    )
    input_ids = inputs["input_ids"].to(device)
    with torch.no_grad():
        # All prompts are fed as a single extra-batched group to the custom ranking model.
        generation_output = model(
            all_input_ids=input_ids.unsqueeze(0),
            attention_mask=inputs["attention_mask"].to(device).unsqueeze(0),
        )
        print(generation_output.logits)
        for each in generation_output.logits:
            print(each[0][0].item())


def main(model, lora_weight, dataset, num_seeds, all_shots, subsample_test_set, compression, compression_length, use_partial_mask, compression_without_input, with_prompt_text, compression_without_prompt_text, with_sequence_order, compression_token_initialization, add_final_article):
    """Load a LLaMA model, then either evaluate LAMA or score few-shot example orderings.

    * ``dataset == "lama"``: few-shot generation accuracy pooled over every LAMA relation.
    * any other dataset: one prompt per ordered selection of the sampled training examples,
      scored in a single forward pass; the logits are printed.

    ``subsample_test_set`` of None or 0 means "use the whole test split".
    ``compression_without_input``, ``with_prompt_text``, ``compression_without_prompt_text``
    and ``with_sequence_order`` are accepted for CLI compatibility but are not used for
    computation.

    Raises:
        NotImplementedError: if ``compression`` is requested.
    """
    print("lora_weight = ", lora_weight)
    print("model = ", model)
    print("dataset = ", dataset)
    print("with_prompt_text = ", with_prompt_text)
    print("compression_without_prompt_text = ", compression_without_prompt_text)

    if compression:
        # No compression model class is available in this file (LlamaForCompressionCausalLM is
        # not imported and its loaders are commented out), so this path used to die later with
        # NameError / UnboundLocalError. Fail fast before downloading anything.
        raise NotImplementedError(
            "compression=True is not supported: LlamaForCompressionCausalLM is not available "
            f"(compression_length={compression_length}, use_partial_mask={use_partial_mask})"
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # LoRA adapters are always applied on top of the 7B decapoda checkpoint.
    base_model = 'decapoda-research/llama-7b-hf' if lora_weight is not None else model

    tokenizer, _num_new_tokens = _load_tokenizer(model, base_model)
    model = _load_model(
        base_model, device, lora_weight, tokenizer,
        compression_length, use_partial_mask, compression_token_initialization,
    )
    model.eval()

    if dataset == "lama":
        _evaluate_lama(model, tokenizer, device, dataset, num_seeds, all_shots, subsample_test_set)
    else:
        _rank_train_permutations(
            model, tokenizer, device, dataset, num_seeds, all_shots, subsample_test_set, add_final_article,
        )
        







if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # required arguments
    parser.add_argument('--model', dest='model', action='store', required=True, help='name of model(s), e.g., GPT2-XL')
    parser.add_argument('--lora_weight', dest='lora_weight', action='store', required=False, default=None, help='name of model(s), e.g., GPT2-XL')
    parser.add_argument('--dataset', dest='dataset', action='store', required=True, help='name of dataset(s), e.g., agnews')
    parser.add_argument('--num_seeds', dest='num_seeds', action='store', required=True, help='num seeds for the training set', type=int)
    # other arguments
    parser.add_argument('--subsample_test_set', dest='subsample_test_set', action='store', required=False, type=int,
                        default=None, help='size of test set to use to speed up eval. None means using all test set')
    # parser.add_argument('--use_last_token_as_classification', dest='use_last_token_as_classification', action='store_const', const=True, default=False,
    #                     help='whether to test the sequence order of the model')
    # parser.add_argument('--use_last_2token_as_classification', dest='use_last_2token_as_classification', action='store_const', const=True, default=False,
    #                     help='whether to test the sequence order of the model')
    parser.add_argument('--add_final_article', dest='add_final_article', action='store_const', const=True, default=False,
                        help='whether to test the sequence order of the model')
    # parser.add_argument('--api_num_log_prob', dest='api_num_log_prob', action='store', required=False, type=int,
    #                     default=100, help='number of top tokens to ask for when querying the model. Capped at 100 for OpenAI GPT-3 API')
    # parser.add_argument('--bs', dest='bs', action='store', required=False, type=int, default=None,
    #                     help='batch size for model queries. For OpenAI API, capped at 20. For local running, set this to max out your GPU memory.')
    # flags
    # parser.add_argument('--use_saved_results', dest='use_saved_results', action='store_const', const=True, default=False,
    #                     help='whether to load the results from pickle files and not run the model')
    parser.add_argument('--compression', dest='compression', action='store_const', const=True, default=False,
                        help='whether to use the compression generation mode')

    parser.add_argument('--compression_without_input', dest='compression_without_input', action='store_const', const=True, default=False,
                        help='whether to use the compression generation mode')
    parser.add_argument('--use_partial_mask', dest='use_partial_mask', action='store_const', const=True, default=False,
                        help='whether to use the partial_mask')
    parser.add_argument('--with_prompt_text', dest='with_prompt_text', action='store_const', const=True, default=False,
                        help='whether to use the prompt text')
    parser.add_argument('--compression_without_prompt_text', dest='compression_without_prompt_text', action='store_const', const=True, default=False,
                        help='whether to use the prompt text')
    parser.add_argument('--with_sequence_order', dest='with_sequence_order', action='store_const', const=True, default=False,
                        help='whether to test the sequence order of the model')
    parser.add_argument('--compression_token_initialization', dest='compression_token_initialization', action='store_const', const=True, default=False,
                        help='whether to test the sequence order of the model')
    # parser.add_argument('--with_prompt_text', dest='with_prompt_text', action='store_const', const=True, default=False,
    #                     help='whether to use the prompt text')
    parser.add_argument('--compression_length', dest='compression_length', action='store', required=False, default=None, help='num of compression_tokens', type=int)
    parser.add_argument('--all_shots', dest='all_shots', action='store', required=True, help='num training examples to use', type=int)
    # compression_token_initialization

    args = parser.parse_args()
    args = vars(args)

    # simple processing
    # def convert_to_list(items, is_int=False):
    #     if is_int:
    #         return [int(s.strip()) for s in items.split(",")]
    #     else:
    #         return [s.strip() for s in items.split(",")]

    # args['models'] = convert_to_list(args['models'])
    # args['datasets'] = convert_to_list(args['datasets'])
    # args['all_shots'] = convert_to_list(args['all_shots'], is_int=True)

    main(**args)