from typing import Mapping, Union, Any

from functools import partial
import torch
from torch import Tensor
from torch.nn import Module, Parameter
from torch.nn import PReLU

from bypass.utils import CROSS_CHANNEL_LAYERS, CHANNEL_WISE_LAYERS, NORMALIZE_LAYERS
from bypass.core.heuristics import Dx2heuristics,ConstantMultipleNorm

from torch_pruning.ops import _CustomizedOp

class ActivationForBypass(Module):  # D x + sigma(x)
    """Wrap an activation ``sigma`` as ``x -> D x + sigma(x)`` with a per-channel diagonal ``D``.

    ``D`` is held in the ``delta`` parameter (one entry per channel). It starts at
    zeros and is frozen (``requires_grad=False``), so the wrapper initially behaves
    exactly like ``activation``.

    ``status`` records the lifecycle:
    0 after construction, 1 after :meth:`embed`, 2 after :meth:`proj`.

    Attribute lookups that miss on the wrapper are forwarded to the wrapped
    activation (see :meth:`__getattr__`).

    Args:
        num_params: number of channels, i.e. the length of ``delta``.
        activation: the wrapped activation module; must not be in-place.
        device, dtype: factory arguments used to create ``delta``.
        channel_last: if True the channel axis is the last axis of the input
            (einsum ``'b...c,c->b...c'``), otherwise axis 1 (``'bc...,c->bc...'``).
    """

    def __init__(self, num_params, activation: Module, device=None, dtype=None, channel_last=False):
        # Reject in-place activations: forward reads x for both the D x term and sigma(x).
        assert not hasattr(activation, 'inplace') or not activation.inplace
        Module.__init__(self)
        self.num_parameters = num_params

        self.factory_kwargs = {'device': device, 'dtype': dtype}
        initial_delta = torch.zeros(num_params, **self.factory_kwargs)
        self.delta = Parameter(initial_delta, requires_grad=False)
        self.activation = activation
        self.status = 0

        spec = 'b...c,c->b...c' if channel_last else 'bc...,c->bc...'
        self.einsum_func = partial(torch.einsum, spec)

    def __repr__(self):
        return '{}({})'.format(self.__class__, self.activation)

    def forward(self, x):
        """Return ``delta * x`` (per channel) plus ``activation(x)``."""
        linear_part = self.einsum_func(x, self.delta)
        return linear_part + self.activation(x)

    def skip_delta(self, x):
        """Apply only the wrapped activation, ignoring ``delta``."""
        return self.activation(x)

    def embed(self):
        """Make ``delta`` trainable and move to status 1."""
        self.status = 1
        self.delta.requires_grad = True

    def proj(self, num_param_update=None):
        """Reset ``delta`` to zeros (frozen) and move to status 2.

        Afterwards the wrapper computes the same result as ``self.activation``.
        Run this last, after the other weight transformations have been applied.

        Args:
            num_param_update: optional new channel count; when given it replaces
                ``num_parameters`` before ``delta`` is rebuilt.
        """
        if num_param_update is not None:
            self.num_parameters = num_param_update
        reference = self.delta
        self.delta.data = torch.zeros(self.num_parameters, device=reference.device, dtype=reference.dtype)
        self.delta.requires_grad = False
        self.status = 2

    def __getattr__(self, name: str):
        """Resolve parameters/buffers/submodules, then fall back to the wrapped activation.

        Raises:
            AttributeError: if the name is found nowhere (raised by the wrapped
                activation when one is registered, otherwise by this method).
        """
        state = self.__dict__
        for store_name in ('_parameters', '_buffers', '_modules'):
            store = state.get(store_name)
            if store is not None and name in store:
                return store[name]

        modules = state.get('_modules')
        if modules is not None and 'activation' in modules:
            return getattr(self.activation, name)

        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, name))


