from .comlib import *
from . import adversary_transform as at
from . import worker_indices as wi
from . import dataset_operation, worker_conf
from .. import util

class AdversaryConf(worker_conf.WorkerConf):
    """Worker configuration serving a transformed subset of a base dataset.

    For a given worker id, the samples of ``baseDataset`` selected by
    ``baseIndicesGen.getIndices(worker_id)`` are wrapped in a
    ``torch.utils.data.Subset`` and then in
    ``dataset_operation.TransformedDataset`` together with ``transform``.
    """

    def __init__(self, baseDataset,
                 baseIndicesGen: wi.BaseIndices,
                 transform: at.Transform):
        self.baseDataset = baseDataset
        self.baseIndicesGen = baseIndicesGen
        self.transform = transform

    def getDataset(self, worker_id):
        """Return the transformed per-worker subset for ``worker_id``."""
        selected = self.baseIndicesGen.getIndices(worker_id)
        subset = torch.utils.data.Subset(self.baseDataset, selected)
        return dataset_operation.TransformedDataset(subset, self.transform)

class NoisyConf(worker_conf.WorkerConf):
    """Worker configuration that marks a random fraction of samples via a mask.

    For a given worker id, the worker's subset of ``baseDataset`` (chosen by
    ``baseIndicesGen``) is wrapped in
    ``dataset_operation.TransformedDatasetWithMask`` with ``transform`` and a
    boolean mask. A sample is selected when its random draw is below
    ``get_wrong_ratio(worker_id)``. Presumably only selected samples are
    transformed — TODO confirm against ``TransformedDatasetWithMask``.
    """

    def __init__(self, baseDataset,
                 baseIndicesGen: wi.BaseIndices,
                 transform: at.Transform, random_seed, get_wrong_ratio):
        self.baseDataset = baseDataset
        self.baseIndicesGen = baseIndicesGen
        self.transform = transform
        # NOTE(review): stored but not read anywhere in this class.
        self.random_seed = random_seed
        # Callable: worker_id -> threshold compared against the random draws.
        self.get_wrong_ratio = get_wrong_ratio

    def getDataset(self, worker_id):
        """Return the masked, transformed per-worker subset for ``worker_id``."""
        selected = self.baseIndicesGen.getIndices(worker_id)
        subset = torch.utils.data.Subset(self.baseDataset, selected)

        # NOTE(review): the RNG is seeded with the hard-coded ("test", 0) rather
        # than self.random_seed. If get_random_state is deterministic, every
        # worker with an equally sized subset gets the same mask — confirm intent.
        state = util.get_random_state("test", 0)
        draws = state.rand(len(subset))
        threshold = self.get_wrong_ratio(worker_id)
        mask = draws < threshold
        return dataset_operation.TransformedDatasetWithMask(subset, self.transform, mask)

