import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from .blocks import *

# Public names exported by ``from <this module> import *``; ILR is not exported.
__all__ = ['BenchResNet', 'ResNet', 'call_ResNet', 'Fusion_module']

class ILR(torch.autograd.Function):
    """Gradient-rescaling identity.

    The forward pass hands ``input`` back untouched. The backward pass divides
    the incoming gradient by ``num_branches``. ``num_branches`` itself receives
    no gradient.
    """

    @staticmethod
    def forward(ctx, input, num_branches):
        # Remember the divisor; the tensor itself passes straight through.
        setattr(ctx, 'num_branches', num_branches)
        return input

    @staticmethod
    def backward(ctx, grad_output):
        divisor = ctx.num_branches
        scaled_grad = torch.div(grad_output, divisor)
        # One gradient per forward input: the tensor, then the (non-tensor) count.
        return scaled_grad, None

class Fusion_module(nn.Module):
    """Fuse concatenated branch feature maps into one vector of class logits.

    The input must have ``channel * (num_branches + 1)`` channels (the input
    width of ``conv1``). Presumably this is the channel-concatenated ``fmap``
    returned by ``BenchResNet``'s ``'ffl'`` forward; verify against the caller.

    Pipeline:
    depthwise 3x3 conv -> BN -> ReLU -> 1x1 conv down to ``channel`` -> BN
    -> ReLU -> ``avg_pool2d(sptial)`` -> flatten -> linear to ``num_classes``.

    Args:
        channel: per-branch channel count and the width after fusion.
        num_classes: number of output logits.
        sptial: kernel size passed to ``F.avg_pool2d``.
        num_branches: the input carries ``num_branches + 1`` channel groups.
    """

    def __init__(self, channel, num_classes, sptial, num_branches):
        super(Fusion_module, self).__init__()
        self.num_branches = num_branches
        in_channels = channel * (num_branches + 1)
        self.fc2 = nn.Linear(channel, num_classes)
        # groups == channels -> depthwise convolution (one 3x3 filter per channel).
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1,
                               groups=in_channels, bias=False)
        self.bn1 = nn.BatchNorm2d(in_channels)
        # 1x1 conv mixes all branch channels down to ``channel``.
        self.conv1_1 = nn.Conv2d(in_channels, channel, kernel_size=1, groups=1, bias=False)
        self.bn1_1 = nn.BatchNorm2d(channel)

        self.sptial = sptial

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, input_):
        """Return fused logits of shape ``(B, num_classes)`` for ``input_``."""
        x = F.relu(self.bn1(self.conv1(input_)))
        x = F.relu(self.bn1_1(self.conv1_1(x)))

        x = F.avg_pool2d(x, self.sptial)
        x = x.view(x.size(0), -1)

        return self.fc2(x)


class ResNet(nn.Module):
    """Single-branch CIFAR ResNet: stem conv, three stages of 16/32/64 base planes,
    ``AvgPool2d(8)`` and a linear classifier.

    Args:
        dataset: ``'cifar10'`` (10 classes) or ``'cifar100'`` (100 classes);
            anything else raises ``ValueError``.
        depth: total depth. Each stage gets ``(depth - 2) // 9`` bottleneck
            blocks or ``(depth - 2) // 6`` basic blocks.
        bottleneck: bottleneck blocks are used only when this ``is True``
            (identity check, so e.g. ``1`` selects basic blocks).
        se: pick the SE variant of the chosen block type.
        KD: when equal to ``True``, ``forward`` returns ``(features, logits)``
            instead of only ``logits``.
    """

    def __init__(self, dataset, depth, bottleneck=False, se=False, KD=False):
        super(ResNet, self).__init__()
        self.inplanes = 16

        if bottleneck is True:
            blocks_per_stage = (depth - 2) // 9
            block = SEBottleneck if se else Bottleneck
        else:
            blocks_per_stage = (depth - 2) // 6
            block = SEBasicBlock if se else BasicBlock
        self.block = block
        self.KD = KD

        if dataset not in ('cifar10', 'cifar100'):
            raise ValueError("No valid dataset is given.")
        num_classes = 10 if dataset == 'cifar10' else 100

        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, blocks_per_stage)
        self.layer2 = self._make_layer(block, 32, blocks_per_stage, stride=2)
        self.layer3 = self._make_layer(block, 64, blocks_per_stage, stride=2)
        self.avgpool = nn.AvgPool2d(8)
        self.classifier = nn.Linear(64 * block.expansion, num_classes)

        # Convs: N(0, sqrt(2 / (k_h * k_w * out_channels))); BatchNorm: weight 1, bias 0.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build a stage of ``blocks`` residual blocks.

        Only the first block may stride and receives a 1x1-conv + BN projection
        shortcut, and only when the stride or channel count changes.
        ``self.inplanes`` is advanced to the stage's output width.
        """
        out_planes = planes * block.expansion
        needs_projection = stride != 1 or self.inplanes != out_planes
        downsample = nn.Sequential(
            nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_planes),
        ) if needs_projection else None

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return logits, or ``(pooled_features, logits)`` when ``self.KD == True``."""
        out = self.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            out = stage(out)

        features = self.avgpool(out)
        features = features.view(features.size(0), -1)
        logits = self.classifier(features)

        # Equality (not truthiness) check: only values == True enable KD output.
        return (features, logits) if self.KD == True else logits