class ActivationForActivionChange(Module):  # D_1 sigma_1(x) + D_2 sigma_2(x)
    """Blend two activations as ``x -> D_1 sigma_1(x) + D_2 sigma_2(x)``.

    ``D_1`` (``delta1``) and ``D_2`` (``delta2``) are per-channel vectors. They start
    at ``D_1 = 1``, ``D_2 = 0`` (pure source activation) and are frozen until
    :meth:`embed`. :meth:`proj` sets ``D_1 = 0``, ``D_2 = 1`` (pure target activation).

    The deltas multiply the activation outputs by plain broadcasting, i.e. against
    the LAST axis, so inputs are expected to be channel-last.

    Args:
        num_params: number of channels (length of each delta).
        source_activation: activation used initially (``sigma_1``); must not be in-place.
        target_activation: activation switched to by :meth:`proj` (``sigma_2``); must not be in-place.
        device, dtype: factory arguments used to create the deltas.
    """

    def __init__(self, num_params, source_activation: Module, target_activation: Module, device=None, dtype=None):
        # In-place activations are rejected: forward applies sigma_1 and then
        # sigma_2 to the same x, so an in-place sigma_1 would clobber it.
        assert not hasattr(source_activation, 'inplace') or not source_activation.inplace
        assert not hasattr(target_activation, 'inplace') or not target_activation.inplace
        super().__init__()
        self.num_params = num_params
        self.factory_kwargs = {'device': device, 'dtype': dtype}
        self.delta1 = Parameter(torch.ones(num_params, **self.factory_kwargs), requires_grad=False)  # start at I
        self.delta2 = Parameter(torch.zeros(num_params, **self.factory_kwargs), requires_grad=False)  # start at 0

        self.sigma1 = source_activation
        self.sigma2 = target_activation

    def forward(self, x):
        """Return ``delta1 * sigma1(x) + delta2 * sigma2(x)`` (deltas broadcast on the last axis)."""
        return self.delta1 * self.sigma1(x) + self.delta2 * self.sigma2(x)

    def embed(self):
        """Make both deltas trainable."""
        self.delta1.requires_grad = True
        self.delta2.requires_grad = True

    def proj(self, num_param_update=None):
        """Project onto the target activation: ``delta1 = 0``, ``delta2 = 1``, both frozen.

        Run this last, after the other weight transformations have been applied.

        Args:
            num_param_update: optional new channel count; when given it replaces
                ``num_params`` before the deltas are rebuilt (same convention as
                ``ActivationForBypass.proj``).
        """
        if num_param_update is not None:
            self.num_params = num_param_update
        # Use each parameter's CURRENT device/dtype rather than the construction-time
        # factory_kwargs: the module may have been moved or cast (.to/.cuda/.double) since.
        self.delta1.data = torch.zeros(self.num_params, device=self.delta1.device, dtype=self.delta1.dtype)
        self.delta1.requires_grad = False

        self.delta2.data = torch.ones(self.num_params, device=self.delta2.device, dtype=self.delta2.dtype)
        self.delta2.requires_grad = False



