import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import torchvision.models as models

__all__ = ['Modified_VGG']

# Checkpoint locations used by Modified_VGG when pretrained=True. The paths are
# relative, so they resolve against the current working directory.
# NOTE(review): the constant names are not ordered consistently
# (..._FINETUNE_CIFAR100_... vs ..._CIFAR10_FINETUNE_...). They are kept as-is
# because other modules may import them.
# NOTE(review): Modified_VGG also refers to VGG_IMAGENET_64_PRETRAINED_PATH for
# dataset='nus_wide', but that constant is not defined in this module.

# ImageNet weights (per the file name); loaded into the VGG backbone itself.
VGG_IMAGENET_PRETRAINED_PATH = './data/ImageNet_pretrained.pth.tar'
# RotNet checkpoint for CIFAR-100 (used when rotnet=True).
VGG_ROTNET_CIFAR100_PRETRAINED_PATH = './data/pretrained/rotnet_cifar100_modified_vgg_reproduce_r.pth'
# Fine-tuned RotNet checkpoint for CIFAR-100 (used when finetuned=True).
VGG_ROTNET_FINETUNE_CIFAR100_PRETRAINED_PATH = './data/pretrained/modified_vgg_rotnet_cifar100_seen_75.pth'
# RotNet checkpoint for CIFAR-10 (used when rotnet=True).
VGG_ROTNET_CIFAR10_PRETRAINED_PATH = './data/pretrained/rotnet_cifar10_modified_vgg_reproduce_r.pth'
# Fine-tuned RotNet checkpoint for CIFAR-10 (used when finetuned=True).
VGG_ROTNET_CIFAR10_FINETUNE_PRETRAINED_PATH = './data/pretrained/modified_vgg_rotnet_cifar10_seen_7.pth'

def xavier_initializer(parameters_lst):
    """Re-initialise each tensor in ``parameters_lst`` in place with Xavier/Glorot uniform values.

    ``parameters_lst`` may be any iterable of tensors (e.g. a list or a generator).
    ``nn.init.xavier_uniform_`` needs at least two dimensions to compute fan-in
    and fan-out, so 1-D tensors such as biases make it raise.
    """
    for weight in parameters_lst:
        nn.init.xavier_uniform_(weight)



def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    At ``stride=1`` the output keeps the input's spatial size.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )

