

from argparse import Namespace

from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import SimSiamPredictionHead
from lightly.models.modules import SimSiamProjectionHead
from torch import nn

from .simclr import SimCLR


class SimSiam(SimCLR):
    """SimSiam model (Chen & He, "Exploring Simple Siamese Representation Learning").

    Subclasses ``SimCLR`` and reuses its ``backbone`` and CLI arguments.
    (Anything else the parent provides is not visible here.) It swaps in
    lightly's SimSiam projection/prediction heads and trains with a symmetric
    negative cosine similarity, where the projection output is detached
    (stop-gradient).
    """

    def __init__(self, args: Namespace):
        """Build the SimSiam heads and loss on top of the parent model.

        Args:
            args: Parsed CLI namespace. It is read for ``embedding_dim``,
                ``projection_hidden_dim``, ``projection_output_dim``,
                ``projection_mlp_layers`` and ``prediction_hidden_dim``.
        """
        super().__init__(args)
        if self.args.projection_mlp_layers <= 0:
            # No projection MLP: backbone features feed the predictor directly.
            # The projection head must therefore be an identity map. Building a
            # SimSiamProjectionHead here would emit ``projection_output_dim``
            # features that the ``embedding_dim``-sized predictor cannot accept.
            self.projection_head = nn.Identity()
            inout_dim = args.embedding_dim
        else:
            self.projection_head = SimSiamProjectionHead(
                input_dim=args.embedding_dim,
                hidden_dim=args.projection_hidden_dim,
                output_dim=args.projection_output_dim,
            )
            inout_dim = args.projection_output_dim
        # The predictor maps projections back into the same space, so that
        # p (from one view) and z (from the other view) can be compared.
        self.prediction_head = SimSiamPredictionHead(
            input_dim=inout_dim,
            hidden_dim=args.prediction_hidden_dim,
            output_dim=inout_dim,
        )

        self.criterion = NegativeCosineSimilarity()

    @classmethod
    def add_model_specific_args(cls, parent_parser):
        """Extend the parent's argument parser with SimSiam predictor options.

        Args:
            parent_parser: Parser passed through to ``SimCLR.add_model_specific_args``.

        Returns:
            The parser returned by the parent, with the extra arguments added.
        """
        parser = super().add_model_specific_args(parent_parser)
        parser.add_argument('--prediction_hidden_dim', type=int, default=128)
        # NOTE(review): this class never reads ``prediction_mlp_layers``. The
        # lightly SimSiamPredictionHead appears to have a fixed depth. Confirm
        # whether anything else consumes this flag before relying on it.
        parser.add_argument('--prediction_mlp_layers', type=int, default=2)
        return parser

    def forward(self, x, emb=False):
        """Run the backbone and, unless ``emb`` is set, both SimSiam heads.

        Args:
            x: Input batch for the backbone.
            emb: If True, return only the flattened backbone features.

        Returns:
            Flattened features ``f`` when ``emb`` is True. Otherwise a tuple
            ``(z, p)``, where ``z`` is the detached projection and ``p`` is
            the prediction.
        """
        f = self.backbone(x).flatten(start_dim=1)
        if emb:
            return f
        z = self.projection_head(f)
        p = self.prediction_head(z)
        # Stop-gradient on the target branch only. ``p`` was computed from
        # the non-detached ``z``, so gradients still reach the projection
        # head through the predictor.
        z = z.detach()
        return z, p

    def _contrastive_loss(self, x, label):
        """Compute the symmetric SimSiam loss for a batch of paired views.

        Args:
            x: Batch indexed along dim 1 as ``x[:, 0]`` / ``x[:, 1]`` to get
                the two augmented views (assumes at least two views on dim 1).
            label: Unused; kept for interface compatibility with the parent.

        Returns:
            ``(loss, stats)``, where ``stats`` is an empty dict.
        """
        x0, x1 = x[:, 0], x[:, 1]
        z0, p0 = self.forward(x0)
        z1, p1 = self.forward(x1)
        # Each view's prediction is matched against the other view's
        # stop-gradient projection; the two directions are averaged.
        loss = 0.5 * (self.criterion(z0, p1) + self.criterion(z1, p0))
        stats = {}
        return loss, stats