class ActivationForDx2(ActivationForBypass):
    """Bypass activation ``x -> D x + sigma(x)`` whose diagonal ``D`` is tracked in two parts.

    ``forward`` uses ``D = self.delta_diff()``. ``delta_diff`` is a bound method that
    is swapped as the module moves through its states:

    * ``_delta_diff_0``: zeros (initial state, status 0),
    * ``_delta_diff_1``: ``delta - delta2`` (status 1),
    * ``_delta_diff_2``: ``delta`` alone, after ``delta2`` has been deleted (status 2).
    """

    def __init__(self, num_params, activation: Module, device=None, dtype=None, channel_last=False):
        ActivationForBypass.__init__(self, num_params=num_params, activation=activation, device=device, dtype=dtype, channel_last=channel_last)

        # zeros_like inherits delta's device and dtype (delta was just built from factory_kwargs)
        self.delta2 = Parameter(torch.zeros_like(self.delta), requires_grad=True)

        self.status = 0
        self.delta_diff = self._delta_diff_0

    def _delta_diff_0(self):
        return torch.zeros_like(self.delta)

    def _delta_diff_1(self):
        return self.delta - self.delta2

    def _delta_diff_2(self):
        return self.delta

    def skip_delta(self, x):
        """Apply the wrapped activation to ``x``; in status 1 also add ``-delta2 * x``.

        Only valid once :meth:`embed` has been called (status 1 or 2).
        NOTE(review): presumably ``x`` is ``(channels,)`` or ``(batch, channels)``; the
        view below appends two singleton axes for a BatchNorm2d-style layer — confirm.
        """
        if isinstance(self.activation, NORMALIZE_LAYERS):
            # Disable the normalization layer's input-rank check so it accepts the
            # reshaped input. NOTE: this patch stays on the wrapped module afterwards.
            self.activation._check_input_dim = lambda x: True

        assert self.status in [1, 2]
        if x.dim() == 1:
            x = x.unsqueeze(0)
        # TODO FLAG: x.view(*x.shape,1,1) was for BN2d. need to generalize
        ret = self.activation(x.view(*x.shape, 1, 1)).squeeze()
        if self.status == 1:
            add_term = self.einsum_func(x, -self.delta2)
            ret = add_term.squeeze() + ret
        return ret

    def forward(self, x):
        """Return ``delta_diff() * x`` (per channel) plus ``activation(x)``."""
        add_term = self.einsum_func(x, self.delta_diff())
        return add_term + self.activation(x)

    def embed(self):
        """Advance the delta state by one step.

        status 0 -> 1: make ``delta`` and ``delta2`` trainable; forward uses ``delta - delta2``.
        status 1 -> 2: set ``delta = -delta2``, delete ``delta2``; forward uses ``delta``.
        status 2: no-op.
        """
        if self.status == 0:
            # original note: "optimize delta1 only" — the code enables grads on both
            self.delta2.requires_grad = True
            self.delta.requires_grad = True
            self.status = 1
            self.delta_diff = self._delta_diff_1
        elif self.status == 1:
            # after pruning: reduce to a single delta (= -delta2) and drop delta2
            self.delta.data = -self.delta2.data
            del self.delta2
            self.delta.requires_grad = True
            self.status = 2
            self.delta_diff = self._delta_diff_2

    def proj(self, num_param_update=None):
        """Status 0/1: same as :meth:`embed`. Status 2: reset ``delta`` to zeros via the parent ``proj``.

        NOTE(review): in the status-2 branch, ``self.status = 0`` is immediately
        overwritten with 2 by ``ActivationForBypass.proj``. The module therefore ends
        in status 2 with ``delta_diff`` returning zeros. Confirm whether status 0 was
        intended; note that ``delta2`` no longer exists at that point, so ``embed()``
        from status 0 would fail.
        """
        if self.status in [0, 1]:
            self.embed()
        elif self.status == 2:
            self.status = 0
            self.delta_diff = self._delta_diff_0
            return super().proj(num_param_update)

    def _init_delta(self, delta_zero: torch.Tensor):
        """Set ``delta`` to ``delta_zero`` and ``delta2`` so the pre-call effective diagonal is kept.

        After the call, ``delta - delta2`` equals the value ``delta_diff()`` had before
        the call. If ``delta2`` was deleted (status 2), it is re-created, status becomes 1
        and ``delta_diff`` switches to ``delta - delta2``. Both parameters are left trainable.
        """
        previous_diff = self.delta_diff().clone()
        self.delta.data = delta_zero.clone()
        if not hasattr(self, 'delta2'):  # status 2 -> 1
            # zeros_like follows delta's current device/dtype; the construction-time
            # factory_kwargs may be stale if the module was moved or cast since.
            self.register_parameter('delta2', Parameter(torch.zeros_like(self.delta), requires_grad=True))
            self.status = 1
            self.delta_diff = self._delta_diff_1
        self.delta2.data = delta_zero.clone() - previous_diff

        self.delta.requires_grad = True
        self.delta2.requires_grad = True
        return None

