from torch.nn.modules import Module
from torchvision.models.vgg import VGG,cfgs,make_layers,vgg19_bn,vgg19,VGG19_BN_Weights,VGG19_Weights
from torchvision.models.resnet import ResNet, resnet50, ResNet50_Weights, Bottleneck, BasicBlock
import torch
import torch.nn as nn
from typing import Type, Union,Optional,List, Callable

import sys
sys.path.append('/workspace/jaeheun_MildPruning/')
from bypass.core.detect import BypassModel
from bypass.core.activation import ActivationForBypass,ActivationForDx2
# from bypass.core.models.resnet_tiny import BasicBlock, Bottleneck, BypassBottleneck
class BypassBottleneck(Bottleneck):
    """Bottleneck block whose two inner ReLUs are wrapped in ``ActivationForBypass``.

    The layer stack (conv1/bn1, conv2/bn2, conv3/bn3, optional downsample and
    the final ``relu``) is inherited unchanged from torchvision's
    ``Bottleneck``. This subclass adds ``relu1``/``relu2`` wrappers after the
    first two conv/bn pairs and routes the residual addition through a ``Sum``
    module instead of an in-place ``+=``.
    """

    def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample: Union[Module,None] = None, groups: int = 1, base_width: int = 64, dilation: int = 1, norm_layer: Union[Callable[..., Module],None] = None) -> None:
        super().__init__(inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer)
        # Size each wrapper from the convolution that feeds it rather than from
        # ``planes``: torchvision widens conv1/conv2 to
        # ``int(planes * base_width / 64) * groups`` channels, which only equals
        # ``planes`` for the default groups=1 / base_width=64.
        # NOTE(review): assumes the first ActivationForBypass argument is the
        # channel count of the wrapped activation -- confirm against its definition.
        self.relu1 = ActivationForBypass(self.conv1.out_channels, torch.nn.ReLU())
        self.relu2 = ActivationForBypass(self.conv2.out_channels, torch.nn.ReLU())
        self.sum = Sum()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the bottleneck on ``x`` and return the activated residual sum."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Project the shortcut only when the block was built with a downsample.
        if self.downsample is not None:
            identity = self.downsample(x)
        # The addition is a module call (not ``out += identity``).
        out = self.sum(out, identity)
        out = self.relu(out)

        return out
    
class BypassBNBottleneck(Bottleneck):
    """Bottleneck block whose normalisation layers are wrapped in ``ActivationForBypass``.

    ``bn1``, ``bn2``, ``bn3`` and the normalisation inside the downsample
    shortcut are each wrapped. ``relu1``/``relu2`` are plain ReLUs and the
    residual addition goes through a ``Sum`` module.
    """

    def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample: Union[Module,None] = None, groups: int = 1, base_width: int = 64, dilation: int = 1, norm_layer: Union[Callable[..., Module],None] = None) -> None:
        super().__init__(inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer)
        if self.downsample is not None:
            # Rebuild the shortcut so its normalisation can be wrapped. The
            # output width is read directly from the existing projection conv;
            # recomputing it as in * (out // in) would silently shrink it when
            # ``out`` is not an exact multiple of ``in``.
            shortcut_conv = self.downsample[0]
            in_planes = shortcut_conv.in_channels
            out_planes = shortcut_conv.out_channels
            # NOTE(review): the rebuilt shortcut always uses BatchNorm2d and
            # does not honour ``norm_layer`` -- confirm this is intended.
            self.downsample = torch.nn.Sequential(
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                ActivationForBypass(out_planes, torch.nn.BatchNorm2d(out_planes)),
            )

        # Size every wrapper from the convolution that feeds the wrapped norm.
        # conv1/conv2 emit int(planes * base_width / 64) * groups channels and
        # conv3 emits planes * self.expansion, so hard-coding ``planes`` and
        # ``planes*4`` is only right for the default groups/base_width.
        # NOTE(review): assumes the first ActivationForBypass argument is the
        # channel count of the wrapped layer -- confirm against its definition.
        self.bn1 = ActivationForBypass(self.conv1.out_channels, self.bn1)
        self.bn2 = ActivationForBypass(self.conv2.out_channels, self.bn2)
        self.bn3 = ActivationForBypass(self.conv3.out_channels, self.bn3)

        self.relu1 = torch.nn.ReLU(inplace=True)
        self.relu2 = torch.nn.ReLU(inplace=True)

        self.sum = Sum()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the bottleneck on ``x`` and return the activated residual sum."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Project the shortcut only when the block was built with a downsample.
        if self.downsample is not None:
            identity = self.downsample(x)
        # The addition is a module call (not ``out += identity``).
        out = self.sum(out, identity)
        out = self.relu(out)

        return out
    
    
