import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from functools import partial
from itertools import repeat
import torchvision
from torch.nn import init

class ResMlp(nn.Module):
    """Fuse two equally sized feature tensors with one learned projection.

    ``forward`` concatenates its two inputs on the last dimension and
    multiplies the result by a ``(2 * in_features, in_features)`` weight
    matrix whose entries all start at 0.5. No bias is added.

    Args:
        in_features: Last-dimension size of each input and of the output.
        act_layer: Stored as ``self.act_layer`` but not applied in ``forward``.
    """

    def __init__(self, in_features, act_layer=nn.GELU):
        super().__init__()
        self.in_dim = in_features
        val = torch.ones(in_features * 2, in_features) / 2
        self.weight = nn.Parameter(data=val)
        # NOTE(review): kept for interface compatibility; forward does not
        # apply any activation.
        self.act_layer = act_layer

    def forward(self, x1, x2):
        """Return ``cat([x1, x2], dim=-1) @ self.weight``."""
        x = torch.cat([x1, x2], dim=-1)
        return x @ self.weight


class CADM(nn.Module):
    """Mix each sample's feature with a context vector pooled across domains.

    The input batch is treated as ``num_domains`` contiguous groups of equal
    size ``b``: rows ``i*b:(i+1)*b`` belong to domain ``i``.

    In training mode ``forward``:

    1. applies ``norm1`` and projects the features to keys and values (``kv``);
    2. averages each domain's normalised features into a prototype and
       projects the prototypes to queries (``q``);
    3. takes the softmax attention ``A`` (shape ``(num_domains, B)``) of each
       prototype over all samples, turns it into complementary weights
       ``(1 - A)`` renormalised to sum to one, and pools the values with them.
       Every sample of domain ``i`` receives the context vector computed from
       domain ``i``'s weights.
    4. updates the ``phi`` buffer. While ``phi`` still sums to zero (before
       the first training batch) it stores the context unchanged; afterwards
       it stores ``beta * phi + (1 - beta) * context``. Gradients flow only
       through the current context.
    5. returns ``norm2(mlp(x, phi_new))``, where ``phi_new`` is the value just
       stored.

    In eval mode the stored ``phi`` is used directly as the context. The eval
    batch size must therefore equal the training batch size that produced it.

    Args:
        dim: Feature width ``D`` of the ``(B, D)`` input.
        num_domains: Number of equally sized domain groups per batch.
        qkv_bias: Whether the ``kv`` and ``q`` projections have a bias.
        act_layer: Forwarded to ``ResMlp`` (which does not apply it).
        norm_layer: Factory for the normalisation layers.
        norm2: Apply ``norm_layer`` to the output when True, else identity.
        beta: Weight of the previous ``phi`` in the moving-average update.
    """

    def __init__(self, dim, num_domains, qkv_bias=False, act_layer=nn.GELU,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6), norm2=True, beta=0.005):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.n_domains = num_domains
        print('DOMAIN NUM {}'.format(num_domains))
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        # Placeholder of shape (1,); replaced by a (B, D) tensor after the
        # first training batch.
        self.register_buffer('phi', torch.zeros(1))

        self.norm2 = norm_layer(dim) if norm2 else nn.Identity()
        self.mlp = ResMlp(in_features=dim, act_layer=act_layer)
        self.beta = beta

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Resize ``phi`` to the checkpoint's shape before the default load.

        ``phi`` changes shape after the first training batch. Without this
        hook, loading a trained checkpoint into a freshly built module fails
        the default size check.
        """
        saved = state_dict.get(prefix + 'phi')
        if saved is not None and saved.shape != self.phi.shape:
            self.phi.data = self.phi.new_empty(saved.shape)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def forward(self, x):
        """Fuse ``x`` with the cross-domain context.

        Args:
            x: Features of shape ``(B, D)``. In training, ``B`` must be a
                positive multiple of ``num_domains``, with each domain's
                samples stored contiguously.

        Returns:
            Tensor of shape ``(B, D)``.

        Raises:
            ValueError: In training, if ``B`` is 0 or not a multiple of
                ``num_domains``.
            RuntimeError: In eval, if no context with ``B`` rows is stored.
                In training, if ``B`` differs from the size of the stored
                context.
        """
        x = self.norm1(x)
        B, D = x.shape

        if not self.training:
            # Eval uses only the stored context; the attention computed in
            # training would be discarded, so it is skipped here.
            if self.phi.dim() != 2 or self.phi.shape[0] != B:
                raise RuntimeError(
                    f"CADM eval needs a stored context with {B} rows, but phi "
                    f"has shape {tuple(self.phi.shape)}; run training batches "
                    f"of size {B} (or load such a checkpoint) first")
            return self.norm2(self.mlp(x, self.phi))

        if B == 0 or B % self.n_domains != 0:
            raise ValueError(
                f"batch size {B} must be a positive multiple of "
                f"num_domains={self.n_domains}")
        b = B // self.n_domains

        # Keys and values, each (B, D): the first D output columns are keys.
        k, v = self.kv(x).reshape(B, 2, D).unbind(dim=1)

        # One prototype per domain: mean over its contiguous rows -> (N, D).
        q = self.q(x.reshape(self.n_domains, b, D).mean(dim=1))

        attn = (q @ k.t()).softmax(dim=-1)              # (N, B)
        complement = 1 - attn
        c_attn = complement / complement.sum(dim=-1, keepdim=True)
        # Sample i uses the weights of its domain i // b: (B, B) @ (B, D).
        phi_n = c_attn.repeat_interleave(b, dim=0) @ v

        if torch.sum(self.phi) == 0:
            grad_phi = phi_n
        else:
            if self.phi.shape != phi_n.shape:
                raise RuntimeError(
                    f"CADM stored context has shape {tuple(self.phi.shape)} "
                    f"but this batch produced {tuple(phi_n.shape)}; the "
                    "training batch size must stay constant")
            grad_phi = self.beta * self.phi + (1 - self.beta) * phi_n
        # Store without autograd history; grad_phi itself keeps its graph.
        self.phi.data = grad_phi.detach()

        return self.norm2(self.mlp(x, grad_phi))


class ADNT(nn.Module):
    """ResNet backbone with a BatchNorm neck and optional CADM feature mixing.

    The torchvision ResNet is used up to ``layer4``. Its first ReLU is dropped
    (``conv1 -> bn1 -> maxpool``), and the first block of ``layer4`` is made
    stride 1, so the final feature map keeps ``layer3``'s spatial size.
    Features are then:

    1. global-average pooled;
    2. passed through ``feat_bn`` (bias frozen);
    3. optionally passed through dropout;
    4. optionally passed through ``CADM``;
    5. optionally classified.

    Args:
        depth: ResNet depth; one of 18, 34, 50, 101, 152.
        num_domains: Number of domain groups per batch, passed to ``CADM``.
        pretrained: Passed to the torchvision constructor. When False,
            ``reset_params`` is called.
        dropout: Dropout probability after ``feat_bn``; 0 disables it.
        num_classes: Size of the bias-free linear classifier; 0 disables it.
        merge_features: Whether to apply ``CADM`` after ``feat_bn``.
        beta: Moving-average weight passed to ``CADM``.

    Raises:
        KeyError: If ``depth`` is not supported.
    """
    __factory = {
        18: torchvision.models.resnet18,
        34: torchvision.models.resnet34,
        50: torchvision.models.resnet50,
        101: torchvision.models.resnet101,
        152: torchvision.models.resnet152,
    }

    def __init__(self, depth, num_domains, pretrained=False, dropout=0,
                 num_classes=0, merge_features=True, beta=0.005):
        super(ADNT, self).__init__()
        self.pretrained = pretrained
        self.depth = depth
        print('MERGE-FEATURES', merge_features)
        # Construct base (pretrained) resnet
        if depth not in ADNT.__factory:
            raise KeyError(f"Unsupported depth: {depth}")
        resnet = ADNT.__factory[depth](pretrained=pretrained)

        # Remove the last downsampling.
        # - torchvision's BasicBlock (depth 18/34) carries the stride-2 conv
        #   on conv1.
        # - Bottleneck (depth >= 50) carries it on conv2.
        # The other conv is already stride 1, so resetting both covers every
        # supported depth. The shortcut conv must match the main branch.
        last_stage_entry = resnet.layer4[0]
        last_stage_entry.conv1.stride = (1, 1)
        last_stage_entry.conv2.stride = (1, 1)
        last_stage_entry.downsample[0].stride = (1, 1)

        self.base = nn.Sequential(
            resnet.conv1, resnet.bn1, resnet.maxpool,  # no relu
            resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4)
        self.gap = nn.AdaptiveAvgPool2d(1)

        self.merge_features = merge_features
        self.num_domains = num_domains
        self.dropout = dropout
        self.num_classes = num_classes

        out_planes = resnet.fc.in_features

        # Append new layers
        self.num_features = out_planes
        self.feat_bn = nn.BatchNorm1d(self.num_features)
        self.feat_bn.bias.requires_grad_(False)

        if self.dropout > 0:
            self.drop = nn.Dropout(self.dropout)
        if self.merge_features:
            print('Use Features Merge')
            self.cadm = CADM(dim=self.num_features, num_domains=num_domains, beta=beta)
        if self.num_classes > 0:
            self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False)
            init.normal_(self.classifier.weight, std=0.001)
        init.constant_(self.feat_bn.weight, 1)
        init.constant_(self.feat_bn.bias, 0)

        if not pretrained:
            self.reset_params()

    def forward(self, x, feature_withbn=False):
        """Run the backbone and heads.

        When ``merge_features`` is set, ``CADM`` slices the batch into
        ``num_domains`` contiguous, equally sized groups.

        Let ``feat`` be the pooled backbone feature of shape
        ``(B, num_features)``. Let ``bn_feat`` be ``feat`` after ``feat_bn``,
        dropout (if enabled) and ``CADM`` (if enabled). The return value is:

        - ``num_classes == 0``: ``(feat, bn_feat)``;
        - else if ``feature_withbn``: ``(bn_feat, logits)``;
        - else in eval mode: ``(F.normalize(feat), logits)``;
        - else: ``(feat, logits)``.
        """
        x = self.base(x)

        x = self.gap(x)
        x = x.view(x.size(0), -1)

        bn_x = self.feat_bn(x)

        if self.dropout > 0:
            bn_x = self.drop(bn_x)

        if self.merge_features:
            bn_x = self.cadm(bn_x)

        if self.num_classes > 0:
            prob = self.classifier(bn_x)
        else:
            return x, bn_x

        if feature_withbn:
            return bn_x, prob

        if not self.training:
            # Eval returns L2-normalised pre-BN features.
            return F.normalize(x), prob

        return x, prob

    def reset_params(self):
        """Re-initialise all submodules, then reload the backbone weights.

        Initialisation applied to every submodule:

        - ``Conv2d``: Kaiming-normal (fan_out), zero bias.
        - ``BatchNorm1d``/``BatchNorm2d``: weight 1, bias 0.
        - ``Linear``: N(0, 0.001) weights, zero bias. This covers the
          classifier and the CADM ``kv``/``q`` projections.

        Afterwards the ``base`` weights are overwritten with those of a newly
        built torchvision ResNet of the same depth, constructed with
        ``pretrained=self.pretrained``.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

        # Strides are not part of state_dict, so the stride-1 layer4 of
        # self.base is preserved by these loads.
        resnet = ADNT.__factory[self.depth](pretrained=self.pretrained)
        self.base[0].load_state_dict(resnet.conv1.state_dict())
        self.base[1].load_state_dict(resnet.bn1.state_dict())
        self.base[2].load_state_dict(resnet.maxpool.state_dict())
        self.base[3].load_state_dict(resnet.layer1.state_dict())
        self.base[4].load_state_dict(resnet.layer2.state_dict())
        self.base[5].load_state_dict(resnet.layer3.state_dict())
        self.base[6].load_state_dict(resnet.layer4.state_dict())


def ADNT50(**kwargs):
    """Build an ``ADNT`` model on a ResNet-50 backbone.

    Every keyword argument is forwarded unchanged to ``ADNT``. ``num_domains``
    has no default there, so it must be supplied.
    """
    resnet_depth = 50
    model = ADNT(resnet_depth, **kwargs)
    return model

