from functools import partial

import torch
from torch.nn.utils.prune import remove
from bypass.core.activation import ActivationForBypass, ActivationForActivionChange, _is_channel_wise
#from bypass.core.detect import BypassDetector, remove_detector, register_detector
#from bypass.core.prune_torch import custom_structured_from_mask
from bypass.core.prune_depgraph import prune_tp, BypassActivationPruner
from bypass.utils import GammaSchedule, NORM_FUNCTIONS
from typing import Iterable

import copy
'''
Mixin classes for bypassing.

Usage:

Given a model class ``Model`` (not yet instantiated), instantiate the
mixed model with

model = Model.__class__('mixed_model', (mixin, Model), {})(*model_args)
'''

def _prune_parameter_and_grad(weight, keep_idxs, pruning_dim):
        pruned_weight = torch.nn.Parameter(torch.index_select(weight, pruning_dim, torch.LongTensor(keep_idxs).to(weight.device).contiguous()))
        if weight.grad is not None:
            pruned_weight.grad = torch.index_select(weight.grad, pruning_dim, torch.LongTensor(keep_idxs).to(weight.device))
        return pruned_weight.to(weight.device)
class BaseMixin:
    """Shared bookkeeping for the bypass mixins.

    The methods call ``self.eval()``, ``self.forward()`` and
    ``self._get_name()``, so the class is expected to be mixed into a
    ``torch.nn.Module`` subclass. Concrete mixins override ``embed``,
    ``projection`` and ``ADW_loss_single`` (there they take a layer-group
    index ``idx``); the versions here only return ``NotImplemented``.
    """

    default_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def detect_activations(self, input_shape, **detector_kwargs):
        """Run a dummy forward pass under a detector and record the layer groups.

        Sets ``self.detector``, ``self.bypass_layers``, ``self.dependancy``
        and ``self.module_names``.

        ``self.bypass_layers`` maps a running integer index to a dict with
        keys ``'A'`` and ``'W'`` (the detector's ``A_layer`` / ``W_layer``
        entries), ``'D'`` (the detected activation, i.e. the key of
        ``detector.detected``) and ``'bypass'`` (initially ``False``).

        Args:
            input_shape: shape of one input sample, without the batch dimension.
            **detector_kwargs: forwarded to ``BypassDetector``.

        NOTE(review): ``BypassDetector``, ``register_detector`` and
        ``remove_detector`` are not imported in this module (the import
        line at the top of the file is commented out), so this method
        raises ``NameError`` unless those names are provided some other
        way. Confirm and restore the import.
        """
        self.eval()
        self.detector = BypassDetector(**detector_kwargs)
        dummy_input = torch.rand([1, *input_shape])  # batch of size 1
        register_detector(self, self.detector, parent_name=self._get_name())
        with torch.no_grad():
            self.forward(dummy_input)
        self.detector.summarize()

        self.bypass_layers = {
            i: {'A': v['A_layer'], 'D': k, 'W': v['W_layer'], 'bypass': False}
            for i, (k, v) in enumerate(self.detector.detected.items())
        }
        self.dependancy = self.detector.dependancy
        # Strip the leading "<ClassName>." prefix from the recorded module names.
        self.module_names = {
            k: v.replace(type(self).__name__ + '.', '')
            for k, v in self.detector.module_names.items()
        }
        remove_detector(self)
        return None

    def set_norm_function(self, norm_type):
        """Set ``self.norm_function`` from ``NORM_FUNCTIONS[norm_type]``.

        The stored callable applies the chosen norm over every dimension of
        its tensor argument.
        """
        norm_func = NORM_FUNCTIONS[norm_type]
        self.norm_function = lambda x: norm_func(x, list(range(x.dim())))

    def embed(self):
        """Placeholder; returns ``NotImplemented``. Overridden by concrete mixins."""
        return NotImplemented

    def projection(self):
        """Placeholder; returns ``NotImplemented``. Overridden by concrete mixins."""
        return NotImplemented

    def ADW_loss_single(self):
        """Placeholder; returns ``NotImplemented``. Overridden by concrete mixins."""
        return NotImplemented

    def set_gamma(self, gamma, gamma_steps=None):
        """Build ``self.gamma_dict``, one ``GammaSchedule`` per layer group.

        Args:
            gamma: one of
                * ``dict`` - ``{key: spec}``; each value is wrapped in a
                  ``GammaSchedule``;
                * ``str`` - a single spec shared by every key of
                  ``self.bypass_layers``;
                * other iterable - must yield ``(key, spec)`` pairs;
                * anything else - a single spec shared by every key.
            gamma_steps: optional ``{key: step}``; assigned to the
                ``global_step`` attribute of the matching schedule.
        """
        self.gamma_dict = {}
        if isinstance(gamma, dict):
            # Iterate items, not bare keys, so key/value unpacking is correct.
            self.gamma_dict = {k: GammaSchedule(v) for k, v in gamma.items()}
        elif isinstance(gamma, str):
            # str is checked before Iterable because a str is itself iterable.
            self.gamma_dict = {k: GammaSchedule(gamma) for k in self.bypass_layers.keys()}
        elif isinstance(gamma, Iterable):
            # NOTE(review): this expects (key, spec) pairs. A plain sequence
            # of specs (as build_hparam_dict accepts) fails to unpack here;
            # confirm whether enumerate()/zip() with the bypass_layers keys
            # was intended.
            for i, g in gamma:
                self.gamma_dict[i] = GammaSchedule(g)
        else:
            self.gamma_dict = {k: GammaSchedule(gamma) for k in self.bypass_layers.keys()}

        if gamma_steps:
            for k, gamma_step in gamma_steps.items():
                self.gamma_dict[k].global_step = gamma_step

    def set_epsilons(self, epsilons):
        """Set ``self.epsilon_dict``, the per-layer projection thresholds.

        Per the original author's note: each layer needs an epsilon (the
        threshold used by "opt2"), and a layer is projected when its ADW
        value is smaller than epsilon.
        """
        self.epsilon_dict = self.build_hparam_dict(epsilons)
        return None

    def set_pruning_epsilons(self, epsilons):
        """Set ``self.pruning_epsilon_dict``, the per-layer pruning thresholds.

        Per the original author's note: a channel is pruned when its delta
        is larger than epsilon.
        """
        self.pruning_epsilon_dict = self.build_hparam_dict(epsilons)
        return None

    def build_hparam_dict(self, epsilon):
        """Expand a hyper-parameter spec into a ``{layer key: value}`` dict.

        Args:
            epsilon: one of
                * ``dict`` - copied; string values go through
                  ``_process_str_hparam``;
                * ``str`` - processed once per key of ``self.bypass_layers``;
                * other iterable - zipped with the keys of
                  ``self.bypass_layers`` in order;
                * anything else - shared by every key.
        """
        if isinstance(epsilon, dict):
            return {
                k: self._process_str_hparam(v) if isinstance(v, str) else v
                for k, v in epsilon.items()
            }
        if isinstance(epsilon, str):
            return {k: self._process_str_hparam(epsilon) for k in self.bypass_layers.keys()}
        if isinstance(epsilon, Iterable):
            return dict(zip(self.bypass_layers.keys(), epsilon))
        return {k: epsilon for k in self.bypass_layers.keys()}

    def _process_str_hparam(self, item: str):
        """Interpret a string hyper-parameter.

        Strings that end with ``'auto'`` or contain ``'ratio'`` or ``'GMM'``
        are returned unchanged; any other string is evaluated as a Python
        expression (e.g. ``'1e-3'`` becomes ``0.001``).
        """
        assert isinstance(item, str), '_process_str_hparam only process str'
        # Each keyword needs its own membership test: a bare 'ratio' operand
        # is always truthy, which made the eval branch below unreachable.
        if item.endswith('auto') or 'ratio' in item or 'GMM' in item:
            return item
        # NOTE(security): eval() runs arbitrary code. Only pass trusted
        # configuration strings here; consider ast.literal_eval if plain
        # literals are all that is needed.
        return eval(item)

    # @property
    def gamma(self, name):
        """Return the schedule stored in ``self.gamma_dict`` under ``name``."""
        return self.gamma_dict[name]

    # @property
    def ADW_loss(self):
        """Sum ``gamma * ADW_loss_single`` over the groups currently bypassing.

        Returns:
            ``(total, losses, gammas)`` where ``total`` is the weighted sum
            (the int ``0`` when no group is bypassing), ``losses`` maps
            ``'ADW<k>'`` to each group's unweighted loss and ``gammas`` maps
            ``'gamma<k>'`` to the value returned by calling that group's
            schedule.
        """
        ADW_loss_dict = {}
        gamma_log_dict = {}
        ret = 0

        for k, group in self.bypass_layers.items():
            if group['bypass']:
                loss = self.ADW_loss_single(k)
                gamma = self.gamma_dict[k]()
                ret = ret + gamma * loss
                ADW_loss_dict[f'ADW{k}'] = loss
                gamma_log_dict[f'gamma{k}'] = gamma
        return ret, ADW_loss_dict, gamma_log_dict

    # @property
    def ADW_metrics(self):
        """Return ``{'ADW<k>': ADW_loss_single(k)}`` for every layer group."""
        return {f'ADW{k}': self.ADW_loss_single(k) for k in self.bypass_layers.keys()}
