import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
from copy import deepcopy

import utils
from encoder import make_encoder
import data_augs as rad 
from algorithms.rad import RAD


class RAD_BYOL(RAD):
    """RAD agent with an auxiliary latent-dynamics objective using BYOL-style losses.

    On top of the inherited critic/actor updates, this class trains
    ``self.predictor`` (online network) and ``self.dynamic`` (latent dynamics
    model with ``predictor_inv`` / ``predictor_acc`` heads) against targets
    produced by ``self.predictor_target``, which is moved towards
    ``self.predictor`` by ``utils.soft_update_params`` after each auxiliary step.

    The networks, optimizers and hyper-parameters used here
    (``predictor``, ``predictor_target``, ``dynamic``, ``predictor_optimizer``,
    ``dynamic_optimizer``, ``log_interval``, ``soda_tau``, ...) are not created
    in this class; they are expected to be set up by the base class.
    """

    def update_aux(self, aug_o, o, next_o, a, L, step):
        """Run one optimisation step of the auxiliary dynamics objective.

        The loss is the sum of two terms, both mean-squared errors between
        L2-normalised vectors:

        1. ``inv_loss`` (augmentation invariance of the dynamics): the
           ``predictor_inv`` head applied to the next-latent predicted from the
           augmented observation, versus the next-latent predicted from the
           un-augmented observation encoded by the target network.
        2. ``acc_loss`` (accuracy of the dynamics): the ``predictor_acc`` head
           applied to the same prediction, versus the target-network encoding
           of the real next observation.

        Args:
            aug_o: Augmented observations, fed to the online encoder.
            o: Un-augmented observations, fed to the target encoder.
            next_o: Un-augmented next observations, fed to the target encoder.
            a: Actions passed to ``self.dynamic`` together with the latents.
            L: Logger exposing ``log(key, value, step)``.
            step: Current training step; losses are logged every
                ``self.log_interval`` steps.
        """
        online_latent = self.predictor.encoder(aug_o)

        # Everything derived from the target encoder is used purely as a
        # regression target. Running it under no_grad yields the same values
        # as forward-then-detach, but avoids building autograd graphs that
        # would never be back-propagated through.
        with torch.no_grad():
            target_latent = self.predictor_target.encoder(o)
            target_next_latent = self.predictor_target.encoder(next_o)

        pred_next_online = self.dynamic(online_latent, a)
        with torch.no_grad():
            # Same dynamics model, but on the target latent: a fixed target.
            pred_next_target = self.dynamic(target_latent, a)

        # Loss 1: augmentation invariance of the predicted next latent.
        # Normalisation is over dim=1 (assumes latents are (batch, feature)
        # -- TODO confirm against the encoder).
        inv_prediction = F.normalize(
            self.dynamic.predictor_inv(pred_next_online), p=2, dim=1
        )
        inv_target = F.normalize(pred_next_target, p=2, dim=1)
        inv_loss = F.mse_loss(inv_prediction, inv_target)

        # Loss 2: predicted next latent versus encoded real next observation.
        acc_prediction = F.normalize(
            self.dynamic.predictor_acc(pred_next_online), p=2, dim=1
        )
        acc_target = F.normalize(target_next_latent, p=2, dim=1)
        acc_loss = F.mse_loss(acc_prediction, acc_target)

        # Joint update of the online predictor and the dynamics model.
        aux_loss = inv_loss + acc_loss
        self.predictor_optimizer.zero_grad()
        self.dynamic_optimizer.zero_grad()
        aux_loss.backward()
        self.predictor_optimizer.step()
        self.dynamic_optimizer.step()

        if step % self.log_interval == 0:
            L.log('train/inv_loss', inv_loss, step)
            L.log('train/acc_loss', acc_loss, step)

    def update(self, replay_buffer, L, step):
        """Perform one full training step from a freshly sampled batch.

        Order of operations: critic update (every call), actor/alpha update
        (every ``actor_update_freq`` steps), soft update of the critic target
        (every ``critic_target_update_freq`` steps), and finally the auxiliary
        dynamics update followed by a soft update of ``predictor_target``
        (every ``cpc_update_freq`` steps, pixel encoders only).

        Args:
            replay_buffer: Buffer exposing ``sample_double_batch(augs_funcs)``,
                which returns ``(obs, action, reward, next_obs, not_done,
                aug_obs, aug_next_obs)``.
            L: Logger exposing ``log(key, value, step)``.
            step: Current training step, used for update scheduling and logging.
        """
        (obs, action, reward, next_obs, not_done,
         aug_obs, aug_next_obs) = replay_buffer.sample_double_batch(self.augs_funcs)

        if step % self.log_interval == 0:
            L.log('train/batch_reward', reward.mean(), step)

        # The RL losses only see the augmented views.
        self.update_critic(aug_obs, action, reward, aug_next_obs, not_done, L, step)

        if step % self.actor_update_freq == 0:
            self.update_actor_and_alpha(aug_obs, L, step)

        if step % self.critic_target_update_freq == 0:
            utils.soft_update_params(
                self.critic.Q1, self.critic_target.Q1, self.critic_tau
            )
            utils.soft_update_params(
                self.critic.Q2, self.critic_target.Q2, self.critic_tau
            )
            utils.soft_update_params(
                self.critic.encoder, self.critic_target.encoder,
                self.encoder_tau
            )

        if step % self.cpc_update_freq == 0 and self.encoder_type == 'pixel':
            # The auxiliary loss pairs the augmented view with the clean
            # observation and clean next observation.
            self.update_aux(aug_obs, obs, next_obs, action, L, step)
            utils.soft_update_params(
                self.predictor, self.predictor_target,
                self.soda_tau
            )

    def save(self, model_dir, step):
        """Write the actor, critic, predictor and dynamics weights to disk.

        Each network's ``state_dict`` is stored as
        ``<model_dir>/<name>_<step>.pt``.

        Args:
            model_dir: Directory the checkpoint files are written to.
            step: Identifier embedded in the checkpoint file names.
        """
        # NOTE(review): predictor_target is not checkpointed here (and not
        # restored in load()); confirm this is intended if training is resumed.
        checkpoints = (
            (self.actor, '%s/actor_%s.pt'),
            (self.critic, '%s/critic_%s.pt'),
            (self.predictor, '%s/predictor_%s.pt'),
            (self.dynamic, '%s/dynamic_%s.pt'),
        )
        for module, path_template in checkpoints:
            torch.save(module.state_dict(), path_template % (model_dir, step))

    def load(self, model_dir, step):
        """Restore the actor, critic, predictor and dynamics weights from disk.

        Reads the files written by ``save`` for the same ``model_dir`` and
        ``step``.

        Args:
            model_dir: Directory containing the checkpoint files.
            step: Identifier embedded in the checkpoint file names.
        """
        checkpoints = (
            (self.actor, '%s/actor_%s.pt'),
            (self.critic, '%s/critic_%s.pt'),
            (self.predictor, '%s/predictor_%s.pt'),
            (self.dynamic, '%s/dynamic_%s.pt'),
        )
        for module, path_template in checkpoints:
            # Deserialize onto the CPU so a checkpoint saved on a GPU can be
            # read on a machine without that device; load_state_dict then
            # copies the values into the module's existing parameters.
            state = torch.load(
                path_template % (model_dir, step), map_location='cpu'
            )
            module.load_state_dict(state)
