from collections import OrderedDict
import torch
from torch import nn
import torch.nn.functional as F
from .cache_model import CacheMemory
from .cornet_rt import CORnet_RT_Coarselite

# from cache_model import CacheMemory
import pdb


HASH = '933c001c'

"""
This script is adapted from the cornet_rt (ganguli 2019), change the input size and kernel size, input is 32x32
not ImageNet.

this version is for direct ensemble learning of the two-passway model.
"""


class Flatten(nn.Module):

    """
    Helper module that flattens every dimension after the first, turning an input of
    shape (batch, ...) into a 2-D tensor (batch, features) for use in Linear modules
    """

    def forward(self, x):
        """
        Args:
        x: tensor whose first dimension is kept as-is.

        Returns:
        A tensor of shape (x.size(0), -1) holding the same elements in the same order.
        """
        # reshape() returns a view whenever view() would have succeeded and falls back
        # to a copy otherwise, so non-contiguous inputs (e.g. transposed or permuted
        # tensors) no longer raise a RuntimeError.
        return x.reshape(x.size(0), -1)


class Identity(nn.Module):

    """
    Pass-through module: forward() hands back exactly the object it was given.
    Registering it under a name inside another module gives that point of the
    computation a named submodule that can be looked up later.
    """

    def forward(self, x):
        """Return x itself, unchanged (no copy is made)."""
        return x


