import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
from copy import deepcopy

import utils
from encoder import make_encoder
import data_augs as rad 
from algorithms.rad_byol import RAD_BYOL


class RAD_BYOL_NOINV(RAD_BYOL):
    """RAD_BYOL variant with no inverse ("inv") auxiliary loss.

    ``update_aux`` is overridden so the auxiliary update trains
    ``self.predictor`` and ``self.dynamic`` with the dynamics-accuracy
    (BYOL-style) loss only.
    """

    def update_aux(self, aug_o, o, next_o, a, L, step):
        """Run one auxiliary update using only the dynamics-accuracy loss.

        Composed of one loss:
        1. Accuracy of dynamic model -- BYOL loss1: the online encoder embeds
           ``aug_o``, ``self.dynamic`` predicts the next embedding from it and
           the action ``a``, and that prediction (passed through
           ``self.dynamic.predictor_acc`` and L2-normalised along dim 1) is
           regressed with an MSE loss onto the L2-normalised target-encoder
           embedding of ``next_o``.

        Args:
            aug_o: augmented observation batch, fed to the online encoder.
            o: not used by this override; kept so the signature matches
                ``update_aux`` as called by the training loop (presumably the
                un-augmented observation — confirm against the parent class).
            next_o: next-observation batch, fed to the target encoder.
            a: action batch, passed to ``self.dynamic``.
            L: logger object; its ``log(key, value, step)`` method is called.
            step: current training step; the loss is logged when
                ``step % self.log_interval == 0``.
        """
        # Online branch: encode the augmented observation, then apply the
        # dynamics model with the action.
        proj_x0 = self.predictor.encoder(aug_o)
        pred_next_x0 = self.dynamic(proj_x0, a)

        # Target branch is a fixed regression target. Evaluating it under
        # no_grad gives the same values as forward-then-detach(), but avoids
        # building an autograd graph that would only be thrown away.
        with torch.no_grad():
            proj_next_x = self.predictor_target.encoder(next_o)

        # Loss1: MSE between the L2-normalised (dim 1) prediction and target.
        f0 = F.normalize(self.dynamic.predictor_acc(pred_next_x0), p=2, dim=1)
        f1 = F.normalize(proj_next_x, p=2, dim=1)
        acc_loss = F.mse_loss(f0, f1)

        # Update: the single loss drives both the online encoder/predictor
        # and the dynamics model.
        aux_loss = acc_loss
        self.predictor_optimizer.zero_grad()
        self.dynamic_optimizer.zero_grad()
        aux_loss.backward()
        self.predictor_optimizer.step()
        self.dynamic_optimizer.step()

        if step % self.log_interval == 0:
            L.log('train/acc_loss', acc_loss, step)
