import math
import torch
import torch.nn as nn
import torch.nn.functional as F

import logging
import torchvision.models as models

from style_op import style_inject, get_styles, normalize_style


class BasicBlock(nn.Module):
    """Pre-activation residual block: BN-ReLU-Conv3x3, BN-ReLU-(dropout)-Conv3x3.

    When ``in_planes != out_planes`` the shortcut is a strided 1x1 projection
    applied to the *pre-activated* input (``relu1(bn1(x))``); otherwise the
    shortcut is the raw input ``x``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: stride of the first 3x3 conv (and of the projection shortcut).
        dropRate: dropout probability applied between the two convs (0 disables).
    """

    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        # 1x1 projection only when the channel count changes; identity otherwise.
        # NOTE(review): an identity shortcut with stride != 1 would mismatch in
        # shape; WideResNet never builds such a block.
        self.convShortcut = (
            None
            if self.equalInOut
            else nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                           padding=0, bias=False)
        )

    def forward(self, x):
        # bn1 produces a fresh tensor, so the in-place ReLU never touches x.
        pre = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(pre)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        # Identity shortcut uses the raw input; projection uses the pre-activation.
        shortcut = x if self.equalInOut else self.convShortcut(pre)
        return torch.add(shortcut, out)


class NetworkBlock(nn.Module):
    """One WideResNet stage: ``nb_layers`` residual blocks run in sequence.

    Only the first block maps ``in_planes -> out_planes`` and uses ``stride``;
    every later block maps ``out_planes -> out_planes`` with stride 1.

    Args:
        nb_layers: number of blocks (converted with ``int``, so floats are accepted).
        in_planes: input channels of the stage.
        out_planes: output channels of every block.
        block: callable ``block(in_planes, out_planes, stride, dropRate)`` returning a module.
        stride: stride of the first block.
        dropRate: dropout rate forwarded to every block.
    """

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        # Conditional expressions (not `cond and a or b`) so falsy values such as
        # in_planes == 0 are not silently replaced by the fallback.
        layers = [
            block(in_planes if i == 0 else out_planes,
                  out_planes,
                  stride if i == 0 else 1,
                  dropRate)
            for i in range(int(nb_layers))
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)