class Modified_VGG(nn.Module):
    """Embedding network: a ``VGG`` backbone with optional pretrained weights and L2 normalisation.

    The backbone is stored as ``self.F``. Its ``linear_projector`` maps pooled
    features to ``embedding_size`` dimensions.

    Args:
        embedding_size: Output dimensionality of the embedding.
        rotnet: Load the RotNet checkpoint for ``dataset``. Takes precedence
            over ``finetuned``.
        finetuned: Load the fine-tuned RotNet checkpoint for ``dataset``.
        pretrained: If False, keep the random initialisation and ignore
            ``rotnet`` / ``finetuned`` / ``dataset``.
        is_norm: L2-normalise the output along dim 1.
        bn_freeze: Keep every ``BatchNorm2d`` in the backbone in eval mode with
            frozen affine parameters, including after later ``.train()`` calls.
        dataset: ``'cifar10'``, ``'cifar100'`` or ``'nus_wide'``. Any other
            value skips weight loading silently (random initialisation).

    Raises:
        RuntimeError: ``dataset='nus_wide'`` with ``pretrained=True`` while the
            module constant ``VGG_IMAGENET_64_PRETRAINED_PATH`` is undefined.
    """

    def __init__(
            self,
            embedding_size, rotnet=False, finetuned=False, pretrained=True, is_norm=True, bn_freeze=False, dataset='cifar10'
    ):
        super(Modified_VGG, self).__init__()
        self.is_norm = is_norm
        self.bn_freeze = bn_freeze
        self.F = VGG(embedding_size=embedding_size)
        if pretrained:
            self._load_pretrained(embedding_size, rotnet, finetuned, dataset)

        if bn_freeze:
            self._freeze_bn()

    def _load_pretrained(self, embedding_size, rotnet, finetuned, dataset):
        """Load the checkpoint selected by ``dataset`` / ``rotnet`` / ``finetuned``.

        RotNet and fine-tuned checkpoints go into the wrapper, so their keys must
        carry the ``F.`` prefix. The ImageNet checkpoint goes into the backbone.
        All loads use ``strict=False``, and the missing/unexpected key report is
        printed.
        """
        if dataset == 'nus_wide':
            # This constant is not defined in the module; fail with a clear
            # message instead of an opaque NameError.
            path = globals().get('VGG_IMAGENET_64_PRETRAINED_PATH')
            if path is None:
                raise RuntimeError(
                    "dataset='nus_wide' requires the module constant "
                    "VGG_IMAGENET_64_PRETRAINED_PATH, which is not defined"
                )
            checkpoint = torch.load(path, map_location='cpu')
            # Temporarily use a 768->200 head, presumably to match the
            # checkpoint's projector shape. Then swap in a fresh head of the
            # requested size.
            self.F.linear_projector = nn.Linear(768, 200)
            print(self.load_state_dict(checkpoint, strict=False))
            self.F.linear_projector = nn.Linear(768, embedding_size)
            return

        # Each CIFAR variant loads its own checkpoint. The original code
        # loaded the other dataset's fine-tuned weights.
        rotnet_paths = {
            'cifar10': VGG_ROTNET_CIFAR10_PRETRAINED_PATH,
            'cifar100': VGG_ROTNET_CIFAR100_PRETRAINED_PATH,
        }
        finetuned_paths = {
            'cifar10': VGG_ROTNET_CIFAR10_FINETUNE_PRETRAINED_PATH,
            'cifar100': VGG_ROTNET_FINETUNE_CIFAR100_PRETRAINED_PATH,
        }
        if dataset not in rotnet_paths:
            return  # unknown dataset: keep the random initialisation

        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
        # load_state_dict copies into the existing parameters either way.
        if rotnet:
            checkpoint = torch.load(rotnet_paths[dataset], map_location='cpu')
            print(self.load_state_dict(checkpoint, strict=False))
        elif finetuned:
            checkpoint = torch.load(finetuned_paths[dataset], map_location='cpu')
            print(self.load_state_dict(checkpoint, strict=False))
        else:
            checkpoint = torch.load(VGG_IMAGENET_PRETRAINED_PATH, map_location='cpu')
            print(self.F.load_state_dict(checkpoint, strict=False))

    def _freeze_bn(self):
        """Put the backbone's BatchNorm2d layers in eval mode and stop gradients to their affine params."""
        for m in self.F.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                m.weight.requires_grad_(False)
                m.bias.requires_grad_(False)

    def train(self, mode=True):
        """Set training mode, but keep BatchNorm layers in eval mode when ``bn_freeze`` is set.

        Without this override, ``nn.Module.train()`` would switch the frozen
        BatchNorm layers back to training mode and resume updating their
        running statistics.
        """
        super(Modified_VGG, self).train(mode)
        # getattr: tolerate instances built without __init__ (e.g. old pickles).
        if getattr(self, 'bn_freeze', False):
            for m in self.F.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
        return self

    def forward(self, x):
        """Embed ``x`` with the backbone; L2-normalise along dim 1 when ``is_norm`` is set."""
        x = self.F(x)
        if self.is_norm:
            x = F.normalize(x, dim=1)
        return x


class VGG(nn.Module):
    """VGG-style conv backbone with a two-scale pooled linear head.

    ``features_1`` stacks three conv stages (64, 128, 256 channels), each
    ending in a 2x2 max-pool. ``features_2`` adds a 512-channel stage with no
    pooling. Both stage outputs are globally average-pooled and concatenated
    (512 + 256 = 768 features). ``linear_projector`` maps the result to
    ``embedding_size``.
    """

    def __init__(
        self,
        embedding_size
    ):
        super(VGG, self).__init__()
        # Stage spec: an int is a conv3x3(pad 1) -> BatchNorm2d -> ReLU block
        # with that many output channels; 'M' is a 2x2 / stride-2 max-pool.
        self.features_1 = self._make_stage(
            3, (64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M')
        )
        self.features_2 = self._make_stage(256, (512, 512, 512))
        self.avgpool_1 = nn.AdaptiveAvgPool2d((1, 1))
        self.avgpool_2 = nn.AdaptiveAvgPool2d((1, 1))

        pooled_dim = 512 + 256  # features_2 channels + features_1 branch channels
        self.linear_projector = nn.Linear(pooled_dim, embedding_size)
        self._initialize_weights()

    @staticmethod
    def _make_stage(in_channels, spec):
        """Build an ``nn.Sequential`` from ``spec``; layer indices match a hand-written list."""
        layers = []
        for item in spec:
            if item == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            layers.extend((
                nn.Conv2d(in_channels, item, kernel_size=3, padding=1),
                nn.BatchNorm2d(item),
                nn.ReLU(inplace=True),
            ))
            in_channels = item
        return nn.Sequential(*layers)

    def forward(self, input):
        """Return the projected embedding, shape ``(batch, embedding_size)``."""
        trunk = self.features_1(input)
        pooled_mid = self.avgpool_1(trunk)
        pooled_top = self.avgpool_2(self.features_2(trunk))
        merged = torch.cat([pooled_top, pooled_mid], dim=1)
        return self.linear_projector(merged.view(merged.size(0), -1))

    def _initialize_weights(self) -> None:
        """Kaiming-normal (fan_out) init for conv and linear weights; zero biases; BN scale 1 / shift 0."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, mode='fan_out')
                nn.init.zeros_(module.bias)