import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.models.model_components import build_encoder, build_projector, build_predictor


# Adapted from https://github.com/google-deepmind/deepmind-research/tree/master/byol
class BYOL(nn.Module):
    """Bootstrap Your Own Latent (BYOL) model.

    Holds an online branch (encoder -> projector -> predictor), which is
    trained by gradient descent, and a target branch (encoder -> projector).
    The target branch receives no gradients. ``update_target_encoder``
    updates it instead, as an exponential moving average (EMA) of the online
    encoder/projector.

    Args:
        args: Configuration namespace forwarded to ``build_encoder``,
            ``build_projector`` and ``build_predictor``.
            ``update_target_encoder`` additionally reads ``args.tau`` and
            ``args.max_steps``.
    """

    def __init__(self, args):
        super(BYOL, self).__init__()

        self.online_encoder = build_encoder(args)
        self.online_projector = build_projector(args)
        self.predictor = build_predictor(args)

        # The target branch mirrors the online encoder/projector (no predictor).
        self.target_encoder = build_encoder(args)
        self.target_projector = build_projector(args)
        self._initialize_target_network()

        # Number of EMA updates performed so far; drives the tau schedule.
        self.global_step = 0
        # Bookkeeping containers; not otherwise used within this class.
        self.loss_history = []
        self.linear_probing_history = {}

    def _initialize_target_network(self):
        """Copy the online encoder/projector weights into the target branch and freeze it."""
        self.target_encoder.load_state_dict(self.online_encoder.state_dict())
        self.target_projector.load_state_dict(self.online_projector.state_dict())

        # The target branch is only ever changed by the EMA update, never by an optimizer.
        for module in (self.target_encoder, self.target_projector):
            for param in module.parameters():
                param.requires_grad = False

    def forward(self, view_1, view_2):
        """Embed two views through both the online and the target branch.

        Args:
            view_1: First input batch.
            view_2: Second input batch (presumably a differently augmented
                copy of ``view_1`` -- confirm with the caller).

        Returns:
            Tuple ``(prediction_1, prediction_2, target_representation_1,
            target_representation_2)``. Every tensor is L2-normalised along
            dim 1. The predictions carry gradients through the online branch.
            The target representations are computed under ``torch.no_grad()``
            and carry none.
        """
        prediction_1 = self.predictor(self.online_projector(self.online_encoder(view_1)))
        prediction_2 = self.predictor(self.online_projector(self.online_encoder(view_2)))

        prediction_1 = F.normalize(prediction_1, dim=1, p=2)
        prediction_2 = F.normalize(prediction_2, dim=1, p=2)

        # Tensors produced under no_grad are already detached from autograd.
        with torch.no_grad():
            target_representation_1 = self.target_projector(self.target_encoder(view_1))
            target_representation_2 = self.target_projector(self.target_encoder(view_2))

            target_representation_1 = F.normalize(target_representation_1, dim=1, p=2)
            target_representation_2 = F.normalize(target_representation_2, dim=1, p=2)

        return prediction_1, prediction_2, target_representation_1, target_representation_2

    def update_target_encoder(self, args):
        """Move the target branch one EMA step toward the online branch.

        The EMA coefficient follows BYOL's cosine schedule::

            tau = 1 - (1 - args.tau) * (cos(pi * k / K) + 1) / 2

        Here ``k = min(self.global_step, K)`` and ``K = args.max_steps``, so
        tau rises from ``args.tau`` at step 0 to 1 at step ``K``. Clamping
        ``k`` keeps ``tau == 1`` for any steps beyond ``K``. Without the clamp,
        the cosine would swing back down toward ``args.tau``.

        Despite the method name, both the target encoder and the target
        projector are updated. Only parameters are averaged; buffers (e.g.
        any BatchNorm running statistics) are left untouched.
        Increments ``self.global_step``.

        Args:
            args: Namespace providing ``tau`` (base EMA coefficient) and
                ``max_steps`` (schedule length; must be non-zero).
        """
        step = min(self.global_step, args.max_steps)
        tau = 1.0 - (1.0 - args.tau) * (math.cos(math.pi * step / args.max_steps) + 1) / 2.0

        self._ema_update(self.online_encoder, self.target_encoder, tau)
        self._ema_update(self.online_projector, self.target_projector, tau)

        self.global_step += 1

    @staticmethod
    def _ema_update(online, target, tau):
        """Update each target parameter in place: ``target <- tau * target + (1 - tau) * online``."""
        # no_grad keeps the in-place update out of autograd; this replaces mutating ``.data``.
        with torch.no_grad():
            for param_online, param_target in zip(online.parameters(), target.parameters()):
                param_target.mul_(tau).add_(param_online, alpha=1.0 - tau)
