from collections import OrderedDict
import torch
from torch import nn
import torch.nn.functional as F
from .cache_model import CacheMemory

# from cache_model import CacheMemory

import pdb


HASH = '933c001c'

"""
This script is adapted from cornet_rt (ganguli 2019). The input size and kernel sizes are changed:
the input is 32x32, not ImageNet-sized.
"""


class Flatten(nn.Module):

    """
    Helper module that flattens every dimension after the first, so the result
    can be fed to Linear modules.

    Input of shape (N, d1, d2, ...) becomes (N, d1*d2*...).
    """

    def forward(self, x):
        # reshape() returns a view whenever view() would have succeeded and falls
        # back to a copy for non-contiguous input (e.g. after transpose/permute),
        # where view() raises a RuntimeError.
        return x.reshape(x.size(0), -1)


class Identity(nn.Module):

    """
    Pass-through module.

    It performs no computation; registering it under a name gives a convenient
    place to read (or hook) the tensor that flows through at that point.
    """

    def forward(self, x):
        # Hand back the very same tensor object, untouched.
        passthrough = x
        return passthrough


class CORblock_RT(nn.Module):
    """
    Two-convolution block: an input convolution (conv -> GroupNorm -> ReLU)
    followed by a 3x3 convolution (conv -> GroupNorm -> ReLU) applied to the sum
    of the input drive and an optional ``state`` tensor.

    Args:
        in_channels: channels of the tensor fed to ``conv_input``.
        out_channels: channels produced by the block. GroupNorm is built with 32
            groups, so this must be divisible by 32.
        kernel_size: kernel size of ``conv_input``; padding is ``kernel_size // 2``.
        stride: stride of ``conv_input``.
        out_shape: spatial side length of the block output. Only read when
            ``forward`` has to build a zero input because ``inp`` is None.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, out_shape=None):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.out_shape = out_shape

        self.conv_input = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                                    stride=stride, padding=kernel_size // 2)
        self.norm_input = nn.GroupNorm(32, out_channels)
        self.nonlin_input = nn.ReLU(inplace=True)

        self.conv1 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, padding=1, bias=False)
        self.norm1 = nn.GroupNorm(32, out_channels)
        self.nonlin1 = nn.ReLU(inplace=True)

        self.output = Identity()  # for an easy access to this block's output

    def forward(self, inp=None, state=None, batch_size=None, add_noise=False, noise_ratio=0.1, **kwargs):
        """
        Args:
            inp: input tensor, or None to use an all-zero drive of shape
                (batch_size, out_channels, out_shape, out_shape).
            state: tensor (or scalar) added to the processed input before the
                second convolution; None is treated as 0.
            batch_size: batch size of the zero drive; required when ``inp`` is None.
            add_noise: add uniform noise in [0, noise_ratio) to the normalised
                input drive before its ReLU.
            noise_ratio: scale of that noise.
            **kwargs: accepted and ignored.

        Returns:
            Block output tensor.

        Raises:
            ValueError: if ``inp`` is None and ``batch_size`` or ``out_shape`` is missing.
        """
        if inp is None:  # at t=0, there is no input yet except to V1
            if batch_size is None or self.out_shape is None:
                raise ValueError(
                    "batch_size and out_shape are required to build the zero input when inp is None")
            # Create the zeros directly with the weights' dtype and device so the
            # block also works for half/double models and for any device.
            weight = self.conv_input.weight
            inp = torch.zeros(batch_size, self.out_channels, self.out_shape, self.out_shape,
                              dtype=weight.dtype, device=weight.device)
        else:
            inp = self.conv_input(inp)
            inp = self.norm_input(inp)
            if add_noise:
                inp = inp + torch.rand_like(inp) * noise_ratio
            inp = self.nonlin_input(inp)
        if state is None:  # at t=0, state is initialized to 0
            state = 0
        skip = inp + state

        x = self.conv1(skip)
        x = self.norm1(x)
        x = self.nonlin1(x)

        return self.output(x)


class CORnet_RT_Fine(nn.Module):
    """
    Four-block network (V1 -> V2 -> V4 -> IT) with a feedback connection from
    ``top_layer`` to ``low_layer`` that is unrolled for ``times`` sweeps.

    Args:
        times: total number of feedforward sweeps (the first one plus
            ``times - 1`` sweeps driven by feedback from ``top_layer``).
        num_classes: size of the classifier output.
        top_layer: name of the block whose output is fed back ("V1", "V2", "V4" or "IT").
        low_layer: name of the block that receives the feedback as its ``state``.
        in_channels: channels of the input image.
        **kwargs: accepted and ignored.
    """

    # Feedforward order of the blocks; also the attribute names they are stored under.
    _LAYER_ORDER = ("V1", "V2", "V4", "IT")

    def __init__(self, times=5, num_classes=10, top_layer="IT", low_layer="V2", in_channels=3, **kwargs):
        super().__init__()
        self.times = times

        self.V1 = CORblock_RT(in_channels, 64, kernel_size=3, stride=1, out_shape=32)
        self.V2 = CORblock_RT(64, 128, stride=2, out_shape=16)
        self.V4 = CORblock_RT(128, 256, stride=2, out_shape=8)
        self.IT = CORblock_RT(256, 512, stride=2, out_shape=4)

        self.top_layer = top_layer
        self.low_layer = low_layer

        self.pool = nn.Sequential(OrderedDict([
            ('avgpool', nn.AdaptiveAvgPool2d(1)),
            ('flatten', Flatten()),
        ]))

        filter_dict = dict(IT=512, V4=256, V2=128, V1=64)

        # The channel ratio between the two blocks doubles as the spatial
        # upsampling factor (each block above halves the side and doubles the channels).
        scale_factor = filter_dict[top_layer] // filter_dict[low_layer]
        if scale_factor == 1:
            self.fb_process = nn.Sequential(Identity())
        else:
            self.fb_process = nn.Sequential(
                # align_corners=False is what bilinear Upsample uses by default;
                # spelling it out avoids the "default behaviour" warning.
                nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False),
                nn.Conv2d(filter_dict[top_layer], filter_dict[low_layer], kernel_size=1, bias=False),
                nn.GroupNorm(32, filter_dict[low_layer]),
                nn.ReLU(inplace=True),
            )
        self.classifier = nn.Linear(512, num_classes)

    def _sweep(self, inp, feedback, add_noise=False, noise_ratio=0.1):
        """Run V1..IT once; only ``low_layer`` receives ``feedback`` (and noise). Returns outputs by name."""
        feats = {}
        x = inp
        for name in self._LAYER_ORDER:
            is_target = self.low_layer == name
            x = getattr(self, name)(
                x,
                feedback if is_target else 0,
                add_noise=add_noise and is_target,
                noise_ratio=noise_ratio,
            )
            feats[name] = x
        return feats

    def forward(self, inp, fast_feedback=None, feedback_time=0, add_noise=False, noise_ratio=0.1, pre_feats=False, **kwargs):
        """
        Args:
            inp: input batch.
            fast_feedback: optional external feedback for the first sweep. It is
                reshaped to (-1, 512, 4, 4) (the shape hard-coded here) and passed
                through ``fb_process``.
            feedback_time: ``fast_feedback`` is only used when this equals 0.
            add_noise: add noise inside ``low_layer`` during the first sweep only.
            noise_ratio: scale of that noise.
            pre_feats: unused.
            **kwargs: accepted and ignored.

        Returns:
            ``(pred, feats)`` where ``feats`` holds the detached, flattened IT
            output of every sweep followed by the final pooled features.
        """
        batch_size = inp.shape[0]

        if fast_feedback is not None and feedback_time == 0:
            fd_inp = self.fb_process(fast_feedback.reshape(-1, 512, 4, 4))
        else:
            fd_inp = 0

        feats = self._sweep(inp, fd_inp, add_noise=add_noise, noise_ratio=noise_ratio)
        it_out_list = [feats["IT"].detach().view(batch_size, -1)]

        for _ in range(1, self.times):
            # Feedback comes from the previous sweep's top-layer output; an
            # unrecognised top_layer means no feedback (state 0).
            top_feat = feats.get(self.top_layer)
            fd_inp = self.fb_process(top_feat) if top_feat is not None else 0

            feats = self._sweep(inp, fd_inp)
            it_out_list.append(feats["IT"].detach().view(batch_size, -1))

        out_pool = self.pool(feats["IT"])
        pred = self.classifier(out_pool)
        return pred, it_out_list + [out_pool]



class CORblock_S(nn.Module):
    """
    Small network block containing a single convolutional layer
    (conv -> GroupNorm -> ReLU), with an optional additive ``state``.

    ``out_shape`` is stored on the instance but not read by this class.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, out_shape=None):
        super().__init__()

        self.in_channels, self.out_channels, self.out_shape = in_channels, out_channels, out_shape

        # "Same"-style padding for odd kernels.
        pad = kernel_size // 2
        self.conv_input = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=pad,
        )
        self.norm_input = nn.GroupNorm(32, out_channels)
        self.nonlin_input = nn.ReLU(inplace=True)

        # Named pass-through so the block output is easy to access.
        self.output = Identity()

    def forward(self, inp=None, state=None, add_noise=False, noise_ratio=0.1, **kwargs):
        """
        Apply conv -> norm -> (optional uniform noise) -> ReLU, add ``state``
        (None counts as 0) and return the result. Extra kwargs are ignored.
        """
        drive = self.norm_input(self.conv_input(inp))

        if add_noise:
            drive = drive + torch.rand_like(drive) * noise_ratio

        drive = self.nonlin_input(drive)

        # At t=0 there is no state yet, so it contributes nothing.
        offset = 0 if state is None else state
        return self.output(drive + offset)

