import pickle
import random
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
import os
import json
import argparse
from constants import *

class NewData(Dataset):
    """Question/answer dataset that builds prompts as ``'Question: <q>\\n'``.

    For the ``'train'`` split each record must provide ``'question'``,
    ``'model_answer'``, ``'hd_target'`` and ``'hd_bin'``. Records whose model
    answer is longer than ``MAX_ANSWER_CHARS`` characters are skipped. Items
    expose the prompt ``x``, the target text ``y``, ``target`` and ``bin_type``.

    For any other split each record must provide ``'question'`` and
    ``'gt_answer'``. Items expose ``x`` and ``gt_answer``.

    Args:
        args: run configuration; stored but not read by this class.
        raw_data: iterable of record dicts (schema above). ``'question'`` (and
            ``'gt_answer'`` for eval) are indexed with ``[0]``, presumably
            single-element lists — confirm against the data producer.
        split: ``'train'`` or the name of an evaluation split.
        soft_prompt_text: mapping with a ``'prefix'`` entry and one
            ``'type{k}'`` entry per calibration bin.
        position: ``'left'`` places the type tokens before the answer; any
            other value places them after it.
        invalid_ans, in_middle, num_new_token, step_type_ids,
        step_type_predictor: stored for callers; not used by this class.
        add_soft_prompts: if True, wrap training targets with prefix/type
            tokens and append the prefix tokens to evaluation prompts.
    """

    # Training examples whose model answer exceeds this many characters are dropped.
    MAX_ANSWER_CHARS = 50

    def __init__(self, args, raw_data, split: str, soft_prompt_text: dict, position: str,
                 invalid_ans="[invalid]", add_soft_prompts=False,
                 in_middle=False, num_new_token=None, step_type_ids=None,
                 step_type_predictor=None):
        super(NewData, self).__init__()

        self.args = args
        self.raw_data = raw_data
        self.position = position
        self.split = split
        self.soft_prompt_text = soft_prompt_text
        self.invalid_ans = invalid_ans
        self.add_soft_prompts = add_soft_prompts
        self.in_middle = in_middle
        self.step_type_ids = step_type_ids
        self.step_type_predictor = step_type_predictor
        self.num_new_token = num_new_token

        self.x = []
        self.y = []
        self.target = []
        self.type = []
        self.gt_answer = []
        self.base_prepare_data()

    def _format_answer(self, model_answer, answer_type):
        """Return the training target text for ``model_answer`` in bin ``answer_type``."""
        if not self.add_soft_prompts:
            return model_answer.strip()
        prefix = self.soft_prompt_text['prefix']
        type_tokens = self.soft_prompt_text[f'type{answer_type}']
        if self.position == 'left':
            sol = prefix + type_tokens + ' ' + model_answer
        else:
            sol = prefix + ' ' + model_answer + ' ' + type_tokens
        # str.strip returns a new string; the original code discarded it.
        return sol.strip()

    def base_prepare_data(self):
        """Fill ``x`` plus ``y``/``target``/``type`` (train) or ``gt_answer`` (eval)."""
        is_train = self.split == 'train'
        for d in self.raw_data:
            x = 'Question: ' + d['question'][0] + '\n'
            if is_train:
                model_answer = d['model_answer']
                if len(model_answer) > self.MAX_ANSWER_CHARS:
                    continue
                answer_type = d['hd_bin']
                self.y.append(self._format_answer(model_answer, answer_type))
                self.target.append(d['hd_target'])
                self.type.append(answer_type)
            else:
                self.gt_answer.append(d['gt_answer'][0])
                # Evaluation prompts end with the prefix tokens so generation
                # starts right after them.
                if self.add_soft_prompts:
                    x += self.soft_prompt_text['prefix']
            self.x.append(x)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        if self.split == 'train':
            return dict(x=self.x[idx], y=self.y[idx], target=self.target[idx], bin_type=self.type[idx])
        return dict(x=self.x[idx], gt_answer=self.gt_answer[idx])





class BaseData(Dataset):
    """Question/answer dataset for the ``triviaqa_brief`` and ``sciq_brief`` tasks.

    ``args.dataset`` selects the prompt template: any name containing
    ``'triviaqa_brief'`` uses ``QA_TEMPLATE``; ``'sciq_brief'`` uses the SciQ
    format with a support passage. Any other value raises ``ValueError``.

    For the ``'train'`` split each record must provide ``'question'``,
    ``'model_answer'``, ``'hd_target'`` and ``'hd_bin'`` (plus ``'support'``
    for sciq_brief). Items expose the prompt ``x``, target text ``y``,
    ``target`` and ``bin_type``.

    For any other split each record must provide ``'question'`` and
    ``'gt_answer'`` (plus ``'support'`` for sciq_brief). Items expose ``x``
    and ``gt_answer``.

    Args:
        args: namespace whose ``dataset`` attribute selects the template.
        raw_data: iterable of record dicts (schema above). ``'question'`` and
            ``'support'`` are indexed with ``[0]``, presumably single-element
            lists — confirm against the data producer.
        split: ``'train'`` or the name of an evaluation split.
        soft_prompt_text: mapping with a ``'prefix'`` entry and one
            ``'type{k}'`` entry per calibration bin, as built by ``tag_data``.
        position: ``'left'`` places the type tokens before the answer; any
            other value places them after it.
        invalid_ans, in_middle, num_new_token, step_type_ids,
        step_type_predictor: stored for callers; not used by this class.
        add_soft_prompts: if True, wrap training targets with prefix/type
            tokens and append the prefix tokens to evaluation prompts.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset.
    """

    def __init__(self, args, raw_data, split: str, soft_prompt_text: dict, position: str,
                 invalid_ans="[invalid]", add_soft_prompts=False,
                 in_middle=False, num_new_token=None, step_type_ids=None,
                 step_type_predictor=None):
        super(BaseData, self).__init__()

        self.args = args
        self.raw_data = raw_data
        self.position = position
        self.split = split
        self.soft_prompt_text = soft_prompt_text
        self.invalid_ans = invalid_ans
        self.add_soft_prompts = add_soft_prompts
        self.in_middle = in_middle
        self.step_type_ids = step_type_ids
        self.step_type_predictor = step_type_predictor
        self.num_new_token = num_new_token

        self.x = []
        self.y = []
        self.target = []
        self.type = []
        self.gt_answer = []
        self.base_prepare_data()

    def _format_prompt(self, d, is_train):
        """Return the question prompt for record ``d`` (without soft-prompt tokens)."""
        dataset = self.args.dataset
        question = d['question']
        if 'triviaqa_brief' in dataset:
            return QA_TEMPLATE.format(question=question[0])
        if dataset == 'sciq_brief':
            # 'support' is only required (and only read) for SciQ.
            support = d['support']
            if is_train:
                return SCIQ_TEMPLATE_FINAL.format(question=question[0], support=support[0])
            # NOTE(review): eval prompts are built inline instead of from
            # SCIQ_TEMPLATE_FINAL; confirm the two formats are meant to agree.
            return 'Question: ' + question[0] + '\n' + 'Support: ' + support[0] + '\n' + 'Answer: '
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected 'triviaqa_brief' or 'sciq_brief'")

    def _format_answer(self, model_answer, answer_type):
        """Return the training target text for ``model_answer`` in bin ``answer_type``."""
        if not self.add_soft_prompts:
            return model_answer.strip()
        prefix = self.soft_prompt_text['prefix']
        type_tokens = self.soft_prompt_text[f'type{answer_type}']
        if self.position == 'left':
            sol = prefix + type_tokens + ' ' + model_answer
        else:
            sol = prefix + ' ' + model_answer + ' ' + type_tokens
        # str.strip returns a new string; the original code discarded it.
        return sol.strip()

    def base_prepare_data(self):
        """Fill ``x`` plus ``y``/``target``/``type`` (train) or ``gt_answer`` (eval)."""
        is_train = self.split == 'train'
        for d in self.raw_data:
            # Built first so an unsupported dataset fails before any list is
            # partially filled.
            x = self._format_prompt(d, is_train)
            if is_train:
                answer_type = d['hd_bin']
                self.y.append(self._format_answer(d['model_answer'], answer_type))
                self.target.append(d['hd_target'])
                self.type.append(answer_type)
            else:
                self.gt_answer.append(d['gt_answer'])
                # Eval prompts end with the prefix tokens, e.g.
                # '... <prefix_0> <prefix_1> <prefix_2>', so generation starts
                # right after them.
                if self.add_soft_prompts:
                    x += self.soft_prompt_text['prefix']
            self.x.append(x)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        if self.split == 'train':
            return dict(x=self.x[idx], y=self.y[idx], target=self.target[idx], bin_type=self.type[idx])
        return dict(x=self.x[idx], gt_answer=self.gt_answer[idx])
    