class WideResNet(nn.Module):
    """ Based on code from https://github.com/yaodongyu/TRADES

    WideResNet with optional per-stage style injection / normalization.

    Args:
        depth: network depth; must satisfy ``(depth - 4) % 6 == 0``.
        num_classes: output size of the final linear layer.
        widen_factor: channel multiplier for the three residual stages.
        sub_block1: if True, also build ``self.sub_block1`` (not used in ``forward``;
            presumably kept for parameter/checkpoint compatibility — confirm).
        dropRate: dropout rate inside each residual block.
        bias_last: whether the final linear layer has a bias.

    Raises:
        ValueError: if ``depth`` is not of the form ``6 * k + 4``.
    """

    def __init__(self, depth=28, num_classes=10, widen_factor=10, sub_block1=False, dropRate=0.0, bias_last=True):
        super(WideResNet, self).__init__()
        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if (depth - 4) % 6 != 0:
            raise ValueError(f"WideResNet depth must satisfy (depth - 4) % 6 == 0, got {depth}")
        n = (depth - 4) // 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        if sub_block1:
            # 1st sub-block
            self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes, bias=bias_last)
        self.nChannels = nChannels[3]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-style normal init scaled by fan-out (k_h * k_w * out_channels).
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear) and m.bias is not None:
                m.bias.data.zero_()

    @staticmethod
    def _style_stage(feat, layer_idx, styles, style_index, inject_layer, norm_layer):
        """Record style stats of ``feat``, then optionally inject and/or normalize.

        Returns ``(feat, var, mu)``. ``var, mu`` come from ``get_styles`` on the
        detached feature *before* any injection or normalization.
        """
        var, mu = get_styles(feat.detach())
        if layer_idx in inject_layer:
            feat = style_inject(feat, styles[layer_idx], style_index=style_index)
        if layer_idx in norm_layer:
            feat = normalize_style(feat)
        return feat, var, mu

    def forward(self, x, style_index=None, styles=None, inject_layer=(),
                embed_only=False, return_style=False, return_feat=False,
                norm_layer=()):
        """Run the network with optional style injection/normalization.

        Style stats are taken after ``conv1`` (0), ``block1`` (1), ``block2`` (2)
        and ``block3`` (3). ``inject_layer`` / ``norm_layer`` are collections of
        those indices. ``styles[k]`` must be provided for every ``k`` in
        ``inject_layer``.

        Returns (depending on flags):
            embed_only:            ``(feat, feat2[, var0, mu0, ..., var3, mu3])``
            return_style:          ``(logits[, feat, feat2], var0, mu0, ..., var3, mu3)``
            return_feat:           ``(logits, feat, feat2)``
            otherwise:             ``logits``
        where ``feat2`` is the (possibly styled) ``block2`` output and ``feat`` is
        the pooled, flattened ``block3`` feature.
        """
        stage_args = (styles, style_index, inject_layer, norm_layer)
        out, var0, mu0 = self._style_stage(self.conv1(x), 0, *stage_args)
        out, var1, mu1 = self._style_stage(self.block1(out), 1, *stage_args)
        feat2, var2, mu2 = self._style_stage(self.block2(out), 2, *stage_args)
        out, var3, mu3 = self._style_stage(self.block3(feat2), 3, *stage_args)

        out = self.relu(self.bn1(out))
        # Fixed 8x8 pooling window: this is a global average only when the
        # block3 output is 8x8 (e.g. 32x32 inputs).
        out = F.avg_pool2d(out, 8)
        feat = out.view(-1, self.nChannels)

        stats = (var0, mu0, var1, mu1, var2, mu2, var3, mu3)

        if embed_only:
            if return_style:
                return (feat, feat2) + stats
            return feat, feat2

        out = self.fc(feat)

        if return_style:
            if return_feat:
                return (out, feat, feat2) + stats
            return (out,) + stats
        if return_feat:
            return out, feat, feat2
        return out

    def get_style(self, x):
        """Return ``[get_styles(...)]`` for the outputs of conv1, block1, block2, block3."""
        style_list = []
        out = x
        for stage in (self.conv1, self.block1, self.block2, self.block3):
            out = stage(out)
            style_list.append(get_styles(out.detach()))
        return style_list

    def get_params_after_inject(self, inject_pos):
        """Return the modules downstream of the deepest injection position.

        Injection at position ``k`` happens on the output of stage ``k`` of
        ``[conv1, block1, block2, block3]``, so the modules after it are
        ``stages[k + 1:]``. With no recognised position (0-3) in ``inject_pos``,
        all modules are returned.
        """
        stages = [self.conv1, self.block1, self.block2, self.block3, self.bn1, self.fc]
        for pos in (3, 2, 1, 0):
            if pos in inject_pos:
                return stages[pos + 1:]
        return stages


def WideResNet28(num_classes=10):
    """Construct a WRN-28-10 (depth 28, widen factor 10) with ``num_classes`` outputs."""
    depth, widen = 28, 10
    return WideResNet(depth=depth, widen_factor=widen, num_classes=num_classes)


# Adapted from https://github.com/DianCh/AdaContrast/blob/master/classifier.py (extended with style-injection hooks)

