# Download the data with the Kaggle CLI (requires accepting the competition rules on Kaggle):
#   kaggle competitions download -c jigsaw-unintended-bias-in-toxicity-classification

# https://www.kaggle.com/competitions/jigsaw-unintended-bias-in-toxicity-classification/data

# from .data_utils import load_csv
# from .customize import CustomizedDataset

# class JigsawToxicity(CustomizedDataset):
#     def __init__(self, cfg, train=True):
#         if train:
#             path = cfg.data_root + 'train.csv'
#         else:
#             path = cfg.data_root + 'test.csv'
#         data = load_csv(path)
#         feature = data.comment_text.to_numpy()
#         label = data.target.to_numpy()
#         index = data.index.to_numpy()
#         super(JigsawToxicity, self).__init__(feature, label, index, preprocess=None)
#         self.raw_label = label
#         self.label = (label > 0.5).astype(int)
    
import os
from .data_utils import load_dataset, load_csv, print_samples
from .customize import CustomizedDataset
import numpy as np

class JigsawToxicity(CustomizedDataset):
    """Jigsaw Unintended Bias in Toxicity Classification dataset.

    Comments are loaded from a CSV under ``cfg.data_root`` through
    ``load_dataset`` with ``load_csv`` as the reader. The label format is
    selected by ``cfg.file_name``:

    * any name other than ``'all_data'`` (``'train'``/``'test'`` by default):
      labels come from the ``target`` column. The continuous scores are kept
      in ``self.raw_label`` and the labels handed to ``CustomizedDataset``
      are binarized as ``target >= 0.5``.
    * ``'all_data'``: each label is a string listing the sub-scores in
      ``ALL_DATA_KEYS`` (``"toxicity: <v>, severe_toxicity: <v>, ..., "``).
    """

    # Score columns serialized, in this order, into the string label in
    # 'all_data' mode.
    ALL_DATA_KEYS = (
        'toxicity',
        'severe_toxicity',
        'obscene',
        'sexual_explicit',
        'identity_attack',
        'insult',
        'threat',
    )

    def __init__(self, cfg, train=True):
        """Load the CSV selected by ``cfg`` and initialize the dataset.

        Args:
            cfg: Config object. Reads ``data_root`` and ``file_name``. Sets
                ``cfg.data_foldername = cfg.data_root``, fills in
                ``cfg.file_name`` when it is ``None``, and sets
                ``cfg.label = []`` in ``'all_data'`` mode. It is also passed
                to ``load_dataset`` (which fields that reads is not visible
                here -- TODO confirm).
            train: Only used when ``cfg.file_name`` is ``None``. Selects
                ``'train'`` (True) or ``'test'`` (False).
        """
        print('The Jigsaw dataset must be pre-downloaded from: https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification/data')

        cfg.data_foldername = cfg.data_root
        if cfg.file_name is None:
            # NOTE(review): this writes back into the caller's cfg. A second
            # JigsawToxicity(cfg, train=False) built from the same cfg will
            # then see file_name == 'train' and load the training file again.
            # Confirm callers reset cfg.file_name between splits.
            cfg.file_name = 'train' if train else 'test'
        else:
            print(f'Use {cfg.file_name} to indicate data')
        self.file_name = cfg.file_name
        data_converter = self.gen_data_converter()
        feature, label, index = load_dataset(cfg, data_converter=data_converter, data_loader=load_csv)

        if cfg.file_name != 'all_data':
            # Keep the continuous scores and binarize at 0.5. This presumably
            # matches the competition's "target >= 0.5 is toxic" convention
            # -- confirm.
            self.raw_label = label
            label = (np.asarray(label) >= 0.5).astype(int)
        else:
            # String labels: no class-label list applies in this mode
            # (presumably what cfg.label holds -- verify against consumers).
            cfg.label = []

        super().__init__(feature, label, index=index, preprocess=None)

    def gen_data_converter(self):
        """Build the DataFrame -> ``(feature, label, index)`` converter.

        The returned callable fills missing ``comment_text`` with ``""`` and
        returns the comments as a list of strings. Its label output depends on
        ``self.file_name``:

        * ``'all_data'``: one string per row, ``"<key>: <value>, "`` repeated
          for every key in ``ALL_DATA_KEYS``, with missing scores treated as
          0.0. The trailing ``", "`` is kept as-is.
        * otherwise: the ``target`` column as a list of floats.

        ``index`` is ``range(len(feature))``.
        """
        keys = self.ALL_DATA_KEYS if self.file_name == 'all_data' else None

        def data_converter(data):
            # Assign the filled column back rather than calling
            # ``data[col].fillna(..., inplace=True)``. That chained in-place
            # call is deprecated in pandas 2.x and no longer updates ``data``
            # under copy-on-write, which would leak NaN values into the output.
            data["comment_text"] = data["comment_text"].fillna("")
            feature = data["comment_text"].to_numpy().tolist()
            if keys:
                label = ""
                for key in keys:
                    data[key] = data[key].fillna(0.0)
                    label = np.char.add(label, f'{key}: ')
                    label = np.char.add(label, data[key].to_numpy().astype(str))
                    label = np.char.add(label, ', ')
            else:
                label = data.target.to_numpy().astype(float)
            label = label.tolist()
            index = range(len(feature))
            # Internal sanity checks on the converter's output types.
            assert isinstance(feature, list)
            assert isinstance(label, list)
            return feature, label, index

        return data_converter


