# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen
#
# This version from DeLightCMU/CASD
#
# Many comments and type annotations added by Brad Ezard
# Brad's comments all start with #!
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import *

from nets.network import Network
from model.config import cfg

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math
import torchvision.models as models


#! Custom module to extract feature maps for CASD-LW
class _VggLayerExtractor(nn.Module):
    def __init__(
        self,
        layers: nn.Sequential
    ):
        super().__init__()
        blocks = []
        current_block = []
        for layer in layers:
            if isinstance(layer, nn.MaxPool2d):
                blocks.append(nn.Sequential(*current_block))
                current_block = [layer]
            else:
                current_block.append(layer)
        self.blocks = nn.ModuleList(blocks)
        self.block_outputs = []  #! hacky place to store block outputs

    def forward(
        self,
        x: torch.Tensor
    ) -> torch.Tensor:
        outputs = []
        for block in self.blocks:
            x = block(x)
            outputs.append(x)
        self.block_outputs.append(outputs)
        return x

    def clear_block_outputs(self):
        del self.block_outputs
        self.block_outputs = []
        

class MELM_vgg16(Network):
  """Network variant that uses torchvision's VGG16 as its backbone.

  The convolutional layers of VGG16 (without the last max-pool) form the
  'head'; the VGG16 classifier without its last layer maps pooled features
  to the fc7 representation.
  """

  def __init__(self):
    """Set the backbone constants read by the rest of the network."""
    Network.__init__(self)
    self._feat_stride = [16, ]
    # Reciprocal of the feature stride.
    self._feat_compress = [1. / float(self._feat_stride[0]), ]
    self._net_conv_channels = 512
    self._fc7_channels = 4096

  def _init_head_tail(self) -> None:
    """Build the VGG16 model and register its convolutional head."""
    self.vgg = models.vgg16()  #! get the base model
    # Remove fc8
    #! Equivalent to self.vgg.classifier = self.vgg.classifier[:-1]
    classifier_layers = list(self.vgg.classifier._modules.values())[:-1]
    self.vgg.classifier = nn.Sequential(*classifier_layers)

    # Fix the layers before conv3:
    #! Equivalent to self.vgg.features[:10].parameters()
    for layer in range(10):
      for p in self.vgg.features[layer].parameters():
        p.requires_grad = False

    # not using the last maxpool layer
    feature_layers = list(self.vgg.features._modules.values())[:-1]
    self._layers['head'] = nn.Sequential(*feature_layers)

    #! Wrapped the layers to extract the feature maps at the different blocks for CASD-LW which appears absent
    #! self._layers['head'] = _VggLayerExtractor(self._layers['head'])

    # ------- parallel Gpu-----------
    self._layers['head'] = torch.nn.DataParallel(self._layers['head'])

  def _image_to_headtest(self) -> torch.Tensor:  #! pass through the convolutional layers of vgg
    """Run ``self._image`` through the head and return the result."""
    net_conv = self._layers['head'](self._image)
    return net_conv

  def _image_to_head(self) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
    """Run the current image(s) through the head.

    In training mode ``self._image`` is indexed at 0 and 1 and a list with
    the two head outputs is returned; otherwise ``self._image`` is passed to
    the head as-is and a single output is returned.
    """
    #! self._image appears to be a list in training, a tensor in eval
    if self.training:
      net_conv = []
      for i in range(2):
        net_conv.append(self._layers['head'](self._image[i]))
    else:
      net_conv = self._layers['head'](self._image)
    return net_conv

  def _head_to_tail(self, pool5) -> torch.Tensor:  #! pass through the FC "neck" of vgg
    """Flatten ``pool5`` per sample and apply the truncated VGG classifier.

    The result is also stored in ``self._predictions['fc7']``.
    """
    # Keep the first dimension and flatten everything else.
    pool5_flat = pool5.view(pool5.size(0), -1)
    fc7 = self.vgg.classifier(pool5_flat)
    self._predictions['fc7'] = fc7

    return fc7

  def load_pretrained_cnn(self, state_dict):
    """Load into ``self.vgg`` the entries of ``state_dict`` it has keys for.

    Entries whose key is not part of ``self.vgg.state_dict()`` are ignored.
    """
    # Build the model's own state dict once instead of once per checkpoint key.
    own_keys = self.vgg.state_dict()
    self.vgg.load_state_dict({k: v for k, v in state_dict.items() if k in own_keys})
    # The classifier is left unwrapped; a DataParallel wrapper around it was
    # previously present here in commented-out form.


