from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange
from torch.autograd import Variable
from torchvision.models import densenet121, resnet18
from torch.distributions import Normal, Independent

from .utils import MyQueue, MLP

# NOTE(review): presumably the expected input image height in pixels; it is
# not referenced by the model classes in this file, so confirm against the
# data pipeline.
IMG_HEIGHT = 224
# Number of target classes; sets the output width of the linear classifiers
# built by the network classes below.
NUM_CLASSES = 62


class FMoWNetwork(nn.Module):
    """DenseNet-121 classifier with an optional bottleneck and SSL heads.

    The encoder is the ``features`` stack of a pretrained torchvision
    DenseNet-121 followed by ReLU, global average pooling and flattening.
    When ``args.trainer.dim_bottleneck_f`` is set, a Linear/BatchNorm/ReLU
    bottleneck is appended to the encoder. A linear layer maps the encoder
    output to ``NUM_CLASSES`` logits.

    When ``args.trainer.method`` is ``'simclr'`` or ``'swav'`` the matching
    projection head (and, for SwaV, prototypes) from ``lightly`` is attached.
    """

    def __init__(self, args, weights=None, ssl_training=False):
        """Build the encoder, the classifier and any method-specific heads.

        Args:
            args: Configuration object; ``args.trainer.dim_bottleneck_f`` and
                ``args.trainer.method`` are read here.
            weights: Optional state dict. It is deep-copied and loaded right
                after the classifier is created, i.e. before any SSL head
                exists on the module.
            ssl_training: When truthy, ``forward`` returns the output of the
                SSL head for the ``'simclr'`` / ``'swav'`` methods instead of
                classifier logits.
        """
        super().__init__()
        self.args = args
        self.num_classes = NUM_CLASSES

        bottleneck_dim = args.trainer.dim_bottleneck_f
        if bottleneck_dim is None:
            extra_layers = []
            self.feature_dim = 1024
        else:
            # The bottleneck is also kept as its own attribute, so it is
            # registered both as ``self.bottleneck`` and as the last stage of
            # ``self.enc``.
            self.bottleneck = nn.Sequential(
                nn.Linear(1024, bottleneck_dim),
                nn.BatchNorm1d(bottleneck_dim),
                nn.ReLU(),
            )
            extra_layers = [self.bottleneck]
            self.feature_dim = bottleneck_dim

        backbone = densenet121(pretrained=True).features
        self.enc = nn.Sequential(
            backbone,
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            *extra_layers,
        )

        self.classifier = nn.Linear(self.feature_dim, self.num_classes)
        if weights is not None:
            self.load_state_dict(deepcopy(weights))

        method = self.args.trainer.method
        if method == 'simclr':
            # SimCLR: projection head only.
            from lightly.models.modules.heads import SimCLRProjectionHead
            self.projection_head = SimCLRProjectionHead(self.feature_dim, self.feature_dim, 128)
        elif method == 'swav':
            # SwaV: projection head plus prototypes.
            from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes
            self.projection_head = SwaVProjectionHead(self.feature_dim, self.feature_dim, 128)
            self.prototypes = SwaVPrototypes(128, n_prototypes=self.feature_dim)
        self.ssl_training = ssl_training

    def forward(self, x):
        """Encode ``x`` and return SSL-head output or classifier logits.

        With ``ssl_training`` set, ``'simclr'`` returns the projection-head
        output and ``'swav'`` returns the prototype scores of the
        L2-normalised projection. In every other case the classifier logits
        are returned.
        """
        features = self.enc(x)

        if not self.ssl_training:
            return self.classifier(features)

        method = self.args.trainer.method
        if method == 'simclr':
            return self.projection_head(features)
        if method == 'swav':
            projected = self.projection_head(features)
            projected = nn.functional.normalize(projected, dim=1, p=2)
            return self.prototypes(projected)
        return self.classifier(features)


class FMoWNetwork_for_Ours(nn.Module):
    """DenseNet-121 encoder with a bias-free classifier and two queues.

    The encoder is the ``features`` stack of a pretrained torchvision
    DenseNet-121 followed by ReLU, global average pooling and flattening,
    optionally extended by a Linear/BatchNorm/ReLU bottleneck when
    ``args.trainer.dim_bottleneck_f`` is set. The module also owns two
    ``MyQueue`` instances (``knowledge_pool`` and ``DM_trainsample_pool``).
    """

    def __init__(self, args, weights=None):
        """Build the encoder, classifier and queues.

        Args:
            args: Configuration object; ``args.trainer.dim_bottleneck_f``,
                ``args.trainer.len_queue`` and ``args.trainer.len_DM_pool``
                are read here.
            weights: Optional state dict, deep-copied and loaded into the
                module before the classifier is created.
        """
        super().__init__()
        self.args = args
        self.num_classes = NUM_CLASSES

        bottleneck_dim = args.trainer.dim_bottleneck_f
        if bottleneck_dim is not None:
            # Kept as its own attribute as well as the last encoder stage, so
            # its parameters appear under both ``bottleneck.*`` and ``enc.4.*``.
            self.bottleneck = nn.Sequential(
                nn.Linear(1024, bottleneck_dim),
                nn.BatchNorm1d(bottleneck_dim),
                nn.ReLU()
            )
            extra_layers = [self.bottleneck]
            self.feature_dim = bottleneck_dim
        else:
            extra_layers = []
            self.feature_dim = 1024

        self.enc = nn.Sequential(densenet121(pretrained=True).features, nn.ReLU(),
                                 nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten(),
                                 *extra_layers)

        # NOTE(review): weights are loaded before ``self.classifier`` is
        # registered, so with the default strict loading the state dict must
        # not contain ``classifier.*`` keys and the classifier always starts
        # from a fresh initialisation. Confirm this is the intended contract.
        if weights is not None:
            self.load_state_dict(deepcopy(weights))

        self.classifier = nn.Linear(self.feature_dim, NUM_CLASSES, bias=False)
        self.knowledge_pool = MyQueue(maxsize=args.trainer.len_queue)
        self.DM_trainsample_pool = MyQueue(maxsize=args.trainer.len_DM_pool)
        self.eps = 1e-6

    def memorize(self, W):
        """Put ``W`` into ``knowledge_pool``.

        The original comment gives the shape of ``W`` as ``[C, D]``.
        """
        self.knowledge_pool.put_item(W)

    def forward_encoder(self, x):
        """Return the encoder features for ``x``."""
        return self.enc(x)

    def forward(self, x):
        """Return ``(features, logits)`` for ``x``.

        Defining ``forward`` under its correct name lets ``model(x)`` work
        through ``nn.Module.__call__``; previously only the misspelled
        ``foward`` existed, so calling the module raised
        ``NotImplementedError``.
        """
        f = self.enc(x)
        logits = self.classifier(f)
        return f, logits

    # Backward-compatible aliases for the original (misspelled) method names,
    # kept so existing callers continue to work.
    foward_encoder = forward_encoder
    foward = forward

    def get_parameters(self, lr):
        """Return optimizer parameter groups for the encoder and classifier.

        Args:
            lr: Base learning rate.

        Returns:
            A list of two dicts with ``"params"`` and ``'lr'`` entries.
        """
        # Both groups currently use the base rate (multiplier 1).
        return [
            {"params": self.enc.parameters(), 'lr': 1 * lr},
            {"params": self.classifier.parameters(), 'lr': 1 * lr},
        ]
