from algos.common.actor_preference import ActorSquashPreference as Actor
from algos.common.critic_preference import CriticSAPreference as Critic
from algos.common.agent_base import AgentBase
from utils import cprint

from .optimizer import MultiTROptimizer
from .storage import ReplayBuffer

import numpy as np
import torch
import os

# Small positive constant.
# NOTE(review): not referenced anywhere in the visible part of this file --
# presumably intended as a numerical-stability epsilon; confirm before removing.
EPS = 1e-8

def NormalKLDivLoss(mu1, log_std1, mu2, log_std2):
    """Return an approximate divergence between two diagonal Gaussians.

    This is not the exact KL divergence: it is the element-wise sum of
    ``2 * (std1 - std2)**2`` and ``(mu1 - mu2)**2``, averaged over every
    element of the (broadcast) inputs.

    Args:
        mu1, mu2: mean tensors of the two distributions.
        log_std1, log_std2: log standard deviation tensors.

    Returns:
        A scalar tensor holding the mean of the element-wise penalty.
    """
    # approximated version: quadratic penalties on the std gap and the mean gap
    std_gap = log_std1.exp() - log_std2.exp()
    mean_gap = mu1 - mu2
    penalty = 2.0 * std_gap.pow(2) + mean_gap.pow(2)
    return penalty.mean()


