import torch
import torch.nn as nn
from collections import OrderedDict
from torchvision import models as tv_models
from .layers import LongRangeModulation

# Public API: the wrapper class and the AlexNet-based model factories.
__all__ = [
    'LRMNet',
    'alexnet_lrm1',
    'alexnet_lrm2',
    'alexnet_lrm3',
]
   
class LRMNet(nn.Module):
    """Run a backbone for several time steps with long-range modulation blocks.

    One ``LongRangeModulation`` block is created per ``(target_layer,
    source_layers)`` pair in ``mod_connections``. Each block is handed the
    backbone itself, so any wiring into the backbone happens inside
    ``LongRangeModulation``.
    NOTE(review): confirm how the blocks attach (hooks or similar).

    The blocks are registered under ``self.lrm``, so their parameters appear in
    the model's state dict under the ``lrm.`` prefix next to the ``backbone.``
    parameters.

    Args:
        backbone: network whose named layers are referenced by
            ``mod_connections``.
        mod_connections: sequence of ``(target_layer, source_layers)`` pairs.
            ``target_layer`` is a name as yielded by
            ``backbone.named_modules()``. ``source_layers`` is a list of layer
            names, presumably in the same naming scheme (TODO confirm in
            ``LongRangeModulation``).
        time_steps: default number of backbone passes per ``forward`` call.
        img_size: forwarded unchanged to every ``LongRangeModulation`` block.
    """

    def __init__(self, backbone, mod_connections, time_steps=2, img_size=224):
        super().__init__()
        self.time_steps = time_steps
        self.backbone = backbone

        # The module that runs right after each modulation target (e.g. the
        # ReLU after a conv) must not operate in place. Presumably this keeps
        # the target's (modulated) output from being clobbered by an in-place
        # op; confirm against LongRangeModulation.
        self.modify_relu_layers([target for target, _ in mod_connections])

        mod_layers = OrderedDict()
        for target_layer, source_layers in mod_connections:
            mod_layer = LongRangeModulation(backbone, target_layer, source_layers, img_size=img_size)
            mod_layers[mod_layer.name] = mod_layer

        # nn.Sequential serves only as an ordered container that registers the
        # blocks as submodules; forward() below never calls it.
        self.lrm = nn.Sequential(mod_layers)

    def modify_relu_layers(self, target_layers):
        """Set ``inplace=False`` on the module that follows each target layer.

        "Follows" means the next entry in ``self.backbone.named_modules()``
        order. For a flat ``nn.Sequential`` this is the next layer, e.g. the
        ReLU after a conv. Modules that have no ``inplace`` attribute are left
        untouched rather than given a meaningless one.

        Args:
            target_layers: iterable of layer names from
                ``self.backbone.named_modules()``.
        """
        targets = set(target_layers)  # membership is tested once per module
        prev_name = None
        for name, module in self.backbone.named_modules():
            if prev_name in targets and hasattr(module, 'inplace'):
                module.inplace = False
            prev_name = name

    def forward(self, x, drop_state=True, time_steps=None):
        """Run ``self.backbone`` on ``x`` repeatedly and return the last output.

        The same input is fed at every step. The modulation blocks presumably
        carry information between steps through their ``mod_inputs``
        (assumption; confirm in ``LongRangeModulation``).

        Args:
            x: input passed unchanged to the backbone at every step.
            drop_state: if True, reset every block's ``mod_inputs`` to ``{}``
                first, so nothing stored by a previous call (e.g. the previous
                batch) is reused.
            time_steps: overrides ``self.time_steps`` for this call when not
                None.

        Returns:
            The backbone output of the final step.

        Raises:
            ValueError: if the number of steps to run is less than 1, since
                there would be no output to return.
        """
        steps = self.time_steps if time_steps is None else time_steps
        # Validate before touching any block state so an invalid call has no
        # side effects.
        if steps < 1:
            raise ValueError(f"time_steps must be >= 1, got {steps}")

        # drop any stored feedback or skip inputs (from previous batch)
        if drop_state:
            for block in self.lrm.children():
                block.mod_inputs = {}

        out = None
        for _ in range(steps):
            out = self.backbone(x)
        return out
    
