import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
from copy import deepcopy

import utils
import utilsmod
from encoder import make_encoder
import data_augs as rad 
from algorithms.rad_byol_sharedproj import RAD_BYOL_SharedProj


class RAD_BYOL_SharedProj_CMC(RAD_BYOL_SharedProj):
    """Shared-projector variant whose auxiliary update uses bilinear contrastive losses.

    Only ``update_aux`` is overridden here. The attributes it uses
    (``predictor``, ``predictor_target``, ``dynamic``, ``cross_entropy_loss``,
    ``predictor_optimizer``, ``dynamic_optimizer``, ``log_interval``) are not
    defined in this class; presumably the parent sets them up — confirm there.
    """

    def _bilinear_logits(self, query, keys):
        """Return ``query @ (W @ keys.T)`` shifted so each row's maximum is 0.

        ``W`` is ``self.dynamic.W``. Row ``i`` scores query ``i`` against every
        key. Subtracting the per-row maximum leaves a softmax/cross-entropy
        value unchanged and keeps the logits bounded above by zero.
        """
        # Same association as the original code: first W @ keys.T, then query @ (.)
        keys_proj = torch.matmul(self.dynamic.W, keys.T)
        logits = torch.matmul(query, keys_proj)
        return logits - logits.max(dim=1, keepdim=True).values

    def update_aux(self, aug_o, o, next_o, a, L, step):
        """Run one auxiliary update of the online encoder and the dynamics model.

        The update minimizes the sum of two contrastive losses. Both use the
        bilinear similarity given by ``self.dynamic.W`` and treat the
        diagonal (sample ``i`` vs. key ``i``) as the positive pair:

        1. ``cont_loss0`` ("node to core"): the dynamics prediction made from
           the online encoding of ``aug_o`` is scored against the
           target-encoder embeddings of ``next_o``.
        2. ``cont_loss3`` ("node to node"): the same prediction is scored
           against the dynamics predictions made from the target-encoder
           embeddings of ``o``. These keys receive no gradient.

        Args:
            aug_o: observation batch fed to the online encoder (presumably an
                augmented view of ``o`` — confirm against the caller).
            o: observation batch fed to the target encoder.
            next_o: next-observation batch fed to the target encoder.
            a: action batch passed to ``self.dynamic``.
            L: logger exposing ``log(key, value, step)``.
            step: training step; losses are logged when
                ``step % self.log_interval == 0``.
        """
        proj_x0 = self.predictor.encoder(aug_o)
        # Target-side tensors are only ever used as gradient-free keys. Running
        # them under no_grad gives the same values as the previous .detach()
        # calls without building autograd graphs that would be thrown away.
        with torch.no_grad():
            proj_x1 = self.predictor_target.encoder(o)
        pred_next_x0 = self.dynamic(proj_x0, a)
        with torch.no_grad():
            pred_next_x1 = self.dynamic(proj_x1, a)
            proj_next_x = self.predictor_target.encoder(next_o)

        # Node to core: online prediction vs. target next-observation embedding.
        logits0 = self._bilinear_logits(pred_next_x0, proj_next_x)
        # Node to node: online prediction vs. target-side dynamics prediction.
        logits3 = self._bilinear_logits(pred_next_x0, pred_next_x1)

        # Positives are on the diagonal. Create the labels on the logits' device
        # (int64 by default) instead of hard-coding .cuda(), so CPU runs work too.
        labels = torch.arange(logits0.shape[0], device=logits0.device)
        cont_loss0 = self.cross_entropy_loss(logits0, labels)
        cont_loss3 = self.cross_entropy_loss(logits3, labels)

        # Joint step on the online encoder (predictor) and the dynamics model.
        aux_loss = cont_loss0 + cont_loss3
        self.predictor_optimizer.zero_grad()
        self.dynamic_optimizer.zero_grad()
        aux_loss.backward()
        self.predictor_optimizer.step()
        self.dynamic_optimizer.step()

        if step % self.log_interval == 0:
            L.log('train/cont_loss0', cont_loss0, step)
            L.log('train/cont_loss3', cont_loss3, step)
