from data_utils import load_dataset
from utils import construct_prompt, random_sampling, construct_prompt_without_test, construct_prompt_instance_prompt_text
import numpy as np
import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, LlamaForCompressionCausalLM, AutoConfig
import argparse
from typing import Dict, Optional, Sequence
import random
# import deepcopy
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: "transformers.PreTrainedTokenizer",
    model: "Optional[transformers.PreTrainedModel]" = None,
):
    """Register special tokens on ``tokenizer`` and optionally resize ``model``.

    The tokens in ``special_tokens_dict`` are added to ``tokenizer`` in place.
    The model's embedding matrix is resized only when ``model`` is passed and
    at least one token was actually added. Without a model this is purely a
    tokenizer update, and callers must account for the returned count
    themselves, e.g. by growing ``config.vocab_size``.

    Note: the resize is the unoptimized version that may make your embedding
    size not be divisible by 64.

    Args:
        special_tokens_dict: Mapping of special-token attribute names
            (e.g. ``"pad_token"``) to token strings, in the form accepted by
            ``tokenizer.add_special_tokens``.
        tokenizer: Tokenizer to update in place.
        model: Optional model whose token embeddings are resized to
            ``len(tokenizer)`` after new tokens are added.

    Returns:
        The number of tokens ``add_special_tokens`` reports as newly added.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    print("num_new_tokens = ", num_new_tokens)
    if model is not None and num_new_tokens > 0:
        # Keep the embedding table in sync with the enlarged vocabulary.
        model.resize_token_embeddings(len(tokenizer))
    return num_new_tokens

def load_model_lora(base_model, device, lora_weights, compression, compression_length, use_partial_mask):
    """Load ``base_model`` and wrap it with the LoRA adapter at ``lora_weights``.

    With ``compression`` set, ``LlamaForCompressionCausalLM`` is instantiated
    from a config carrying ``compression_size``/``use_partial_mask``.
    Otherwise a plain ``LlamaForCausalLM`` is loaded with its own default
    config. The loading options depend on ``device``:

    * ``"cuda"``: fp16 weights, ``device_map="auto"``, 8-bit loading off.
    * ``"mps"``: fp16 weights placed on the MPS device.
    * anything else: no explicit dtype, placed on ``device`` with
      ``low_cpu_mem_usage=True``.

    Returns:
        The ``PeftModel`` wrapping the loaded base model.
    """
    config = AutoConfig.from_pretrained(base_model, cache_dir='.cache')
    if compression:
        config.compression_size = compression_length
        config.use_partial_mask = use_partial_mask
        print('use_partial_mask = ', use_partial_mask)

    # Per-device keyword arguments for the base model and the LoRA wrapper.
    if device == "cuda":
        backbone_kwargs = {
            "load_in_8bit": False,
            "torch_dtype": torch.float16,
            "device_map": "auto",
        }
        adapter_kwargs = {"torch_dtype": torch.float16}
    elif device == "mps":
        backbone_kwargs = {"device_map": {"": device}, "torch_dtype": torch.float16}
        adapter_kwargs = {"device_map": {"": device}, "torch_dtype": torch.float16}
    else:
        backbone_kwargs = {"device_map": {"": device}, "low_cpu_mem_usage": True}
        adapter_kwargs = {"device_map": {"": device}}

    if compression:
        backbone_cls = LlamaForCompressionCausalLM
        backbone_kwargs["config"] = config
    else:
        backbone_cls = LlamaForCausalLM

    backbone = backbone_cls.from_pretrained(base_model, **backbone_kwargs)
    return PeftModel.from_pretrained(backbone, lora_weights, **adapter_kwargs)

# LAMA relation ids evaluated when ``dataset == "lama"``. Each id has its own
# training split, so they are loaded and sampled one at a time, never mixed.
_LAMA_RELATION_IDS = (
    1001, 101, 103, 106, 108, 127, 1303, 131, 136, 1376, 138, 140, 1412, 159,
    17, 176, 178, 19, 190, 20, 264, 27, 276, 279, 30, 31, 36, 361, 364, 37, 39,
    407, 413, 449, 463, 47, 495, 527, 530, 740, 937,
)

# Datasets whose decoded answers are mapped back to label ids through
# ``params['inv_label_dict']``; all other datasets use string prefix matching.
_LABEL_DICT_DATASETS = frozenset({'cb', 'rte', 'dbpedia', 'sst2', 'trec', 'agnews'})

# Token id appended once per compression slot; masked to 0 before decoding.
_COMPRESSION_PLACEHOLDER_ID = -99

# Optional instruction appended to every prompt when ``with_prompt_text`` is set.
_INSTRUCTION_SUFFIX = (
    "\n\nPlease read the instructions and examples provided carefully, "
    "summarize and analyze the information, and then generate a new answer "
    "for a new instance provided."
)


def _load_tokenizer():
    """Load the Open-LLaMA tokenizer and register any missing special tokens.

    Returns:
        ``(tokenizer, num_new_tokens)``, where ``num_new_tokens`` is the count
        reported by ``smart_tokenizer_and_embedding_resize``.
    """
    tokenizer = LlamaTokenizer.from_pretrained('openlm-research/open_llama_3b')
    defaults = (
        ("pad_token", "[PAD]"),
        ("eos_token", "</s>"),
        ("bos_token", "<s>"),
        ("unk_token", "<unk>"),
    )
    special_tokens_dict = {
        name: token for name, token in defaults if getattr(tokenizer, name) is None
    }
    num_new_tokens = smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict,
        tokenizer=tokenizer,
    )
    return tokenizer, num_new_tokens


def _load_base_model(base_model, tokenizer, num_new_tokens, compression,
                     compression_length, use_partial_mask,
                     compression_token_initialization, structural_compression):
    """Load ``base_model`` (no LoRA) in fp16 with an Open-LLaMA-derived config.

    In compression mode the config's vocabulary grows by ``num_new_tokens``,
    because the tokenizer gained tokens but the embeddings were not resized.
    """
    config = transformers.AutoConfig.from_pretrained(
        'openlm-research/open_llama_3b',
        cache_dir='.cache',
    )
    config.compression_token_initialization = compression_token_initialization
    config.structural_compression = structural_compression
    if compression_token_initialization:
        config.initialize_ids = tokenizer("Article: N/A \n\n Answer: N/A \n\n", return_tensors="pt")['input_ids'].tolist()

    model_cls = LlamaForCausalLM
    if compression:
        config.vocab_size += num_new_tokens
        config.compression_size = compression_length
        config.use_partial_mask = use_partial_mask
        print("use_partial_mask = ", use_partial_mask)
        model_cls = LlamaForCompressionCausalLM

    return model_cls.from_pretrained(
        base_model,
        config=config,
        torch_dtype=torch.float16,
        device_map="auto",
        cache_dir='.cache',
    )


def _generate(model, input_ids, generation_config, max_new_tokens, compression):
    """Run ``generate`` (or ``compression_generate``) without gradients."""
    generate = model.compression_generate if compression else model.generate
    with torch.no_grad():
        return generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
        )


def _build_prompt(params, train_sentences, train_labels, test_sentence,
                  compression_without_input, compression_without_prompt_text,
                  with_prompt_text):
    """Return ``(prompt, instance_prompt)`` for one test sentence.

    ``instance_prompt`` is ``None`` unless ``compression_without_input`` is set.
    In that case the test instance is kept out of ``prompt`` and returned
    separately.
    """
    instance_prompt = None
    if compression_without_input:
        if compression_without_prompt_text:
            prompt, input_prompt_text, instance_prompt = construct_prompt_instance_prompt_text(
                params, train_sentences, train_labels, test_sentence)
            instance_prompt = input_prompt_text + instance_prompt
        else:
            prompt, instance_prompt = construct_prompt_without_test(
                params, train_sentences, train_labels, test_sentence)
    else:
        prompt = construct_prompt(params, train_sentences, train_labels, test_sentence)

    if with_prompt_text:
        prompt += _INSTRUCTION_SUFFIX
    return prompt, instance_prompt


def _encode_inputs(tokenizer, prompt, instance_prompt, compression,
                   compression_length, compression_without_input):
    """Tokenize ``prompt``; in compression mode append the placeholder slots.

    Layout in compression mode: prompt ids, then ``compression_length``
    placeholder ids, then (with ``compression_without_input``) the instance ids.
    """
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
    if not compression:
        return input_ids

    pieces = [
        input_ids,
        torch.full((1, compression_length), _COMPRESSION_PLACEHOLDER_ID, dtype=torch.long),
    ]
    if compression_without_input:
        instance_ids = tokenizer(instance_prompt, return_tensors="pt")["input_ids"]
        # Drop the first token of the instance encoding (presumably the BOS the
        # tokenizer prepends — TODO confirm) so the sequence has a single BOS.
        pieces.append(instance_ids[:, 1:])
    return torch.cat(pieces, dim=1)


def _answer_matches(answer, label, params, use_label_dict):
    """Return True when a decoded ``answer`` counts as correct for ``label``."""
    if use_label_dict:
        inv_label_dict = params['inv_label_dict']
        return answer in inv_label_dict and label == inv_label_dict[answer]
    # ``label.startswith("")`` is always True, so an empty answer must never
    # be counted as a hit.
    if not answer:
        return False
    return answer.startswith(label) or label.startswith(answer)


def _accuracy(correct, total):
    """Return ``correct / total``, or 0.0 when nothing was evaluated."""
    return correct / total if total else 0.0


def _evaluate_lama(model, tokenizer, device, generation_config, max_new_tokens,
                   compression, dataset, num_seeds, all_shots):
    """Evaluate every LAMA relation with its own few-shot sample; print pooled accuracy."""
    correct = 0
    samples_num = 0
    for relation_id in _LAMA_RELATION_IDS:
        params = {'dataset': f"lama_{relation_id}"}
        orig_train_sentences, orig_train_labels, test_sentences, test_labels = load_dataset(params)

        # Re-seed per relation so each relation's few-shot sample is reproducible.
        np.random.seed(num_seeds)
        samples_num += len(test_labels)
        train_sentences, train_labels = random_sampling(orig_train_sentences, orig_train_labels, all_shots)

        for test_sentence, label in zip(test_sentences, test_labels):
            prompt = construct_prompt(params, train_sentences, train_labels, test_sentence)
            input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)
            generation_output = _generate(model, input_ids, generation_config, max_new_tokens, compression)
            output = tokenizer.decode(generation_output.sequences[0])
            answer = output.split(':')[-1].strip()
            if _answer_matches(answer, label, params, use_label_dict=False):
                correct += 1

    print("model = ", model)
    print("dataset = ", dataset)
    print("num_seeds = ", num_seeds)
    print("all_shots = ", all_shots)
    print("ACC = ", _accuracy(correct, samples_num))


def _evaluate_dataset(model, tokenizer, device, generation_config, max_new_tokens,
                      dataset, num_seeds, all_shots, subsample_test_set,
                      compression, compression_length, compression_without_input,
                      with_prompt_text, compression_without_prompt_text,
                      with_sequence_order):
    """Evaluate a single dataset, optionally over 5 shuffled demonstration orders."""
    params = {'dataset': dataset}
    orig_train_sentences, orig_train_labels, orig_test_sentences, orig_test_labels = load_dataset(params)

    np.random.seed(num_seeds)
    samples_num = subsample_test_set
    # ``None`` (the CLI default) and 0 both mean "use the whole test set".
    if not samples_num:
        test_sentences, test_labels = orig_test_sentences, orig_test_labels
        samples_num = len(orig_test_labels)
    else:
        test_sentences, test_labels = random_sampling(orig_test_sentences, orig_test_labels, samples_num)

    train_sentences, train_labels = random_sampling(orig_train_sentences, orig_train_labels, all_shots)
    print("----------Train-----------")
    print(train_sentences[:8], train_labels[:8])

    use_label_dict = dataset in _LABEL_DICT_DATASETS
    # Local RNG keeps the demonstration reshuffles reproducible per seed without
    # touching the global ``random`` state.
    shuffle_rng = random.Random(num_seeds)
    rounds = 5 if with_sequence_order else 1

    for count in range(rounds, 0, -1):
        correct = 0
        for test_sentence, label in zip(test_sentences, test_labels):
            prompt, instance_prompt = _build_prompt(
                params, train_sentences, train_labels, test_sentence,
                compression_without_input, compression_without_prompt_text,
                with_prompt_text,
            )
            input_ids = _encode_inputs(
                tokenizer, prompt, instance_prompt, compression,
                compression_length, compression_without_input,
            ).to(device)
            generation_output = _generate(model, input_ids, generation_config, max_new_tokens, compression)

            sequence = generation_output.sequences[0]
            # Placeholder ids are not real vocabulary entries; map them to 0 so decoding works.
            sequence = sequence.masked_fill(sequence.eq(_COMPRESSION_PLACEHOLDER_ID), 0)
            output = tokenizer.decode(sequence)
            answer = output.split(':')[-1].replace("</s>", '').strip()
            if _answer_matches(answer, label, params, use_label_dict):
                correct += 1

        print("model = ", model)
        print("dataset = ", dataset)
        print("num_seeds = ", num_seeds)
        print("all_shots = ", all_shots)
        print("count = ", count)
        print("ACC = ", _accuracy(correct, samples_num))

        # Reorder the demonstrations for the next round, if there is one.
        combined = list(zip(train_sentences, train_labels))
        if count > 1 and combined:
            shuffle_rng.shuffle(combined)
            train_sentences, train_labels = zip(*combined)


def main(model, lora_weight, dataset, num_seeds, all_shots, subsample_test_set, compression, compression_length, use_partial_mask, compression_without_input, with_prompt_text, compression_without_prompt_text, with_sequence_order, compression_token_initialization, structural_compression):
    """Load a (LoRA / compression) LLaMA model and print its few-shot accuracy.

    Args:
        model: Base checkpoint to load when ``lora_weight`` is None.
            Otherwise a fixed LLaMA-7B base is used and this value is only printed.
        lora_weight: LoRA adapter weights, or None to skip LoRA.
        dataset: Dataset name; ``"lama"`` evaluates all LAMA relations.
        num_seeds: Seed for sampling demonstrations / the test subset.
        all_shots: Number of demonstrations placed in each prompt.
        subsample_test_set: Test-set size; None or 0 uses the whole set
            (ignored for ``"lama"``).
        compression, compression_length, use_partial_mask,
        compression_without_input, compression_without_prompt_text,
        compression_token_initialization, structural_compression:
            Compression-model options forwarded to the config / prompt builder.
        with_prompt_text: Append a fixed instruction sentence to each prompt.
        with_sequence_order: Evaluate 5 shuffled demonstration orders.

    Raises:
        ValueError: ``compression`` is set without ``compression_length`` for
            a non-LAMA dataset (the placeholder slots cannot be built).
    """
    print("lora_weight = ", lora_weight)
    print("model = ", model)
    print("dataset = ", dataset)
    print("with_prompt_text = ", with_prompt_text)
    print("compression_without_prompt_text = ", compression_without_prompt_text)

    # Fail before the expensive model load rather than on the first example.
    if compression and compression_length is None and dataset != "lama":
        raise ValueError("--compression requires --compression_length")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    base_model = 'decapoda-research/llama-7b-hf' if lora_weight is not None else model

    tokenizer, num_new_tokens = _load_tokenizer()

    if lora_weight is not None:
        model = load_model_lora(base_model, device, lora_weight, compression, compression_length, use_partial_mask)
        print("model = ", model)
    else:
        model = _load_base_model(
            base_model, tokenizer, num_new_tokens, compression,
            compression_length, use_partial_mask,
            compression_token_initialization, structural_compression,
        )

    generation_config = GenerationConfig(
        temperature=0.8,
        top_p=0.75,
        top_k=40,
        num_beams=4,
    )
    max_new_tokens = 2 if compression else 1
    model.eval()

    if dataset == "lama":
        _evaluate_lama(
            model, tokenizer, device, generation_config, max_new_tokens,
            compression, dataset, num_seeds, all_shots,
        )
    else:
        _evaluate_dataset(
            model, tokenizer, device, generation_config, max_new_tokens,
            dataset, num_seeds, all_shots, subsample_test_set,
            compression, compression_length, compression_without_input,
            with_prompt_text, compression_without_prompt_text,
            with_sequence_order,
        )










def build_arg_parser():
    """Build the CLI parser; its parsed fields map one-to-one onto ``main``'s parameters."""
    parser = argparse.ArgumentParser()

    # Required arguments.
    parser.add_argument('--model', dest='model', required=True,
                        help='base model checkpoint to load (ignored when --lora_weight is given)')
    parser.add_argument('--dataset', dest='dataset', required=True,
                        help='name of dataset, e.g., agnews; "lama" evaluates all LAMA relations')
    parser.add_argument('--num_seeds', dest='num_seeds', required=True, type=int,
                        help='random seed used to sample the few-shot examples and test subset')
    parser.add_argument('--all_shots', dest='all_shots', required=True, type=int,
                        help='num training examples to use')

    # Optional valued arguments.
    parser.add_argument('--lora_weight', dest='lora_weight', default=None,
                        help='LoRA adapter weights to apply; omit to evaluate --model directly')
    parser.add_argument('--subsample_test_set', dest='subsample_test_set', type=int, default=None,
                        help='size of test set to use to speed up eval. None means using all test set')
    parser.add_argument('--compression_length', dest='compression_length', type=int, default=None,
                        help='num of compression_tokens')

    # Boolean flags (all default to False).
    parser.add_argument('--compression', dest='compression', action='store_true',
                        help='whether to use the compression generation mode')
    parser.add_argument('--compression_without_input', dest='compression_without_input', action='store_true',
                        help='keep the test instance out of the compressed prompt and append it after the compression tokens')
    parser.add_argument('--use_partial_mask', dest='use_partial_mask', action='store_true',
                        help='whether to use the partial_mask')
    parser.add_argument('--with_prompt_text', dest='with_prompt_text', action='store_true',
                        help='append a fixed instruction sentence to every prompt')
    parser.add_argument('--compression_without_prompt_text', dest='compression_without_prompt_text', action='store_true',
                        help='with --compression_without_input, also move the prompt text after the compression tokens')
    parser.add_argument('--with_sequence_order', dest='with_sequence_order', action='store_true',
                        help='whether to test the sequence order of the model')
    parser.add_argument('--compression_token_initialization', dest='compression_token_initialization', action='store_true',
                        help='set config.compression_token_initialization and seed config.initialize_ids from a fixed template')
    parser.add_argument('--structural_compression', dest='structural_compression', action='store_true',
                        help='set config.structural_compression on the model config')
    return parser


if __name__ == '__main__':
    main(**vars(build_arg_parser().parse_args()))