class CORnet_RT_Coarse(nn.Module):

    """
    Three-block feedforward coarse network (S1 -> S2 -> S3 -> pool -> classifier).

    Args:
        in_channels: channels of the input image.
        num_classes: size of the classifier output.
        kernel_type: selects the kernel sizes of (S1, S2, S3). One of
            "very_large" (9, 7, 5), "large" (7, 5, 3), "media" (5, 3, 3),
            "small" (3, 3, 3) or "large_S1" (9, 3, 3). "large_S1" means only
            the front kernel is large.

    Raises:
        ValueError: if ``kernel_type`` is not one of the names above.
    """

    # kernel_type -> kernel sizes for (S1, S2, S3)
    _KERNEL_SIZES = {
        "very_large": (9, 7, 5),
        "large": (7, 5, 3),
        "media": (5, 3, 3),
        "small": (3, 3, 3),
        "large_S1": (9, 3, 3),
    }

    def __init__(self, in_channels=1, num_classes=10, kernel_type="large"):
        super().__init__()

        # Fail with a clear message instead of leaving the kernel list unbound.
        if kernel_type not in self._KERNEL_SIZES:
            raise ValueError("No such kernel_type : {}. Expected one of {}.".format(
                kernel_type, sorted(self._KERNEL_SIZES)))
        kernel_list = self._KERNEL_SIZES[kernel_type]

        self.in_channels = in_channels
        self.num_classes = num_classes

        self.S1 = CORblock_S(in_channels, 64, kernel_size=kernel_list[0], stride=2, out_shape=16)
        self.S2 = CORblock_S(64, 128, kernel_size=kernel_list[1], stride=2, out_shape=8)
        self.S3 = CORblock_S(128, 256, kernel_size=kernel_list[2], stride=2, out_shape=4)

        self.pool = nn.Sequential(OrderedDict([
            ('avgpool', nn.AdaptiveAvgPool2d(1)),
            ('flatten', Flatten()),
        ]))

        self.classifier = nn.Linear(256, num_classes)

    def forward(self, inp, **kwargs):
        """Return ``(pred, [s3_features, pooled_features])``; extra kwargs are ignored."""
        x = self.S1(inp)
        x = self.S2(x)
        x = self.S3(x)

        pool = self.pool(x)
        pred = self.classifier(pool)

        return pred, [x, pool]