class KshotDataset(Dataset):
    """Dataset for in-context (k-shot) learning.

    Each item's prompt is ``k`` demonstrations (each rendered as ``x + y``,
    separated by blank lines), then a blank line, then the query's ``x``.
    Items of ``dataset`` must have ``'x'``, ``'y'`` and ``'target'`` keys;
    items of ``demo_dataset`` must have ``'x'`` and ``'y'``.

    Args:
        dataset: query examples.
        demo_dataset: pool demonstrations are drawn from. It may be the same
            object as ``dataset``; the query itself is then never used as one
            of its own demonstrations.
        k: number of demonstrations per prompt.
        demo_selection: ``'uniform'``. ``'prompt'`` and ``'similar'`` require a
            ``prompt_selection`` / ``similarity_selection`` method (e.g. from
            a subclass).
        selection_model, tokenizer, prompt_text, save_dir: stored for
            selection strategies; unused by uniform selection.

    Raises:
        NotImplementedError: if the requested selection strategy is unavailable.
        ValueError: from ``random.sample`` if the pool is smaller than ``k``.
    """

    # Maps demo_selection names to the method implementing them.
    _SELECTORS = {
        'uniform': 'uniform_selection',
        'prompt': 'prompt_selection',
        'similar': 'similarity_selection',
    }

    def __init__(self, dataset: Dataset, demo_dataset: Dataset, k=4,
                 demo_selection='uniform',
                 selection_model=None, tokenizer=None, prompt_text=None, save_dir=None):
        super(KshotDataset, self).__init__()
        self.data = dataset
        self.demo_data = demo_dataset
        self.demo_embeddings = None
        self.k = k
        self.selection_model = selection_model
        self.tokenizer = tokenizer
        self.prompt_text = prompt_text
        self.save_dir = save_dir
        self.sorted_demo_data = None

        method_name = self._SELECTORS.get(demo_selection)
        selection_func = getattr(self, method_name, None) if method_name else None
        if selection_func is None:
            # 'prompt'/'similar' have no implementation in this class; fail
            # with a clear message instead of an AttributeError.
            raise NotImplementedError(f'demo_selection={demo_selection!r} is not implemented')
        self.selection_func = selection_func

        self.x = []
        self.y = []
        self.target = []
        # Index-based loop: the selection function needs the query index.
        for i in range(len(self.data)):
            query = self.data[i]
            demos_text = '\n\n'.join(d['x'] + d['y'] for d in self.selection_func(i))
            self.x.append(demos_text + '\n\n' + query['x'])
            self.y.append(query['y'])
            self.target.append(query['target'])

    def uniform_selection(self, index=None):
        """Return ``k`` demonstrations sampled uniformly without replacement.

        When the demo pool is the query dataset and ``index`` is given, the
        query example is excluded so its own answer cannot leak into the prompt.
        """
        pool_size = len(self.demo_data)
        if index is not None and self.demo_data is self.data:
            # Sample from pool_size - 1 slots and shift ids at/after ``index``
            # up by one: uniform over every id except ``index``.
            rand_ids = [j if j < index else j + 1
                        for j in random.sample(range(pool_size - 1), self.k)]
        else:
            rand_ids = random.sample(range(pool_size), self.k)
        return [self.demo_data[i] for i in rand_ids]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return dict(x=self.x[i], y=self.y[i], target=self.target[i])

