
from .comlib import *
from . import backdoor

class Transform:
    """Base class for callable sample transforms.

    Subclasses implement ``__call__``. This base only supplies a debugging
    representation of the form ``"<ClassName>, <instance __dict__>"``.
    """

    def __repr__(self):
        name = type(self).__name__
        attributes = self.__dict__
        return "{}, {}".format(name, attributes)
    
class IdTransform(Transform):
    """Identity transform: returns its positional arguments unchanged, as a tuple."""

    def __call__(self, *args):
        passthrough = args
        return passthrough

class CompositeTransform(Transform):
    """Apply a sequence of transforms one after another.

    Each transform is called with the star-unpacked result of the previous
    one, so every transform except the last must return something that can be
    unpacked into positional arguments (e.g. a tuple). With an empty
    ``transList`` the incoming argument tuple is returned unchanged.
    """

    def __init__(self, transList):
        self.transList = transList

    def __call__(self, *args):
        current = args
        for step in self.transList:
            # Feed the previous output in as positional arguments.
            current = step(*current)
        return current

class InputTransform(Transform):
    """Apply ``trans`` to the input of an ``(input, target)`` pair only.

    The call expects exactly two positional arguments; the target is passed
    through untouched.
    """

    def __init__(self, trans):
        self.trans = trans

    def __call__(self, *args):
        sample, label = args
        transformed = self.trans(sample)
        return transformed, label
    
class BackdoorTransform(Transform):
    """Stamp a backdoor pattern onto a sample and relabel it.

    ``backdoor_pattern`` can be built, for example, as::

        backdoor.BackdoorPattern(pattern_position=None,
                                 input_shape=task.MnistTask().input_shape,
                                 pattern_size=(5, 10))

    Args:
        backdoor_pattern: pattern object; the return value of its
            ``add(input)`` method is used as the backdoored input.
        backdoor_label: label assigned to every backdoored sample. When None,
            the label is shifted instead: ``(target + 1) % num_classes``.
        num_classes: number of classes used by the default label shift.
            Defaults to 10. Set it for datasets with a different class count
            so that shifted labels cover the whole label range.
    """

    def __init__(self, backdoor_pattern: backdoor.BackdoorPattern,
                 backdoor_label=None, num_classes=10):
        self.backdoor_label = backdoor_label
        self.backdoor_pattern = backdoor_pattern
        self.num_classes = num_classes

    def __call__(self, *args):
        """Return ``(backdoored_input, backdoored_target)``.

        Expects exactly two positional arguments, ``(input, target)``; any
        other count raises ``ValueError`` from tuple unpacking. ``target``
        only needs to support ``+`` and ``%`` (an int works; presumably
        tensors do too — confirm against the dataset in use).
        """
        sample, target = args
        backdoored_input = self.backdoor_pattern.add(sample)
        if self.backdoor_label is None:
            # Shift to the next class, wrapping around at num_classes.
            backdoored_target = (target + 1) % self.num_classes
        else:
            backdoored_target = self.backdoor_label
        return backdoored_input, backdoored_target
    
# stale
class WrongLabelTransform(Transform):
    """Replace a sample's label with a wrong one, leaving the input untouched.

    Args:
        wrong_label: label assigned to every sample. When None, the label is
            shifted instead: ``(target + 1) % num_classes``.
        num_classes: number of classes used by the default label shift.
            Defaults to 10. Set it for datasets with a different class count
            so that shifted labels cover the whole label range.
    """

    def __init__(self, wrong_label=None, num_classes=10):
        self.wrong_label = wrong_label
        self.num_classes = num_classes

    def __call__(self, *args):
        """Return ``(input, wrong_target)``.

        Expects exactly two positional arguments, ``(input, target)``; any
        other count raises ``ValueError`` from tuple unpacking.
        """
        sample, target = args
        if self.wrong_label is None:
            # Shift to the next class, wrapping around at num_classes.
            wrong_target = (target + 1) % self.num_classes
        else:
            wrong_target = self.wrong_label
        return sample, wrong_target


# class WrongLabelTransform11(Transform):
#     def __init__(self,wrong_label=None,wrong_ratio=1):
#         self.wrong_label=wrong_label
#         self.wrong_ratio=wrong_ratio

#     def flipLabel(self,target):
#         if self.wrong_label is None:
#             wrongtarget= (target+1)%10
#         else:
#             wrongtarget= self.wrong_label
#         return wrongtarget
    
#     def __call__(self, *args):
#         input,target= args
#         mask= random.random()<self.wrong_ratio
#         if mask:
#             return input, self.flipLabel(target)
#         else:
#             return input,target

