import torch
import torch.nn as nn
from src.models.model_components import build_encoder, build_projector, build_predictor


# MoCo-v3 adapted from https://github.com/facebookresearch/moco-v3
class MoCo(nn.Module):
    """Two-branch MoCo-v3 style model.

    The online branch is ``online_encoder -> online_projector -> predictor``.
    The momentum branch is ``momentum_encoder -> momentum_projector``; its
    parameters start as a copy of the online ones, are excluded from
    gradient computation, and are only changed through
    :meth:`update_momentum_encoder`.
    """

    def __init__(self, args):
        """Build both branches from ``args``.

        Args:
            args: Configuration object handed unchanged to ``build_encoder``,
                ``build_projector`` and ``build_predictor``. Which fields
                those builders read is not visible from this file.
        """
        super().__init__()

        # Online branch (receives gradients).
        self.online_encoder = build_encoder(args)
        self.online_projector = build_projector(args)
        self.predictor = build_predictor(args)

        # Momentum branch (no predictor), synced to the online weights below.
        self.momentum_encoder = build_encoder(args)
        self.momentum_projector = build_projector(args)
        self._initialize_momentum_encoder()

        # Start empty; nothing in this class writes to them afterwards.
        self.loss_history = []
        self.linear_probing_history = {}

    def _initialize_momentum_encoder(self):
        """Copy the online encoder/projector state into the momentum branch and freeze it."""
        module_pairs = (
            (self.online_encoder, self.momentum_encoder),
            (self.online_projector, self.momentum_projector),
        )
        for source_module, target_module in module_pairs:
            target_module.load_state_dict(source_module.state_dict())
            # The momentum branch is never optimised directly.
            for frozen_param in target_module.parameters():
                frozen_param.requires_grad_(False)

    def forward(self, view_1, view_2):
        """Encode two views with both branches.

        Args:
            view_1: First input, passed to both encoders.
            view_2: Second input, passed to both encoders.

        Returns:
            Tuple ``(query_1, query_2, key_1, key_2)``. The queries are the
            predictor outputs of the online branch; the keys are the
            momentum-projector outputs, computed without gradient tracking
            and detached.
        """
        # Online branch: view_1 first, then view_2.
        online_feat_1 = self.online_encoder(view_1)
        query_1 = self.predictor(self.online_projector(online_feat_1))
        online_feat_2 = self.online_encoder(view_2)
        query_2 = self.predictor(self.online_projector(online_feat_2))

        # Momentum branch: same view order, no autograd graph is recorded.
        with torch.no_grad():
            momentum_feat_1 = self.momentum_encoder(view_1)
            key_1 = self.momentum_projector(momentum_feat_1)
            momentum_feat_2 = self.momentum_encoder(view_2)
            key_2 = self.momentum_projector(momentum_feat_2)

        return query_1, query_2, key_1.detach(), key_2.detach()

    def update_momentum_encoder(self, args):
        """Move the momentum parameters towards the online parameters.

        Each momentum parameter becomes
        ``tau * momentum + (1 - tau) * online`` with ``tau = args.tau``.
        Only tensors yielded by ``parameters()`` are updated; the update is
        done in place on ``.data``.

        Args:
            args: Object exposing a numeric ``tau`` attribute.
        """
        tau = args.tau
        online_weight = 1.0 - tau

        module_pairs = (
            (self.online_encoder, self.momentum_encoder),
            (self.online_projector, self.momentum_projector),
        )
        for source_module, target_module in module_pairs:
            paired_params = zip(source_module.parameters(), target_module.parameters())
            for source_param, target_param in paired_params:
                target_param.data.mul_(tau).add_(source_param.data, alpha=online_weight)
