'''ResNet in PyTorch.
For Pre-activation ResNet, see 'preact_resnet.py'.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from simple import SimpleNet
from torch.autograd import Variable
import logging

logger = logging.getLogger("logger")



class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions plus a shortcut connection.

    The first convolution applies ``stride``; the shortcut is the identity
    when shapes already match and a strided 1x1 convolution followed by a
    normalisation layer otherwise.  ``expansion`` is the ratio between the
    block's output channels and ``planes`` (1 for this block).
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, BN_layer=nn.BatchNorm2d):
        super().__init__()
        width = planes * self.expansion

        # Main path: conv -> norm -> (relu in forward) -> conv -> norm.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = BN_layer(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = BN_layer(planes)

        # A projection is required whenever the residual's shape would not
        # match the main path's output (spatial downsampling or a change in
        # channel count); otherwise an empty Sequential acts as identity.
        needs_projection = stride != 1 or in_planes != width
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width, kernel_size=1, stride=stride, bias=False),
                BN_layer(width),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        residual = self.shortcut(x)
        return F.relu(hidden + residual)


class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 (strided) -> 1x1 expand.

    The block outputs ``expansion * planes`` channels.  The shortcut is the
    identity when input and output shapes match, otherwise a strided 1x1
    projection followed by a normalisation layer.

    Args:
        in_planes: number of input channels.
        planes: width of the inner 3x3 convolution.
        stride: stride of the 3x3 convolution (and of the projection shortcut).
        BN_layer: normalisation-layer factory called with a channel count.
            ``ResNet._make_layer`` passes this keyword to every block, so the
            block must accept it.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1, BN_layer=nn.BatchNorm2d):
        super(Bottleneck, self).__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = BN_layer(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = BN_layer(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = BN_layer(out_planes)

        # Identity shortcut unless the residual's shape would differ from the
        # main path's output (downsampling or channel expansion).
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                BN_layer(out_planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(SimpleNet):
    """ResNet trunk followed by a stochastic Gaussian bottleneck and a linear classifier.

    The trunk (3x3 stem + four residual stages of width 32/64/128/256) is
    globally average-pooled.  ``self.feature`` maps the pooled vector to
    ``2 * dimZ`` values.  These are split into a mean ``mu`` and, after a
    softplus, a standard deviation ``sigma``.  ``forward`` draws
    ``num_samples`` latent vectors ``mu + sigma * eps`` per input and
    classifies each with ``self.linear``.

    Args:
        block: residual block class; constructed as
            ``block(in_planes, planes, stride, BN_layer=...)``.
        num_blocks: number of blocks in each of the four stages.
        num_classes: number of classifier outputs.
        dataset: selects the stem's input channels via ``_INPUT_CHANNELS``;
            also passed on to ``SimpleNet.__init__``.
        alpha: offset subtracted from the raw scale output before the softplus.
        dimZ: dimensionality of the latent vector.
        device: device passed to ``Tensor.to`` for the sampled noise (may be None).
        BN_layer: normalisation-layer factory called with a channel count.

    Raises:
        ValueError: if ``dataset`` is not a key of ``_INPUT_CHANNELS``.  Previously
            an unknown name (including the factories' default ``'cifar'``)
            silently left ``conv1`` undefined and ``forward`` failed later.
    """

    # Number of input image channels of the stem convolution per dataset name.
    _INPUT_CHANNELS = {
        'femnist': 1,
        'emnist': 1,
        'cifar': 3,
        'cifar10': 3,
        'cifar100': 3,
        'tinyimagenet': 3,
    }

    def __init__(self, block, num_blocks, num_classes=10, dataset='cifar10', alpha=0, dimZ=256, device=None, BN_layer=nn.BatchNorm2d):
        if dataset not in self._INPUT_CHANNELS:
            raise ValueError(
                f"Unsupported dataset {dataset!r}; expected one of {sorted(self._INPUT_CHANNELS)}"
            )
        super(ResNet, self).__init__(dataset)
        # NOTE: `sd_hook` is not registered here; enable it with
        # self._register_load_state_dict_pre_hook(self.sd_hook) if needed.
        self.in_planes = 32
        self.conv1 = nn.Conv2d(self._INPUT_CHANNELS[dataset], 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = BN_layer(self.in_planes)
        self.layer1 = self._make_layer(block, 32, num_blocks[0], stride=1, bn_norm=BN_layer)
        self.layer2 = self._make_layer(block, 64, num_blocks[1], stride=2, bn_norm=BN_layer)
        self.layer3 = self._make_layer(block, 128, num_blocks[2], stride=2, bn_norm=BN_layer)
        self.layer4 = self._make_layer(block, 256, num_blocks[3], stride=2, bn_norm=BN_layer)

        self.dimZ = dimZ
        self.alpha = alpha

        # Encoder head: first dimZ outputs are mu, last dimZ parameterise sigma.
        self.feature = nn.Linear(256 * block.expansion, 2 * self.dimZ)
        # Decoder/classifier applied to each sampled latent vector.
        self.linear = nn.Linear(self.dimZ, num_classes)
        self.device = device

    def _make_layer(self, block, planes, num_blocks, stride, bn_norm=nn.BatchNorm2d):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s, BN_layer=bn_norm))
            # The next block's input width is this block's output width.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def sd_hook(self, state_dict, *_):
        """load_state_dict pre-hook that delegates to the module-level ``slt_nn_hook``."""
        slt_nn_hook(self, state_dict)

    def forward(self, x, num_samples):
        """Encode ``x``, sample latents and classify them.

        Returns:
            ``(encoder_Z_distr, decoder_logits, regL2R)`` where
            ``encoder_Z_distr`` is the ``(mu, sigma)`` pair (each
            ``(batch, dimZ)``), ``decoder_logits`` has shape
            ``(num_samples, batch, num_classes)`` and ``regL2R`` is the 2-norm
            of all sampled latent vectors taken together.
        """
        batch_size = x.size(0)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling.  For 32x32 (CIFAR) and 28x28 (FEMNIST) inputs
        # the final map is 4x4, so this equals the former avg_pool2d(out, 4);
        # unlike that fixed kernel it also yields 256*expansion features for
        # larger inputs such as 64x64 TinyImageNet (8x8 final map).
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)

        encoder_output = self.feature(out)
        encoder_Z_distr = self.encoder_result(encoder_output)
        to_decoder = self.sample_encoder_Z(batch_size=batch_size, encoder_Z_distr=encoder_Z_distr, num_samples=num_samples)
        decoder_logits = self.linear(to_decoder)

        regL2R = torch.norm(to_decoder)
        return encoder_Z_distr, decoder_logits, regL2R

    def gaussian_noise(self, num_samples, K):
        """Draw standard-normal noise of shape ``(*num_samples, K)``.

        ``num_samples`` may be an int or a tuple of ints.  Noise is drawn on the
        CPU and then moved with ``.to(self.device)``.
        """
        if isinstance(num_samples, int):
            num_samples = (num_samples,)
        # Mean 0, standard deviation 1.
        noise = torch.normal(torch.zeros(*num_samples, K), torch.ones(*num_samples, K))
        return noise.to(self.device)

    def encoder_result(self, encoder_output):
        """Split the head output into ``(mu, sigma)``; sigma = softplus(raw - alpha) > 0."""
        mu = encoder_output[:, :self.dimZ]
        sigma = F.softplus(encoder_output[:, self.dimZ:] - self.alpha)
        return mu, sigma

    def sample_encoder_Z(self, batch_size, encoder_Z_distr, num_samples):
        """Reparameterised sample ``mu + sigma * eps``, shape ``(num_samples, batch, dimZ)``."""
        mu, sigma = encoder_Z_distr
        eps = self.gaussian_noise(num_samples=(num_samples, batch_size), K=self.dimZ)
        # Align the noise with mu: with self.device=None (the default for most
        # factories) the CPU-drawn noise would otherwise mismatch CUDA or
        # reduced-precision activations.
        eps = eps.to(device=mu.device, dtype=mu.dtype)
        return mu + sigma * eps

class SupConResNet_backbone(SimpleNet):
    """ResNet trunk returning L2-normalised, globally pooled features (no classifier).

    Args:
        block: residual block class, constructed as ``block(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the four stages.
        num_classes: accepted for interface compatibility; not used by this class.
        name, created_time: forwarded to ``SimpleNet.__init__``.
        dataset: selects the stem's input channels via ``_INPUT_CHANNELS``.

    Raises:
        ValueError: if ``dataset`` is not a key of ``_INPUT_CHANNELS``.  Previously
            an unknown name (including the factories' default ``'cifar'``) left
            ``conv1`` undefined and ``forward`` failed later.
    """

    # Number of input image channels of the stem convolution per dataset name.
    _INPUT_CHANNELS = {
        'femnist': 1,
        'emnist': 1,
        'cifar': 3,
        'cifar10': 3,
        'cifar100': 3,
        'tinyimagenet': 3,
    }

    def __init__(self, block, num_blocks, num_classes=10, name=None, created_time=None, dataset='cifar'):
        if dataset not in self._INPUT_CHANNELS:
            raise ValueError(
                f"Unsupported dataset {dataset!r}; expected one of {sorted(self._INPUT_CHANNELS)}"
            )
        super(SupConResNet_backbone, self).__init__(name, created_time)
        self.in_planes = 32

        self.conv1 = nn.Conv2d(self._INPUT_CHANNELS[dataset], 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.layer1 = self._make_layer(block, 32, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 64, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 128, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 256, num_blocks[3], stride=2)
        # Pool to 1x1 regardless of the input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return unit-L2-norm feature vectors of shape ``(batch, 256 * block.expansion)``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = F.normalize(out, dim=1)
        return out


def SupConResNet18(name=None, created_time=None, dataset='cifar'):
    """Build an 18-layer SupCon backbone (BasicBlock, stage depths 2-2-2-2)."""
    stage_depths = [2, 2, 2, 2]
    return SupConResNet_backbone(
        BasicBlock,
        stage_depths,
        name=f'{name}_SupConResNet_18',
        created_time=created_time,
        dataset=dataset,
    )

def SupConResNet34(name=None, created_time=None, dataset='cifar'):
    """Build a 34-layer SupCon backbone (BasicBlock, stage depths 3-4-6-3)."""
    stage_depths = [3, 4, 6, 3]
    return SupConResNet_backbone(
        BasicBlock,
        stage_depths,
        name=f'{name}_SupConResNet_34',
        created_time=created_time,
        dataset=dataset,
    )

def SupConResNet50(name=None, created_time=None, dataset='cifar'):
    """Build a 50-layer SupCon backbone (Bottleneck, stage depths 3-4-6-3)."""
    stage_depths = [3, 4, 6, 3]
    return SupConResNet_backbone(
        Bottleneck,
        stage_depths,
        name=f'{name}_SupConResNet_50',
        created_time=created_time,
        dataset=dataset,
    )

def ResNet18(device, num_classes=10, dataset='cifar', alpha=0, dimZ=256, bn_layer=None):
    """Build the variational ResNet-18 (BasicBlock, stage depths 2-2-2-2).

    ``bn_layer`` overrides the normalisation layer only when given; otherwise
    ``ResNet``'s own default for ``BN_layer`` applies.
    """
    options = dict(num_classes=num_classes, dataset=dataset, alpha=alpha,
                   dimZ=dimZ, device=device)
    if bn_layer is not None:
        options['BN_layer'] = bn_layer
    return ResNet(BasicBlock, [2, 2, 2, 2], **options)

def ResNet34(name=None, created_time=None, num_classes=10, dataset='cifar'):
    """Build the variational ResNet-34 (BasicBlock, stage depths 3-4-6-3).

    ``name`` and ``created_time`` are kept for backward compatibility but are
    not forwarded: ``ResNet.__init__`` has no such parameters, and passing
    them made this factory always raise ``TypeError``.
    """
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, dataset=dataset)

def ResNet50(name=None, created_time=None, num_classes=10, dataset='cifar'):
    """Build the variational ResNet-50 (Bottleneck, stage depths 3-4-6-3).

    ``name`` and ``created_time`` are kept for backward compatibility but are
    not forwarded: ``ResNet.__init__`` has no such parameters, and passing
    them made this factory always raise ``TypeError``.
    """
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, dataset=dataset)

def ResNet101(name=None, created_time=None):
    """Build the variational ResNet-101 (Bottleneck, stage depths 3-4-23-3).

    ``name`` and ``created_time`` are kept for backward compatibility but are
    not forwarded: ``ResNet.__init__`` has no such parameters, and passing
    them made this factory always raise ``TypeError``.
    """
    return ResNet(Bottleneck, [3, 4, 23, 3])

def ResNet152(name=None, created_time=None):
    """Build the variational ResNet-152 (Bottleneck, stage depths 3-8-36-3).

    ``name`` and ``created_time`` are kept for backward compatibility but are
    not forwarded: ``ResNet.__init__`` has no such parameters, and passing
    them made this factory always raise ``TypeError``.
    """
    return ResNet(Bottleneck, [3, 8, 36, 3])


def test():
    """Smoke test: push one CIFAR-sized image through ResNet18 on CPU and print the logits' shape.

    ``ResNet18`` requires a ``device`` argument, ``ResNet.forward`` requires
    ``num_samples``, and it returns a ``(distr, logits, reg)`` tuple rather
    than a single tensor.  The previous body got all three wrong.
    """
    net = ResNet18(torch.device('cpu'), dataset='cifar10')
    net.eval()
    with torch.no_grad():
        _, logits, _ = net(torch.randn(1, 3, 32, 32), num_samples=1)
    print(logits.size())


def slt_nn_hook(module, state_dict, *_):
    """Re-create ``module``'s tensors with the shapes found in ``state_dict``.

    Meant to run before ``load_state_dict`` copies values (``ResNet.sd_hook``
    calls it).  Every parameter/buffer named in ``state_dict`` is replaced by
    a zero tensor of the incoming shape, so the subsequent copy sees matching
    shapes.

    The optional entry ``state_dict['frozen']`` is a list of key prefixes.  It
    is popped from ``state_dict`` (mutating the caller's dict).  Parameters
    whose key starts with one of those prefixes are created with
    ``requires_grad=False``.

    Args:
        module: root module whose submodules are addressed by the key prefixes.
        state_dict: mapping from dotted names to tensors (only ``.shape`` is read).

    Raises:
        NotImplementedError: for a key whose last component is not one of
            weight, bias, running_mean, running_var, num_batches_tracked.
    """
    modules_dict = dict(module.named_modules())
    frozen = state_dict.pop('frozen', [])

    for key, value in state_dict.items():
        # 'layer1.0.bn1.weight' -> owner 'layer1.0.bn1', attribute 'weight';
        # a top-level key such as 'weight' maps to the root module ('').
        module_key, _, attr = key.rpartition('.')
        owner = modules_dict[module_key]

        if attr in ('weight', 'bias'):
            requires_grad = not any(key.startswith(prefix) for prefix in frozen)
            setattr(owner, attr,
                    torch.nn.Parameter(torch.zeros(value.shape), requires_grad=requires_grad))
        elif attr in ('running_mean', 'running_var'):
            # Running statistics are buffers in BatchNorm, not trainable
            # parameters; registering them as nn.Parameter (as this hook used
            # to) would expose them to optimizers through module.parameters().
            owner.register_buffer(attr, torch.zeros(value.shape))
        elif attr == 'num_batches_tracked':
            # Present in every BatchNorm state dict; an integer counter buffer.
            owner.register_buffer(attr, torch.zeros(value.shape, dtype=torch.long))
        else:
            raise NotImplementedError(f"slt_nn_hook cannot handle state-dict key {key!r}")