def alexnet_lrm1(weights=None):
    """Build an AlexNet ``LRMNet`` with two long-range modulation connections.

    Connections (target <= source), using torchvision AlexNet layer indices:
        features.8 (conv4) <= classifier.6 (output)
        features.0 (conv1) <= features.9 (conv4 ReLU)

    Args:
        weights: optional object exposing ``get_state_dict()``; presumably a
            torchvision-style weights entry for an LRM checkpoint (confirm).
            When given, its state dict is loaded into the complete ``LRMNet``
            (backbone and modulation blocks) with ``strict=True``.

    Returns:
        An ``LRMNet`` configured for 2 time steps and 224-pixel inputs.
    """
    mod_connections = [
        ('features.8', ['classifier.6']),  # output => conv4
        ('features.0', ['features.9']),    # conv4 => conv1
    ]

    # Build the backbone without pretrained weights. The checkpoint below is a
    # state dict for the whole LRMNet (keys under ``backbone.`` / ``lrm.``), and
    # the strict load overwrites every backbone parameter anyway. Passing it to
    # tv_models.alexnet would be rejected by torchvision unless it were an
    # AlexNet_Weights member, whose plain AlexNet state dict could never
    # satisfy the strict whole-model load.
    backbone = tv_models.alexnet(weights=None)

    model = LRMNet(backbone, mod_connections, time_steps=2, img_size=224)

    if weights is not None:
        msg = model.load_state_dict(weights.get_state_dict(), strict=True)
        print(msg)

    return model

def alexnet_lrm2(weights=None):
    """Build an AlexNet ``LRMNet`` with four long-range modulation connections.

    Connections (target <= source), using torchvision AlexNet layer indices:
        features.0 (conv1)  <= features.9 (conv4)
        features.3 (conv2)  <= features.12 (conv5)
        features.8 (conv4)  <= classifier.6 (output)
        features.10 (conv5) <= classifier.6 (output)

    Args:
        weights: optional object exposing ``get_state_dict()``; presumably a
            torchvision-style weights entry for an LRM checkpoint (confirm).
            When given, its state dict is loaded into the complete ``LRMNet``
            (backbone and modulation blocks) with ``strict=True``.

    Returns:
        An ``LRMNet`` configured for 2 time steps and 224-pixel inputs.
    """
    mod_connections = [
        ('features.0', ['features.9']),    # conv4 => conv1
        ('features.3', ['features.12']),   # conv5 => conv2
        ('features.8', ['classifier.6']),  # output => conv4
        ('features.10', ['classifier.6']), # output => conv5
    ]

    # Build the backbone without pretrained weights. The checkpoint below is a
    # state dict for the whole LRMNet (keys under ``backbone.`` / ``lrm.``), and
    # the strict load overwrites every backbone parameter anyway. Passing it to
    # tv_models.alexnet would be rejected by torchvision unless it were an
    # AlexNet_Weights member, whose plain AlexNet state dict could never
    # satisfy the strict whole-model load.
    backbone = tv_models.alexnet(weights=None)

    model = LRMNet(backbone, mod_connections, time_steps=2, img_size=224)

    if weights is not None:
        msg = model.load_state_dict(weights.get_state_dict(), strict=True)
        print(msg)

    return model

def alexnet_lrm3(weights=None):
    """Build an AlexNet ``LRMNet`` with five long-range modulation connections.

    Connections (target <= source), using torchvision AlexNet layer indices:
        features.0 (conv1)  <= features.9 (conv4)
        features.3 (conv2)  <= features.12 (conv5)
        features.6 (conv3)  <= classifier.2 (fc6)
        features.8 (conv4)  <= classifier.6 (output)
        features.10 (conv5) <= classifier.6 (output)

    Args:
        weights: optional object exposing ``get_state_dict()``; presumably a
            torchvision-style weights entry for an LRM checkpoint (confirm).
            When given, its state dict is loaded into the complete ``LRMNet``
            (backbone and modulation blocks) with ``strict=True``.

    Returns:
        An ``LRMNet`` configured for 2 time steps and 224-pixel inputs.
    """
    mod_connections = [
        ('features.0', ['features.9']),    # conv4 => conv1
        ('features.3', ['features.12']),   # conv5 => conv2
        ('features.6', ['classifier.2']),  # fc6 => conv3
        # NOTE: an output => fc6 connection ('classifier.1', ['classifier.6'])
        # is intentionally disabled in this variant.
        ('features.8', ['classifier.6']),  # output => conv4
        ('features.10', ['classifier.6']), # output => conv5
    ]

    # Build the backbone without pretrained weights. The checkpoint below is a
    # state dict for the whole LRMNet (keys under ``backbone.`` / ``lrm.``), and
    # the strict load overwrites every backbone parameter anyway. Passing it to
    # tv_models.alexnet would be rejected by torchvision unless it were an
    # AlexNet_Weights member, whose plain AlexNet state dict could never
    # satisfy the strict whole-model load.
    backbone = tv_models.alexnet(weights=None)

    model = LRMNet(backbone, mod_connections, time_steps=2, img_size=224)

    if weights is not None:
        msg = model.load_state_dict(weights.get_state_dict(), strict=True)
        print(msg)

    return model