###
class CORnet_RT_Coarselite(nn.Module):

    """
    Lite version of the coarse network: two blocks (S1 -> S2) followed by
    global average pooling and a linear classifier.

    Args:
        in_channels: channels of the input image.
        num_classes: size of the classifier output.
        kernel_type: selects the kernel sizes of (S1, S2): "large" (7, 5) or
            "small" (5, 3).

    Raises:
        ValueError: if ``kernel_type`` is neither "large" nor "small".
    """

    # kernel_type -> kernel sizes for (S1, S2)
    _KERNEL_SIZES = {
        "large": (7, 5),
        "small": (5, 3),
    }

    def __init__(self, in_channels=1, num_classes=10, kernel_type="large"):
        super().__init__()

        # Fail with a clear message instead of leaving the kernel list unbound.
        if kernel_type not in self._KERNEL_SIZES:
            raise ValueError("No such kernel_type : {}. Expected one of {}.".format(
                kernel_type, sorted(self._KERNEL_SIZES)))
        kernel_list = self._KERNEL_SIZES[kernel_type]

        self.in_channels = in_channels
        self.num_classes = num_classes

        # NOTE(review): out_shape is only stored by CORblock_S, never read. With a
        # stride-4 S1 a 32x32 input would give 8x8, not 16 - confirm the intended value.
        self.S1 = CORblock_S(in_channels, 64, kernel_size=kernel_list[0], stride=4, out_shape=16)
        self.S2 = CORblock_S(64, 512, kernel_size=kernel_list[1], stride=2, out_shape=4)

        self.pool = nn.Sequential(OrderedDict([
            ('avgpool', nn.AdaptiveAvgPool2d(1)),
            ('flatten', Flatten()),
        ]))

        # Disabled alternative head (original note: distillation does not work
        # here, so increase the model size?):
        # self.pool = nn.Sequential(OrderedDict([
        #        ('flatten', Flatten()),
        #        ('linear', nn.Linear(512*4*4, 512)),
        #        ('act', nn.ReLU(inplace=True))
        # ]))

        self.classifier = nn.Linear(512, num_classes)

    def forward(self, inp, **kwargs):
        """Return ``(pred, [s2_features, pooled_features])``; extra kwargs are ignored."""
        x = self.S1(inp)
        x = self.S2(x)
        pool = self.pool(x)
        pred = self.classifier(pool)

        return pred, [x, pool]