class OriginalBottleneck(Bottleneck):
    """Plain torchvision-style bottleneck whose residual add is a ``Sum`` module.

    No bypass wrappers are installed; the inherited ``relu`` is reused after
    every stage.
    """

    def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample: Union[Module,None] = None, groups: int = 1, base_width: int = 64, dilation: int = 1, norm_layer: Union[Callable[..., Module],None] = None) -> None:
        super().__init__(
            inplanes=inplanes,
            planes=planes,
            stride=stride,
            downsample=downsample,
            groups=groups,
            base_width=base_width,
            dilation=dilation,
            norm_layer=norm_layer,
        )
        self.sum = Sum()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the three conv/bn stages, add the shortcut, and activate."""
        # Main branch: 1x1 -> 3x3 -> 1x1, with the shared ReLU after the first two.
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))

        # Shortcut branch: identity unless a projection was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        return self.relu(self.sum(branch, shortcut))
    
class Sum(torch.nn.Module):
    """Module form of ``inp1 + inp2``.

    Any further positional arguments are accepted and ignored.
    """

    def forward(self, inp1, inp2, *args):
        """Return the sum of the first two inputs."""
        total = inp1 + inp2
        return total
class BypassResnet_torchvision(ResNet):
    """torchvision ``ResNet`` with every ReLU swapped for a non-inplace one.

    After the backbone is built, each ``torch.nn.ReLU`` submodule is replaced
    by ``torch.nn.ReLU(inplace=False)``. If ``weights`` is given, its state
    dict is then loaded with ``strict=False``, so key mismatches are not
    reported. ``input_shape`` is accepted but not used by this constructor.
    """

    src_modules = [torch.nn.ReLU]
    bypass_wrapper = ActivationForBypass

    def __init__(self,
                 block: Type[Union[BasicBlock, Bottleneck]],
                 layers: List[int],
                 weights: Optional[ResNet50_Weights],
                 progress: bool,
                 num_classes: int,
                 input_shape=(3,224,224),
                 **kwargs):
        super().__init__(block=block, layers=layers, num_classes=num_classes, **kwargs)

        # Snapshot the qualified names first: the module tree is mutated below,
        # so it must not be modified while ``named_modules`` is being walked.
        relu_names = [
            qualified_name
            for qualified_name, submodule in self.named_modules()
            if isinstance(submodule, torch.nn.ReLU)
        ]
        for qualified_name in relu_names:
            # "a.b.relu" -> parent "a.b", leaf "relu"; a top-level name yields
            # an empty parent path, which get_submodule resolves to ``self``.
            parent_path, _, leaf = qualified_name.rpartition('.')
            owner = self.get_submodule(parent_path)
            owner.register_module(leaf, torch.nn.ReLU(inplace=False))

        if weights is None:
            return
        pretrained_state = weights.get_state_dict(progress=progress)
        self.load_state_dict(pretrained_state, strict=False)
     
class OriResnet_torchvision(ResNet):
    """Unmodified torchvision ``ResNet`` with optional pretrained weights.

    ``input_shape`` is accepted but not used by this constructor.
    """

    def __init__(self,
                 block: Type[Union[BasicBlock, Bottleneck]],
                 layers: List[int],
                 weights: Optional[ResNet50_Weights],
                 progress: bool,
                 num_classes: int,
                 input_shape=(3,224,224),
                 **kwargs):
        super().__init__(block=block, layers=layers, num_classes=num_classes, **kwargs)

        if weights is None:
            return
        # strict=False: keys that do not line up are skipped without an error.
        pretrained_state = weights.get_state_dict(progress=progress)
        self.load_state_dict(pretrained_state, strict=False)

class BypassVGG_torchvision(VGG):
    """torchvision ``VGG`` whose feature extractor is wrapped in a ``BypassModel``.

    Construction order:
      1. build the conv stack with ``make_layers(cfgs[cfg], batch_norm)``;
      2. replace every ReLU in it with a non-inplace ReLU;
      3. initialise the parent ``VGG`` and, if ``weights`` is given, load them;
      4. wrap ``self.features`` in ``BypassModel`` and re-register it.

    Args:
        cfg: key into torchvision's ``cfgs`` table (e.g. ``'E'``).
        batch_norm: forwarded to ``make_layers``.
        weights: object exposing ``get_state_dict(progress=...)``, or ``None``
            to keep randomly initialised parameters.
        progress: forwarded to ``weights.get_state_dict``.
        num_classes: forwarded to ``VGG``.
        input_shape: forwarded to ``BypassModel``.
        **kwargs: forwarded to ``VGG.__init__``.
    """

    src_modules=[torch.nn.ReLU]
    bypass_wrapper=ActivationForBypass

    def __init__(self, cfg:str, batch_norm=False,weights=None,progress:bool=True, num_classes=1000,input_shape=(3,224,224),**kwargs):
        features = make_layers(cfgs[cfg],batch_norm=batch_norm)

        # Collect the ReLUs first, then replace them: the module tree must not
        # be mutated while ``named_modules`` is being walked.
        tgt=[]
        for name, module in features.named_modules():
            if isinstance(module,torch.nn.ReLU):
                tgt.append([name,module])
        for name, module in tgt:
            parent_module_name = '.'.join(name.split('.')[0:-1])
            node_name =  name.split('.')[-1]
            parent_module = features.get_submodule(parent_module_name)
            parent_module.register_module(node_name,torch.nn.ReLU(False))

        # ``init_weights`` is a boolean switch for VGG's own parameter
        # initialisation. Passing ``weights`` straight through inverted it:
        # ``None`` skipped the initialisation for from-scratch models, while a
        # weights object triggered an initialisation that was immediately
        # overwritten by load_state_dict. Initialise only when nothing is loaded.
        super().__init__(features=features, num_classes=num_classes,init_weights=weights is None,**kwargs)
        if weights is not None:
            self.load_state_dict(weights.get_state_dict(progress=progress))

        # Wrap after loading so the state-dict keys still match the plain VGG layout.
        bp_features = BypassModel(self.features,input_shape = input_shape,src_modules=self.src_modules, bypass_wrapper=self.bypass_wrapper)
        self.register_module('features',bp_features)
        self.flatten = torch.nn.Flatten()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run features -> avgpool -> flatten -> classifier on ``x``."""
        x = self.features(x)
        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.classifier(x)
        return x

        
