from datasets import interleave_datasets, load_dataset, load_from_disk
import random
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer,set_seed,DataCollatorWithPadding,AutoModel
from coh.data.templates import (dialogue_template, summary_template, webgpt_template,
                                webgpt_tie_template)
class HH_rlhf():
    """Preference-pair dataset built from dialogue samples.

    Each sample is expected to carry ``'chosen'`` and ``'rejected'`` strings in
    which assistant turns are introduced by the literal marker
    ``"Assistant: "``.  ``process`` turns one sample into two tokenised
    views, stored under the keys ``"coh_dic"`` (context preceded by a feedback
    instruction taken from ``dialogue_template``) and ``"coa_dic"`` (plain
    context with separate good/bad loss masks).

    Attributes set by ``__init__``: ``args``, ``tokenizer``, ``train``, ``test``.
    """

    def __init__(self, path_train, path_test, tokenizer, args, max_train_samples=40000):
        """Load both splits from disk, shuffle them and optionally pre-process.

        Args:
            path_train: Directory passed to ``load_from_disk`` for the train split.
            path_test: Directory passed to ``load_from_disk`` for the test split.
            tokenizer: Object providing ``encode`` and ``eos_token_id``.
            args: Namespace providing ``set_seed``, ``test_data_path`` and
                ``seq_length``.
            max_train_samples: Upper bound on the number of shuffled training
                rows kept (default 40000, the previous hard-coded value).
                ``None`` keeps every row.  The bound is clamped to the dataset
                size so small datasets no longer fail in ``select``.
        """
        self.args = args
        train = load_from_disk(path_train).shuffle(seed=self.args.set_seed)
        if max_train_samples is not None:
            train = train.select(range(min(max_train_samples, len(train))))
        self.train = train
        self.test = load_from_disk(path_test).shuffle(seed=self.args.set_seed)
        self.tokenizer = tokenizer
        # Tokenise/filter the splits only when no external test file is configured.
        if self.args.test_data_path == '':
            # process and filter train dataset
            self.train = self.train.map(self.process)
            self.train = self.train.filter(self.filter)
            print(self.train)
            # process and filter test dataset
            self.test = self.test.map(self.process)
            self.test = self.test.filter(self.filter)
            print(self.test)

    def inference_test(self, path):
        """Load a JSON file, run ``process`` on it and return its ``'train'`` split.

        Note that, unlike the splits prepared in ``__init__``, the result is
        not passed through ``filter``.
        """
        inference_test = load_dataset("json", data_files=path, cache_dir='./')
        print(inference_test)
        inference_test = inference_test.map(self.process)
        return inference_test['train']

    @staticmethod
    def _split_template(template):
        """Split a two-slot feedback template into its instruction segments.

        Args:
            template: String containing both ``{pos}`` and ``{neg}`` and ending
                with one of them.

        Returns:
            ``(first, middle, pos_last)`` where ``first`` is the text before
            the first slot, ``middle`` the text between the two slots, and
            ``pos_last`` is True when the template ends with ``{pos}``.

        Raises:
            ValueError: If the template matches neither layout.
        """
        if template.endswith('{pos}') and '{neg}' in template:
            pieces = template.split('{neg}')
            return pieces[0], pieces[1].split('{pos}')[0], True
        if template.endswith('{neg}') and '{pos}' in template:
            pieces = template.split('{pos}')
            return pieces[0], pieces[1].split('{neg}')[0], False
        raise ValueError(
            "Unsupported feedback template (must contain both '{pos}' and "
            "'{neg}' and end with one of them): %r" % (template,)
        )

    def process(self, sample):
        """Tokenise one chosen/rejected pair into the "coh" and "coa" views.

        Args:
            sample: Mapping with ``'chosen'`` and ``'rejected'`` dialogue strings.

        Returns:
            Dict with keys ``"coh_dic"``, ``"coa_dic"``,
            ``"chosen_prefix_tokens"`` and ``"rejected_prefix_tokens"``.

        Raises:
            ValueError: If the randomly selected template is not a supported
                two-slot layout.
        """
        encode = self.tokenizer.encode
        eos = self.tokenizer.eos_token_id
        marker = "Assistant: "

        # text: the last assistant turn is the answer, everything before it
        # (re-terminated with the marker) is the dialogue context.
        prefix = ""
        chosen_turns = sample['chosen'].split(marker)
        rejected_turns = sample['rejected'].split(marker)
        last_chosen_answer = chosen_turns[-1]
        last_rejected_answer = rejected_turns[-1]
        chosen_prefix = marker.join(chosen_turns[:-1]) + marker
        rejected_prefix = marker.join(rejected_turns[:-1]) + marker

        # tokens
        prefix_tokens = encode(prefix)
        chosen_prefix_tokens = encode(chosen_prefix)
        rejected_prefix_tokens = encode(rejected_prefix)
        last_chosen_answer_tokens = encode(last_chosen_answer)
        last_rejected_answer_tokens = encode(last_rejected_answer)

        # coh content: one of the first two dialogue templates supplies the
        # instruction text placed in front of each response's context.
        template = random.choice(dialogue_template[:2])
        first, middle, pos_last = self._split_template(template)
        # The first two characters of the middle segment are dropped
        # (presumably the separator that follows the first slot -- confirm
        # against dialogue_template).
        middle = middle[2:]
        first_tokens = encode(first)
        middle_tokens = encode(middle)
        if pos_last:
            # "... {neg} ... {pos}": middle text introduces the chosen answer.
            pos_instr_tokens, neg_instr_tokens = middle_tokens, first_tokens
        else:
            # "... {pos} ... {neg}": leading text introduces the chosen answer.
            pos_instr_tokens, neg_instr_tokens = first_tokens, middle_tokens

        generate_prefix_tokens = prefix_tokens + pos_instr_tokens + chosen_prefix_tokens
        generate_attention_masks = [1] * len(generate_prefix_tokens)
        neg_context_len = len(prefix_tokens) + len(neg_instr_tokens) + len(rejected_prefix_tokens)

        # Loss is taken on the answer tokens and the trailing EOS only.
        pos_input_tokens = generate_prefix_tokens + last_chosen_answer_tokens + [eos]
        neg_input_tokens = (prefix_tokens + neg_instr_tokens + rejected_prefix_tokens
                            + last_rejected_answer_tokens + [eos])
        pos_loss_masks = [0] * len(generate_prefix_tokens) + [1] * (len(last_chosen_answer_tokens) + 1)
        neg_loss_masks = [0] * neg_context_len + [1] * (len(last_rejected_answer_tokens) + 1)

        # Both entries carry the generation prefix/target of the chosen answer.
        coh_pos_dic = {'input_ids': pos_input_tokens, 'masks': pos_loss_masks, "attention_masks": generate_attention_masks, 'prefix': generate_prefix_tokens, 'target': last_chosen_answer_tokens}
        coh_neg_dic = {'input_ids': neg_input_tokens, 'masks': neg_loss_masks, "attention_masks": generate_attention_masks, 'prefix': generate_prefix_tokens, 'target': last_chosen_answer_tokens}
        coh_dic = {"pos_dict": coh_pos_dic, "neg_dict": coh_neg_dic}

        # coa content: plain context + answer + EOS, no instruction text.
        chosen_context_len = len(prefix_tokens) + len(chosen_prefix_tokens)
        rejected_context_len = len(prefix_tokens) + len(rejected_prefix_tokens)
        pos_input_tokens = prefix_tokens + chosen_prefix_tokens + last_chosen_answer_tokens + [eos]
        neg_input_tokens = prefix_tokens + rejected_prefix_tokens + last_rejected_answer_tokens + [eos]
        # pos sample: answer (+EOS) counts towards the "good" loss only
        pos_good_loss_masks = [0] * chosen_context_len + [1] * (len(last_chosen_answer_tokens) + 1)
        pos_bad_loss_masks = [0] * len(pos_input_tokens)
        # neg sample: answer (+EOS) counts towards the "bad" loss only
        neg_good_loss_masks = [0] * len(neg_input_tokens)
        neg_bad_loss_masks = [0] * rejected_context_len + [1] * (len(last_rejected_answer_tokens) + 1)

        # Sized from the chosen context and shared by both entries; filter()
        # drops pairs whose chosen/rejected contexts differ, so for retained
        # samples it also matches the rejected prefix.
        generate_attention_masks = [1] * chosen_context_len

        # dict
        pos_dict = {"attention_masks": generate_attention_masks, 'prefix': prefix_tokens + chosen_prefix_tokens, 'target': last_chosen_answer_tokens, 'input_ids': pos_input_tokens, 'good_loss_masks': pos_good_loss_masks, 'bad_loss_masks': pos_bad_loss_masks}
        neg_dict = {"attention_masks": generate_attention_masks, 'prefix': prefix_tokens + rejected_prefix_tokens, 'target': last_rejected_answer_tokens, 'input_ids': neg_input_tokens, 'good_loss_masks': neg_good_loss_masks, 'bad_loss_masks': neg_bad_loss_masks}
        coa_dic = {"pos_dict": pos_dict, "neg_dict": neg_dict}

        return {"coh_dic": coh_dic, "coa_dic": coa_dic, "chosen_prefix_tokens": chosen_prefix_tokens, "rejected_prefix_tokens": rejected_prefix_tokens}

    def filter(self, sample):
        """Return True for processed samples that should be kept.

        A sample is dropped when either "coh" sequence is longer than
        ``args.seq_length`` or when the chosen and rejected dialogue contexts
        tokenise differently.
        """
        coh = sample["coh_dic"]
        limit = self.args.seq_length
        if len(coh['pos_dict']["input_ids"]) > limit or len(coh['neg_dict']["input_ids"]) > limit:
            return False
        if sample["chosen_prefix_tokens"] != sample["rejected_prefix_tokens"]:
            return False
        return True