class BenchResNet(nn.Module):
    """CIFAR ResNet with a shared trunk and several parallel ``layer3`` branches,
    used for online knowledge-distillation benchmarks.

    ``conv1``/``bn1``/``layer1``/``layer2`` are shared. ``num_branches + 1``
    copies of ``layer3`` plus a linear classifier are built on top of them, as
    ``layer3_<i>``/``classifier_<i>``. ``bench`` is matched with
    ``str.startswith``. It selects the extra head modules built in
    ``__init__`` and the output of :meth:`forward`:

    * ``'one'``: returns ``(pro, x_m)``.
      ``pro`` is ``(B, num_classes, num_branches + 1)`` and stacks all branch
      logits. ``x_m`` is ``(B, num_classes)``: the branch logits weighted by a
      softmax gate computed from the layer2 output.
    * ``'clilr'``: returns ``(pro, x_m)``.
      ``x_m[:, :, i]`` is the mean of the *other* ``num_branches`` branches'
      logits. The trunk output first passes through :class:`ILR`, which divides
      its gradient by ``num_branches + 1``.
    * ``'okddip'``: returns ``(pro, x_m, stu)``.
      Branches ``0..num_branches-1`` are peers (``pro`` is
      ``(B, num_classes, num_branches)``). ``x_m`` holds attention-weighted peer
      ensembles. ``stu`` is the logits of branch ``num_branches`` (the student).
      Requires ``num_branches >= 1``.
    * ``'ffl'``: returns ``(pro, fmap)``.
      ``fmap`` is the channel-wise concatenation of all branch ``layer3`` maps.

    The fixed-size pools (``AvgPool2d(8)``, ``AvgPool2d(16)``) presumably
    assume 32x32 CIFAR inputs; confirm before using other resolutions.

    Args:
        dataset: ``'cifar10'`` or ``'cifar100'``; anything else raises ``ValueError``.
        depth: blocks per stage are ``(depth - 2) // 9`` (bottleneck) or
            ``(depth - 2) // 6`` (basic).
        num_branches: see above; ``num_branches + 1`` ``layer3`` copies are built.
        bench: benchmark name. An unknown prefix raises ``ValueError``.
        bottleneck: bottleneck blocks only when this ``is True`` (identity check).
        se: use the SE variant of the chosen block type.
    """

    def __init__(self, dataset, depth, num_branches, bench, bottleneck=False, se=False):
        super(BenchResNet, self).__init__()
        self.inplanes = 16
        self.num_branches = num_branches
        self.benchmark = bench

        # Identity check: only the literal ``True`` selects bottleneck blocks.
        if bottleneck is True:
            n = (depth - 2) // 9
            block = SEBottleneck if se else Bottleneck
        else:
            n = (depth - 2) // 6
            block = SEBasicBlock if se else BasicBlock
        self.block = block

        if dataset == 'cifar10':
            num_classes = 10
        elif dataset == 'cifar100':
            num_classes = 100
        else:
            raise ValueError("No valid dataset is given.")

        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, n)
        self.layer2 = self._make_layer(block, 32, n, stride=2)
        fix_inplanes = self.inplanes  # channel width of the shared layer2 output

        # OKDDip: branches 0..num_branches-1 are peers and branch num_branches is
        # the student. The other benchmarks treat all num_branches+1 branches as peers.
        for i in range(self.num_branches + 1):
            setattr(self, 'layer3_' + str(i), self._make_layer(block, 64, n, stride=2))
            # _make_layer advanced inplanes; every branch must start from layer2's width.
            self.inplanes = fix_inplanes
            setattr(self, 'classifier_' + str(i), nn.Linear(64 * block.expansion, num_classes))

        if self.benchmark.startswith('one'):
            # Gate: pooled layer2 output -> one score per branch.
            self.control_v1 = nn.Linear(fix_inplanes, self.num_branches + 1)
            self.bn_v1 = nn.BatchNorm1d(self.num_branches + 1)
            self.avgpool = nn.AvgPool2d(8)
            self.avgpool_c = nn.AvgPool2d(16)
        elif self.benchmark.startswith('clilr'):
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.layer_ILR = ILR.apply
        elif self.benchmark.startswith('okddip'):
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.query_weight = nn.Linear(64 * block.expansion, 64 * block.expansion // 8, bias=False)
            self.key_weight = nn.Linear(64 * block.expansion, 64 * block.expansion // 8, bias=False)
        elif self.benchmark.startswith('ffl'):
            self.avgpool = nn.AvgPool2d(8)
        else:
            raise ValueError('You should define a benchmark what you want to run!!')

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build a stage of ``blocks`` residual blocks and advance ``self.inplanes``.

        Only the first block may stride and gets a 1x1-conv + BN projection
        shortcut when the stride or channel count changes.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes, planes * block.expansion,
                              kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the shared trunk, then the branch heads of ``self.benchmark``.

        See the class docstring for the per-benchmark return values.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)

        if self.benchmark.startswith('one'):
            return self._forward_one(x)
        elif self.benchmark.startswith('clilr'):
            return self._forward_clilr(x)
        elif self.benchmark.startswith('okddip'):
            return self._forward_okddip(x)
        elif self.benchmark.startswith('ffl'):
            return self._forward_ffl(x)
        # Only reachable if ``benchmark`` was changed after construction.
        raise ValueError('You should define a benchmark what you want to run!!')

    def _branch_features(self, i, x):
        """Branch ``i``: return its ``layer3`` feature map and the pooled, flattened vector."""
        fmap = getattr(self, 'layer3_' + str(i))(x)
        pooled = self.avgpool(fmap)
        return fmap, pooled.view(pooled.size(0), -1)

    def _branch_logits(self, i, x):
        """Branch ``i``: return classifier logits of shape ``(B, num_classes)``."""
        _, feat = self._branch_features(i, x)
        return getattr(self, 'classifier_' + str(i))(feat)

    def _forward_one(self, x):
        """ONE head: stacked branch logits plus a softmax-gated ensemble."""
        gate = self.avgpool_c(x)
        gate = gate.view(gate.size(0), -1)
        gate = F.softmax(F.relu(self.bn_v1(self.control_v1(gate))), dim=1)  # B x (num_branches+1)

        logits = [self._branch_logits(i, x) for i in range(self.num_branches + 1)]
        # gate[:, i:i+1] is (B, 1) and broadcasts over the class dimension.
        x_m = gate[:, 0:1] * logits[0]
        for i in range(1, self.num_branches + 1):
            x_m += gate[:, i:i + 1] * logits[i]
        pro = torch.stack(logits, -1)  # B x num_classes x (num_branches+1)
        return pro, x_m

    def _forward_clilr(self, x):
        """CL-ILR head: branch logits plus leave-one-out peer means."""
        x = self.layer_ILR(x, self.num_branches + 1)  # backprop rescaling
        pro = torch.stack([self._branch_logits(i, x) for i in range(self.num_branches + 1)], -1)

        # peers[i] = mean of the other num_branches branches' logits (j != i).
        peers = []
        for i in range(self.num_branches + 1):
            temp = 0
            for j in range(self.num_branches + 1):
                if j != i:
                    temp += 1 / self.num_branches * pro[:, :, j]
            peers.append(temp)
        x_m = torch.stack(peers, -1)  # B x num_classes x (num_branches+1)
        return pro, x_m

    def _forward_okddip(self, x):
        """OKDDip head: peer logits, attention-weighted peer ensembles, student logits."""
        logits, queries, keys = [], [], []
        for i in range(self.num_branches):  # peers only; the last branch is the student
            _, feat = self._branch_features(i, x)      # B x 64*expansion
            queries.append(self.query_weight(feat))    # B x 64*expansion//8
            keys.append(self.key_weight(feat))
            logits.append(getattr(self, 'classifier_' + str(i))(feat))

        pro = torch.stack(logits, -1)     # B x num_classes x num_branches
        proj_q = torch.stack(queries, 1)  # B x num_branches x d
        proj_k = torch.stack(keys, 1)     # B x num_branches x d

        energy = torch.bmm(proj_q, proj_k.permute(0, 2, 1))  # B x num_branches x num_branches
        attention = F.softmax(energy, dim=-1)
        # x_m[:, :, i] = sum_j attention[:, i, j] * pro[:, :, j]
        x_m = torch.bmm(pro, attention.permute(0, 2, 1))     # Teacher

        stu = self._branch_logits(self.num_branches, x)      # Student
        return pro, x_m, stu

    def _forward_ffl(self, x):
        """FFL head: branch logits plus the channel-concatenated branch feature maps."""
        logits, fmaps = [], []
        for i in range(self.num_branches + 1):
            fmap, feat = self._branch_features(i, x)
            fmaps.append(fmap)
            logits.append(getattr(self, 'classifier_' + str(i))(feat))
        return torch.stack(logits, -1), torch.cat(fmaps, 1)


def call_ResNet(dataset, depth, bottleneck=False, **kwargs):
    """Factory for :class:`ResNet`.

    ``dataset``, ``depth`` and ``bottleneck`` are passed by keyword. Any extra
    keyword arguments (e.g. ``se``, ``KD``) are forwarded unchanged.
    """
    return ResNet(dataset=dataset, depth=depth, bottleneck=bottleneck, **kwargs)

if __name__ == '__main__':
    # Smoke test: a 4-branch CIFAR-100 ResNet-20 with the 'one' head on a dummy batch.
    # The benchmark must be one of 'one', 'clilr', 'okddip' or 'ffl'; anything
    # else makes BenchResNet.__init__ raise ValueError.
    model = BenchResNet('cifar100', 20, 4, 'one')
    print(model)
    x = torch.randn(2, 3, 32, 32)
    # 'okddip' instead returns three values: pro, x_m, stu = model(x)
    pro, x_m = model(x)  # 'one' and 'clilr' return (pro, x_m)
    print(pro.size(), x_m.size())
