
import random
import json
import torch
from datasets import interleave_datasets, load_dataset, load_from_disk
from ml_collections import ConfigDict
from coh.data.templates import (dialogue_template, summary_template, webgpt_template,
                                webgpt_tie_template)
from coh.data.all_dataset import HH_rlhf, Summary_dataset

def store_data(datas, path,
               exclude_keys=('coa_dic', 'coh_dic', 'chosen_prefix_tokens',
                             'rejected_prefix_tokens')):
    """Write every record of ``datas`` to ``path`` as one JSON object per line.

    Keys listed in ``exclude_keys`` are dropped before serialisation. By default
    these are the nested ``coa_dic``/``coh_dic`` dicts and the prefix-token
    fields. All remaining values must be JSON-serialisable. Any existing file
    at ``path`` is overwritten.

    Args:
        datas: iterable of mapping-like records (anything exposing ``keys()``
            and ``__getitem__``, e.g. rows of a ``datasets.Dataset``).
        path: output file path.
        exclude_keys: keys to omit from every written record.
    """
    # Build the lookup set once instead of re-creating the list for every key.
    excluded = frozenset(exclude_keys)
    with open(path, 'w', encoding='utf-8') as f:
        for record in datas:
            kept = {k: record[k] for k in record.keys() if k not in excluded}
            f.write(json.dumps(kept) + "\n")
        


