import copy
import torch
import torch.nn as nn
from attention import Attention

def transform_layer(conv, load_filters=True, stride_one=False, **kwargs):
    """Build an ``Attention`` module configured from a convolution layer.

    Args:
        conv: Convolution-like module exposing ``in_channels``,
            ``out_channels``, ``kernel_size``, ``stride``, ``padding``,
            ``groups`` and ``bias``. Only the first entry of
            ``kernel_size``, ``stride`` and ``padding`` is used.
        load_filters: When true, ``load_filters(conv)`` is called on the new
            module (presumably to initialise it from the convolution's
            weights -- confirm against ``Attention``).
        stride_one: When true, the new module is built with stride 1
            instead of the convolution's own stride.
        **kwargs: Extra keyword arguments forwarded to ``Attention``.

    Returns:
        The newly constructed ``Attention`` module.
    """
    # Settings mirrored from the convolution, in constructor-argument order.
    mirrored = {
        'in_dim': conv.in_channels,
        'out_dim': conv.out_channels,
        'kernel_size': conv.kernel_size[0],
        'stride': 1 if stride_one else conv.stride[0],
        'padding': conv.padding[0],
        'groups': conv.groups,
        'bias': conv.bias is not None,
    }
    layer = Attention(**mirrored, **kwargs)

    if load_filters:
        layer.load_filters(conv)
    return layer

def transform_block(conv_block, model_name='ResNet', **kwargs):
    """Return a deep copy of ``conv_block`` with its convolutions swapped out.

    The input block is left untouched; all edits are applied to the copy.

    * ``'RegNet'``: ``conv2.conv`` is replaced by the result of
      ``transform_layer``.
    * ``'ResNet'``, ``'SENet'``, ``'NormFreeNet'``: every direct child whose
      class name contains ``conv`` and whose kernel size is not 1 is
      replaced by the result of ``transform_layer``. If ``stride_one`` is
      truthy in ``kwargs``, children whose name contains ``downsample`` also
      get their ``nn.Conv2d`` strides forced to ``(1, 1)`` and their
      ``nn.AvgPool2d`` modules replaced by ``nn.Identity``.

    Args:
        conv_block: The block module to convert.
        model_name: Class name of the model the block belongs to.
        **kwargs: Forwarded to ``transform_layer``; ``stride_one`` is also
            inspected here.

    Returns:
        The converted copy of ``conv_block``.

    Raises:
        NotImplementedError: If ``model_name`` is not supported. This is a
            ``RuntimeError`` subclass, so handlers written against the
            previous bare ``raise`` keep working.
    """
    supported = ('RegNet', 'ResNet', 'SENet', 'NormFreeNet')
    if model_name not in supported:
        # Fail before paying for the deepcopy, with an actionable message.
        raise NotImplementedError(
            f"transform_block does not support model_name={model_name!r}; "
            f"expected one of {supported}"
        )

    new_block = copy.deepcopy(conv_block)

    if model_name == 'RegNet':
        attn = transform_layer(conv_block.conv2.conv, **kwargs)
        setattr(new_block.conv2, 'conv', attn)
        return new_block

    stride_one = bool(kwargs.get('stride_one'))

    # Snapshot the children: the loop replaces entries of the same container.
    for child_name, child in list(new_block.named_children()):
        if stride_one and 'downsample' in child_name:
            for sub_name, sub in list(child.named_modules()):
                # Exact class match is kept on purpose (subclasses are skipped).
                if sub.__class__ == nn.Conv2d:
                    sub.stride = (1, 1)
                elif sub.__class__ == nn.AvgPool2d:  # ResNet-D architecture
                    if not sub_name:
                        # The downsample child is itself the pooling layer.
                        setattr(new_block, child_name, nn.Identity())
                        continue
                    # named_modules() yields dotted paths for nested modules;
                    # walk to the real parent before replacing the leaf.
                    parent_path, _, leaf = sub_name.rpartition('.')
                    parent = child
                    if parent_path:
                        for part in parent_path.split('.'):
                            parent = getattr(parent, part)
                    setattr(parent, leaf, nn.Identity())

        if 'conv' in child.__class__.__name__.lower():
            if child.kernel_size[0] == 1:
                # Kernel-size-1 convolutions are kept as convolutions.
                continue
            attn = transform_layer(child, **kwargs)
            setattr(new_block, child_name, attn)

    return new_block

def Cnn2Transformer(model, first_attn_layer=4, **kwargs):
    """Return a deep copy of ``model`` with its later stages converted.

    Stages are the children of ``model.stages`` when the model's class name
    starts with ``NormFree``; otherwise they are the top-level children
    whose name starts with ``layer`` or whose class is named ``RegStage``.
    Stages at index ``first_attn_layer - 1`` and above are rebuilt as an
    ``nn.Sequential`` of ``transform_block`` results; earlier stages are
    kept as copied. The input ``model`` is not modified.

    Args:
        model: The model to convert.
        first_attn_layer: 1-based index of the first stage to convert. It is
            also stored on the returned model as ``first_attn_layer``.
        **kwargs: Forwarded to ``transform_block``.

    Returns:
        The converted copy of ``model``.
    """
    print('Transforming model')

    model_name = model.__class__.__name__
    new_model = copy.deepcopy(model)
    new_model.first_attn_layer = first_attn_layer

    if model_name.startswith('NormFree'):
        layers = list(model.stages.named_children())
        # Stage names are relative to ``stages``, so the converted stages
        # must be written back there rather than onto the model itself.
        container = new_model.stages
    else:
        layers = [
            (name, child)
            for name, child in model.named_children()
            if name.startswith('layer') or child.__class__.__name__ == 'RegStage'
        ]
        container = new_model

    for layer_num, (child_name, child) in enumerate(layers):
        if layer_num < first_attn_layer - 1:
            continue
        new_layers = [
            transform_block(conv_block, model_name, **kwargs)
            for conv_block in child.children()
        ]
        setattr(container, child_name, nn.Sequential(*new_layers))

    return new_model

if __name__ == '__main__':
    # Sanity check: build a timm resnet18, convert it with the default
    # settings, and print the norm of the output gap on one random input.
    from timm import create_model

    baseline = create_model('resnet18')
    converted = Cnn2Transformer(baseline)
    sample = torch.randn(1, 3, 224, 224)
    gap = (baseline(sample) - converted(sample)).norm().item()
    print('Difference between CNN and T-CNN:', gap)