class CORblock_S(nn.Module):
    """
    Small network block built around a single convolutional layer.

    The pipeline is: Conv2d -> GroupNorm(32 groups) -> optional additive noise -> ReLU,
    after which an optional ``state`` tensor is added and the sum is routed through an
    Identity module registered as ``output``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, out_shape=None):
        """
        Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by the convolution; GroupNorm is built with
                      32 groups, so this must be divisible by 32.
        kernel_size: convolution kernel size; padding is kernel_size // 2.
        stride: convolution stride.
        out_shape: stored on the instance only; forward() never reads it.
        """
        super().__init__()

        # Plain bookkeeping attributes, not used by forward().
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.out_shape = out_shape

        half_kernel = kernel_size // 2
        self.conv_input = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=half_kernel,
        )
        self.norm_input = nn.GroupNorm(32, out_channels)
        self.nonlin_input = nn.ReLU(inplace=True)

        # Named pass-through so this block's output can be reached by name.
        self.output = Identity()

    def forward(self, inp=None, state=None,add_noise=False,noise_ratio=0.1,**kwargs):
        """
        Args:
        inp: input tensor for the convolution.
        state: optional value added to the activation after the ReLU; None means 0.
        add_noise: when true, add noise to the normalised activation before the ReLU.
        noise_ratio: scale applied to the noise.
        **kwargs: accepted and ignored.

        Returns:
        The post-ReLU activation plus ``state``.
        """
        feat = self.norm_input(self.conv_input(inp))

        if add_noise:
            # torch.rand_like samples uniformly from [0, 1); the noise is scaled by
            # noise_ratio and injected before the non-linearity.
            feat = feat + torch.rand_like(feat)*noise_ratio

        feat = self.nonlin_input(feat)

        # With no state supplied (t=0) a plain zero is added instead.
        carried = 0 if state is None else state
        return self.output(feat + carried)

class CORnet_RT_FineliteB(nn.Module):
    """
    This lite FineNet

    A purely feed-forward stack of four CORblock_S stages (V1 -> V2 -> V4 -> IT)
    followed by global average pooling and a linear classifier. No feedback_mode
    option is implemented in this class.

    Args:
    num_classes: number of outputs of the linear classifier.
    top_layer: stored on the instance; not read by forward().
    low_layer: stored on the instance; not read by forward().
    in_channels: number of channels of the input image.
    **kwargs: accepted and ignored.
    """

    def __init__(self, num_classes=10,top_layer="IT",
                       low_layer="V2",in_channels=3,**kwargs):
        super().__init__()

        # out_shape is bookkeeping only; the values match the spatial sizes produced
        # by strides 1, 2, 2, 2 on a 32x32 input.
        self.V1 = CORblock_S(in_channels, 64, kernel_size=3, stride=1, out_shape=32)
        self.V2 = CORblock_S(64,  128,  stride=2, out_shape=16)
        self.V4 = CORblock_S(128, 256,  stride=2, out_shape=8)
        self.IT = CORblock_S(256, 512,  stride=2, out_shape=4)

        self.top_layer = top_layer
        self.low_layer = low_layer

        self.pool = nn.Sequential(OrderedDict([
               ('avgpool', nn.AdaptiveAvgPool2d(1)),
               ('flatten', Flatten()),
        ]))

        # 512 matches the channel count produced by the IT stage.
        self.classifier = nn.Linear(512,num_classes)

    def forward(self, inp, **kwargs):
        """
        Args:
        inp: input image batch with ``in_channels`` channels.
        **kwargs: accepted and ignored.

        Returns:
        (pred, out_list) where pred is the classifier output and
        out_list == [it_out, out_pool]: the IT feature map followed by its pooled,
        flattened version.
        """
        v1_out = self.V1(inp)
        v2_out = self.V2(v1_out)
        v4_out = self.V4(v2_out)
        it_out = self.IT(v4_out)

        out_pool = self.pool(it_out)
        out_list = [it_out, out_pool]

        pred = self.classifier(out_pool)
        return pred,out_list


class CORnet_TwopassB(nn.Module):
    """
    Two-pathway ensemble model: a fine network (CORnet_RT_FineliteB) and a coarse
    network (CORnet_RT_Coarselite) are run on their own inputs, one hidden feature
    map from each is concatenated along the channel axis, and a small fusion head
    plus a linear classifier produces the ensemble prediction.

    Args:
    fine_model_dict: keyword arguments forwarded to CORnet_RT_FineliteB.
    coarse_model_dict: keyword arguments forwarded to CORnet_RT_Coarselite.
    cache_model_dict: dict contains ["source_data","target_data","mode","T"];
                      accepted but not read by this class.
    ensemble_mode: "mode_a" builds the fusion head. No other mode is implemented;
                   the model can still be constructed, but forward() raises.
    num_classes: number of outputs of the ensemble classifier.
    **kwargs: accepted and ignored.
    """

    def __init__(self,fine_model_dict=dict(),coarse_model_dict=dict(),cache_model_dict=dict(),
                      ensemble_mode="mode_a",num_classes=10,**kwargs):
        super().__init__()

        # The dict defaults are only unpacked here, never mutated.
        self.finenet = CORnet_RT_FineliteB(**fine_model_dict)
        self.coarsenet = CORnet_RT_Coarselite(**coarse_model_dict)

        self.ensemble_mode = ensemble_mode

        if self.ensemble_mode == "mode_a":
            # Depthwise 3x3 conv (groups == channels) followed by a 1x1 conv that
            # halves the channel count, then global average pooling.
            self.ensemble_net = nn.Sequential(
                         nn.Conv2d(1024,1024,kernel_size=3,padding=1,groups=1024),
                         nn.GroupNorm(32, 1024),
                         nn.Conv2d(1024,512,kernel_size=1,bias=False,stride=1, padding=0),
                         nn.GroupNorm(32, 512),
                         nn.ReLU(inplace=True),
                         nn.AdaptiveAvgPool2d(1),
                         Flatten(),
                         )
        self.classifier = nn.Linear(512,num_classes)

    @staticmethod
    def _run_branch(net, inp, track_grad):
        """Run one pathway, wrapping the call in torch.no_grad() when it is frozen."""
        if track_grad:
            return net(inp)
        with torch.no_grad():
            return net(inp)

    def forward(self,x,fine_training=True,coarse_training=True,pre_feats=False,**kwargs):
        """
        Args:
        x: pair (fine_x, coarse_x) with the inputs of the fine and coarse pathways.
        fine_training: when false, the fine pathway runs under torch.no_grad().
        coarse_training: when false, the coarse pathway runs under torch.no_grad().
        pre_feats: when true, also return the hidden lists of both pathways.
        **kwargs: accepted and ignored.

        Returns:
        (pred, fine_pred, coarse_pred), or
        (pred, fine_pred, coarse_pred, (fine_hids, coarse_hids)) when pre_feats is true.

        Raises:
        NotImplementedError: if no fusion head exists for ``ensemble_mode``.
        """
        # Fail before running either pathway, with an explicit message, instead of
        # hitting an AttributeError on self.ensemble_net after the work is done.
        ensemble_net = getattr(self, "ensemble_net", None)
        if ensemble_net is None:
            raise NotImplementedError(
                f"ensemble_mode={self.ensemble_mode!r} has no ensemble_net; "
                "only 'mode_a' is implemented"
            )

        fine_x, coarse_x = x

        coarse_pred, coarse_hids = self._run_branch(self.coarsenet, coarse_x, coarse_training)
        fine_pred, fine_hids = self._run_branch(self.finenet, fine_x, fine_training)

        # The second-to-last hidden entry of each pathway is fused. For the fine net
        # this is the 512-channel IT feature map; the fusion head expects 1024 input
        # channels, so the coarse entry must supply the other 512 with the same
        # batch and spatial sizes (TODO confirm against CORnet_RT_Coarselite).
        conv_coarse = coarse_hids[-2]
        conv_fine = fine_hids[-2]

        ensemble_input = torch.cat([conv_fine,conv_coarse],dim=1)
        ensemble_out = ensemble_net(ensemble_input)
        pred = self.classifier(ensemble_out)

        if pre_feats:
            feats = (fine_hids,coarse_hids)
            return pred, fine_pred, coarse_pred,feats
        return pred, fine_pred, coarse_pred

###
class CORnet_TwopassC(nn.Module):
    """
    Two-pathway ensemble model in which both pathways are CORnet_RT_FineliteB
    instances. The IT feature maps of the two pathways are concatenated along the
    channel axis, and a small fusion head plus a linear classifier produces the
    ensemble prediction.

    Args:
    fine_model_dict: keyword arguments forwarded to the fine CORnet_RT_FineliteB.
    coarse_model_dict: keyword arguments forwarded to the coarse CORnet_RT_FineliteB.
    cache_model_dict: dict contains ["source_data","target_data","mode","T"];
                      accepted but not read by this class.
    ensemble_mode: "mode_a" builds the fusion head. No other mode is implemented;
                   the model can still be constructed, but forward() raises.
    num_classes: number of outputs of the ensemble classifier.
    **kwargs: accepted and ignored.
    """

    def __init__(self,fine_model_dict=dict(),coarse_model_dict=dict(),cache_model_dict=dict(),
                      ensemble_mode="mode_a",num_classes=10,**kwargs):
        super().__init__()

        # The dict defaults are only unpacked here, never mutated.
        self.finenet = CORnet_RT_FineliteB(**fine_model_dict)
        self.coarsenet = CORnet_RT_FineliteB(**coarse_model_dict)

        self.ensemble_mode = ensemble_mode

        if self.ensemble_mode == "mode_a":
            # Depthwise 3x3 conv (groups == channels) followed by a 1x1 conv that
            # halves the channel count, then global average pooling.
            self.ensemble_net = nn.Sequential(
                         nn.Conv2d(1024,1024,kernel_size=3,padding=1,groups=1024),
                         nn.GroupNorm(32, 1024),
                         nn.Conv2d(1024,512,kernel_size=1,bias=False,stride=1, padding=0),
                         nn.GroupNorm(32, 512),
                         nn.ReLU(inplace=True),
                         nn.AdaptiveAvgPool2d(1),
                         Flatten(),
                         )
        self.classifier = nn.Linear(512,num_classes)

    @staticmethod
    def _run_branch(net, inp, track_grad):
        """Run one pathway, wrapping the call in torch.no_grad() when it is frozen."""
        if track_grad:
            return net(inp)
        with torch.no_grad():
            return net(inp)

    def forward(self,x,fine_training=True,coarse_training=True,pre_feats=False,**kwargs):
        """
        Args:
        x: pair (fine_x, coarse_x) with the inputs of the fine and coarse pathways.
        fine_training: when false, the fine pathway runs under torch.no_grad().
        coarse_training: when false, the coarse pathway runs under torch.no_grad().
        pre_feats: when true, also return the hidden lists of both pathways.
        **kwargs: accepted and ignored.

        Returns:
        (pred, fine_pred, coarse_pred), or
        (pred, fine_pred, coarse_pred, (fine_hids, coarse_hids)) when pre_feats is true.

        Raises:
        NotImplementedError: if no fusion head exists for ``ensemble_mode``.
        """
        # Fail before running either pathway, with an explicit message, instead of
        # hitting an AttributeError on self.ensemble_net after the work is done.
        ensemble_net = getattr(self, "ensemble_net", None)
        if ensemble_net is None:
            raise NotImplementedError(
                f"ensemble_mode={self.ensemble_mode!r} has no ensemble_net; "
                "only 'mode_a' is implemented"
            )

        fine_x, coarse_x = x

        coarse_pred, coarse_hids = self._run_branch(self.coarsenet, coarse_x, coarse_training)
        fine_pred, fine_hids = self._run_branch(self.finenet, fine_x, fine_training)

        # Each pathway returns [it_out, out_pool]; index -2 is the 512-channel IT
        # feature map, so the concatenation has the 1024 channels the head expects.
        # Both inputs must share batch and spatial sizes for the cat to succeed.
        conv_coarse = coarse_hids[-2]
        conv_fine = fine_hids[-2]

        ensemble_input = torch.cat([conv_fine,conv_coarse],dim=1)
        ensemble_out = ensemble_net(ensemble_input)
        pred = self.classifier(ensemble_out)

        if pre_feats:
            feats = (fine_hids,coarse_hids)
            return pred, fine_pred, coarse_pred,feats
        return pred, fine_pred, coarse_pred



if __name__ == "__main__":
    # Smoke test: one forward pass through the default two-pathway model.
    # NOTE: this module uses relative imports, so it has to be run as part of its
    # package (python -m <package>.<module>) rather than as a standalone file.
    net = CORnet_TwopassB()
    inp = torch.ones(1,3,32,32)   # fine-pathway input
    inp1 = torch.ones(1,1,32,32)  # coarse-pathway input; assumes the default
                                  # CORnet_RT_Coarselite takes 1 channel - TODO confirm
    pred, fine_pred, coarse_pred = net((inp,inp1))

    # Report the output shapes instead of dropping into an unconditional debugger
    # breakpoint, which would hang any non-interactive run.
    print("pred:", tuple(pred.shape))
    print("fine_pred:", tuple(fine_pred.shape))
    print("coarse_pred:", tuple(coarse_pred.shape))







###