class BypassMixin_D(BaseMixin):
    """Bypass variant whose loss depends only on the activation's ``delta``."""

    def embed(self, idx):
        """Call ``embed()`` on group ``idx``'s activation and flag the group as bypassing."""
        group = self.bypass_layers[idx]
        group['D'].embed()
        group['bypass'] = True
        return None

    def projection(self, idx):
        """Call ``proj()`` on group ``idx``'s activation and clear its bypass flag."""
        group = self.bypass_layers[idx]
        group['D'].proj()
        group['bypass'] = False
        return None

    def ADW_loss_single(self, idx):
        """Return the root-mean-square of the activation's ``delta`` for group ``idx``."""
        delta = self.bypass_layers[idx]['D'].delta
        return delta.pow(2).mean().sqrt()
class BypassMixin(BaseMixin):  # ActivationForBypass + (ADW = 0)
    """Bypass mixin with an A*D*W-style loss, bias folding and channel pruning.

    Each entry of ``self.bypass_layers`` is used as a dict with
    ``'W'`` / ``'A'`` (sequences of layers), ``'D'`` (the activation object)
    and ``'bypass'`` (bool flag).
    """

    def embed(self, idx):
        """Call ``embed()`` on group ``idx``'s activation and flag the group as bypassing."""
        bypass_layer_group = self.bypass_layers[idx]
        bypass_layer_group['D'].embed()
        bypass_layer_group['bypass'] = True
        return None

    def ADW_loss_single(self, idx):
        """Return the summed norm of D*W (times A where applicable) for group ``idx``.

        For every (W layer, A layer) pair, ``delta * W.weight.T`` is formed.
        When that product is 1-D and the A layer has a ``weight``, it is
        broadcast along dimension 1 of A's weight and multiplied in. The
        result is reduced with ``self.norm_function`` and the per-pair
        values are summed (``0`` when there are no pairs).
        """
        bypass_layer_group = self.bypass_layers[idx]
        D_vec = bypass_layer_group['D'].delta
        ret = []
        for W_layer in bypass_layer_group['W']:
            # Independent of the A layer, so computed once per W layer.
            DW_base = D_vec * W_layer.weight.T
            for A_layer in bypass_layer_group['A']:
                DW_mat = DW_base
                if hasattr(A_layer, 'weight') and DW_mat.dim() == 1:
                    A_mat = A_layer.weight
                    # Reshape the vector to (1, C, 1, ...) so it lines up with
                    # dimension 1 of A's weight. The index must be a tuple
                    # with a full slice at position 1; an integer there would
                    # pick a single element of the vector instead of
                    # broadcasting all of it.
                    broadcast = [None] * A_mat.dim()
                    broadcast[1] = slice(None)
                    DW_mat = A_mat * DW_mat[tuple(broadcast)]
                ret.append(self.norm_function(DW_mat))
        return sum(ret)

    def pass_channel_wise_layer(self, bias_update, layer):
        """Carry a per-channel bias through ``layer`` when it is known to pass unchanged.

        Returns ``bias_update`` as-is for the supported layer kinds and
        ``None`` for any other layer.
        """
        # Author's case 1: layer(x + b) = layer(x) + b.
        if isinstance(layer, (torch.nn.modules.pooling._MaxPoolNd,
                              torch.nn.modules.pooling._AdaptiveAvgPoolNd)):
            return bias_update

        # Author's case 2: layer(x + b) = layer(x) + layer(b).
        # Exact type match: subclasses of these layers are not accepted.
        if type(layer) in (torch.nn.ReLU, torch.nn.Flatten, torch.nn.Dropout):
            return bias_update

        # Unsupported layer kind.
        return None

    def apply_bias_update(self, DbW, A_layer, D_activ, prune_indices, bias_passed):
        """Fold bias contributions into the first Linear/Conv2d at or after ``A_layer``.

        If ``A_layer`` is a ``Linear`` or ``Conv2d``, its bias is increased
        by ``A.weight @ DbW`` (Conv2d weights are first summed over their
        two kernel dimensions). When ``prune_indices`` is given, it is also
        increased by the matching product of the selected input columns of
        the weight with ``bias_passed``. A missing bias is created.

        Otherwise the values are passed through ``pass_channel_wise_layer``
        (and expanded across ``Flatten``) and the method recurses on the
        single successor from ``self.dependancy``.

        Args:
            DbW: per-channel bias contribution from the bypassed activation.
            A_layer: layer at which to apply (or through which to forward) the update.
            D_activ: the activation object; only forwarded to the recursive call.
            prune_indices: index tensor of channels being pruned, or ``None``.
            bias_passed: bias values of the pruned channels, or ``None``.

        Returns:
            ``(bias_update, prune_bias_update)``; the second item is ``None``
            when ``prune_indices`` is ``None``.

        Raises:
            NotImplementedError: if an intermediate layer has more than one successor.
        """
        if isinstance(A_layer, (torch.nn.Linear, torch.nn.Conv2d)):
            A_mat = A_layer.weight
            b_A = A_layer.bias
            is_linear = isinstance(A_layer, torch.nn.Linear)
            prune_bias_update = None
            with torch.no_grad():
                # Bias update for bypassing.
                if is_linear:
                    bias_update = torch.matmul(A_mat, DbW)
                else:
                    conv_Amat = A_mat.sum(dim=3).sum(dim=2)
                    bias_update = torch.matmul(conv_Amat, DbW)
                # Bias update for pruning.
                if prune_indices is not None:
                    A_mat_prune = torch.index_select(A_mat, 1, prune_indices)
                    if not is_linear:
                        A_mat_prune = A_mat_prune.sum(dim=3).sum(dim=2)
                    prune_bias_update = torch.matmul(A_mat_prune, bias_passed)

            if b_A is not None:
                b_A.data = b_A.data + bias_update
            else:
                A_layer.bias = torch.nn.Parameter(data=bias_update.data, requires_grad=True)

            if prune_bias_update is not None:
                # Re-read the bias: it may have just been created above, so
                # it is guaranteed to exist at this point.
                b_A = A_layer.bias
                b_A.data = b_A.data + prune_bias_update
            return bias_update, prune_bias_update

        new_DbW = self.pass_channel_wise_layer(DbW, A_layer)
        new_bias_passed = (
            self.pass_channel_wise_layer(bias_passed, A_layer)
            if bias_passed is not None else None
        )
        new_A_layer_cand = self.dependancy[A_layer]['next']
        if len(new_A_layer_cand) > 1:
            raise NotImplementedError
        new_A_layer = new_A_layer_cand[0]

        if isinstance(A_layer, torch.nn.Flatten):
            # After flattening, every channel occupies prod(expand_dims)
            # consecutive features, so indices and per-channel values are
            # expanded accordingly.
            num_channels = DbW.shape[0]
            expand_dims = self.dependancy[A_layer]['prev'][0].output_size
            if prune_indices is not None:
                # The 4-D indexing assumes two spatial dimensions - TODO confirm.
                dummy = torch.zeros([1, num_channels, *expand_dims], device=prune_indices.device)
                dummy[:, prune_indices, :, :] = 1
                prune_indices = torch.where(dummy.flatten() == 1)[0]
            if new_DbW is not None:
                new_DbW = new_DbW.expand([*expand_dims, -1]).transpose(-1, 0).flatten()
            if new_bias_passed is not None:
                new_bias_passed = new_bias_passed.expand([*expand_dims, -1]).transpose(-1, 0).flatten()

        return self.apply_bias_update(new_DbW, new_A_layer, D_activ, prune_indices, new_bias_passed)

    def projection(self, idx, prune_indices=None):
        """Project group ``idx``'s activation, folding its effect into the next bias.

        Args:
            idx: key into ``self.bypass_layers``.
            prune_indices: optional index tensor of channels to prune after
                projecting.

        Returns:
            ``NotImplemented`` unless the group has exactly one W layer and
            one A layer. Otherwise the value returned by ``prune`` when
            ``prune_indices`` is given, else an empty list.
        """
        bypass_layer_group = self.bypass_layers[idx]
        pruned_layers = []
        D_activ = bypass_layer_group['D']
        A_layers = bypass_layer_group['A']
        W_layers = bypass_layer_group['W']
        # Require exactly one of each. Checking only the sum of the lengths
        # would let (2, 0) or (0, 2) through and fail on the [0] lookups.
        if len(A_layers) != 1 or len(W_layers) != 1:
            return NotImplemented
        W_layer = W_layers[0]
        A_layer = A_layers[0]

        if W_layer.bias is not None:
            b_W = W_layer.bias
            Dbw = D_activ.delta * b_W
            if prune_indices is not None:
                bias_passed = D_activ.skip_delta(b_W).index_select(0, prune_indices)
            else:
                bias_passed = None
            # Update b_A for the bypass (and for pruning when requested).
            self.apply_bias_update(Dbw, A_layer, D_activ, prune_indices, bias_passed)

        D_activ.proj()
        if prune_indices is not None:
            pruned_layers = self.prune(idx, prune_indices)

        bypass_layer_group['bypass'] = False
        return pruned_layers

    def prune(self, idx, prune_indices, prune_verbose=False):
        """Prune ``prune_indices`` around group ``idx``'s activation.

        Walks ``self.dependancy`` backwards from the activation pruning
        output channels, prunes the activation itself, then walks forwards
        pruning input channels. Each walk stops at the first layer for which
        ``_is_channel_wise`` is false.

        Args:
            idx: key into ``self.bypass_layers``.
            prune_indices: index tensor of the channels to prune.
            prune_verbose: unused.

        Returns:
            A list of ``(flag, module_name)`` tuples where ``flag`` is ``0``
            for output-side pruning and ``1`` for input-side pruning, or
            ``None`` on the early-exit path noted below.

        Raises:
            NotImplementedError: if a visited layer has more than one
                predecessor / successor.
        """
        D_layer = self.bypass_layers[idx]['D']
        pruned_layers = []

        # Backward pruning (prune W).
        out_channel_prune_target = D_layer
        while True:
            target_cand = self.dependancy[out_channel_prune_target]['prev']
            if len(target_cand) > 1:
                raise NotImplementedError
            target = target_cand[0]
            prune_tp(target, prune_indices, 'out')
            pruned_layers.append((0, self.module_names[target]))
            if not _is_channel_wise(target):
                break
            out_channel_prune_target = target

        # Prune the activation; remember the channel count from before
        # pruning for the Flatten index expansion below.
        num_current_channels = D_layer.num_parameters
        prune_tp(D_layer, prune_indices)
        pruned_layers.append((0, self.module_names[D_layer]))

        if hasattr(D_layer.activation, 'weight') and (D_layer.activation.weight is not None):
            if _is_channel_wise(D_layer):
                prune_tp(D_layer.activation, prune_indices, target='out')
                pruned_layers.append((0, self.module_names[D_layer]))
            else:
                prune_tp(D_layer.activation, prune_indices, target='in')
                pruned_layers.append((1, self.module_names[D_layer]))
                # NOTE(review): this exits with None rather than
                # pruned_layers and skips the forward pass below - confirm
                # this is intended.
                return None

        # Forward pruning (prune A).
        in_channel_prune_target = D_layer
        while True:
            target_cand = self.dependancy[in_channel_prune_target]['next']
            if len(target_cand) > 1:
                raise NotImplementedError
            if isinstance(in_channel_prune_target, torch.nn.Flatten):
                # Translate channel indices into flattened-feature indices.
                # The 4-D indexing assumes two spatial dimensions - TODO confirm.
                expand_dims = self.dependancy[in_channel_prune_target]['prev'][0].output_size
                dummy = torch.zeros([1, num_current_channels, *expand_dims], device=prune_indices.device)
                dummy[:, prune_indices, :, :] = 1
                prune_indices = torch.where(dummy.flatten() == 1)[0]
            target = target_cand[0]
            prune_tp(target, prune_indices, 'in')
            pruned_layers.append((1, self.module_names[target]))
            if not _is_channel_wise(target):
                break
            in_channel_prune_target = target
        return pruned_layers
    