class SubstitudeAdversary(Dataset):
    """Dataset that replaces base samples by adversary samples, class by class.

    For every class of ``baseDataset`` the number of samples is preserved, but
    as many slots as possible are filled with samples of the same class taken
    from ``adversaryDataset``. Any remaining slots are filled with a random
    subset (without replacement) of that class's base samples. Items are laid
    out class by class, following the iteration order of the per-class index
    mapping.

    Args:
        baseDataset: dataset whose per-class sample counts are preserved.
        adversaryDataset: dataset supplying the substitute samples.
        random_seed: seed for ``numpy.random.RandomState`` (presumably an int
            or ``None`` — TODO confirm against callers).
    """
    BASE = 0
    ADVERSARY = 1

    def __init__(self, baseDataset: Dataset, adversaryDataset: Dataset, random_seed):
        import numpy as np

        self.baseDataset = baseDataset
        self.adversaryDataset = adversaryDataset
        self.random_seed = random_seed

        # NOTE(review): assumes getClassIndex returns a mapping
        # target -> sequence of indices into the given dataset; confirm.
        self.baseIndicesPerClass = wi.SpecLabelNumIndices.getClassIndex(baseDataset)
        self.adversaryIndicesPerClass = wi.SpecLabelNumIndices.getClassIndex(adversaryDataset)

        # Use a single RandomState for all classes, so that classes of equal
        # size do not all receive the same selection pattern.
        rs = np.random.RandomState(random_seed)
        self.whichDataset = []
        self.localIndex = []
        for target in self.baseIndicesPerClass:
            try:
                adversaryIndices = self.adversaryIndicesPerClass[target]
            except (KeyError, IndexError):
                # The adversary set has no sample of this class: keep all base samples.
                adversaryIndices = []
            which, local = self.mixIndices(
                rs, self.baseIndicesPerClass[target], adversaryIndices)
            # Accumulate across classes (previously overwritten each iteration).
            self.whichDataset.extend(which)
            self.localIndex.extend(local)
        # Equals len(baseDataset) when getClassIndex partitions the dataset.
        # Derived from the built lists so __len__ and __getitem__ always agree.
        self.len = len(self.whichDataset)

    @staticmethod
    def mixIndices(random_seed, baseIndices, adversaryIndices):
        """Build the (source, index) lists for the samples of one class.

        Args:
            random_seed: an int/``None`` seed, or an existing
                ``numpy.random.RandomState`` to draw from.
            baseIndices: indices into the base dataset for this class.
            adversaryIndices: indices into the adversary dataset for this class.

        Returns:
            ``(whichDataset, localIndex)``: two lists of length
            ``len(baseIndices)``. ``whichDataset[i]`` is ``BASE`` or
            ``ADVERSARY``, and ``localIndex[i]`` is the index into that dataset.
            If there are more adversary than base samples, a random subset of
            the adversary indices fills every slot. Otherwise all adversary
            indices are used, followed by a random subset of the base indices.
        """
        import numpy as np

        if isinstance(random_seed, np.random.RandomState):
            rs = random_seed
        else:
            rs = np.random.RandomState(random_seed)

        def sample(indices, k):
            # k distinct entries of `indices`, in random order.
            if k == 0:
                return []
            positions = rs.choice(len(indices), size=k, replace=False)
            return [int(indices[int(p)]) for p in positions]

        ad_data_num = len(adversaryIndices)
        data_num = len(baseIndices)
        if ad_data_num > data_num:
            whichDataset = [SubstitudeAdversary.ADVERSARY] * data_num
            localIndex = sample(adversaryIndices, data_num)
        else:
            base_num = data_num - ad_data_num
            whichDataset = ([SubstitudeAdversary.ADVERSARY] * ad_data_num
                            + [SubstitudeAdversary.BASE] * base_num)
            localIndex = ([int(i) for i in adversaryIndices]
                          + sample(baseIndices, base_num))

        return whichDataset, localIndex

    def __getitem__(self, item):
        """Return ``(input, target)`` from the dataset slot ``item`` maps to."""
        index = self.localIndex[item]
        if self.whichDataset[item] == SubstitudeAdversary.BASE:
            data, target = self.baseDataset[index]
        else:
            data, target = self.adversaryDataset[index]
        return data, target

    def __len__(self):
        return self.len
    

class MixConf(worker_conf.WorkerConf):
    """Worker configuration combining a normal and an adversary configuration.

    ``mixRule`` selects how the two per-worker datasets are combined:

    * ``MixConf.SUBSTITUDE``: wrap them in ``SubstitudeAdversary`` (seeded
      with ``random_seed``).
    * ``MixConf.CONCAT``: concatenate them with
      ``dataset_operation.Concat_Dataset``.
    """
    SUBSTITUDE = 0
    CONCAT = 1

    def __init__(self, random_seed,
                 normalConf: worker_conf.WorkerConf,
                 adversaryConf: worker_conf.WorkerConf,
                 mixRule):
        self.normalConf = normalConf
        self.adversaryConf = adversaryConf
        self.mixRule = mixRule
        self.random_seed = random_seed

    def getDataset(self, worker_id):
        """Return the combined dataset for ``worker_id``.

        Raises:
            ValueError: if ``mixRule`` is neither ``SUBSTITUDE`` nor ``CONCAT``
                (previously this surfaced as an UnboundLocalError).
        """
        normalDataset = self.normalConf.getDataset(worker_id)
        adversaryDataset = self.adversaryConf.getDataset(worker_id)
        if self.mixRule == self.SUBSTITUDE:
            return SubstitudeAdversary(normalDataset, adversaryDataset, self.random_seed)
        if self.mixRule == self.CONCAT:
            return dataset_operation.Concat_Dataset([normalDataset, adversaryDataset])
        raise ValueError(f"unknown mixRule: {self.mixRule!r}")