class Summary_dataset():
    """Preference-pair dataset built from summary-comparison samples.

    Each sample is expected to provide ``info`` (with a ``post`` or
    ``article`` text), two ``summaries`` and a ``choice`` index naming the
    preferred one.  ``process`` turns one sample into two tokenised views,
    stored under ``"coh_dic"`` (prompt followed by a feedback instruction from
    ``summary_template``) and ``"coa_dic"`` (plain prompt with separate
    good/bad loss masks).

    Attributes set by ``__init__``: ``args``, ``tokenizer``, ``train``,
    ``test`` and ``store_dic`` (decoded prompts already accepted by
    ``filter``, used for de-duplication).
    """

    def __init__(self, path, train_split, test_split, tokenizer, args):
        """Load the ``'comparisons'`` configuration, shuffle and optionally pre-process.

        Args:
            path: Dataset path/name passed to ``load_dataset``.
            train_split: Split expression for the training data.
            test_split: Split expression for the test data.
            tokenizer: Object providing ``encode``, ``decode`` and
                ``eos_token_id``.
            args: Namespace providing ``set_seed``, ``test_data_path`` and
                ``seq_length``.
        """
        self.args = args
        self.train = load_dataset(path, 'comparisons', split=train_split, cache_dir='./').shuffle(seed=self.args.set_seed)
        self.test = load_dataset(path, 'comparisons', split=test_split, cache_dir='./').shuffle(seed=self.args.set_seed)
        self.tokenizer = tokenizer
        self.store_dic = {}
        # Tokenise/filter the splits only when no external test file is configured.
        if self.args.test_data_path == '':
            # process and filter train dataset
            self.train = self.train.map(self.process)
            self.train = self.train.filter(self.filter)
            print(self.train)
            # process and filter test dataset
            self.test = self.test.map(self.process)
            self.test = self.test.filter(self.filter)
            print(self.test)

    def inference_test(self, path):
        """Load a JSON file, run ``process`` on it and return its ``'train'`` split.

        Note that the result is not passed through ``filter``.
        """
        inference_test = load_dataset("json", data_files=path, cache_dir='./')
        print(inference_test)
        inference_test = inference_test.map(self.process)
        return inference_test['train']

    @staticmethod
    def _split_template(template):
        """Split a two-slot feedback template into its instruction segments.

        Args:
            template: String containing both ``{pos}`` and ``{neg}`` and ending
                with one of them.

        Returns:
            ``(first, middle, pos_last)`` where ``first`` is the text before
            the first slot, ``middle`` the text between the two slots, and
            ``pos_last`` is True when the template ends with ``{pos}``.

        Raises:
            ValueError: If the template matches neither layout.
        """
        if template.endswith('{pos}') and '{neg}' in template:
            pieces = template.split('{neg}')
            return pieces[0], pieces[1].split('{pos}')[0], True
        if template.endswith('{neg}') and '{pos}' in template:
            pieces = template.split('{pos}')
            return pieces[0], pieces[1].split('{neg}')[0], False
        raise ValueError(
            "Unsupported feedback template (must contain both '{pos}' and "
            "'{neg}' and end with one of them): %r" % (template,)
        )

    def process(self, sample):
        """Tokenise one summary comparison into the "coh" and "coa" views.

        Args:
            sample: Mapping with ``info`` (``post`` or ``article``),
                ``summaries`` (two entries with ``text``) and ``choice``.

        Returns:
            Dict with keys ``"coh_dic"`` and ``"coa_dic"``.

        Raises:
            ValueError: If the randomly selected template is not a supported
                two-slot layout.
        """
        encode = self.tokenizer.encode
        eos = self.tokenizer.eos_token_id

        # text: use the post when present, otherwise fall back to the article
        info = sample['info']
        source_text = info['post'] if info['post'] else info['article']
        prefix = source_text + "\n\n" + "### Response: TL;DR: "
        pos_ind = int(sample['choice'])
        neg_ind = 1 - pos_ind
        chosen = sample['summaries'][pos_ind]['text']
        rejected = sample['summaries'][neg_ind]['text']

        # tokens
        prefix_tokens = encode(prefix)
        chosen_tokens = encode(chosen)
        rejected_tokens = encode(rejected)

        # coh content: one of the first two summary templates supplies the
        # instruction text inserted between the prompt and each summary.
        template = random.choice(summary_template[:2])
        first, middle, pos_last = self._split_template(template)
        if pos_last:
            # NOTE(review): only the "{neg} ... {pos}" layout trims the first
            # two characters of the middle segment; the "{pos} ... {neg}"
            # layout keeps it whole.  Kept as-is -- confirm against
            # summary_template whether this asymmetry is intended.
            middle = middle[2:]
        first_tokens = encode(first)
        middle_tokens = encode(middle)
        if pos_last:
            pos_instr_tokens, neg_instr_tokens = middle_tokens, first_tokens
        else:
            pos_instr_tokens, neg_instr_tokens = first_tokens, middle_tokens

        generate_prefix_tokens = prefix_tokens + pos_instr_tokens
        generate_attention_masks = [1] * len(generate_prefix_tokens)

        # Loss is taken on the summary tokens and the trailing EOS only.
        pos_input_tokens = generate_prefix_tokens + chosen_tokens + [eos]
        neg_input_tokens = prefix_tokens + neg_instr_tokens + rejected_tokens + [eos]
        pos_loss_masks = [0] * len(generate_prefix_tokens) + [1] * (len(chosen_tokens) + 1)
        neg_loss_masks = [0] * (len(prefix_tokens) + len(neg_instr_tokens)) + [1] * (len(rejected_tokens) + 1)

        # Both entries carry the generation prefix/target of the chosen summary.
        coh_pos_dic = {'input_ids': pos_input_tokens, 'masks': pos_loss_masks, "attention_masks": generate_attention_masks, 'prefix': generate_prefix_tokens, 'target': chosen_tokens}
        coh_neg_dic = {'input_ids': neg_input_tokens, 'masks': neg_loss_masks, "attention_masks": generate_attention_masks, 'prefix': generate_prefix_tokens, 'target': chosen_tokens}
        coh_dic = {"pos_dict": coh_pos_dic, "neg_dict": coh_neg_dic}

        # coa content: plain prompt + summary + EOS, no instruction text.
        pos_input_tokens = prefix_tokens + chosen_tokens + [eos]
        neg_input_tokens = prefix_tokens + rejected_tokens + [eos]
        # pos sample: summary (+EOS) counts towards the "good" loss only
        pos_good_loss_masks = [0] * len(prefix_tokens) + [1] * (len(chosen_tokens) + 1)
        pos_bad_loss_masks = [0] * len(pos_input_tokens)
        # neg sample: summary (+EOS) counts towards the "bad" loss only
        neg_good_loss_masks = [0] * len(neg_input_tokens)
        neg_bad_loss_masks = [0] * len(prefix_tokens) + [1] * (len(rejected_tokens) + 1)

        generate_attention_masks = [1] * len(prefix_tokens)

        # dict
        pos_dict = {"attention_masks": generate_attention_masks, 'prefix': prefix_tokens, 'target': chosen_tokens, 'input_ids': pos_input_tokens, 'good_loss_masks': pos_good_loss_masks, 'bad_loss_masks': pos_bad_loss_masks}
        neg_dict = {"attention_masks": generate_attention_masks, 'prefix': prefix_tokens, 'target': rejected_tokens, 'input_ids': neg_input_tokens, 'good_loss_masks': neg_good_loss_masks, 'bad_loss_masks': neg_bad_loss_masks}
        coa_dic = {"pos_dict": pos_dict, "neg_dict": neg_dict}

        return {"coh_dic": coh_dic, "coa_dic": coa_dic}

    def filter(self, sample):
        """Return True for processed samples that should be kept.

        A sample is dropped when either "coh" sequence is longer than
        ``args.seq_length`` or when its decoded prompt has already been
        accepted.  Accepted prompts are recorded in ``self.store_dic``, so the
        result depends on the order in which samples are seen.
        """
        coh = sample["coh_dic"]
        limit = self.args.seq_length
        if len(coh['pos_dict']["input_ids"]) > limit or len(coh['neg_dict']["input_ids"]) > limit:
            return False
        # Decode once and reuse the text as the de-duplication key.
        prompt_text = self.tokenizer.decode(sample["coa_dic"]["pos_dict"]["prefix"])
        if prompt_text in self.store_dic:
            return False
        self.store_dic[prompt_text] = 1
        return True

