import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
from copy import deepcopy

import utils
import utilsmod
from encoder import make_encoder
import data_augs as rad 
from algorithms.rad_byol_sharedproj import RAD_BYOL_SharedProj


class RAD_BYOL_SharedProj_NOINV(RAD_BYOL_SharedProj):
    """Shared-projector variant whose auxiliary update omits the invariance loss.

    Only ``update_aux`` is overridden: the dynamics model is trained with a
    single BYOL-style prediction ("accuracy") loss.
    """

    def update_aux(self, aug_o, o, next_o, a, L, step):
        """Take one optimisation step on the dynamics-accuracy loss.

        ``aug_o`` is encoded with ``self.predictor.encoder`` and rolled forward
        by ``self.dynamic`` together with the action.  The prediction is passed
        through ``self.dynamic.predictor_acc`` and compared, after L2
        normalisation along dim 1, with the detached embedding of ``next_o``
        from ``self.predictor_target.encoder`` using an MSE loss.

        Args:
            aug_o: current observation fed to ``self.predictor.encoder``
                (presumably the augmented view -- confirm against the caller).
            o: accepted for interface compatibility; not used by this variant.
            next_o: next observation fed to ``self.predictor_target.encoder``.
            a: action passed to ``self.dynamic``.
            L: logger object; ``L.log(key, value, step)`` is called on it.
            step: training step; the loss is logged whenever it is a multiple
                of ``self.log_interval``.
        """
        # Embedding of the current observation (gradients flow through it).
        current_embedding = self.predictor.encoder(aug_o)

        # Regression target: detached so no gradient reaches the target encoder.
        target_embedding = self.predictor_target.encoder(next_o)
        target_embedding = target_embedding.detach()

        # Predict the next-step embedding from (embedding, action).
        predicted_next = self.dynamic(current_embedding, a)
        predicted_next = self.dynamic.predictor_acc(predicted_next)

        # Accuracy loss between L2-normalised prediction and target.
        # NOTE: the invariance loss is deliberately not computed here.
        prediction_unit = F.normalize(predicted_next, p=2, dim=1)
        target_unit = F.normalize(target_embedding, p=2, dim=1)
        acc_loss = F.mse_loss(prediction_unit, target_unit)

        # The accuracy loss is the whole auxiliary objective for this variant.
        self.predictor_optimizer.zero_grad()
        self.dynamic_optimizer.zero_grad()
        acc_loss.backward()
        self.predictor_optimizer.step()
        self.dynamic_optimizer.step()

        if step % self.log_interval != 0:
            return
        L.log('train/acc_loss', acc_loss, step)