import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from math import pi, cos, e
import numpy as np
from collections import OrderedDict
from federatedscope.contrib.model.resnet import BasicBlock, Bottleneck


# Model class
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10, cfg=None):
        super(ResNet, self).__init__()
        self.train_sup = (num_classes > 0)
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3,
                               64,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64, affine=True)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.output_dim = 512 * block.expansion
        if (self.train_sup):
            self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = out.view(out.size(0), -1)
        if (self.train_sup):
            out = self.linear(out)
        return out


class ResNet_basic(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10, cfg=None):
        super(ResNet_basic, self).__init__()
        self.train_sup = (num_classes > 0)

        self.in_planes = 16
        self.conv1 = nn.Conv2d(3,
                               16,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(16, affine=True)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.output_dim = 512 * block.expansion
        if (self.train_sup):
            self.linear = nn.Linear(64 * block.expansion,
                                    num_classes,
                                    bias=True)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = out.view(out.size(0), -1)
        if (self.train_sup):
            out = self.linear(out)
        return out


def get_block(block):
    if (block == "BasicBlock"):
        return BasicBlock
    elif (block == "Bottleneck"):
        return Bottleneck


def ResNet18(num_classes=10, block="BasicBlock"):
    return ResNet(get_block(block), [2, 2, 2, 2], num_classes=num_classes)


def ResNet34(num_classes=10, block="BasicBlock"):
    return ResNet(get_block(block), [3, 4, 6, 3], num_classes=num_classes)


def create_backbone(name, num_classes=10, block='BasicBlock'):
    if (name == 'res18'):
        net = ResNet18(num_classes=num_classes, block=block)
    elif (name == 'res34'):
        net = ResNet34(num_classes=num_classes, block=block)

    return net


# Projector
class projection_MLP_simclr(nn.Module):
    def __init__(self, in_dim, hidden_dim=512, out_dim=512):
        super(projection_MLP_simclr, self).__init__()
        self.layer1 = nn.Linear(in_dim, hidden_dim, bias=False)
        self.layer1_bn = nn.BatchNorm1d(hidden_dim, affine=True)
        self.layer2 = nn.Linear(hidden_dim, out_dim)
        self.layer2_bn = nn.BatchNorm1d(out_dim, affine=False)

    def forward(self, x):
        x = F.relu(self.layer1_bn(self.layer1(x)))
        x = self.layer2_bn(self.layer2(x))
        return x


# SimCLR
class simclr(nn.Module):
    def __init__(self, bbone_arch):
        super(simclr, self).__init__()
        self.register_buffer("rounds_done", torch.zeros(1))

        self.backbone = create_backbone(bbone_arch, num_classes=0)
        self.projector = projection_MLP_simclr(self.backbone.output_dim,
                                               hidden_dim=512,
                                               out_dim=512)

    def forward(self, x1, x2, x3=None, deg_labels=None):
        z1, z2 = self.projector(self.backbone(x1)), self.projector(
            self.backbone(x2))

        return z1, z2


class simclr_linearprob(nn.Module):
    def __init__(self, bbone_arch, num_classes=10):
        super(simclr_linearprob, self).__init__()
        self.register_buffer("rounds_done", torch.zeros(1))

        self.backbone = create_backbone(bbone_arch, num_classes=0)
        self.linear = nn.Linear(512, num_classes, bias=True)

    def forward(self, x):
        with torch.no_grad():
            out = self.backbone(x)
        out = self.linear(out)

        return out


class simclr_supervised(nn.Module):
    def __init__(self, bbone_arch, num_classes=10):
        super(simclr_supervised, self).__init__()
        self.register_buffer("rounds_done", torch.zeros(1))

        self.backbone = create_backbone(bbone_arch, num_classes=0)
        self.linear = nn.Linear(512, num_classes, bias=True)

    def forward(self, x):
        out = self.backbone(x)
        out = self.linear(out)

        return out


def ModelBuilder(model_config, local_data):
    # You can also build models without local_data
    if model_config.type == "SimCLR":
        model = simclr(bbone_arch='res18')
        return model
    if model_config.type in ["SimCLR_linear"]:
        model = simclr_linearprob(bbone_arch='res18', num_classes=10)
        return model
    if model_config.type in ["supervised_local", "supervised_fedavg"]:
        model = simclr_supervised(bbone_arch='res18', num_classes=10)
        return model


from federatedscope.register import register_model


def get_simclr(model_config, local_data):
    model = ModelBuilder(model_config, local_data)
    return model


register_model("SimCLR", get_simclr)
register_model("SimCLR_linear", get_simclr)
