import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
from copy import deepcopy

import utils
import utilsmod
from encoder import make_encoder
import data_augs as rad 
from algorithms.rad_byol_sharedproj import RAD_BYOL_SharedProj


class RAD_BYOL_SharedProj_AUG(RAD_BYOL_SharedProj):
    """Shared-projector agent whose updates use the augmented observations.

    Differs from the parent only in the methods below: ``update`` draws a
    "double" batch (raw and augmented views) from the replay buffer and
    passes the augmented next observation to the auxiliary update.
    """

    def update(self, replay_buffer, L, step):
        """Run one training step: critic, actor/alpha, targets, auxiliary.

        Args:
            replay_buffer: object exposing ``sample_double_batch(augs)``,
                which must return the 7-tuple unpacked below.
            L: logger exposing ``log(key, value, step)``.
            step: current step counter; it gates logging and each of the
                periodic updates through the ``*_freq`` / ``log_interval``
                attributes on ``self``.
        """
        (
            raw_obs,
            act,
            rew,
            _raw_next_obs,  # sampled but not consumed by this method
            not_done,
            aug_obs,
            aug_next_obs,
        ) = replay_buffer.sample_double_batch(self.augs_funcs)

        if step % self.log_interval == 0:
            L.log('train/batch_reward', rew.mean(), step)

        # The critic and the actor are both trained on the augmented views.
        self.update_critic(
            aug_obs, act, rew, aug_next_obs, not_done, L, step
        )

        if step % self.actor_update_freq == 0:
            self.update_actor_and_alpha(aug_obs, L, step)

        if step % self.critic_target_update_freq == 0:
            online, target = self.critic, self.critic_target
            # Q-heads use critic_tau; the encoder has its own encoder_tau.
            utils.soft_update_params(online.Q1, target.Q1, self.critic_tau)
            utils.soft_update_params(online.Q2, target.Q2, self.critic_tau)
            utils.soft_update_params(
                online.encoder, target.encoder, self.encoder_tau
            )

        # The auxiliary update only runs on its own schedule and only for
        # the 'pixel' encoder type.
        if step % self.cpc_update_freq != 0 or self.encoder_type != 'pixel':
            return

        self.update_aux(aug_obs, raw_obs, aug_next_obs, act, L, step)
        utils.soft_update_params(
            self.predictor, self.predictor_target, self.soda_tau
        )

    def load(self, model_dir, step):
        """Restore actor, critic, dynamic and predictor state dicts.

        Each checkpoint is read from ``<model_dir>/<name>_<step>.pt`` and
        loaded in the order listed below.

        Args:
            model_dir: directory containing the checkpoint files.
            step: step identifier embedded in the checkpoint file names.
        """
        checkpoints = (
            ('actor', '%s/actor_%s.pt'),
            ('critic', '%s/critic_%s.pt'),
            ('dynamic', '%s/dynamic_%s.pt'),
            ('predictor', '%s/predictor_%s.pt'),
        )
        for attr_name, path_template in checkpoints:
            restore = getattr(self, attr_name).load_state_dict
            restore(torch.load(path_template % (model_dir, step)))