###
class CORnet_RT_Finelite(nn.Module):
    """
    Lite FineNet: the same V1 -> V2 -> V4 -> IT layout as the fine network but
    built from single-convolution ``CORblock_S`` blocks.

    Args:
        times: total number of feedforward sweeps.
        num_classes: size of the classifier output.
        top_layer: name of the block whose output is fed back ("V1", "V2", "V4" or "IT").
        low_layer: name of the block that receives the feedback as its ``state``.
        in_channels: channels of the input image.
        feedback_mode: how feedback is processed when the two layers differ in size:
            1) "upsample_pconv": upsample + point convolution + GroupNorm + ReLU
            2) "upsample_pconv_gate": upsample + point convolution + GroupNorm + sigmoid
            Ignored when the channel ratio of the two layers is 1.
        **kwargs: accepted and ignored.

    Raises:
        ValueError: if a feedback path has to be built and ``feedback_mode`` is unknown.
    """

    # Feedforward order of the blocks; also the attribute names they are stored under.
    _LAYER_ORDER = ("V1", "V2", "V4", "IT")
    _FEEDBACK_MODES = ("upsample_pconv", "upsample_pconv_gate")

    def __init__(self, times=1, num_classes=10, top_layer="IT",
                 low_layer="V2", in_channels=3, feedback_mode="upsample_pconv", **kwargs):
        super().__init__()

        filter_dict = dict(IT=512, V4=256, V2=128, V1=64)
        # The channel ratio between the two blocks doubles as the spatial upsampling factor.
        scale_factor = filter_dict[top_layer] // filter_dict[low_layer]
        # Reject unknown modes up front; previously fb_process was silently left
        # undefined and the model failed later with an AttributeError.
        if scale_factor != 1 and feedback_mode not in self._FEEDBACK_MODES:
            raise ValueError("No such feedback_mode : {}. Expected one of {}.".format(
                feedback_mode, list(self._FEEDBACK_MODES)))

        self.times = times

        self.V1 = CORblock_S(in_channels, 64, kernel_size=3, stride=1, out_shape=32)
        self.V2 = CORblock_S(64, 128, stride=2, out_shape=16)
        self.V4 = CORblock_S(128, 256, stride=2, out_shape=8)
        self.IT = CORblock_S(256, 512, stride=2, out_shape=4)

        self.top_layer = top_layer
        self.low_layer = low_layer

        self.pool = nn.Sequential(OrderedDict([
            ('avgpool', nn.AdaptiveAvgPool2d(1)),
            ('flatten', Flatten()),
        ]))

        if scale_factor == 1:
            self.fb_process = nn.Sequential(Identity())
        else:
            if feedback_mode == "upsample_pconv":
                activation = nn.ReLU(inplace=True)
            else:  # "upsample_pconv_gate"
                activation = nn.Sigmoid()
            self.fb_process = nn.Sequential(
                # align_corners=False is what bilinear Upsample uses by default;
                # spelling it out avoids the "default behaviour" warning.
                nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False),
                nn.Conv2d(filter_dict[top_layer], filter_dict[low_layer], kernel_size=1, bias=False),
                nn.GroupNorm(32, filter_dict[low_layer]),
                activation,
            )
        self.classifier = nn.Linear(512, num_classes)

    def _sweep(self, inp, feedback, add_noise=False, noise_ratio=0.1):
        """Run V1..IT once; only ``low_layer`` receives ``feedback`` (and noise). Returns outputs by name."""
        feats = {}
        x = inp
        for name in self._LAYER_ORDER:
            is_target = self.low_layer == name
            x = getattr(self, name)(
                x,
                feedback if is_target else 0,
                add_noise=add_noise and is_target,
                noise_ratio=noise_ratio,
            )
            feats[name] = x
        return feats

    def forward(self, inp, fast_feedback=None, feedback_time=0, add_noise=False, noise_ratio=0.1, **kwargs):
        """
        Args:
            inp: input batch.
            fast_feedback: optional external feedback for the first sweep. It is
                reshaped to (-1, 512, 4, 4) (the shape hard-coded here) and passed
                through ``fb_process``.
            feedback_time: ``fast_feedback`` is only used when this equals 0.
            add_noise: add noise inside ``low_layer`` during the first sweep only.
            noise_ratio: scale of that noise.
            **kwargs: accepted and ignored.

        Returns:
            ``(pred, feats)`` with ``feats`` laid out as
            [it_out_1, out_pool_1, pred_1, it_out_2, ..., it_out_times, out_pool_final];
            every entry except the last is detached and flattened per sample.
        """
        batch_size = inp.shape[0]

        if fast_feedback is not None and feedback_time == 0:
            fd_inp = self.fb_process(fast_feedback.reshape(-1, 512, 4, 4))
        else:
            fd_inp = 0

        feats = self._sweep(inp, fd_inp, add_noise=add_noise, noise_ratio=noise_ratio)

        # First-sweep IT features, pooled features and prediction.
        out_list = [feats["IT"].detach().view(batch_size, -1)]
        out_pool = self.pool(feats["IT"])
        pred = self.classifier(out_pool)
        out_list.append(out_pool.detach().view(batch_size, -1))
        out_list.append(pred.detach().view(batch_size, -1))

        for _ in range(1, self.times):
            # Feedback comes from the previous sweep's top-layer output; an
            # unrecognised top_layer means no feedback (state 0).
            top_feat = feats.get(self.top_layer)
            fd_inp = self.fb_process(top_feat) if top_feat is not None else 0

            feats = self._sweep(inp, fd_inp)
            out_list.append(feats["IT"].detach().view(batch_size, -1))

        out_pool = self.pool(feats["IT"])
        pred = self.classifier(out_pool)
        return pred, out_list + [out_pool]  # [it_out1, out_pool1, pred1, it_out2, ..., out_pool]