class Classifier(nn.Module):
    def __init__(self, args, checkpoint_path=None):
        super().__init__()
        self.args = args
        model = None

        # 1) ResNet backbone (up to penultimate layer)
        if not self.use_bottleneck:
            model = models.__dict__[args.model.lower()](pretrained=False)
            modules = list(model.children())[:-1]
            self.encoder = nn.Sequential(*modules)
            self._output_dim = model.fc.in_features
        # 2) ResNet backbone + bottlenck (last fc as bottleneck)
        else:
            model = models.__dict__[args.model.lower()](pretrained=False)
            model.fc = nn.Linear(model.fc.in_features, args.bottleneck_dim)
            bn = nn.BatchNorm1d(args.bottleneck_dim)
            self.encoder = nn.Sequential(model, bn)
            self._output_dim = args.bottleneck_dim

        self.fc = nn.Linear(self.output_dim, 12)

        if self.use_weight_norm:
            self.fc = nn.utils.weight_norm(self.fc, dim=args.weight_norm_dim)

        if checkpoint_path:
            self.load_from_checkpoint(checkpoint_path)

        self.encoder_init()

    def encoder_init(self):
        if len(self.encoder) > 2:
            raise NotImplemented
        module_list = list(self.encoder[0].children())

        self.conv1 = module_list[0]
        self.bn1 = module_list[1]
        self.relu = module_list[2]
        self.maxpool = module_list[3]

        self.layer1 = module_list[4]
        self.layer2 = module_list[5]
        self.layer3 = module_list[6]
        self.layer4 = module_list[7]

        assert isinstance(self.layer4, nn.Sequential)

        self.avgpool = module_list[8]

        if len(module_list) == 10:
            self.remaining_layers = nn.Sequential(*module_list[9:], self.encoder[1])
        else:
            raise NotImplemented

    # def forward(self, x, return_feats=False):
    #     # 1) encoder feature
    #     feat = self.encoder(x)
    #     feat = torch.flatten(feat, 1)

    #     logits = self.fc(feat)

    #     if return_feats:
    #         return feat, logits
    #     return logits

    def forward(self, x, style_index=None, styles=None, inject_layer=[],
                embed_only=False, return_feat=False, return_style=False, norm_layer=[]):
        # 1) encoder feature
        # feat = self.encoder(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        var0, mu0 = get_styles(x.detach())
        if 0 in inject_layer:
            x = style_inject(x, styles[0], style_index=style_index)
        if 0 in norm_layer:
            x = normalize_style(x)

        x = self.maxpool(x)

        x = self.layer1(x)
        var1, mu1 = get_styles(x.detach())
        if 1 in inject_layer:
            x = style_inject(x, styles[1], style_index=style_index)
        if 1 in norm_layer:
            x = normalize_style(x)

        x = self.layer2(x)
        var2, mu2 = get_styles(x.detach())
        if 2 in inject_layer:
            x = style_inject(x, styles[2], style_index=style_index)
        if 2 in norm_layer:
            x = normalize_style(x)

        feat3 = self.layer3(x)
        var3, mu3 = get_styles(feat3.detach())
        if 3 in inject_layer:
            feat3 = style_inject(feat3, styles[3], style_index=style_index)
        if 3 in norm_layer:
            feat3 = normalize_style(feat3)

        x = self.layer4(feat3)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        feat = self.remaining_layers(x)

        # feat = torch.flatten(feat, 1)

        if embed_only:
            if return_style:
                return feat, feat3, var0, mu0, var1, mu1, var2, mu2, var3, mu3
            return feat, feat3

        logits = self.fc(feat)
        
        # if return_feat:
        #     return logits, feat
        # return logits

        if return_style:
            if return_feat:
                return logits, feat, feat3, var0, mu0, var1, mu1, var2, mu2, var3, mu3
            return logits, var0, mu0, var1, mu1, var2, mu2, var3, mu3
        if return_feat:
            return logits, feat, feat3
        return logits


    def load_from_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        state_dict = dict()
        for name, param in checkpoint["state_dict"].items():
            # get rid of 'module.' prefix brought by DDP
            name = name.replace("module.", "")
            state_dict[name] = param
        # msg = self.load_state_dict(state_dict, strict=False)
        msg = self.load_state_dict(state_dict, strict=True)
        logging.info(
            f"Loaded from {checkpoint_path}; missing params: {msg.missing_keys}"
        )

    def get_params(self):
        """
        Backbone parameters use 1x lr; extra parameters use 10x lr.
        """
        backbone_params = []
        extra_params = []
        # case 1)
        if not self.use_bottleneck:
            backbone_params.extend(self.encoder.parameters())
        # case 2)
        else:
            resnet = self.encoder[0]
            for module in list(resnet.children())[:-1]:
                backbone_params.extend(module.parameters())
            # bottleneck fc + (bn) + classifier fc
            extra_params.extend(resnet.fc.parameters())
            extra_params.extend(self.encoder[1].parameters())
            extra_params.extend(self.fc.parameters())

        # exclude frozen params
        backbone_params = [param for param in backbone_params if param.requires_grad]
        extra_params = [param for param in extra_params if param.requires_grad]

        return backbone_params, extra_params

    @property
    def num_classes(self):
        return self.fc.weight.shape[0]

    @property
    def output_dim(self):
        return self._output_dim

    @property
    def use_bottleneck(self):
        return self.args.bottleneck_dim > 0

    @property
    def use_weight_norm(self):
        return self.args.weight_norm_dim >= 0