class Agent(AgentBase):
    def __init__(self, args) -> None:
        """Build the actor, twin critics, optimizers and replay buffer from ``args``.

        ``args`` is a configuration namespace. Besides the fields forwarded to
        ``AgentBase.__init__`` (name, device, obs_dim, action_dim, reward_dim,
        norm_obs, norm_reward) this reads: save_dir, action_bound_min,
        action_bound_max, n_steps, n_envs, max_episode_len, discount_factor,
        critic_lr, actor_lr, n_actor_update_iters, n_critic_update_iters,
        max_tr_size, max_grad_norm, batch_size, soft_update_ratio,
        len_replay_buffer, con_thresholds, con_entropy and model. When
        ``con_entropy`` is truthy, con_ent_thresh and ent_alpha_lr are read as
        well. ``cost_dim`` is optional and defaults to 0.

        Raises:
            AssertionError: if the number of constraint thresholds differs
                from ``cost_dim``.
        """
        super().__init__(
            name=args.name,
            device=args.device,
            obs_dim=args.obs_dim,
            action_dim=args.action_dim,
            reward_dim=args.reward_dim,
            norm_obs=args.norm_obs,
            norm_reward=args.norm_reward,
        )
        # NOTE(review): self.device, self.obs_dim, self.action_dim, self.reward_dim,
        # self.name, self.norm_obs and self.norm_reward are used below / elsewhere
        # without being assigned here -- presumably set by AgentBase.__init__.

        # for base
        self.save_dir = args.save_dir
        self.checkpoint_dir = f"{self.save_dir}/checkpoint"
        self.cost_dim = args.__dict__.get('cost_dim', 0)
        # one preference weight per reward dimension
        self.preference_dim = self.reward_dim
        self.action_bound_min = args.action_bound_min
        self.action_bound_max = args.action_bound_max
        self.n_steps = args.n_steps
        self.n_envs = args.n_envs
        self.max_episode_len = args.max_episode_len
        self.n_objs = self.reward_dim
        self.n_cons = self.cost_dim

        # for learning
        self.discount_factor = args.discount_factor
        self.critic_lr = args.critic_lr
        self.actor_lr = args.actor_lr
        self.n_actor_update_iters = args.n_actor_update_iters
        self.n_critic_update_iters = args.n_critic_update_iters
        self.max_tr_size = args.max_tr_size
        self.max_grad_norm = args.max_grad_norm
        self.batch_size = args.batch_size
        self.soft_update_ratio = args.soft_update_ratio
        self.len_replay_buffer = args.len_replay_buffer

        # for constraints
        # The thresholds are rescaled by 1 / (1 - discount_factor).
        # Out-of-place division on purpose: an in-place `/=` raises a casting
        # error when the configured thresholds are integers, because NumPy
        # cannot store the float quotient back into an integer array.
        self.con_thresholds = np.array(args.con_thresholds) / (1.0 - self.discount_factor)
        assert len(self.con_thresholds) == self.n_cons, \
            "the number of con_thresholds must equal cost_dim"

        # for entropy
        self.con_entropy = args.con_entropy
        if self.con_entropy:
            # learnable temperature, parameterized in log space so it stays positive
            self.target_entropy = self.action_dim*args.con_ent_thresh
            self.log_entropy_alpha = torch.tensor(0.0, requires_grad=True, device=self.device)
            self.getEntropyAlpha = lambda: torch.exp(self.log_entropy_alpha)
            self.ent_alpha_lr = args.ent_alpha_lr
        else:
            # entropy term disabled: the temperature is a constant zero
            self.getEntropyAlpha = lambda: torch.tensor(0.0, device=self.device)

        # declare actor and critic
        model_cfg = args.model

        def build_critic(value_dim, cfg_key):
            # All critics share the same constructor arguments except for the
            # output dimension and the model-config section.
            return Critic(
                self.device, self.obs_dim, self.action_dim, value_dim,
                self.preference_dim, model_cfg[cfg_key]).to(self.device)

        self.actor = Actor(
            self.device, self.obs_dim, self.action_dim, self.action_bound_min,
            self.action_bound_max, self.preference_dim, model_cfg['actor']).to(self.device)
        # Construction order is kept (critic1, critic2, target1, target2) so that
        # any RNG-dependent parameter initialization is unchanged.
        self.reward_critic1 = build_critic(self.reward_dim, 'reward_critic')
        self.reward_critic2 = build_critic(self.reward_dim, 'reward_critic')
        self.reward_critic_target1 = build_critic(self.reward_dim, 'reward_critic')
        self.reward_critic_target2 = build_critic(self.reward_dim, 'reward_critic')
        if self.cost_dim > 0:
            self.cost_critic1 = build_critic(self.cost_dim, 'cost_critic')
            self.cost_critic2 = build_critic(self.cost_dim, 'cost_critic')
            self.cost_critic_target1 = build_critic(self.cost_dim, 'cost_critic')
            self.cost_critic_target2 = build_critic(self.cost_dim, 'cost_critic')

        # declare optimizers
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.actor_lr)
        self.intermediate_actor_optimizer = MultiTROptimizer(
            self.actor, self.max_tr_size, self.n_objs, self.n_cons, self.con_thresholds, self.device)
        # each pair of twin critics is trained by a single optimizer
        self.reward_critic_params = list(self.reward_critic1.parameters()) + list(self.reward_critic2.parameters())
        self.reward_critic_optimizer = torch.optim.Adam(self.reward_critic_params, lr=self.critic_lr)
        if self.cost_dim > 0:
            self.cost_critic_params = list(self.cost_critic1.parameters()) + list(self.cost_critic2.parameters())
            self.cost_critic_optimizer = torch.optim.Adam(self.cost_critic_params, lr=self.critic_lr)
        if self.con_entropy:
            self.entropy_alpha_optimizer = torch.optim.Adam([self.log_entropy_alpha], lr=self.ent_alpha_lr)

        # for replay buffer
        self.replay_buffer = ReplayBuffer(
            self.device, self.len_replay_buffer, self.discount_factor, self.batch_size, self.n_envs, self.max_episode_len,
            self.obs_dim, self.action_dim, self.preference_dim, self.reward_dim, self.cost_dim)

    ################
    # Public Methods
    ################

    @torch.no_grad()
    def getAction(self, state:np.ndarray, preference:np.ndarray, deterministic:bool) -> np.ndarray:
        """Query the actor for an action and cache the inputs for the next ``step``.

        The observation is normalized with ``self.obs_rms`` before being fed to
        the actor. Copies of ``state`` and ``preference`` and the first output
        of ``actor.sample`` are stored on ``self`` so that ``step`` can add the
        transition to the replay buffer.

        Args:
            state: observation array; its last axis is replaced by
                ``action_dim`` to shape the sampling noise.
            preference: preference array passed to the actor.
            deterministic: forwarded to ``actor.sample``.

        Returns:
            The second output of ``actor.sample`` as a NumPy array.
        """
        tensor_kwargs = dict(dtype=torch.float32, device=self.device)
        obs_t = torch.tensor(self.obs_rms.normalize(state), **tensor_kwargs)
        pref_t = torch.tensor(preference, **tensor_kwargs)

        # noise shares the leading (batch) dimensions of the observation
        noise_shape = obs_t.shape[:-1] + (self.action_dim,)
        noise_t = torch.randn(noise_shape, device=self.device)

        self.actor.updateActionDist(obs_t, pref_t, noise_t)
        sampled = self.actor.sample(deterministic)
        norm_action_t, unnorm_action_t = sampled

        # remember what was fed in / produced, for the upcoming step() call
        self.state = state.copy()
        self.preference = preference.copy()
        self.action = norm_action_t.detach().cpu().numpy()

        env_action = unnorm_action_t.detach().cpu().numpy()
        return env_action

    def step(self, *args) -> None:
        """Store the latest transition and update the running statistics.

        Positional arguments are ``(reward, done, fail, next_state)`` when
        ``cost_dim == 0`` and ``(reward, cost, done, fail, next_state)``
        otherwise. The state, action and preference come from the values
        cached by the preceding ``getAction`` call.

        Raises:
            ValueError: if the number of positional arguments does not match
                the expected layout (raised by tuple unpacking).
        """
        has_costs = self.cost_dim != 0
        if has_costs:
            np_reward, np_cost, np_done, np_fail, np_nextState = args
            outcome = (np_reward, np_cost, np_done, np_fail, np_nextState)
        else:
            np_reward, np_done, np_fail, np_nextState = args
            outcome = (np_reward, np_done, np_fail, np_nextState)

        self.replay_buffer.addTransition(
            self.state, self.action, self.preference, *outcome)

        # update statistics
        if self.norm_obs:
            self.obs_rms.update(self.state)
        if self.norm_reward:
            self.reward_rms.update(np_reward)

    def readyToTrain(self):
        """Return True once the replay buffer length reaches 1000 * n_critic_update_iters."""
        stored = self.replay_buffer.getLen()
        required = 1000*self.n_critic_update_iters
        return stored >= required

    def train(self):
        """Run one training round and return logging statistics.

        The round has two stages:

        1. ``n_critic_update_iters`` iterations. Each one samples a batch,
           updates the twin reward critics (and the twin cost critics when
           ``cost_dim > 0``), soft-updates the target critics, calls
           ``intermediate_actor_optimizer.step`` to obtain target action
           means / log-stds for that batch and, when ``con_entropy`` is set,
           takes one optimizer step on ``log_entropy_alpha``.
        2. ``n_actor_update_iters`` gradient steps that fit the actor to the
           collected targets by minimizing ``NormalKLDivLoss``.

        Returns:
            dict with keys 'objectives', 'actor_loss', 'reward_critic_loss',
            'entropy', 'ent_alpha' and 'kl', plus 'constraints' and
            'cost_critic_loss' when ``cost_dim > 0``. The critic and actor
            losses are those of the last iteration of their loops;
            'objectives' and 'constraints' are averaged over the critic
            iterations.

        NOTE(review): both iteration counts are assumed to be >= 1 -- with
        zero iterations the loss variables read when building the results
        dict are never assigned.
        """
        # for actor update
        # per-iteration results, concatenated / averaged after the critic loop
        actor_states_tensor_list = []
        actor_preferences_tensor_list = []
        action_means_tensor_list = []
        action_log_stds_tensor_list = []
        objectives_tensor_list = []
        constraints_tensor_list = []

        # update critics and actor
        for _ in range(self.n_critic_update_iters):
            # sample batches
            # the batch layout differs by one element (costs) depending on cost_dim
            batch_list = self.replay_buffer.getBatches(self.obs_rms, self.reward_rms)
            if self.cost_dim > 0:
                states_tensor, actions_tensor, rewards_tensor, costs_tensor, preferences_tensor, \
                    dones_tensor, fails_tensor, next_states_tensor = batch_list
            else:
                states_tensor, actions_tensor, rewards_tensor, preferences_tensor, \
                    dones_tensor, fails_tensor, next_states_tensor = batch_list

            # ================= update critics ================= #
            # calculate critic targets
            with torch.no_grad():
                entropy_alpha_tensor = self.getEntropyAlpha()
                epsilons_tensor = torch.randn_like(actions_tensor)
                self.actor.updateActionDist(next_states_tensor, preferences_tensor, epsilons_tensor)
                next_actions_tensor = torch.clamp(self.actor.sample(deterministic=False)[0], -1.0, 1.0) # (batch_size, action_dim)
                next_log_probs_tensor = self.actor.getLogProb() # (batch_size,)
                next_reward_values1_tensor = self.reward_critic_target1(
                    next_states_tensor, next_actions_tensor, preferences_tensor) # (batch_size, reward_dim)
                next_reward_values2_tensor = self.reward_critic_target2(
                    next_states_tensor, next_actions_tensor, preferences_tensor) # (batch_size, reward_dim)
                # pessimistic reward estimate: element-wise minimum of the two target critics
                min_reward_values_tensor = torch.min(
                    next_reward_values1_tensor, next_reward_values2_tensor) # (batch_size, reward_dim)
                # bootstrapping is masked by `fails` (not `dones`); the entropy bonus
                # -alpha * log_prob is added to every reward dimension
                reward_targets_tensor = rewards_tensor + self.discount_factor * (1.0 - fails_tensor.unsqueeze(-1)) \
                    * (min_reward_values_tensor - entropy_alpha_tensor * next_log_probs_tensor.unsqueeze(-1)) # (batch_size, reward_dim)
                if self.cost_dim > 0:
                    next_cost_values1_tensor = self.cost_critic_target1(
                        next_states_tensor, next_actions_tensor, preferences_tensor) # (batch_size, cost_dim)
                    next_cost_values2_tensor = self.cost_critic_target2(
                        next_states_tensor, next_actions_tensor, preferences_tensor) # (batch_size, cost_dim)
                    # pessimistic cost estimate: element-wise maximum of the two target critics
                    max_cost_values_tensor = torch.max(
                        next_cost_values1_tensor, next_cost_values2_tensor) # (batch_size, cost_dim)
                    # no entropy term in the cost target
                    cost_targets_tensor = costs_tensor + self.discount_factor \
                        * (1.0 - fails_tensor.unsqueeze(-1)) * max_cost_values_tensor

            # reward critic update
            # both twin critics regress onto the same target and share one optimizer
            reward_critic_loss = self.reward_critic1.getLoss(states_tensor, actions_tensor, preferences_tensor, reward_targets_tensor) \
                                + self.reward_critic2.getLoss(states_tensor, actions_tensor, preferences_tensor, reward_targets_tensor)
            self.reward_critic_optimizer.zero_grad()
            reward_critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.reward_critic_params, self.max_grad_norm)
            self.reward_critic_optimizer.step()

            if self.cost_dim > 0:
                # cost critic update
                cost_critic_loss = self.cost_critic1.getLoss(states_tensor, actions_tensor, preferences_tensor, cost_targets_tensor) \
                                    + self.cost_critic2.getLoss(states_tensor, actions_tensor, preferences_tensor, cost_targets_tensor)
                self.cost_critic_optimizer.zero_grad()
                cost_critic_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.cost_critic_params, self.max_grad_norm)
                self.cost_critic_optimizer.step()

            # soft update
            # target <- soft_update_ratio * target + (1 - soft_update_ratio) * online
            self._softUpdate(self.reward_critic_target1, self.reward_critic1, self.soft_update_ratio)
            self._softUpdate(self.reward_critic_target2, self.reward_critic2, self.soft_update_ratio)
            if self.cost_dim > 0:
                self._softUpdate(self.cost_critic_target1, self.cost_critic1, self.soft_update_ratio)
                self._softUpdate(self.cost_critic_target2, self.cost_critic2, self.soft_update_ratio)
            # ================================================== #

            # ================== update actor ================== #
            with torch.no_grad():
                entropy_alpha_tensor = self.getEntropyAlpha()
                epsilons_tensor = torch.randn_like(actions_tensor)
                # NOTE(review): only the first row is used as the batch's preference --
                # assumes every row of the batch shares one preference; confirm against
                # ReplayBuffer.getBatches.
                sampled_preference = preferences_tensor[0, :].detach().cpu().numpy()
                # NOTE(review): get_obj_con below refreshes the distribution again with the
                # same inputs; presumably this call primes the actor state for the
                # optimizer step -- verify against MultiTROptimizer.step.
                self.actor.updateActionDist(states_tensor, preferences_tensor, epsilons_tensor)

            def get_obj_con():
                # Closure handed to the intermediate optimizer: re-evaluates the
                # objectives and constraints for the current batch, reusing the
                # same noise (epsilons_tensor) on every call.
                self.actor.updateActionDist(states_tensor, preferences_tensor, epsilons_tensor)
                cur_actions = self.actor.sample(deterministic=False)[0]
                log_probs_tensor = self.actor.getLogProb() # (batch_size,)
                # the same scalar entropy bonus is broadcast onto every objective
                objectives = self._getObjectives(states_tensor, cur_actions, preferences_tensor) \
                    - (entropy_alpha_tensor*log_probs_tensor).mean().unsqueeze(0)
                if self.cost_dim > 0:
                    constraints = self._getConstraints(states_tensor, cur_actions, preferences_tensor)
                else:
                    # n_cons == cost_dim, so this is an empty tensor when there are no costs
                    constraints = torch.zeros(self.n_cons, device=self.device)
                return objectives, constraints

            # update actor's gradient
            # The preference is rescaled so that its largest entry is 1.
            # NOTE(review): divides by zero if every entry of the preference is 0.
            action_means_tensor, action_log_stds_tensor, objectives, constraints \
                = self.intermediate_actor_optimizer.step(
                    get_obj_con, sampled_preference/np.max(sampled_preference), 
                    states_tensor, preferences_tensor)
            
            # store results
            actor_states_tensor_list.append(states_tensor)
            actor_preferences_tensor_list.append(preferences_tensor)
            action_means_tensor_list.append(action_means_tensor)
            action_log_stds_tensor_list.append(action_log_stds_tensor)
            objectives_tensor_list.append(objectives)
            constraints_tensor_list.append(constraints)
            # ================================================== #

            # ============== entropy alpha update ============== #
            # entropy of whatever distribution the actor holds after the optimizer step
            with torch.no_grad():
                entropy = self.actor.getEntropy()
            if self.con_entropy:
                entropy_alpha_tensor = self.getEntropyAlpha()
                # gradient flows only through alpha: it shrinks while the entropy is
                # above the target and grows while it is below
                entropy_alpha_loss = torch.mean(entropy_alpha_tensor*(entropy - self.target_entropy).detach())
                self.entropy_alpha_optimizer.zero_grad()
                entropy_alpha_loss.backward()
                self.entropy_alpha_optimizer.step()
                # keep log(alpha) within [-8, 8]
                self.log_entropy_alpha.data.copy_(self.log_entropy_alpha.data.clamp(-8.0, 8.0))
            # ================================================== #

        # get old actions and entropy
        # snapshot the actor's distribution on all collected states before the
        # fitting loop below, so the change can be reported as 'kl'
        with torch.no_grad():
            states_tensor = torch.cat(actor_states_tensor_list, dim=0)
            preferences_tensor = torch.cat(actor_preferences_tensor_list, dim=0)
            target_means_tensor = torch.cat(action_means_tensor_list, dim=0)
            target_log_stds_tensor = torch.cat(action_log_stds_tensor_list, dim=0)
            epsilons_tensor = torch.randn_like(target_means_tensor)
            self.actor.updateActionDist(states_tensor, preferences_tensor, epsilons_tensor)
            old_action_dists = self.actor.getDist()
            entropy = self.actor.getEntropy()

        # update universal actor
        # regress the actor's (mean, log_std) outputs onto the targets returned by
        # the intermediate optimizer
        for _ in range(self.n_actor_update_iters):
            cur_means, cur_log_stds, _ = self.actor(states_tensor, preferences_tensor)
            actor_loss = NormalKLDivLoss(cur_means, cur_log_stds, target_means_tensor, target_log_stds_tensor)
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
            self.actor_optimizer.step()

        # get objectives and constraints
        with torch.no_grad():
            self.actor.updateActionDist(states_tensor, preferences_tensor, epsilons_tensor)
            cur_action_dists = self.actor.getDist()
            # KL(old || current), summed over the last axis and averaged over the batch
            kl = torch.mean(torch.sum(
                torch.distributions.kl.kl_divergence(
                    old_action_dists, cur_action_dists
                ), dim=-1))
            objectives = torch.stack(objectives_tensor_list, dim=0).mean(dim=0)
            constraints = torch.stack(constraints_tensor_list, dim=0).mean(dim=0)

        # return results
        # NOTE(review): when con_entropy is set, 'ent_alpha' is the value computed
        # before the final entropy_alpha_optimizer step of the loop above.
        results = {
            'objectives': objectives.detach().cpu().numpy(),
            'actor_loss': actor_loss.item(),
            'reward_critic_loss': reward_critic_loss.item(),
            'entropy': entropy.item(),
            'ent_alpha': entropy_alpha_tensor.item(),
            'kl': kl.item(),
        }
        if self.cost_dim > 0:
            results.update({
                'constraints': constraints.detach().cpu().numpy(),
                'cost_critic_loss': cost_critic_loss.item(),
            })
        return results

    def save(self, model_num):
        """Persist the normalizer statistics and a checkpoint for ``model_num``.

        The checkpoint is written to ``<checkpoint_dir>/model_<model_num>.pt``
        and holds the state dicts of the actor, the reward critics and their
        targets, and the actor / reward-critic optimizers. Cost-critic entries
        are added when ``cost_dim > 0`` and entropy-temperature entries when
        ``con_entropy`` is set.

        Args:
            model_num: identifier embedded in the checkpoint file name and
                forwarded to the normalizers' ``save`` methods.
        """
        # save rms
        self.obs_rms.save(self.save_dir, model_num)
        self.reward_rms.save(self.save_dir, model_num)

        # save network models
        save_dict = {
            'actor': self.actor.state_dict(),
            'reward_critic1': self.reward_critic1.state_dict(),
            'reward_critic2': self.reward_critic2.state_dict(),
            'reward_critic_target1': self.reward_critic_target1.state_dict(),
            'reward_critic_target2': self.reward_critic_target2.state_dict(),
            'actor_optimizer': self.actor_optimizer.state_dict(),
            'reward_critic_optimizer': self.reward_critic_optimizer.state_dict(),
        }
        if self.cost_dim > 0:
            save_dict.update({
                'cost_critic1': self.cost_critic1.state_dict(),
                'cost_critic2': self.cost_critic2.state_dict(),
                'cost_critic_target1': self.cost_critic_target1.state_dict(),
                'cost_critic_target2': self.cost_critic_target2.state_dict(),
                'cost_critic_optimizer': self.cost_critic_optimizer.state_dict(),
            })
        if self.con_entropy:
            save_dict.update({
                'log_entropy_alpha': self.log_entropy_alpha.data,
                'entropy_alpha_optimizer': self.entropy_alpha_optimizer.state_dict(),
            })

        # The checkpoint directory is otherwise only created by load(); create it
        # here as well so that saving does not depend on load() having run first.
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        torch.save(save_dict, f"{self.checkpoint_dir}/model_{model_num}.pt")
        cprint(f'[{self.name}] save success.', bold=True, color="blue")

    def load(self, model_num):
        """Restore the agent from checkpoint ``model_num`` or initialize it fresh.

        The normalizer statistics are always loaded first. If
        ``<checkpoint_dir>/model_<model_num>.pt`` exists, every network and
        optimizer state saved by ``save`` is restored. Otherwise the actor and
        the online critics are initialized and each target critic is set to an
        exact copy of its online critic.

        Args:
            model_num: identifier of the checkpoint to load.

        Returns:
            ``int(model_num)`` when a checkpoint was loaded, ``0`` otherwise.
        """
        # load rms
        self.obs_rms.load(self.save_dir, model_num)
        self.reward_rms.load(self.save_dir, model_num)

        # load network models
        # exist_ok avoids the check-then-create race of isdir() + makedirs()
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        checkpoint_file = f"{self.checkpoint_dir}/model_{model_num}.pt"
        if os.path.isfile(checkpoint_file):
            # NOTE: torch.load unpickles the file -- only load checkpoints from
            # trusted sources.
            checkpoint = torch.load(checkpoint_file, map_location=self.device)
            self.actor.load_state_dict(checkpoint['actor'])
            self.reward_critic1.load_state_dict(checkpoint['reward_critic1'])
            self.reward_critic2.load_state_dict(checkpoint['reward_critic2'])
            self.reward_critic_target1.load_state_dict(checkpoint['reward_critic_target1'])
            self.reward_critic_target2.load_state_dict(checkpoint['reward_critic_target2'])
            self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
            self.reward_critic_optimizer.load_state_dict(checkpoint['reward_critic_optimizer'])
            if self.cost_dim > 0:
                self.cost_critic1.load_state_dict(checkpoint['cost_critic1'])
                self.cost_critic2.load_state_dict(checkpoint['cost_critic2'])
                self.cost_critic_target1.load_state_dict(checkpoint['cost_critic_target1'])
                self.cost_critic_target2.load_state_dict(checkpoint['cost_critic_target2'])
                self.cost_critic_optimizer.load_state_dict(checkpoint['cost_critic_optimizer'])
            if self.con_entropy:
                # Copy in place (as done elsewhere in this class) instead of rebinding
                # `.data`, so the tensor tracked by the alpha optimizer keeps its own
                # storage, device and dtype.
                self.log_entropy_alpha.data.copy_(checkpoint['log_entropy_alpha'])
                self.entropy_alpha_optimizer.load_state_dict(checkpoint['entropy_alpha_optimizer'])
            cprint(f'[{self.name}] load success.', bold=True, color="blue")
            return int(model_num)
        else:
            # no checkpoint: start from freshly initialized networks
            self.actor.initialize()
            self.reward_critic1.initialize()
            self.reward_critic2.initialize()
            # polyak = 0.0 turns the soft update into a hard copy online -> target
            self._softUpdate(self.reward_critic_target1, self.reward_critic1, 0.0)
            self._softUpdate(self.reward_critic_target2, self.reward_critic2, 0.0)
            if self.cost_dim > 0:
                self.cost_critic1.initialize()
                self.cost_critic2.initialize()
                self._softUpdate(self.cost_critic_target1, self.cost_critic1, 0.0)
                self._softUpdate(self.cost_critic_target2, self.cost_critic2, 0.0)
            cprint(f'[{self.name}] load fail.', bold=True, color="red")
            return 0

    #################
    # Private Methods
    #################

    def _softUpdate(self, target, source, polyak):
        """Blend ``source`` parameters into ``target`` in place.

        Each target parameter becomes
        ``polyak * target + (1 - polyak) * source``; ``polyak = 0.0`` is a
        hard copy. Only tensors yielded by ``parameters()`` are touched.
        """
        param_pairs = zip(target.parameters(), source.parameters())
        for tgt_param, src_param in param_pairs:
            blended = tgt_param.data * polyak + src_param.data * (1.0 - polyak)
            tgt_param.data.copy_(blended)

    def _getObjectives(self, states_tensor, cur_actions, preferences):
        """Return the batch mean (over dim 0) of the element-wise minimum of the twin reward critics."""
        q_first = self.reward_critic1(states_tensor, cur_actions, preferences)
        q_second = self.reward_critic2(states_tensor, cur_actions, preferences)
        # pessimistic estimate: take the smaller of the two critic outputs
        pessimistic_q = torch.min(q_first, q_second)
        return pessimistic_q.mean(dim=0)

    def _getConstraints(self, states_tensor, cur_actions, preferences):
        """Return the batch mean (over dim 0) of the element-wise maximum of the twin cost critics."""
        c_first = self.cost_critic1(states_tensor, cur_actions, preferences)
        c_second = self.cost_critic2(states_tensor, cur_actions, preferences)
        # pessimistic estimate: take the larger of the two critic outputs
        pessimistic_c = torch.max(c_first, c_second)
        return pessimistic_c.mean(dim=0)
