#!/usr/bin/env python
# coding: utf-8

# Here we demo how to construct an LRM network manually, load our lrm3, lrm2, and lrm1 alexnet models, and build a variant with feedback between classifier layers.

# In[1]:


import torch
import models
from models import LRMNet
from torchvision.models import alexnet


# ## manually construct alexnet_lrm network

# In[2]:


# load backbone
backbone = alexnet()

# Feedback is specified as a list of tuples, each with a 'target_layer' and a list of [source_layers]
# - features.8 receives feedback from classifier.6
# - features.0 receives feedback from features.9

mod_connections = [
    ('features.8', ['classifier.6']),
    ('features.0', ['features.9']),
]

# create the model with default number of forward passes (time_steps) and expected img_size
# - default number of forward passes can be overridden in forward pass
# - actual input img_size can be any size and LRMNet will adapt feedback size automatically
model = LRMNet(backbone, mod_connections, time_steps=2, img_size=224)
# A bare `model` expression only displays inside a notebook; print it so the
# exported script also shows the architecture.
print(model)


# In[3]:


# Run a random (5, 3, 224, 224) input batch through the model in eval mode,
# without gradient tracking.
x = torch.rand(5, 3, 224, 224)
model.eval()
with torch.no_grad():
    out = model(x)
# print explicitly: a bare `out.shape` only displays inside a notebook
print(out.shape)


# ## lrm3

# In[4]:


# Build the pre-defined lrm3 AlexNet model and show its architecture
# (print explicitly: a bare `model` only displays inside a notebook).
model = models.alexnet_lrm3()
print(model)


# In[5]:


# Run a random (5, 3, 224, 224) input batch through the lrm3 model in eval
# mode, without gradient tracking.
x = torch.rand(5, 3, 224, 224)
model.eval()
with torch.no_grad():
    out = model(x)
# print explicitly: a bare `out.shape` only displays inside a notebook
print(out.shape)


# ## lrm2

# In[6]:


# Build the pre-defined lrm2 AlexNet model and show its architecture
# (print explicitly: a bare `model` only displays inside a notebook).
model = models.alexnet_lrm2()
print(model)


# In[7]:


# Run a random (5, 3, 224, 224) input batch through the lrm2 model in eval
# mode, without gradient tracking.
x = torch.rand(5, 3, 224, 224)
model.eval()
with torch.no_grad():
    out = model(x)
# print explicitly: a bare `out.shape` only displays inside a notebook
print(out.shape)


# ## lrm1

# In[8]:


# Build the pre-defined lrm1 AlexNet model and show its architecture
# (print explicitly: a bare `model` only displays inside a notebook).
model = models.alexnet_lrm1()
print(model)


# In[9]:


# Run a random (5, 3, 224, 224) input batch through the lrm1 model in eval
# mode, without gradient tracking.
x = torch.rand(5, 3, 224, 224)
model.eval()
with torch.no_grad():
    out = model(x)
# print explicitly: a bare `out.shape` only displays inside a notebook
print(out.shape)


# In[10]:


# load backbone
backbone = alexnet()

# Feedback is specified as a list of tuples, each with a 'target_layer' and a list of [source_layers]
# - classifier.1 receives feedback from classifier.6

mod_connections = [
    ('classifier.1', ['classifier.6']),
]

# create the model with default number of forward passes (time_steps) and expected img_size
# - default number of forward passes can be overridden in forward pass
# - actual input img_size can be any size and LRMNet will adapt feedback size automatically
model = LRMNet(backbone, mod_connections, time_steps=2, img_size=224)
# A bare `model` expression only displays inside a notebook; print it so the
# exported script also shows the architecture.
print(model)


# In[11]:


# Run a random (5, 3, 224, 224) input batch through the classifier-feedback
# model in eval mode, without gradient tracking.
x = torch.rand(5, 3, 224, 224)
model.eval()
with torch.no_grad():
    out = model(x)
# print explicitly: a bare `out.shape` only displays inside a notebook
print(out.shape)


# In[ ]:




