from .customize import CustomizedDataset
import gzip, json, re
import numpy as np
import os
import torch
import sys
from .data_utils import load_dataset
from datasets import load_dataset as hf_load_dataset


class PKU_ALIGN(CustomizedDataset):
    """Dataset wrapper for PKU-Alignment safety data fetched through Hugging Face.

    The config's ``file_name`` is passed to ``datasets.load_dataset`` and also
    selects the schema: names containing ``"Beaver"`` are read with a single
    ``response`` / ``is_safe`` / ``category`` column set, every other name is
    read with paired ``response_0`` / ``response_1`` columns.

    The produced feature list has two halves of equal length: first the
    formatted question/answer strings, then the bare answers. All label
    lists are duplicated accordingly so they stay aligned with the features.
    """

    def __init__(self, cfg):
        """Build the dataset described by ``cfg``.

        Args:
            cfg: Config object. ``cfg.file_name`` must be set; note that
                ``cfg.data_foldername`` is overwritten with
                ``"PKU-Alignment"``.

        Raises:
            NameError: If ``cfg.file_name`` is ``None``.
        """
        cfg.data_foldername = "PKU-Alignment"
        if cfg.file_name is None:
            raise NameError('Must have a file_name in the config file')
        print(f'Use {cfg.file_name} to indicate data')
        self.file_name = cfg.file_name

        data_converter = self.gen_data_converter()
        # NOTE(review): `load_dataset` is the project helper from .data_utils;
        # presumably it calls `data_loader` and then `data_converter` on the
        # result -- confirm against that module.
        feature, label, index = load_dataset(
            cfg,
            data_converter=data_converter,
            data_loader=self.load_pku_align,
        )

        super().__init__(feature, label, index=index, preprocess=None)

    def load_pku_align(self, file_name):
        """Return the Hugging Face dataset identified by ``file_name``."""
        return hf_load_dataset(file_name)

    @staticmethod
    def _format_qa(prompts, responses):
        """Return ``"[Human:] <prompt>\\n[Assistant:] <response>"`` strings as a list."""
        text = np.char.add("[Human:] ", prompts)
        text = np.char.add(text, "\n[Assistant:] ")
        return np.char.add(text, responses).tolist()

    def gen_data_converter(self):
        """Create the converter that turns a raw dataset into features and labels.

        Returns:
            A function ``data_converter(data)`` that reads the ``'train'`` and
            ``'test'`` splits of ``data`` and returns ``(feature, label, index)``:

            * ``feature``: list of strings -- QA-formatted samples followed by
              the same samples' answers alone.
            * ``label``: list of lists. ``label[0]`` is ``1 - is_safe`` as int
              (1 marks an unsafe response), ``label[1]`` is the position of
              the sample within the QA half (repeated for the answers-only
              half), and, for Beaver data only, ``label[2]`` is the category.
            * ``index``: ``range(len(feature))``.
        """
        is_beaver = "Beaver" in self.file_name
        format_qa = self._format_qa

        def data_converter(data):
            train, test = data['train'], data['test']

            if is_beaver:
                prompts = train['prompt'] + test['prompt']
                responses = train['response'] + test['response']
                is_safe = train['is_safe'] + test['is_safe']
                extra_labels = [train['category'] + test['category']]
            else:
                # Each prompt has two responses: order is all response_0
                # samples (train, test) followed by all response_1 samples.
                prompts = (train['prompt'] + test['prompt']) * 2
                responses = (
                    train['response_0'] + test['response_0']
                    + train['response_1'] + test['response_1']
                )
                is_safe = (
                    train['is_response_0_safe'] + test['is_response_0_safe']
                    + train['is_response_1_safe'] + test['is_response_1_safe']
                )
                extra_labels = []

            unsafe = (1 - np.asarray(is_safe)).astype(int).tolist()
            raw_idx = list(range(len(unsafe)))

            # First half: full QA text; second half: answers only.
            feature = format_qa(prompts, responses)
            feature += responses

            # Duplicate every label list so it lines up with both halves.
            label = [column * 2 for column in [unsafe, raw_idx] + extra_labels]

            assert isinstance(feature, list)
            assert isinstance(label, list)
            index = range(len(feature))
            return feature, label, index

        return data_converter
