from .comlib import *
from . import adversary_transform

class Serial_Dataset(Dataset):
    """Dataset view that prepends the requested index to every sample.

    Indexing yields ``(item, *fields)``, where ``fields`` are the unpacked
    elements of the wrapped dataset's sample at ``item``.
    """

    def __init__(self, data_set):
        """
        :param data_set: wrapped dataset; its samples must be iterable so
            they can be unpacked after the index
        """
        self.data_set = data_set

    def __len__(self):
        return len(self.data_set)

    def __getitem__(self, item):
        sample = self.data_set[item]
        # The index is returned exactly as it was passed in (not normalized).
        return (item, *sample)

class Concat_Dataset(Dataset):
    """Present several datasets as a single dataset, laid out end to end.

    Global index ``k`` resolves to
    ``data_sets[set_index[k]][local_index[k]]``, with the sub-datasets
    placed in the order they were given.

    Attributes:
        data_sets: the sub-datasets, as passed in.
        dataset_num: number of sub-datasets.
        lens: length of each sub-dataset.
        set_index: int32 tensor; entry ``k`` is the sub-dataset holding
            global sample ``k``.
        local_index: int64 tensor; entry ``k`` is the position of global
            sample ``k`` inside its sub-dataset.
    """

    def __init__(self, data_sets: list):
        """
        :param data_sets: datasets to concatenate; each must support
            ``len()`` and indexing by a plain ``int``
        """
        self.data_sets = data_sets
        self.dataset_num = len(data_sets)
        self.lens = [len(data_set) for data_set in data_sets]

        # Seed both maps with an empty tensor so that an empty ``data_sets``
        # list still concatenates into valid (empty) lookup tables.
        set_parts = [torch.empty(0, dtype=torch.int)]
        local_parts = [torch.empty(0, dtype=torch.long)]
        for set_id, size in enumerate(self.lens):
            set_parts.append(torch.full((size,), set_id, dtype=torch.int))
            local_parts.append(torch.arange(size))
        self.set_index = torch.cat(set_parts)
        self.local_index = torch.cat(local_parts)

    def __getitem__(self, item):
        # Convert the 0-dim lookup results to plain ints before indexing:
        # a tensor index hashes by identity, so sub-datasets backed by a
        # mapping (or ones that type-check their index) would reject it.
        set_id = int(self.set_index[item])
        local_id = int(self.local_index[item])
        return self.data_sets[set_id][local_id]

    def __len__(self):
        # ``__len__`` must produce a Python int, not a tensor.
        return sum(self.lens)

class TransformedDataset(Dataset):
    """Dataset view that passes every ``(input, label)`` sample through a transform."""

    def __init__(self, data, trans: adversary_transform.Transform):
        """
        :param data: raw dataset; each item must unpack into exactly two
            values (input, label)
        :param trans: data-processing callable, invoked as
            ``trans(input, label)``; its return value is the sample
        """
        self.trans = trans
        self.rawdata = data

    def __len__(self):
        return len(self.rawdata)

    def __getitem__(self, item):
        sample, label = self.rawdata[item]
        return self.trans(sample, label)
    

class TransformedDatasetWithMask(Dataset):
    """Dataset view that transforms only the samples selected by a mask."""

    def __init__(self, data, trans: adversary_transform.Transform,mask):
        """
        :param data: raw dataset; each item must unpack into exactly two
            values (input, label)
        :param trans: data-processing callable, invoked as
            ``trans(input, label)`` for masked-in samples
        :param mask: indexable by the same index as ``data``; a truthy entry
            means the sample at that index is transformed
        """
        self.mask = mask
        self.trans = trans
        self.rawdata = data

    def __len__(self):
        return len(self.rawdata)

    def __getitem__(self, item):
        sample, label = self.rawdata[item]
        # Masked-out samples are handed back untouched.
        if not self.mask[item]:
            return sample, label
        return self.trans(sample, label)


class WrapDataset(Dataset):
    """Dataset view that feeds each unpacked sample to ``wrapper.wrap``."""

    def __init__(self, data_set,wrapper):
        """
        :param data_set: wrapped dataset; its samples must be iterable so
            they can be unpacked into positional arguments
        :param wrapper: object exposing ``wrap(*fields)``; its return value
            becomes the sample
        """
        self.wrapper = wrapper
        self.data_set = data_set

    def __len__(self):
        return len(self.data_set)

    def __getitem__(self, item):
        fields = self.data_set[item]
        return self.wrapper.wrap(*fields)