from .customize import CustomizedDataset
import gzip, json, re
import numpy as np
import os
import torch
import sys
from .data_utils import load_dataset


class HH_RLHF_new(CustomizedDataset):
    """Dataset of Human/Assistant dialogues loaded from a ``.pt`` file.

    The raw data is fetched through ``load_dataset`` using ``load_pt`` as the
    file loader and the closure returned by ``gen_data_converter`` as the
    converter that flattens dialogues into text features and labels.
    """

    def __init__(self, cfg):
        """Load and convert the data described by ``cfg``.

        Args:
            cfg: Config object. ``cfg.data_root`` is copied into
                ``cfg.data_foldername`` (the passed object is mutated) and
                ``cfg.file_name`` selects the data file.

        Raises:
            NameError: If ``cfg.file_name`` is ``None``.
        """
        cfg.data_foldername = cfg.data_root
        if cfg.file_name is None:
            raise NameError('Must have a file_name in the config file')
        print(f'Use {cfg.file_name} to indicate data')
        self.file_name = cfg.file_name

        data_converter = self.gen_data_converter()
        feature, label, index = load_dataset(
            cfg, data_converter=data_converter, data_loader=self.load_pt
        )

        super().__init__(feature, label, index=index, preprocess=None)

    def load_pt(self, file_name):
        """Load ``file_name`` with ``torch.load``, appending ``.pt`` if missing.

        Args:
            file_name: Path to the file, with or without the ``.pt`` suffix.

        Returns:
            Whatever object ``torch.load`` deserializes from the file.
        """
        if not file_name.endswith(".pt"):
            file_name += ".pt"
        # NOTE(security): torch.load unpickles the file contents; only use it
        # on files from a trusted source.
        return torch.load(file_name)

    def gen_data_converter(self):
        """Build the converter that flattens raw dialogues into features.

        The mode is fixed when this method is called: every ``file_name``
        other than ``'anthropic_red_team_raw'`` is treated as paired data.

        Returns:
            A function ``data_converter(data) -> (feature, label, index)``.
        """
        is_pair = self.file_name != 'anthropic_red_team_raw'

        def data_converter(data):
            """Flatten ``data`` into text features with per-feature labels.

            Args:
                data: Mapping with ``'feature'`` and ``'label'`` entries.
                    In paired mode both are lists of dialogue samples; samples
                    from ``'feature'`` are labelled 0 and samples from
                    ``'label'`` are labelled 1. Otherwise ``'feature'`` holds
                    the samples and ``'label'`` the per-sample labels. Each
                    sample maps ``'Human:'`` and ``'Assistant:'`` to lists of
                    turn strings.

            Returns:
                Tuple ``(feature, label, index)``: ``feature`` is a list of
                strings, ``label`` is ``[sample_labels, sample_indices]`` with
                one entry per feature, and ``index`` is
                ``range(len(feature))``.
            """
            if is_pair:  # higher is harmful
                samples = data['feature'] + data['label']
                sample_labels = [0] * len(data['feature']) + [1] * len(data['label'])
            else:
                samples = data['feature']
                sample_labels = data['label']

            feature = []
            label = [[], []]
            for sample_idx, sample in enumerate(samples):
                # One string per Q&A turn; zip stops at the shorter side, so
                # an unanswered trailing turn is dropped.
                turns = [
                    ' '.join(['[Human:]', human, '[Assistant:]', assistant])
                    for human, assistant in zip(sample['Human:'], sample['Assistant:'])
                ]
                whole_stack = [" ".join(turns)]

                if is_pair:
                    # Whole conversation followed by every raw assistant reply.
                    new_items = whole_stack + sample['Assistant:']
                else:
                    # Every individual turn followed by the whole conversation.
                    new_items = turns + whole_stack

                feature += new_items
                added = len(new_items)
                label[0] += [sample_labels[sample_idx]] * added
                # Position of the source sample within the (possibly
                # concatenated) sample list.
                label[1] += [sample_idx] * added

            print("use new data")
            index = range(len(feature))
            return feature, label, index

        return data_converter