class MildPruningMixin_A(BypassMixin):  # ActivationForBypass + (AD = 0)
    """Mild-pruning variant whose loss penalises the product A * D."""

    def embed(self, idx):
        """Call ``embed()`` on group ``idx``'s activation and flag the group as bypassing."""
        bypass_layer_group = self.bypass_layers[idx]
        bypass_layer_group['D'].embed()
        bypass_layer_group['bypass'] = True
        return None

    def projection(self, idx):
        """Project group ``idx``'s activation (no pruning yet) and clear its bypass flag."""
        # TODO: add pruning here
        bypass_layer_group = self.bypass_layers[idx]
        D_activ = bypass_layer_group['D']
        D_activ.proj(D_activ.num_parameters)
        bypass_layer_group['bypass'] = False
        return None

    def ADW_loss_single(self, idx):
        """Return the summed RMS of ``A.weight * delta`` over the group's A layers.

        The ``'A'`` entry may be a single layer or a sequence of layers
        (the parent class iterates it as a sequence); both are accepted.
        """
        bypass_layer_group = self.bypass_layers[idx]
        A_layers = bypass_layer_group['A']
        if hasattr(A_layers, 'weight'):
            # A single layer object rather than a sequence of layers.
            A_layers = [A_layers]
        D_vec = bypass_layer_group['D'].delta

        return sum(
            torch.sqrt(torch.mean((A_layer.weight * D_vec) ** 2))
            for A_layer in A_layers
        )
    
