from collections import OrderedDict
import torch
from torch import nn
import torch.nn.functional as F
from .cache_model import CacheMemory
from .cornet_rt import CORnet_RT_Coarselite

# from cache_model import CacheMemory

import pdb

# NOTE(review): not referenced anywhere in this file; presumably an identifier (e.g. a
# pretrained-weights hash) carried over from the upstream cornet_rt code — confirm before use.
HASH = "933c001c"

"""
This script is adapted from cornet_rt (Ganguli 2019), with the input size and kernel sizes changed so that the
networks take 32x32 inputs rather than ImageNet-sized images.
"""

class Flatten(nn.Module):
    """
    Collapse every dimension after the batch axis into a single feature axis.

    Turns a ``(batch, d1, d2, ...)`` tensor into ``(batch, d1 * d2 * ...)`` so it can be
    fed to ``nn.Linear``. Uses ``view``, so the input must be view-compatible.
    """

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)


class Identity(nn.Module):
    """
    Pass-through module: returns its input object unchanged.

    Registering it as a named submodule gives a stable name for reaching that point of
    the network (for example, to attach a forward hook).
    """

    def forward(self, x):
        out = x
        return out


class CORblock_S(nn.Module):
    """
    Small block built around a single convolutional layer.

    Pipeline: conv -> GroupNorm -> (optional additive noise) -> ReLU -> add ``state``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels; GroupNorm uses 32 groups, so this must
            be divisible by 32.
        kernel_size: convolution kernel size; padding is ``kernel_size // 2``.
        stride: convolution stride.
        out_shape: expected spatial size of the output; stored on the block, not used by it.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, out_shape=None):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.out_shape = out_shape

        pad = kernel_size // 2
        self.conv_input = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=pad,
        )
        self.norm_input = nn.GroupNorm(32, out_channels)
        self.nonlin_input = nn.ReLU(inplace=True)

        # Pass-through module so the block's output can be accessed by name.
        self.output = Identity()

    def forward(self, inp=None, state=None, add_noise=False, noise_ratio=0.1, **kwargs):
        """
        Apply the block to ``inp`` and add ``state`` to the result.

        Args:
            inp: input feature map.
            state: tensor (or scalar) added after the ReLU; None is treated as 0 (t=0).
            add_noise: if True, add ``torch.rand_like(...) * noise_ratio`` (uniform noise in
                [0, noise_ratio)) after normalization and before the ReLU.
            noise_ratio: scale of the added noise.
            **kwargs: accepted and ignored.
        """
        feats = self.norm_input(self.conv_input(inp))
        if add_noise:
            feats = feats + torch.rand_like(feats) * noise_ratio
        feats = self.nonlin_input(feats)

        offset = 0 if state is None else state
        return self.output(feats + offset)

class CORnet_RT_FineliteA(nn.Module):
    """
    Lite fine-pathway network for 32x32 inputs with optional top-down feedback.

    Four CORblock_S stages run in sequence (V1 -> V2 -> V4 -> IT), followed by global
    average pooling and a linear classifier. On passes 2..``times`` the output of
    ``top_layer`` from the previous pass is mapped by ``self.fb_process`` and added to
    ``low_layer`` as its state.

    Args:
        times: number of passes through the stages (1 = purely feedforward).
        num_classes: number of classifier outputs.
        top_layer: stage whose output is fed back; one of "IT", "V4", "V2", "V1".
        low_layer: stage that receives the feedback; one of "IT", "V4", "V2", "V1".
        in_channels: number of input image channels.
        feedback_mode: how feedback is mapped when ``top_layer != low_layer``:
            1) "upsample_pconv": bilinear upsample + 1x1 conv + GroupNorm + ReLU
            2) "upsample_pconv_gate": bilinear upsample + 1x1 conv + GroupNorm + Sigmoid
            When ``top_layer == low_layer`` feedback is passed through unchanged and this
            argument is not checked.
        **kwargs: accepted and ignored.

    Raises:
        KeyError: if ``top_layer`` or ``low_layer`` is not a known stage name.
        ValueError: if ``top_layer != low_layer`` and ``feedback_mode`` is unknown.
    """

    # Stage names in execution order.
    _STAGES = ("V1", "V2", "V4", "IT")

    def __init__(self, times=1, num_classes=10, top_layer="IT",
                 low_layer="V2", in_channels=3, feedback_mode="upsample_pconv", **kwargs):
        super().__init__()
        self.times = times

        self.V1 = CORblock_S(in_channels, 64, kernel_size=3, stride=1, out_shape=32)
        self.V2 = CORblock_S(64, 128, stride=2, out_shape=16)
        self.V4 = CORblock_S(128, 256, stride=2, out_shape=8)
        self.IT = CORblock_S(256, 512, stride=2, out_shape=4)

        self.top_layer = top_layer
        self.low_layer = low_layer

        self.pool = nn.Sequential(OrderedDict([
            ('avgpool', nn.AdaptiveAvgPool2d(1)),
            ('flatten', Flatten()),
        ]))

        # Output channels of each stage; an unknown stage name raises KeyError here.
        filter_dict = {name: getattr(self, name).out_channels for name in self._STAGES}
        top_ch = filter_dict[top_layer]
        low_ch = filter_dict[low_layer]

        # Channels double exactly as the spatial side halves between stages (see the
        # out_shape values above), so the channel ratio equals the spatial upsampling
        # factor from top_layer to low_layer.
        # NOTE(review): if top_layer is below low_layer this factor is 0, and fb_process
        # will fail as soon as it is used (times > 1 or fast feedback).
        scale_factor = int(top_ch / low_ch)

        if scale_factor == 1:
            self.fb_process = nn.Sequential(Identity())
        elif feedback_mode == "upsample_pconv":
            self.fb_process = nn.Sequential(
                nn.Upsample(scale_factor=scale_factor, mode='bilinear'),
                nn.Conv2d(top_ch, low_ch, kernel_size=1, bias=False),
                nn.GroupNorm(32, low_ch),
                nn.ReLU(inplace=True),
            )
        elif feedback_mode == "upsample_pconv_gate":
            self.fb_process = nn.Sequential(
                nn.Upsample(scale_factor=scale_factor, mode='bilinear'),
                nn.Conv2d(top_ch, low_ch, kernel_size=1, bias=False),
                nn.GroupNorm(32, low_ch),
                nn.Sigmoid(),
            )
        else:
            raise ValueError("No such feedback mode !")
        self.classifier = nn.Linear(self.IT.out_channels, num_classes)

    def _run_stages(self, inp, feedback, add_noise=False, noise_ratio=0.1):
        """Run V1 -> V2 -> V4 -> IT once; feedback and noise go only to ``self.low_layer``.

        Returns a dict mapping each stage name to its output.
        """
        outs = {}
        x = inp
        for name in self._STAGES:
            is_low = self.low_layer == name
            x = getattr(self, name)(x, feedback if is_low else 0,
                                    add_noise=add_noise and is_low,
                                    noise_ratio=noise_ratio)
            outs[name] = x
        return outs

    def forward(self, inp, fast_feedback=None, feedback_time=0, add_noise=False, noise_ratio=0.1, **kwargs):
        """
        Run ``self.times`` passes and classify the final IT output.

        Args:
            inp: input image batch; the stage out_shape values (32/16/8/4) assume a 32x32 input.
            fast_feedback: optional external feedback, reshaped to the IT output shape
                (batch, 512, 4, 4). If ``feedback_time == 0`` it is also mapped by
                ``fb_process`` and injected into ``low_layer`` on the first pass. In
                addition, when ``top_layer == "IT"`` it is added to the IT output before
                the feedback mapping on every later pass.
            feedback_time: 0 enables the first-pass injection described above.
            add_noise, noise_ratio: noise injected into ``low_layer``, first pass only.
            **kwargs: accepted and ignored.

        Returns:
            ``(pred, hids)``. ``pred`` is the logits of the final pass. ``hids`` contains:
            ``[it_out_1, out_pool_1, pred_1]``, followed by ``it_out_t`` for each later
            pass t, then the final ``out_pool``. Every entry except the last is detached
            and viewed as (batch, -1).
        """
        batch_size = inp.shape[0]

        if fast_feedback is None:
            fd_inp1 = 0
            fast_feedback = 0  # neutral term for the IT feedback sum below
        else:
            side = self.IT.out_shape
            fast_feedback = fast_feedback.reshape(-1, self.IT.out_channels, side, side)
            fd_inp1 = self.fb_process(fast_feedback) if feedback_time == 0 else 0

        stages = self._run_stages(inp, fd_inp1, add_noise=add_noise, noise_ratio=noise_ratio)
        out_pool = self.pool(stages["IT"])
        pred = self.classifier(out_pool)
        out_list = [
            stages["IT"].detach().view(batch_size, -1),
            out_pool.detach().view(batch_size, -1),
            pred.detach().view(batch_size, -1),
        ]

        for _ in range(1, self.times):
            # Feedback is computed from the previous pass's outputs.
            if self.top_layer == "IT":
                fd_inp = self.fb_process(stages["IT"] + fast_feedback)
            else:
                fd_inp = self.fb_process(stages[self.top_layer])
            # NOTE(review): noise is only injected on the first pass — confirm intent.
            stages = self._run_stages(inp, fd_inp)
            out_list.append(stages["IT"].detach().view(batch_size, -1))

        if self.times > 1:
            # Classify the last pass; with a single pass the first-pass results are final.
            out_pool = self.pool(stages["IT"])
            pred = self.classifier(out_pool)

        return pred, out_list + [out_pool]


class CORnet_TwopassA(nn.Module):
    """
    Two-pass model combining a coarse network, an optional cache memory and a fine network.

    The coarse network processes the coarse input. Its last hidden output
    (``coarse_hids[-1]``) queries the cache memory, and the cache read-out is passed to
    the fine network as ``fast_feedback``.

    Args:
        fine_model_dict: keyword arguments for the fine network (None means ``{}``).
        coarse_model_dict: keyword arguments for the coarse network (None means ``{}``).
        cache_model_dict: extra keyword arguments for CacheMemory, e.g. "mode" and "T"
            (None means ``{}``). ``source_data``/``target_data`` are passed separately and
            must not be repeated here (that would be a duplicate keyword argument).
        source_data, target_data: initial cache contents passed to CacheMemory.
        dataloader: iterable of batches ``(fine_input, coarse_input, ...)`` read by
            ``update_cachememory``.
        using_cache: whether to build and query the cache memory.
        finenet_type: "CORnet_RT_Fine" or "CORnet_RT_FineliteA".
        coarsenet_type: "CORnet_RT_Coarse", "CORnet_RT_Coarselite" or "CORnet_RT_Finelite".

    Raises:
        ValueError: for an unknown ``finenet_type`` or ``coarsenet_type``.
    """

    def __init__(self, fine_model_dict=None, coarse_model_dict=None, cache_model_dict=None,
                 source_data=None, target_data=None, dataloader=None, using_cache=True,
                 finenet_type="CORnet_RT_Fine", coarsenet_type="CORnet_RT_Coarse"):
        super().__init__()
        # None defaults avoid sharing one mutable dict between all instances.
        fine_model_dict = {} if fine_model_dict is None else fine_model_dict
        coarse_model_dict = {} if coarse_model_dict is None else coarse_model_dict
        cache_model_dict = {} if cache_model_dict is None else cache_model_dict

        # NOTE(review): CORnet_RT_Fine, CORnet_RT_Coarse and CORnet_RT_Finelite are neither
        # defined nor imported in this file, so selecting them (including the defaults)
        # raises NameError unless they are provided elsewhere — confirm. The class names are
        # only looked up inside their own branch, so the other types keep working.
        if finenet_type == "CORnet_RT_Fine":
            self.finenet = CORnet_RT_Fine(**fine_model_dict)
        elif finenet_type == "CORnet_RT_FineliteA":
            self.finenet = CORnet_RT_FineliteA(**fine_model_dict)
        else:
            raise ValueError("No such finenet_type : {}.".format(finenet_type))

        if coarsenet_type == "CORnet_RT_Coarse":
            self.coarsenet = CORnet_RT_Coarse(**coarse_model_dict)
        elif coarsenet_type == "CORnet_RT_Coarselite":
            self.coarsenet = CORnet_RT_Coarselite(**coarse_model_dict)
        elif coarsenet_type == "CORnet_RT_Finelite":
            self.coarsenet = CORnet_RT_Finelite(**coarse_model_dict)
        else:
            raise ValueError("No such coarsenet_type : {}.".format(coarsenet_type))

        self.source_data = source_data
        self.target_data = target_data
        self.dataloader = dataloader
        self.using_cache = using_cache

        # When False, forward() skips the cache look-up.
        self.update_flag = True

        if using_cache:
            self.cachenet = CacheMemory(source_data=self.source_data, target_data=self.target_data,
                                        **cache_model_dict)

    def update_cachememory(self, update_flag=True):
        """
        Rebuild the cache contents from ``self.dataloader``, or mark the cache as stale.

        With ``update_flag=True``:
        - every batch is run without gradients through the coarse and fine networks;
        - the flattened ``coarse_hids[-1]`` become the new ``source_data``;
        - the flattened ``fine_hids[-2]`` become the new ``target_data``;
        - the cache memory is updated (when ``using_cache``);
        - ``self.update_flag`` is set to True.

        With ``update_flag=False`` nothing is computed and ``self.update_flag`` is set to
        False, which disables cache look-ups in ``forward``.
        """
        print("update cache memory")
        if not update_flag:
            self.update_flag = False
            print("wait to update the cache memory")
            return

        # Send batches to the device the model lives on rather than a hard-coded "cuda".
        device = next(self.parameters()).device
        source_data = []
        target_data = []
        for data_i in self.dataloader:
            batch_size = data_i[0].shape[0]
            fine_data_i = data_i[0].to(device)
            coarse_data_i = data_i[1].to(device)
            with torch.no_grad():
                _, coarse_hids = self.coarsenet(coarse_data_i)
                coarse_flat = coarse_hids[-1].reshape(batch_size, -1)
                source_data.append(coarse_flat)
                # The cache can only be queried once it exists and holds data.
                if self.using_cache and self.source_data is not None and self.target_data is not None:
                    cache_out = self.cachenet(coarse_flat)
                else:
                    cache_out = None
                _, fine_preds = self.finenet(fine_data_i, fast_feedback=cache_out, feedback_time=0)
                target_data.append(fine_preds[-2].reshape(batch_size, -1))

        self.source_data = torch.cat(source_data, dim=0)
        self.target_data = torch.cat(target_data, dim=0)
        if self.using_cache:
            self.cachenet.source_data = self.source_data
            self.cachenet.target_data = self.target_data
        self.update_flag = True

        print("state to update the cache memory")

    def forward(self, x, fine_training=True, feedback_time=0, coarse_training=True, add_noise=False,
                noise_ratio=0.1, pre_feats=False, **kwargs):
        """
        Run the coarse network, the optional cache look-up, then the fine network.

        Args:
            x: either a ``[fine_x, coarse_x]`` list/tuple, or a single tensor whose first
                half along dim 0 is ``fine_x`` and second half is ``coarse_x`` (used for
                adversarial attacks).
            fine_training: if False, the fine network runs under ``torch.no_grad()``.
            feedback_time: forwarded to the fine network.
            coarse_training: if False, the coarse network runs under ``torch.no_grad()``.
            add_noise, noise_ratio: accepted but currently unused.
            pre_feats: also return intermediate features.
            **kwargs: accepted and ignored.

        Returns:
            ``(fine_pred, coarse_pred)``. When ``pre_feats`` is True, returns
            ``(fine_pred, coarse_pred, (fine_hids, coarse_hids, cache_pred))`` instead;
            ``cache_pred`` is None when the cache is not queried.

        Raises:
            ValueError: if ``x`` is a single tensor with an odd leading dimension.
        """
        if isinstance(x, (list, tuple)):
            fine_x, coarse_x = x
        else:
            if x.shape[0] % 2 != 0:
                raise ValueError(
                    "expected an even leading dimension when x is a single tensor, got {}".format(x.shape[0]))
            fine_x, coarse_x = x.split(x.shape[0] // 2, dim=0)

        if coarse_training:
            coarse_pred, coarse_hids = self.coarsenet(coarse_x)
        else:
            with torch.no_grad():
                coarse_pred, coarse_hids = self.coarsenet(coarse_x)

        pool_out = coarse_hids[-1]

        if self.using_cache and self.update_flag:
            with torch.no_grad():
                cache_out, cache_pred = self.cachenet(pool_out, pre_feats=True)
        else:
            cache_out = None
            cache_pred = None

        # NOTE(review): add_noise / noise_ratio are not forwarded to either sub-network;
        # confirm whether that is intended.
        if fine_training:
            fine_pred, fine_hids = self.finenet(fine_x, fast_feedback=cache_out, feedback_time=feedback_time)
        else:
            with torch.no_grad():
                fine_pred, fine_hids = self.finenet(fine_x, fast_feedback=cache_out, feedback_time=feedback_time)

        if pre_feats:
            return fine_pred, coarse_pred, (fine_hids, coarse_hids, cache_pred)
        return fine_pred, coarse_pred

if __name__ == "__main__":
    # Smoke test: one forward pass of the coarse network on a dummy single-channel 32x32 input.
    # Prints the result instead of stopping in a debugger, so it also works non-interactively.
    net = CORnet_RT_Coarselite(in_channels=1, num_classes=10, kernel_type="small")
    inp1 = torch.ones(1, 1, 32, 32)
    pred, out = net(inp1)
    print("pred shape:", tuple(pred.shape))
    print("hidden outputs type:", type(out).__name__)







###