class imagenet_BypassResnet50(BypassResnet_torchvision):
    """1000-class ResNet-50 layout ([3, 4, 6, 3]) built from ``BypassBottleneck`` blocks."""

    src_modules = [torch.nn.ReLU]
    bypass_wrapper = ActivationForBypass

    def __init__(self, weights=None, progress: bool = True, **kwargs):
        super().__init__(
            block=BypassBottleneck,
            layers=[3, 4, 6, 3],
            weights=weights,
            progress=progress,
            num_classes=1000,
            **kwargs,
        )
        
class imagenet_BypassBNResnet50(BypassResnet_torchvision):
    """1000-class ResNet-50 layout ([3, 4, 6, 3]) built from ``BypassBNBottleneck`` blocks."""

    src_modules = [torch.nn.ReLU]
    bypass_wrapper = ActivationForBypass

    def __init__(self, weights=None, progress: bool = True, **kwargs):
        super().__init__(
            block=BypassBNBottleneck,
            layers=[3, 4, 6, 3],
            weights=weights,
            progress=progress,
            num_classes=1000,
            **kwargs,
        )
        
class imagenet_OriResnet50(OriResnet_torchvision):
    """1000-class ResNet-50 layout ([3, 4, 6, 3]) built from ``OriginalBottleneck`` blocks."""

    def __init__(self, weights=None, progress: bool = True, **kwargs):
        super().__init__(
            block=OriginalBottleneck,
            layers=[3, 4, 6, 3],
            weights=weights,
            progress=progress,
            num_classes=1000,
            **kwargs,
        )

        
class imagenet_BypassVGG19(BypassVGG_torchvision):
    """1000-class VGG built from configuration ``'E'`` with batch norm enabled."""

    src_modules = [torch.nn.ReLU]
    bypass_wrapper = ActivationForBypass

    def __init__(self, weights=None, progress: bool = True, **kwargs):
        super().__init__(
            cfg='E',
            batch_norm=True,
            weights=weights,
            progress=progress,
            num_classes=1000,
            **kwargs,
        )

if __name__ == '__main__':
    # Smoke test: build both bypass ResNet-50 variants from the ImageNet V1
    # checkpoint, then print a marker once construction has succeeded.
    # (An earlier, disabled experiment that swapped ReLUs by hand and wrapped
    # ``model.features`` in BypassModel used to live here as commented-out code.)
    weights = ResNet50_Weights.IMAGENET1K_V1
    model, model2 = (
        imagenet_BypassResnet50(weights=weights),
        imagenet_BypassBNResnet50(weights=weights),
    )
    print(1)


