from typing import Union

from torch import Tensor
from torch import nn
from torch.nn import functional as F
from torchmeta.modules import *


def D(p: Tensor, z: Tensor) -> Tensor:
    """
    Negative cosine similarity, using the notation of Algorithm 1 in the
    SimSiam paper.

    :param p: BxD, predictor output; gradients flow through this argument
    :param z: BxD, target; detached inside so it is treated as a constant
              (the "stop grad" branch)
    :return: scalar tensor holding the negated mean, over all leading
             dimensions, of the cosine similarity taken along the last one
    """
    # Detach first so that no gradient can reach the target branch.
    target = z.detach()
    similarity = F.cosine_similarity(p, target, dim=-1)
    return -similarity.mean()


class MLPBlock(MetaModule):
    """
    Linear -> BatchNorm1d -> ReLU block built from torchmeta layers, so its
    parameters can optionally be supplied from outside at call time.

    :param in_dim: number of input features of the linear layer
    :param out_dim: number of output features (also the batch-norm width)
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        linear = MetaLinear(in_dim, out_dim)
        norm = MetaBatchNorm1d(out_dim)
        activation = nn.ReLU(inplace=True)
        self.network = MetaSequential(linear, norm, activation)

    def forward(self, x, params=None):
        """
        Run the block on ``x``.

        :param x: input passed to the underlying sequential network
        :param params: optional parameter dictionary; only the entries under
                       the ``network`` prefix are forwarded. ``None`` is
                       passed through unchanged.
        :return: output of the sequential network
        """
        network_params = self.get_subdict(params, 'network')
        return self.network(x, params=network_params)


class ProjectionMLP(MetaModule):
    """
    Three-stage projection head: two ``MLPBlock`` stages followed by a
    linear layer with batch norm (no activation on the last stage).

    :param in_dim: number of input features
    :param hidden_dim: width of the two hidden stages
    :param feat_dim: number of output features
    """

    # Attribute names of the stages, in the order they are applied.
    _STAGES = ('l1', 'l2', 'l3')

    def __init__(self, in_dim, hidden_dim, feat_dim):
        super().__init__()
        self.l1 = MLPBlock(in_dim, hidden_dim)
        self.l2 = MLPBlock(hidden_dim, hidden_dim)
        out_linear = MetaLinear(hidden_dim, feat_dim)
        out_norm = MetaBatchNorm1d(feat_dim)
        self.l3 = MetaSequential(out_linear, out_norm)

    def forward(self, x, params=None):
        """
        Apply the three stages in sequence.

        :param x: input to the first stage
        :param params: optional parameter dictionary; each stage receives the
                       entries stored under its own attribute name
        :return: output of the last stage
        """
        out = x
        for stage_name in self._STAGES:
            stage = getattr(self, stage_name)
            out = stage(out, params=self.get_subdict(params, stage_name))
        return out


class PredictionMLP(MetaModule):
    """
    Two-layer prediction head: an ``MLPBlock`` from ``feat_dim`` to
    ``hidden_dim`` followed by a plain linear layer back to ``feat_dim``.

    :param feat_dim: number of input and output features
    :param hidden_dim: width of the hidden layer
    """

    def __init__(self, feat_dim, hidden_dim):
        super().__init__()
        self.l1 = MLPBlock(feat_dim, hidden_dim)
        self.l2 = MetaLinear(hidden_dim, feat_dim)

    def forward(self, x, params=None):
        """
        Map ``x`` through the hidden block and the output linear layer.

        :param x: input to the hidden block
        :param params: optional parameter dictionary; ``l1`` and ``l2`` each
                       receive the entries stored under their own name
        :return: output of the final linear layer
        """
        hidden_params = self.get_subdict(params, 'l1')
        output_params = self.get_subdict(params, 'l2')
        hidden = self.l1(x, params=hidden_params)
        return self.l2(hidden, params=output_params)


class SimSiam(MetaModule):
    """
    SimSiam model: a backbone followed by a projection head (together the
    "encoder") and a prediction head, trained with the symmetric negative
    cosine similarity loss.

    :param model: backbone network; must expose an ``output_dim`` attribute
                  giving the size of its output features
    :param proj_hidden_dim: hidden width of the projection head
    :param pred_hidden_dim: hidden width of the prediction head
    :param feat_dim: output width of the projection and prediction heads
    """

    def __init__(self,
                 model: Union[nn.Module, MetaModule],
                 proj_hidden_dim=2048,
                 pred_hidden_dim=512,
                 feat_dim=2048):
        super().__init__()
        self.model = model

        model_out_dim = self.model.output_dim

        self.projector = ProjectionMLP(model_out_dim, proj_hidden_dim, feat_dim)
        # NOTE: ``encoder`` wraps the very same module objects registered
        # above as ``model`` and ``projector``; it adds no new parameters.
        self.encoder = MetaSequential(
            self.model,
            self.projector
        )
        self.predictor = PredictionMLP(feat_dim, pred_hidden_dim)

    @property
    def device(self):
        """Device of the first parameter of this module."""
        return next(self.parameters()).device

    @staticmethod
    def _encoder_params(params):
        """
        Build the parameter dictionary handed to ``self.encoder``.

        ``model`` and ``projector`` are registered before ``encoder``, and
        shared modules are de-duplicated when parameters are named, so
        dictionaries derived from this module's named parameters carry
        ``model.*`` / ``projector.*`` keys rather than ``encoder.0.*`` /
        ``encoder.1.*``. Looking up only the ``encoder`` prefix would then
        find nothing and the encoder would silently fall back to its stored
        weights. Both spellings are accepted here; explicit ``encoder.*``
        entries take precedence when both are present.

        :param params: full parameter dictionary, or ``None``
        :return: dictionary keyed relative to ``encoder`` (``0.*`` for the
                 backbone, ``1.*`` for the projector), or ``None`` when there
                 is nothing to forward
        """
        if params is None:
            return None

        routed = {}
        for alias, index in (('model.', '0.'), ('projector.', '1.')):
            for name, value in params.items():
                if name.startswith(alias):
                    routed[index + name[len(alias):]] = value
        # Applied last so that explicit ``encoder.*`` keys win over aliases.
        for name, value in params.items():
            if name.startswith('encoder.'):
                routed[name[len('encoder.'):]] = value

        return routed or None

    def loss_fn(self, p1, z2, p2, z1):
        """
        Symmetric SimSiam loss: the average of ``D(p1, z2)`` and ``D(p2, z1)``.
        """
        return (D(p1, z2) + D(p2, z1)) / 2

    def forward(self, x1, x2, params=None):
        """
        Compute the SimSiam loss for two inputs.

        :param x1: first input, passed through the encoder
        :param x2: second input, passed through the encoder
        :param params: optional parameter dictionary used instead of the
                       stored parameters
        :return: loss value produced by ``loss_fn``
        """
        p1, z2, p2, z1 = self.predict(x1, x2, params=params)
        return self.loss_fn(p1, z2, p2, z1)

    def predict(self, x1, x2, params=None):
        """
        Run encoder and predictor on both inputs without computing the loss.

        :param x1: first input
        :param x2: second input
        :param params: optional parameter dictionary used instead of the
                       stored parameters
        :return: tuple ``(p1, z2, p2, z1)`` where ``z*`` are encoder outputs
                 and ``p*`` are predictor outputs, in the argument order
                 expected by ``loss_fn``
        """
        encoder_params = self._encoder_params(params)
        predictor_params = self.get_subdict(params, 'predictor')

        # Keep the call order x1 then x2: the batch-norm layers update their
        # running statistics on every call in training mode.
        z1 = self.encoder(x1, params=encoder_params)
        z2 = self.encoder(x2, params=encoder_params)
        p1 = self.predictor(z1, params=predictor_params)
        p2 = self.predictor(z2, params=predictor_params)
        return p1, z2, p2, z1

    def trainable_parameters(self):
        """Return a list of the parameters that have ``requires_grad`` set."""
        return [p for p in self.parameters() if p.requires_grad]
