import torch
import torch.nn as nn
from torch.autograd import Variable
from einops import rearrange
from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes
from lightly.models.modules.heads import SimCLRProjectionHead
from .utils import MyQueue


class RotatedMNISTNetwork(nn.Module):
    """Four-layer conv encoder for Rotated MNIST with optional SSL heads.

    The encoder maps an image batch to one 128-d feature vector per sample.
    What ``forward`` returns depends on ``args.trainer.method`` and
    ``ssl_training``:

    * ``'simclr'`` with ``ssl_training``: SimCLR projection-head output.
    * ``'swav'`` with ``ssl_training``: SwaV prototype scores computed on the
      L2-normalised projection.
    * anything else: class logits from the linear classifier.

    Args:
        args: config object; only ``args.trainer.method`` is read here.
        num_input_channels: channel count of the input images.
        num_classes: output size of the linear classifier.
        ssl_training: when True and the method is 'simclr' or 'swav',
            ``forward`` returns the self-supervised head output instead of logits.
    """

    def __init__(self, args, num_input_channels, num_classes, ssl_training=False):
        super().__init__()
        self.args = args

        # conv1..conv4: 3x3 kernels with padding 1; only conv2 downsamples
        # (stride 2). Registered in conv1..conv4 order.
        conv_specs = (
            (num_input_channels, 64, 1),
            (64, 128, 2),
            (128, 128, 1),
            (128, 128, 1),
        )
        for idx, (c_in, c_out, step) in enumerate(conv_specs, start=1):
            setattr(self, f'conv{idx}', nn.Conv2d(c_in, c_out, 3, step, padding=1))

        # bn0..bn3 are GroupNorm layers with 8 groups, despite the "bn" prefix.
        for idx, width in enumerate((64, 128, 128, 128)):
            setattr(self, f'bn{idx}', nn.GroupNorm(8, width))

        self.relu = nn.ReLU()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Every stage is conv -> ReLU -> GroupNorm (one shared ReLU module),
        # followed by a global average pool.
        stages = []
        for conv, norm in zip((self.conv1, self.conv2, self.conv3, self.conv4),
                              (self.bn0, self.bn1, self.bn2, self.bn3)):
            stages.extend((conv, self.relu, norm))
        stages.append(self.avgpool)
        self.enc = nn.Sequential(*stages)

        self.hid_dim = 128
        self.classifier = nn.Linear(self.hid_dim, num_classes)

        method = args.trainer.method
        if method == 'simclr':
            # SimCLR: projection head only.
            self.projection_head = SimCLRProjectionHead(input_dim=128, hidden_dim=128, output_dim=128)
        elif method == 'swav':
            # SwaV: projection head plus learnable prototypes.
            self.projection_head = SwaVProjectionHead(input_dim=128, hidden_dim=128, output_dim=128)
            self.prototypes = SwaVPrototypes(input_dim=128, n_prototypes=128)
        self.ssl_training = ssl_training

    def forward(self, x):
        """Encode ``x`` and return logits or SSL-head output (see class doc)."""
        feats = self.enc(x)
        feats = feats.view(len(feats), -1)  # [b, 128]

        method = self.args.trainer.method
        if not self.ssl_training or method not in ('simclr', 'swav'):
            return self.classifier(feats)

        z = self.projection_head(feats)
        if method == 'simclr':
            return z
        # SwaV scores the prototypes against the L2-normalised projection.
        return self.prototypes(nn.functional.normalize(z, dim=1, p=2))



class RotatedMNISTNetwork_for_Ours(nn.Module):
    """Rotated MNIST conv encoder with a bias-free linear classifier and two queues.

    ``forward(x)`` returns ``(features, logits)``, where features are the
    flattened 128-d encoder output. ``knowledge_pool`` receives classifier-like
    weight matrices via ``memorize``. ``DM_trainsample_pool`` is created here
    but not used within this class.

    Args:
        args: config object; reads ``args.trainer.len_queue`` and
            ``args.trainer.len_DM_pool`` as the queue capacities.
        num_input_channels: channel count of the input images.
        num_classes: output size of the classifier.
    """

    def __init__(self, args, num_input_channels, num_classes):
        super().__init__()
        self.args = args
        self.conv1 = nn.Conv2d(num_input_channels, 64, 3, 1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 128, 3, 1, padding=1)
        self.conv4 = nn.Conv2d(128, 128, 3, 1, padding=1)

        # GroupNorm with 8 groups, despite the "bn" prefix.
        self.bn0 = nn.GroupNorm(8, 64)
        self.bn1 = nn.GroupNorm(8, 128)
        self.bn2 = nn.GroupNorm(8, 128)
        self.bn3 = nn.GroupNorm(8, 128)

        self.relu = nn.ReLU()

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # conv -> ReLU -> GroupNorm, four times, then global average pooling.
        self.enc = nn.Sequential(self.conv1, self.relu, self.bn0, self.conv2, self.relu, self.bn1,
                                 self.conv3, self.relu, self.bn2, self.conv4, self.relu, self.bn3,
                                 self.avgpool)
        self.feature_dim = 128

        self.classifier = nn.Linear(self.feature_dim, num_classes, bias=False)
        # NOTE(review): MyQueue semantics (eviction policy, put_item behaviour)
        # live in .utils and are not visible here.
        self.knowledge_pool = MyQueue(maxsize=args.trainer.len_queue)
        self.DM_trainsample_pool = MyQueue(maxsize=args.trainer.len_DM_pool)
        # NOTE(review): eps is not used inside this class; presumably read by
        # the trainer — confirm.
        self.eps = 1e-6

    def memorize(self, W):
        """Store ``W`` (expected shape [C, D]) in ``knowledge_pool``."""
        self.knowledge_pool.put_item(W)

    def forward_encoder(self, x):
        """Return the flattened encoder features, shape [b, feature_dim]."""
        f = self.enc(x)
        return f.view(len(f), -1)

    def forward(self, x):
        """Return ``(features, logits)`` for the batch ``x``.

        Defining ``forward`` (the original only had the misspelled ``foward``)
        is what makes ``model(x)`` work; before, ``nn.Module.__call__`` fell
        through to the base ``forward`` and raised ``NotImplementedError``.
        """
        f = self.forward_encoder(x)
        logits = self.classifier(f)
        return f, logits

    # Backward-compatible aliases for the original misspelled method names.
    def foward_encoder(self, x):
        """Deprecated alias of :meth:`forward_encoder`."""
        return self.forward_encoder(x)

    def foward(self, x):
        """Deprecated alias of :meth:`forward`."""
        return self.forward(x)

    def get_parameters(self, lr):
        """Return optimizer param groups for the encoder and classifier, both at ``lr``."""
        return [
            {"params": self.enc.parameters(), 'lr': lr},
            {"params": self.classifier.parameters(), 'lr': lr},
        ]