class HumanFeedbackDataset(object):
    """Human feedback dataset mixing WebGPT, HH-RLHF and summarize-from-feedback data.

    ``config.data_contain`` is a comma-separated string that selects any of
    ``webgpt``, ``rlhf`` and ``summary``. The selected sources are mixed with
    ``datasets.interleave_datasets`` and then flattened. Each interleaved sample
    contributes its ``pos_dict`` and ``neg_dict`` entries to ``real_dataset``.
    These entries come from ``coh_dic`` when ``train_method == "coh"`` and from
    ``coa_dic`` otherwise.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default config, overridden by ``updates`` when given.

        ``data_contain`` has no default and must be supplied in ``updates``.
        ``test_data_path`` defaults to ``''``, which means the rlhf/summary
        loaders use their regular train/test splits.
        """
        config = ConfigDict()
        config.seq_length = 512
        config.split = 'train'
        config.batch_size = 1
        config.weight = ""
        config.dataset_size = 1000
        config.train_method = None
        config.factor = 1
        config.test_data_path = ''

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    @staticmethod
    def make_webgpt_test_set(test_size=0.1):
        """Load the WebGPT comparisons and split them into 'train'/'test'."""
        webgpt_data = load_dataset('./webgpt_comparisons/webgpt_comparisons.py', split='train', cache_dir='./')
        webgpt_data_split = webgpt_data.train_test_split(test_size=test_size)
        return webgpt_data_split

    def __init__(self, config, tokenizer, webgpt_data=None):
        """Load, interleave and flatten the requested feedback datasets.

        Args:
            config: overrides for ``get_default_config``; must provide
                ``data_contain``.
            tokenizer: tokenizer handed to the rlhf/summary loaders. Its
                length is exposed as ``vocab_size``.
            webgpt_data: mapping with 'train'/'test' splits (for example the
                result of ``make_webgpt_test_set``). Only required when
                ``webgpt`` is selected.

        Raises:
            ValueError: if no known dataset is selected, if ``webgpt`` is
                selected without ``webgpt_data``, or if ``config.weight`` does
                not match the selected datasets.
        """
        self.config = self.get_default_config(config)
        # Read from the merged config so plain-dict updates work too.
        self.data_contain = [name.strip() for name in self.config.data_contain.split(',')]
        self._tokenizer = tokenizer
        # Kept so that __getstate__/__setstate__ can rebuild the object.
        self._webgpt_data = webgpt_data
        self.data_list = []
        self.weight_list = []
        self.real_dataset = []

        # Load order (webgpt, rlhf, summary) is also the order that
        # config.weight entries refer to.
        if "webgpt" in self.data_contain:
            self.data_list.append(self._load_webgpt(webgpt_data))
        if "rlhf" in self.data_contain:
            self.data_list.append(self._load_rlhf())
        if "summary" in self.data_contain:
            self.data_list.append(self._load_summary())
        if not self.data_list:
            raise ValueError(
                f"data_contain={self.config.data_contain!r} selects none of "
                "'webgpt', 'rlhf', 'summary'")

        self.weight_list = self._normalized_weights()
        self._dataset = interleave_datasets(self.data_list, self.weight_list, seed=42)

        dic_key = "coh_dic" if self.config.train_method == "coh" else "coa_dic"
        for sample in self._dataset:
            self.real_dataset.append(sample[dic_key]["pos_dict"])
            self.real_dataset.append(sample[dic_key]["neg_dict"])

        if self.config.split == 'train':
            random.shuffle(self.real_dataset)
        print(len(self.real_dataset))

    def _load_webgpt(self, webgpt_data):
        """Take the first ``dataset_size`` (train) or ``dataset_size / factor`` rows of WebGPT."""
        print("########################## Load webgpt dataset ######################")
        if webgpt_data is None:
            raise ValueError("'webgpt' is in data_contain but no webgpt_data was given")
        webgpt_split = 'test' if self.config.split == 'validation' else self.config.split
        split_data = webgpt_data[webgpt_split]
        if webgpt_split == "train":
            size = self.config.dataset_size
        else:
            size = int(self.config.dataset_size / self.config.factor)
        # Clamp so that a dataset_size larger than the split does not make select() fail.
        return split_data.select(range(min(size, len(split_data))))

    def _load_rlhf(self):
        """Load HH-RLHF data (train/test only; 'validation' maps to 'test') and dump it."""
        print("########################## Load rlhf dataset ######################")
        rlhf_dataset = HH_rlhf("./rlhf_train", "./rlhf_test", self._tokenizer, self.config)
        if self.config.test_data_path != '':
            return rlhf_dataset.inference_test(self.config.test_data_path)
        rlhf_split = 'test' if self.config.split == 'validation' else self.config.split
        # Dump files are named after the dataset, not its position in data_contain.
        if rlhf_split == "train":
            train = rlhf_dataset.train
            data = train.select(range(min(self.config.dataset_size, len(train))))
            store_data(data, './rlhf_train_data')
        else:
            data = rlhf_dataset.test
            store_data(data, './rlhf_test_all_data')
        return data

    def _load_summary(self):
        """Load summarize-from-feedback data (train/validation; 'test' maps to 'validation') and dump it."""
        print("########################## Load summary dataset ######################")
        summary_dataset = Summary_dataset("./summarize_from_feedback/summarize_from_feedback.py", "train", "validation", self._tokenizer, self.config)
        if self.config.test_data_path != '':
            return summary_dataset.inference_test(self.config.test_data_path)
        sff_split = 'validation' if self.config.split == 'test' else self.config.split
        if sff_split == "train":
            data = summary_dataset.train
            store_data(data, './summary_train_data')
        else:
            data = summary_dataset.test
            store_data(data, './summary_test_data')
        return data

    def _normalized_weights(self):
        """Return sampling probabilities for ``data_list`` that sum to 1.

        With an empty ``config.weight``, each dataset is weighted by its size.
        Otherwise ``config.weight`` is a comma-separated list with one number
        per selected dataset, in load order.
        """
        if self.config.weight == "":
            raw = [len(d) for d in self.data_list]
        else:
            raw = [float(x) for x in self.config.weight.split(',')]
            if len(raw) != len(self.data_list):
                raise ValueError(
                    f"config.weight has {len(raw)} entries but "
                    f"{len(self.data_list)} datasets were loaded")
        total = sum(raw)
        if total <= 0:
            raise ValueError(f"dataset weights must sum to a positive value, got {raw}")
        return [w / total for w in raw]

    def __iter__(self):
        """Yield each example as a dict of int64 tensors of shape ``(batch_size, -1)``.

        A fresh dict of tensors is built on every pass, and ``real_dataset`` is
        left untouched, so the dataset can be iterated any number of times.
        """
        batch_size = self.config.batch_size
        for sample in self.real_dataset:
            yield {
                k: torch.as_tensor(v, dtype=torch.long).reshape(batch_size, -1)
                for k, v in sample.items()
            }

    def __getstate__(self):
        return self.config, self._tokenizer, self._webgpt_data

    def __setstate__(self, state):
        # Accept the legacy 2-tuple state as well as the current 3-tuple.
        config, tokenizer, *rest = state
        webgpt_data = rest[0] if rest else None
        self.__init__(config, tokenizer, webgpt_data)

    @property
    def seq_length(self):
        return self.config.seq_length

    @property
    def dataset_size(self):
        return len(self.real_dataset)

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def dataset(self):
        return self.real_dataset

    @property
    def vocab_size(self):
        return len(self._tokenizer)