class CORnet_Twopass(nn.Module):
    """
    Two-pathway model: a coarse network whose last hidden output queries a cache
    memory, and a fine network that receives the cache output as ``fast_feedback``.

    Args:
        fine_model_dict: keyword arguments for the fine network constructor (None = no extras).
        coarse_model_dict: keyword arguments for the coarse network constructor (None = no extras).
        cache_model_dict: extra keyword arguments for ``CacheMemory``; ``source_data``
            and ``target_data`` are passed separately. (The original note mentions
            "mode" and "T" as keys - verify against ``CacheMemory``.)
        source_data: initial cache source data, forwarded to ``CacheMemory``.
        target_data: initial cache target data, forwarded to ``CacheMemory``.
        dataloader: iterable of ``(fine_batch, coarse_batch, label)`` used by
            ``update_cachememory``.
        using_cache: build and use the cache memory when True.
        finenet_type: "CORnet_RT_Fine" or "CORnet_RT_Finelite".
        coarsenet_type: "CORnet_RT_Coarse", "CORnet_RT_Coarselite" or "CORnet_RT_Finelite".

    Raises:
        ValueError: for an unknown ``finenet_type`` or ``coarsenet_type``.
    """

    def __init__(self, fine_model_dict=None, coarse_model_dict=None, cache_model_dict=None,
                 source_data=None, target_data=None, dataloader=None, using_cache=True,
                 finenet_type="CORnet_RT_Fine", coarsenet_type="CORnet_RT_Coarse"):
        super().__init__()
        # None instead of a shared mutable dict() default.
        fine_model_dict = {} if fine_model_dict is None else fine_model_dict
        coarse_model_dict = {} if coarse_model_dict is None else coarse_model_dict
        cache_model_dict = {} if cache_model_dict is None else cache_model_dict

        if finenet_type == "CORnet_RT_Fine":
            self.finenet = CORnet_RT_Fine(**fine_model_dict)
        elif finenet_type == "CORnet_RT_Finelite":
            self.finenet = CORnet_RT_Finelite(**fine_model_dict)
        else:
            raise ValueError("No such finenet_type : {}.".format(finenet_type))

        if coarsenet_type == "CORnet_RT_Coarse":
            self.coarsenet = CORnet_RT_Coarse(**coarse_model_dict)
        elif coarsenet_type == "CORnet_RT_Coarselite":
            self.coarsenet = CORnet_RT_Coarselite(**coarse_model_dict)
        elif coarsenet_type == "CORnet_RT_Finelite":
            self.coarsenet = CORnet_RT_Finelite(**coarse_model_dict)
        else:
            raise ValueError("No such coarsenet_type : {}.".format(coarsenet_type))

        self.source_data = source_data
        self.target_data = target_data
        self.dataloader = dataloader
        self.using_cache = using_cache

        if using_cache:
            self.cachenet = CacheMemory(source_data=self.source_data, target_data=self.target_data, **cache_model_dict)

    def _module_device(self):
        """Device of the model's first parameter (CPU if the model has no parameters)."""
        first_param = next(self.parameters(), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    def update_cachememory(self):
        """
        Rebuild the cache from ``self.dataloader``.

        For every batch the coarse network's last hidden output becomes a source
        entry and the second-to-last element of the fine network's feature list
        becomes the matching target entry. The concatenated results are stored on
        this module and on ``self.cachenet``.

        Raises:
            RuntimeError: if the model was built with ``using_cache=False``
                (there is no ``cachenet`` to update).
        """
        if not self.using_cache:
            raise RuntimeError("update_cachememory() requires the model to be built with using_cache=True.")

        print("update cache memory")
        # Follow the model's own device instead of assuming "cuda", so this also
        # works on CPU and on non-default GPUs.
        device = self._module_device()
        # The existing cache is only queried when it already holds data; these
        # attributes are not reassigned until the loop is finished.
        cache_ready = self.source_data is not None and self.target_data is not None

        source_data = []
        target_data = []
        with torch.no_grad():
            for fine_data_i, coarse_data_i, _label_i in self.dataloader:
                batch_size = fine_data_i.shape[0]
                fine_data_i = fine_data_i.to(device)
                coarse_data_i = coarse_data_i.to(device)

                _, coarse_hids = self.coarsenet(coarse_data_i)
                query = coarse_hids[-1].reshape(batch_size, -1)
                source_data.append(query)

                cache_out = self.cachenet(query) if cache_ready else None
                _, fine_preds = self.finenet(fine_data_i, fast_feedback=cache_out, feedback_time=0)
                target_data.append(fine_preds[-2].reshape(batch_size, -1))

        self.source_data = torch.cat(source_data, dim=0)
        self.target_data = torch.cat(target_data, dim=0)
        self.cachenet.source_data = self.source_data
        self.cachenet.target_data = self.target_data

    def forward(self, x, fine_training=True, coarse_training=True, add_noise=False, noise_ratio=0.1, pre_feats=False):
        """
        Args:
            x: pair ``(fine_x, coarse_x)`` of inputs for the two networks.
            fine_training: run the fine network with gradient tracking left as is;
                when False it runs under ``torch.no_grad()`` and without noise.
            coarse_training: same switch for the coarse network.
            add_noise: forwarded to the fine network (only when ``fine_training``).
            noise_ratio: forwarded to the fine network (only when ``fine_training``).
            pre_feats: also return ``(fine_hids, coarse_hids)``.

        Returns:
            ``(fine_pred, coarse_pred)``, or ``(fine_pred, coarse_pred, feats)``
            when ``pre_feats`` is True.
        """
        fine_x, coarse_x = x
        if coarse_training:
            coarse_pred, coarse_hids = self.coarsenet(coarse_x)
        else:
            with torch.no_grad():
                coarse_pred, coarse_hids = self.coarsenet(coarse_x)

        pool_out = coarse_hids[-1]
        if self.using_cache:
            # The cache lookup is never differentiated through.
            with torch.no_grad():
                cache_out = self.cachenet(pool_out)
        else:
            cache_out = None

        if fine_training:
            fine_pred, fine_hids = self.finenet(fine_x, fast_feedback=cache_out, feedback_time=0,
                                                add_noise=add_noise, noise_ratio=noise_ratio)
        else:
            with torch.no_grad():
                fine_pred, fine_hids = self.finenet(fine_x, fast_feedback=cache_out, feedback_time=0)

        if pre_feats:
            feats = (fine_hids, coarse_hids)
            return fine_pred, coarse_pred, feats
        return fine_pred, coarse_pred


# class CORnet_RT_Twopass(nn.Module):
#
#     def __init__(self,times=1,coarse_in_channels=1,feedback_mode="gate",num_classes=10):
#         super().__init__()
#         self.coarse_net = CORnet_RT_Coarse(times=1,in_channels=coarse_in_channels,num_classes=10)
#         self.fine_net = CORnet_RT_Fine(times=times,num_classes=10)
#         self.feedback = nn.Linear(256,512,bias=False)
#         self.feedforwad = nn.Linear(512,512,bias=False)
#         self.bias = nn.Parameter(torch.zeros(1,512))
#         self.linear = nn.Linear(512,1000)
#         self.classifier = nn.Linear(1000,num_classes)
#         self.feedback_mode = feedback_mode
#
#     def forward(self,inps):
#         inp_fine, inp_coarse = inps
#         pred_fine, out_pool_fine = self.fine_net(inp_fine)
#         pred_coarse, out_pool_coarse = self.coarse_net(inp_coarse)
#         x = self.feedforwad(out_pool_fine)
#         y = self.feedback(out_pool_coarse)
#
#         if self.feedback_mode == "add":
#             out_twopass = F.relu(x + y+self.bias)
#         if self.feedback_mode == "gate":
#             out_twopass = F.relu(x*torch.sigmoid(y)+self.bias)
#         if self.feedback_mode == "mode":
#             out_twopass = F.relu(x*(1+y)+self.bias)
#         if self.feedback_mode == "sub":
#             out_twopass = F.relu(x+self.bias) - y
#
#         out_twopass = F.relu(self.linear(out_twopass))
#         pred_twopass = self.classifier(out_twopass)
#
#         return pred_twopass,pred_fine,pred_coarse
#

if __name__ == "__main__":
    # Manual smoke test: push a single 1x32x32 dummy image through the lite
    # coarse network, then stop in the debugger to inspect `pred` and `out`.
    import pdb

    # Alternative model to try by hand:
    # net = CORnet_RT_Fine(times=2,top_layer="IT",low_layer="V1")  # with inp = torch.ones(1,3,32,32)
    net = CORnet_RT_Coarselite(kernel_type="small", num_classes=10, in_channels=1)
    inp1 = torch.ones((1, 1, 32, 32))
    pred, out = net(inp1)

    pdb.set_trace()







###