class MildPruningMixin_W(BypassMixin):  # ActivationForBypass + (DW = 0)
    """Mild-pruning variant whose loss penalises the product D * W."""

    def embed(self, idx):
        """Call ``embed()`` on group ``idx``'s activation and flag the group as bypassing."""
        group = self.bypass_layers[idx]
        group['D'].embed()
        group['bypass'] = True
        return None

    def ADW_loss_single(self, idx):
        """Return the sum over the group's W layers of ``norm_function(delta * W.weight.T)``."""
        group = self.bypass_layers[idx]
        activation = group['D']
        per_layer = [
            self.norm_function(activation.delta * layer.weight.T)
            for layer in group['W']
        ]
        return sum(per_layer)

class ActivationChangeMixin(BaseMixin):  # ActivationForActivationChange + (AD_2 = 0)
    """Activation-change variant: the loss uses ``delta1``, the projection uses ``delta2``.

    The ``'A'`` entry of a layer group may be a single layer or a sequence
    of layers (other mixins in this file treat it as a sequence); both are
    accepted.
    """

    def embed(self, idx):
        """Call ``embed()`` on group ``idx``'s activation and flag the group as bypassing."""
        bypass_layer_group = self.bypass_layers[idx]
        bypass_layer_group['D'].embed()
        bypass_layer_group['bypass'] = True
        return None

    def projection(self, idx):
        """Scale each A layer's weight by ``delta2`` in place, then project the activation.

        NOTE(review): returns ``NotImplemented`` even though the weights
        were modified and the bypass flag cleared - confirm what callers
        expect before changing it.
        """
        bypass_layer_group = self.bypass_layers[idx]
        A_layers = bypass_layer_group['A']
        if hasattr(A_layers, 'weight'):
            # A single layer object rather than a sequence of layers.
            A_layers = [A_layers]
        D_activ = bypass_layer_group['D']
        D_vec = D_activ.delta2

        for A_layer in A_layers:
            A_mat = A_layer.weight
            A_mat.data = A_mat.data * D_vec

        D_activ.proj()
        bypass_layer_group['bypass'] = False
        return NotImplemented

    def ADW_loss_single(self, idx):
        """Return the summed RMS of ``A.weight * delta1`` over the group's A layers."""
        bypass_layer_group = self.bypass_layers[idx]
        A_layers = bypass_layer_group['A']
        if hasattr(A_layers, 'weight'):
            # A single layer object rather than a sequence of layers.
            A_layers = [A_layers]
        D_vec = bypass_layer_group['D'].delta1

        return sum(
            torch.sqrt(torch.mean((A_layer.weight * D_vec) ** 2))
            for A_layer in A_layers
        )