def tag_data(args):
    """Build soft-prompt token text, sanity-build the datasets and save the prompt text.

    ``prompt_text`` maps ``'prefix'`` to ``args.num_prefix`` tokens
    ``<prefix_i>``, and each ``'type{k}'`` (k = 1..``args.num_types_cali``)
    to ``args.num_cali`` tokens ``<type{k}_i>``. Every token is preceded by a
    space.

    Loads ``<dataset_path>/<dataset>/<model_name>/hd_data_updated.pkl`` and
    splits it 80/20 with ``random_state=42``. It then constructs ``BaseData``
    for both splits, plus a 4-shot ``KshotDataset`` if ``args.k_shot`` is set.
    Finally it writes ``prompt_text.json`` into a subdirectory named after the
    prompt hyper-parameters.

    Args:
        args: namespace with ``model``, ``dataset_path``, ``dataset``,
            ``num_types_cali``, ``num_prefix``, ``num_cali`` and
            ``add_soft_prompt``. Optional ``position`` (default ``'left'``)
            and ``k_shot`` (default False).
    """
    # Keys: 'prefix' first, then 'type1'..'typeN' (one per calibration bin).
    prompt_text = {'prefix': ''}
    for i in range(1, args.num_types_cali + 1):
        prompt_text[f'type{i}'] = ''
    for k in prompt_text:
        num_tokens = args.num_prefix if k == 'prefix' else args.num_cali
        # e.g. ' <prefix_0> <prefix_1> <prefix_2>'
        prompt_text[k] = ''.join(f' <{k}_{i}>' for i in range(num_tokens))

    model_name = INDENTIFIER2NAME[args.model]
    data_dir = os.path.join(args.dataset_path, args.dataset, model_name)
    data_path = os.path.join(data_dir, 'hd_data_updated.pkl')
    save_path = os.path.join(
        data_dir,
        f'prefix_{args.num_prefix}_calitypes_{args.num_types_cali}_calinum_{args.num_cali}/')
    os.makedirs(save_path, exist_ok=True)
    prompt_text_path = os.path.join(save_path, 'prompt_text.json')

    # pickle.load can execute arbitrary code: only load trusted, locally produced files.
    with open(data_path, 'rb') as f:
        data = pickle.load(f)
    train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)

    # The original calls omitted `args` and the required `position` argument.
    # NOTE(review): 'left' is assumed when the CLI does not supply --position.
    position = getattr(args, 'position', 'left')
    train_dataset = BaseData(args, train_data, 'train', prompt_text, position,
                             add_soft_prompts=args.add_soft_prompt)  # 8872 examples in the original run
    test_dataset = BaseData(args, test_data, 'test', prompt_text, position,
                            add_soft_prompts=args.add_soft_prompt)  # 2219 examples in the original run

    if getattr(args, 'k_shot', False):
        # Built only as a sanity check of the k-shot setting; not persisted.
        KshotDataset(train_dataset, train_dataset, k=4)

    with open(prompt_text_path, 'w', encoding='utf-8') as f:
        json.dump(prompt_text, f, indent=4)

    print('Done')


if __name__ == "__main__":
    def _str2bool(value):
        """Parse a command-line boolean such as 'true'/'false'/'1'/'0'."""
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean, got {value!r}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default=MODEL_IDENTIFIER)
    parser.add_argument('--dataset_path', type=str, default=DATASET_PATH)
    parser.add_argument('--num_types_cali', type=int, default=10)
    parser.add_argument('--num_prefix', type=int, default=3)
    parser.add_argument('--num_cali', type=int, default=3)
    # NOTE(review): BaseData only recognises 'triviaqa_brief' / 'sciq_brief';
    # confirm this default is intended.
    parser.add_argument('--dataset', type=str, default='triviaqa')
    # Without a converter argparse would pass the raw string, so
    # '--add_soft_prompt False' would be truthy.
    parser.add_argument('--add_soft_prompt', type=_str2bool, default=True)
    parser.add_argument('--position', type=str, default='left',
                        help="'left' puts calibration type tokens before the answer, anything else after it")
    parser.add_argument('--k_shot', action='store_true',
                        help='also build a 4-shot in-context dataset from the training split')
    args = parser.parse_args()

    tag_data(args)