from itertools import chain
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import get_model, get_head
from .base import BaseMethod
from .norm_mse import norm_mse_loss


class BYOL(BaseMethod):
    """ BYOL objective, see https://arxiv.org/abs/2006.07733

    Adds a predictor MLP and a gradient-free target encoder/head pair next to
    the online `self.model` / `self.head` (these, together with
    `self.out_size` and `self.num_pairs`, are not defined in this class and
    are presumably provided by BaseMethod -- confirm there).
    """

    def __init__(self, cfg):
        """ build the predictor and the frozen target networks

        Reads `cfg.emb`, `cfg.head_size`, `cfg.arch`, `cfg.dataset`,
        `cfg.byol_tau` and `cfg.norm`; `cfg` is also forwarded to the base
        class and to `get_head`.
        """
        super().__init__(cfg)
        # plain (non-module) settings
        self.byol_tau = cfg.byol_tau
        self.loss_f = norm_mse_loss if cfg.norm else F.mse_loss

        # predictor: emb -> head_size -> emb, applied to the online branch only
        self.pred = nn.Sequential(
            nn.Linear(cfg.emb, cfg.head_size),
            nn.BatchNorm1d(cfg.head_size),
            nn.ReLU(),
            nn.Linear(cfg.head_size, cfg.emb),
        )

        # target branch, built with the same factories as used for the head
        self.model_t, _ = get_model(cfg.arch, cfg.dataset)
        self.head_t = get_head(self.out_size, cfg)
        for frozen_net in (self.model_t, self.head_t):
            for weight in frozen_net.parameters():
                weight.requires_grad = False

        # tau = 0 overwrites the target parameters with the online ones
        self.update_target(0)

    def update_target(self, tau):
        """ blend online parameters into the target parameters

        Each target parameter becomes `tau * target + (1 - tau) * online`,
        so `tau = 0` is a plain copy and `tau = 1` leaves the target as is.
        Only tensors returned by `parameters()` are touched.
        """
        online_share = 1.0 - tau
        paired = chain(
            zip(self.model_t.parameters(), self.model.parameters()),
            zip(self.head_t.parameters(), self.head.parameters()),
        )
        for target_w, online_w in paired:
            target_w.data.copy_(target_w.data * tau + online_w.data * online_share)

    def forward(self, samples):
        """ symmetric BYOL loss over every pair of entries in `samples`

        Every entry goes through the online branch (model, head, predictor)
        and, without gradient tracking, through the target branch (model_t,
        head_t). For each pair i < j the loss of prediction i against target j
        and of prediction j against target i is accumulated; the sum is then
        divided by `self.num_pairs`.
        """
        predictions = [self.pred(self.head(self.model(view))) for view in samples]
        with torch.no_grad():
            targets = [self.head_t(self.model_t(view)) for view in samples]

        n_views = len(samples)
        total = 0
        for first in range(n_views - 1):
            for second in range(first + 1, n_views):
                total += (
                    self.loss_f(predictions[first], targets[second])
                    + self.loss_f(predictions[second], targets[first])
                )
        total /= self.num_pairs
        return total

    def step(self, progress):
        """ refresh the target network using a cosine schedule for tau

        tau equals `self.byol_tau` at `progress = 0` and reaches 1 at
        `progress = 1` (progress is presumably the fraction of training
        completed -- confirm with the caller).
        """
        cosine_term = math.cos(math.pi * progress) + 1
        gap = 1 - self.byol_tau
        tau = 1 - gap * cosine_term / 2
        self.update_target(tau)