# Registry used by model_factory: maps a ``bypass_type`` string to the mixin
# class that is combined with the base model class.
AVAILABLE_MIXIN = {
    'base': BaseMixin,  # for debug
    'bypass': BypassMixin,
    'bypass_D': BypassMixin_D,
    'mild_pruning_A': MildPruningMixin_A,
    'mild_pruning_W': MildPruningMixin_W,
    'activ_change': ActivationChangeMixin,
}

def model_factory(base,bypass_type,input_shape,verbose=False,norm_type='l2mean',*model_args,**model_kwargs):
    """Create an instance of ``base`` combined with the mixin selected by ``bypass_type``.

    A new class named ``'<base name>_<bypass_type>'`` is derived from
    ``(mixin, base)`` and instantiated with ``*model_args`` /
    ``**model_kwargs``. The instance is switched to eval mode, activation
    detection is run on it and its norm function is set.

    Args:
        base: model class to extend.
        bypass_type: key into ``AVAILABLE_MIXIN``.
        input_shape: shape of one input sample, passed to ``detect_activations``.
        verbose: forwarded to ``detect_activations`` as a keyword argument.
        norm_type: passed to ``set_norm_function``.

    Returns:
        The prepared model instance.

    Raises:
        KeyError: if ``bypass_type`` is not a key of ``AVAILABLE_MIXIN``.
    """
    mixin_cls = AVAILABLE_MIXIN[bypass_type]
    mixed_cls = type(f'{base.__name__}_{bypass_type}', (mixin_cls, base), {})
    instance = mixed_cls(*model_args, **model_kwargs)
    instance.eval()
    instance.detect_activations(input_shape, verbose=verbose)
    instance.set_norm_function(norm_type)
    return instance
if __name__ == '__main__':
    # Manual smoke test: build a mixed model on CPU and call the gamma,
    # loss and projection entry points once.
    from bypass.core.models import cifar10_BypassBNresnet56,cifar10_NaiveBypassBNresnet56

    model=model_factory(cifar10_NaiveBypassBNresnet56,'mild_pruning_W',[3,32,32],verbose=False,norm_type='l1mean').cpu()
    # Earlier debugging experiments, previously kept here as commented-out
    # code: building a DNNMNIST model with the 'base' mixin by hand, running
    # detect_activations([28,28]) on it, and calling copy.deepcopy on every
    # buffer and sub-module (and on the whole model) to find the ones that
    # fail to copy.
    model.set_gamma('0')
    ADW_loss=model.ADW_loss()
    model.projection(0)
    print(1)