class TrivialActivationForDx2(ActivationForBypass):
    """Two-part delta wrapper that scales the activation output: ``x -> (1 + D) * sigma(x)``.

    ``D = self.delta_diff()`` is applied per channel. ``delta_diff`` is a bound
    method swapped as the module moves through its states:

    * ``_delta_diff_0``: zeros (initial state, status 0),
    * ``_delta_diff_1``: ``delta - delta2`` (status 1),
    * ``_delta_diff_2``: ``delta`` alone, after ``delta2`` has been deleted (status 2).
    """

    def __init__(self, num_params, activation: Module, device=None, dtype=None, channel_last=False):
        ActivationForBypass.__init__(self, num_params=num_params, activation=activation, device=device, dtype=dtype, channel_last=channel_last)

        # zeros_like inherits delta's device and dtype (delta was just built from factory_kwargs)
        self.delta2 = Parameter(torch.zeros_like(self.delta), requires_grad=True)

        self.status = 0
        self.delta_diff = self._delta_diff_0

    def _delta_diff_0(self):
        return torch.zeros_like(self.delta)

    def _delta_diff_1(self):
        return self.delta - self.delta2

    def _delta_diff_2(self):
        return self.delta

    def skip_delta(self, x):
        """Apply the wrapped activation to ``x``; in status 1 also add ``-delta2 * x``.

        Only valid once :meth:`embed` has been called (status 1 or 2).
        NOTE(review): unlike :meth:`forward` (which scales ``sigma(x)``), this adds
        ``-delta2 * x`` — the same formula as ``ActivationForDx2.skip_delta``. Confirm
        this is intended for the trivial variant. Presumably ``x`` is ``(channels,)``
        or ``(batch, channels)``; the view appends two singleton axes for a
        BatchNorm2d-style layer.
        """
        if isinstance(self.activation, NORMALIZE_LAYERS):
            # Disable the normalization layer's input-rank check so it accepts the
            # reshaped input. NOTE: this patch stays on the wrapped module afterwards.
            self.activation._check_input_dim = lambda x: True

        assert self.status in [1, 2]
        if x.dim() == 1:
            x = x.unsqueeze(0)
        # TODO FLAG: x.view(*x.shape,1,1) was for BN2d. need to generalize
        ret = self.activation(x.view(*x.shape, 1, 1)).squeeze()
        if self.status == 1:
            add_term = self.einsum_func(x, -self.delta2)
            ret = add_term.squeeze() + ret
        return ret

    def forward(self, x):
        """Return ``(1 + delta_diff()) * activation(x)`` per channel, i.e. ``sigma(x) + D sigma(x)``."""
        x = self.activation(x)
        x = self.einsum_func(x, self.delta_diff() + 1)
        return x

    def embed(self):
        """Advance the delta state by one step.

        status 0 -> 1: make ``delta`` and ``delta2`` trainable; forward uses ``delta - delta2``.
        status 1 -> 2: set ``delta = -delta2``, delete ``delta2``; forward uses ``delta``.
        status 2: no-op.
        """
        if self.status == 0:
            # original note: "optimize delta1 only" — the code enables grads on both
            self.delta2.requires_grad = True
            self.delta.requires_grad = True
            self.status = 1
            self.delta_diff = self._delta_diff_1
        elif self.status == 1:
            # after pruning: reduce to a single delta (= -delta2) and drop delta2
            self.delta.data = -self.delta2.data
            del self.delta2
            self.delta.requires_grad = True
            self.status = 2
            self.delta_diff = self._delta_diff_2

    def proj(self, num_param_update=None):
        """Status 0/1: same as :meth:`embed`. Status 2: reset ``delta`` to zeros via the parent ``proj``.

        NOTE(review): in the status-2 branch, ``self.status = 0`` is immediately
        overwritten with 2 by ``ActivationForBypass.proj``. The module therefore ends
        in status 2 with ``delta_diff`` returning zeros. Confirm whether status 0 was
        intended; ``delta2`` no longer exists at that point, so ``embed()`` from
        status 0 would fail.
        """
        if self.status in [0, 1]:
            self.embed()
        elif self.status == 2:
            self.status = 0
            self.delta_diff = self._delta_diff_0
            return super().proj(num_param_update)

    def _init_delta(self, delta_zero: torch.Tensor):
        """Set ``delta`` to ``delta_zero`` and ``delta2`` so the pre-call effective diagonal is kept.

        After the call, ``delta - delta2`` equals the value ``delta_diff()`` had before
        the call. If ``delta2`` was deleted (status 2), it is re-created, status becomes 1
        and ``delta_diff`` switches to ``delta - delta2``. Both parameters are left trainable.
        """
        previous_diff = self.delta_diff().clone()
        self.delta.data = delta_zero.clone()
        if not hasattr(self, 'delta2'):  # status 2 -> 1
            # zeros_like follows delta's current device/dtype; the construction-time
            # factory_kwargs may be stale if the module was moved or cast since.
            self.register_parameter('delta2', Parameter(torch.zeros_like(self.delta), requires_grad=True))
            self.status = 1
            self.delta_diff = self._delta_diff_1
        self.delta2.data = delta_zero.clone() - previous_diff

        self.delta.requires_grad = True
        self.delta2.requires_grad = True
        return None

