import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
from copy import deepcopy

import utils
from encoder import make_encoder
import data_augs as rad 
from algorithms.rad_byol import RAD_BYOL


class RAD_SIMSIAM(RAD_BYOL):
    """RAD_BYOL variant whose auxiliary dynamics objective uses a SimSiam-style loss.

    The auxiliary loss is a negative cosine similarity with a stop-gradient on
    the target branch (see ``sym_loss``). The networks and optimizers used here
    (``self.predictor``, ``self.dynamic``, ``self.predictor_optimizer``,
    ``self.dynamic_optimizer``) are not defined in this class; they are
    expected to be provided by the parent class.
    """

    def sym_loss(self, z0, z1):
        """Return the negative mean cosine similarity between ``z0`` and ``z1``.

        Args:
            z0: Target tensor. It is detached inside this method, so no
                gradient flows back through it.
            z1: Predicted tensor; gradients flow through this branch.

        Returns:
            A scalar tensor: both inputs are L2-normalized along ``dim=1``,
            multiplied element-wise, summed over ``dim=1`` and averaged over
            the remaining dimensions, then negated.

        Note:
            Assumes inputs are laid out as (batch, feature) so that ``dim=1``
            is the feature axis -- TODO confirm against the encoder output.
        """
        target = F.normalize(z0.detach(), p=2, dim=1)
        prediction = F.normalize(z1, p=2, dim=1)

        return -(target * prediction).sum(dim=1).mean()

    def update_aux(self, aug_o, o, next_o, a, L, step):
        """Run one optimization step on the auxiliary dynamics objective.

        The objective is the sum of two losses:

        1. ``inv_loss`` -- augmentation invariance of the dynamics model: the
           dynamics output for one view (``o`` / ``aug_o``) is the
           stop-gradient target for ``predictor_inv`` applied to the dynamics
           output of the other view, symmetrized over both views.
        2. ``acc_loss`` -- accuracy of the dynamics model: the encoding of
           ``next_o`` is the stop-gradient target for ``predictor_acc``
           applied to the dynamics output of each view.

        Args:
            aug_o: Augmented observation batch.
            o: Observation batch paired with ``aug_o``.
            next_o: Next-observation batch used as the accuracy target.
            a: Action batch fed to the dynamics model.
            L: Logger exposing ``log(key, value, step)``.
            step: Current training step; used for the logging interval.
        """
        proj_x0 = self.predictor.encoder(o)
        proj_x1 = self.predictor.encoder(aug_o)
        dyn_next_x0 = self.dynamic(proj_x0, a)
        dyn_next_x1 = self.dynamic(proj_x1, a)

        # Loss 1 (inv_loss): each view's dynamics output is the target for the
        # prediction made from the other view.
        predinv_next_x0 = self.dynamic.predictor_inv(dyn_next_x0)
        predinv_next_x1 = self.dynamic.predictor_inv(dyn_next_x1)
        inv_loss = 0.5 * (
            self.sym_loss(dyn_next_x0, predinv_next_x1)
            + self.sym_loss(dyn_next_x1, predinv_next_x0)
        )

        # Loss 2 (acc_loss): the next-observation encoding is only ever used
        # as a stop-gradient target, so compute it without recording an
        # autograd graph instead of building one and detaching afterwards.
        with torch.no_grad():
            proj_next_x = self.predictor.encoder(next_o)
        predacc_next_x0 = self.dynamic.predictor_acc(dyn_next_x0)
        predacc_next_x1 = self.dynamic.predictor_acc(dyn_next_x1)
        acc_loss = 0.5 * (
            self.sym_loss(proj_next_x, predacc_next_x0)
            + self.sym_loss(proj_next_x, predacc_next_x1)
        )

        aux_loss = inv_loss + acc_loss

        # Both optimizers are stepped from a single backward pass over the
        # combined loss.
        self.predictor_optimizer.zero_grad()
        self.dynamic_optimizer.zero_grad()
        aux_loss.backward()
        self.predictor_optimizer.step()
        self.dynamic_optimizer.step()

        if step % self.log_interval == 0:
            L.log('train/inv_loss', inv_loss, step)
            L.log('train/acc_loss', acc_loss, step)

    def update(self, replay_buffer, L, step):
        """Sample a batch and update the critic, actor/alpha, targets and aux heads.

        The critic and actor are trained on the augmented observations; the
        auxiliary update additionally receives the second (``obs`` /
        ``next_obs``) batch returned by ``sample_double_batch``.

        Args:
            replay_buffer: Buffer exposing ``sample_double_batch(augs_funcs)``,
                which must return the 7-tuple unpacked below.
            L: Logger exposing ``log(key, value, step)``.
            step: Current training step; gates logging and the update
                frequencies of the actor, the target critic and the aux loss.
        """
        (
            obs,
            action,
            reward,
            next_obs,
            not_done,
            aug_obs,
            aug_next_obs,
        ) = replay_buffer.sample_double_batch(self.augs_funcs)

        if step % self.log_interval == 0:
            L.log('train/batch_reward', reward.mean(), step)

        self.update_critic(aug_obs, action, reward, aug_next_obs, not_done, L, step)

        if step % self.actor_update_freq == 0:
            self.update_actor_and_alpha(aug_obs, L, step)

        if step % self.critic_target_update_freq == 0:
            # Soft-update the target critic: Q-heads use critic_tau, the
            # encoder uses its own encoder_tau.
            utils.soft_update_params(
                self.critic.Q1, self.critic_target.Q1, self.critic_tau
            )
            utils.soft_update_params(
                self.critic.Q2, self.critic_target.Q2, self.critic_tau
            )
            utils.soft_update_params(
                self.critic.encoder, self.critic_target.encoder,
                self.encoder_tau
            )

        if step % self.cpc_update_freq == 0 and self.encoder_type == 'pixel':
            # Original author's note: there is no EMA for the auxiliary update.
            self.update_aux(aug_obs, obs, next_obs, action, L, step)
