
import math
from argparse import Namespace
from itertools import chain
from typing import Optional, Callable, Any

import torch
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch.optim import Optimizer

from .simclr import SimCLR
import copy
from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import BYOLProjectionHead
from lightly.models.modules.heads import BYOLPredictionHead
from lightly.models.utils import deactivate_requires_grad
from lightly.models.utils import update_momentum


class BYOL(SimCLR):
    """BYOL (Bootstrap Your Own Latent) built on top of the SimCLR module.

    The online network is ``backbone -> projection_head -> prediction_head``.
    Momentum ("target") copies of the backbone and projection head are kept
    frozen w.r.t. gradients and are moved toward the online weights with
    ``update_momentum`` after every training batch, using a momentum ``tau``
    that follows a cosine schedule over epochs.
    """

    def __init__(self, args: Namespace):
        """Build online heads and momentum copies from parsed CLI ``args``.

        Reads ``args.embedding_dim``, ``args.projection_mlp_layers``,
        ``args.projection_hidden_dim``, ``args.projection_output_dim`` and
        ``args.prediction_hidden_dim``.
        """
        super().__init__(args)

        if self.args.projection_mlp_layers <= 0:
            # No projection MLP: the prediction head works directly on backbone
            # embeddings. The projection head must be an identity here, otherwise
            # its projection_output_dim-sized output would not match the
            # prediction head's embedding_dim-sized input.
            self.projection_head = torch.nn.Identity()
            inout_dim = args.embedding_dim
        else:
            self.projection_head = BYOLProjectionHead(
                input_dim=args.embedding_dim,
                hidden_dim=args.projection_hidden_dim,
                output_dim=args.projection_output_dim,
            )
            inout_dim = args.projection_output_dim

        # The predictor maps projections back into the same space so its output
        # can be compared against the momentum network's projection.
        self.prediction_head = BYOLPredictionHead(
            input_dim=inout_dim,
            hidden_dim=args.prediction_hidden_dim,
            output_dim=inout_dim,
        )

        # Momentum (target) networks start as exact copies of the online ones.
        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_momentum = copy.deepcopy(self.projection_head)

        # Target networks are only updated through update_target(), never by
        # the optimizer.
        deactivate_requires_grad(self.backbone_momentum)
        deactivate_requires_grad(self.projection_head_momentum)
        self.criterion = NegativeCosineSimilarity()
        # NOTE(review): Lightning typically calls .train() on the whole module
        # when fitting starts, which may undo these .eval() calls — confirm the
        # intended train/eval mode of the momentum networks.
        self.backbone_momentum.eval()
        self.projection_head_momentum.eval()

    def forward(self, x, emb=False):
        """Run the online network.

        Args:
            x: input batch for the backbone.
            emb: if True, return the flattened backbone features only.

        Returns:
            Flattened backbone features when ``emb`` is True, otherwise the
            prediction-head output.
        """
        y = self.backbone(x).flatten(start_dim=1)
        if emb:
            return y
        z = self.projection_head(y)
        p = self.prediction_head(z)
        return p

    def forward_momentum(self, x):
        """Return the (detached) projection produced by the momentum network."""
        y = self.backbone_momentum(x).flatten(start_dim=1)
        z = self.projection_head_momentum(y)
        # Momentum parameters already have requires_grad disabled; detach makes
        # the "no gradient into the target" contract explicit.
        z = z.detach()
        return z

    def _contrastive_loss(self, x, label_l):
        """Symmetric BYOL loss between the two views stacked in ``x``.

        Args:
            x: tensor whose second dimension holds the two augmented views
                (``x[:, 0]`` and ``x[:, 1]``).
            label_l: unused here; kept for interface compatibility with the
                parent class.

        Returns:
            ``(loss, stats)`` where ``stats`` is always an empty dict.
        """
        x0, x1 = x[:, 0], x[:, 1]
        p0 = self.forward(x0)
        z0 = self.forward_momentum(x0)
        p1 = self.forward(x1)
        z1 = self.forward_momentum(x1)
        # Each view's prediction is matched against the *other* view's target.
        loss = 0.5 * (self.criterion(p0, z1) + self.criterion(p1, z0))
        stats = {}
        return loss, stats

    @classmethod
    def add_model_specific_args(cls, parent_parser):
        """Extend the SimCLR argument parser with BYOL-specific options."""
        parser = super().add_model_specific_args(parent_parser)
        # Base momentum for the target network; the cosine schedule raises it
        # toward 1 over training.
        parser.add_argument('--tau', type=float, default=0.99)
        parser.add_argument('--prediction_hidden_dim', type=int, default=128)
        # NOTE(review): registered but not read anywhere in this class.
        parser.add_argument('--prediction_mlp_layers', type=int, default=2)
        return parser

    def update_target(self, tau):
        """Move the momentum (target) networks toward the online networks.

        Applies ``update_momentum`` with momentum ``tau`` to both the backbone
        and the projection head pairs.
        """
        update_momentum(self.backbone, self.backbone_momentum, tau)
        update_momentum(self.projection_head, self.projection_head_momentum, tau)

    @staticmethod
    def _cosine_tau(base_tau: float, epoch: int, max_epochs: int) -> float:
        """Cosine schedule: ``base_tau`` at epoch 0, approaching 1 at ``max_epochs``.

        Progress is clamped to [0, 1] so a non-positive ``max_epochs`` (no
        division by zero) or an epoch beyond ``max_epochs`` does not send the
        cosine past pi and lower tau again.
        """
        progress = epoch / max_epochs if max_epochs > 0 else 0.0
        progress = min(max(progress, 0.0), 1.0)
        return 1 - (1 - base_tau) * (math.cos(math.pi * progress) + 1) / 2

    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, unused: int = 0) -> None:
        """EMA-update the target networks after each batch, then defer to the parent."""
        tau = self._cosine_tau(self.args.tau, self.current_epoch, self.args.max_epochs)
        self.update_target(tau)
        super().on_train_batch_end(outputs, batch, batch_idx, unused)