class TrivialActivationForBypass(ActivationForBypass):
    '''
    Bypass wrapper that scales the activation output instead of adding a linear term.

    In contrast to ActivationForBypass, it treats the identity activation that follows
    the module as the thing being wrapped: ``forward`` returns ``(1 + delta) * sigma(x)``
    per channel.
    '''

    def forward(self, x):
        activated = self.activation(x)
        scale = self.delta + 1
        return self.einsum_func(activated, scale)

    def _init_delta(self, delta_zero: torch.Tensor):
        """Copy ``delta_zero`` into ``delta`` and make it trainable."""
        self.delta.data = delta_zero.clone()
        self.delta.requires_grad = True
# Wrapper classes defined in this module; _is_channel_wise matches modules against them by exact type.
BYPASS_LAYERS = [
    ActivationForBypass,
    ActivationForActivionChange,
    ActivationForDx2,
    TrivialActivationForDx2,
    TrivialActivationForBypass,
]
def _is_channel_wise(module: torch.nn.Module):
    """Report whether ``module`` acts on each channel independently.

    Returns:
        * ``True`` if ``module``'s exact type is in ``CHANNEL_WISE_LAYERS``.
        * ``False`` if it is in ``CROSS_CHANNEL_LAYERS``.
        * For a bypass wrapper (exact type in ``BYPASS_LAYERS``), the answer for its
          wrapped ``activation``.
        * For a ``torch.nn.Sequential``, the combined answer for its children:
          ``False`` if any child is ``False``, else ``NotImplemented`` if any child
          is unknown, else ``True`` (an empty ``Sequential`` gives ``True``).
        * ``NotImplemented`` for anything else.
    """
    module_type = type(module)
    if module_type in CHANNEL_WISE_LAYERS:
        return True
    if module_type in CROSS_CHANNEL_LAYERS:
        return False
    if module_type in BYPASS_LAYERS:
        return _is_channel_wise(module.activation)
    if isinstance(module, torch.nn.Sequential):
        # torch.all only accepts Tensors (a Python list raises TypeError), so fold the
        # children explicitly. NotImplemented is compared by identity because using it
        # in a boolean context is deprecated (TypeError on Python 3.14+).
        verdict = True
        for child in module:
            child_verdict = _is_channel_wise(child)
            if child_verdict is NotImplemented:
                verdict = NotImplemented
            elif not child_verdict:
                return False
        return verdict
    return NotImplemented
if __name__ == '__main__':
    # Smoke test: run the wrapper on unbatched and batched inputs and check that the
    # output shape equals the input shape. ActivationForActivionChange broadcasts its
    # deltas against the last axis, so only channel-last layouts are exercised.
    activation = torch.nn.ReLU()
    activation2 = torch.nn.Tanh()

    hidden_feature_size = 10
    batch_size = 128
    test_activation = ActivationForActivionChange(hidden_feature_size, activation, activation2)

    dummy_tensors = [
        torch.arange(hidden_feature_size, dtype=torch.float32),  # unbatched vector
        torch.ones([32, 32, hidden_feature_size]),  # convolution type, channel last, unbatched
        torch.ones([batch_size, hidden_feature_size]),  # fully connected type
        torch.ones([batch_size, 32, 32, hidden_feature_size]),  # convolution type, channel last
    ]
    for dummy_tensor in dummy_tensors:
        output = test_activation(dummy_tensor)
        assert output.shape == dummy_tensor.shape, (output.shape, dummy_tensor.shape)
    print(1)