from algos.common.critic_distribution import CriticSAPreferDist as Critic
from algos.common.actor_preference import ActorSquashPreference as Actor
from algos.common.agent_base import AgentBase
from utils import cprint

from .optimizer import MultiTROptimizer
from .storage import ReplayBuffer

import numpy as np
import torch
import os

# Small epsilon for numerical stability. NOTE(review): not referenced anywhere in
# this module; possibly imported by other modules — confirm before removing.
EPS: float = 1.0e-8

def NormalKLDivLoss(mu1, log_std1, mu2, log_std2):
    """Return a cheap surrogate for the KL divergence between two diagonal Gaussians.

    Computes ``mean((mu1 - mu2)**2 + 2 * (std1 - std2)**2)`` with
    ``std = exp(log_std)``. The second-order expansion of
    ``KL(N(mu1, std1) || N(mu2, std2))`` around equal parameters is
    ``((mu1 - mu2)**2 + 2 * (std1 - std2)**2) / (2 * std2**2)``. This loss is
    therefore that approximation without the per-element ``1 / (2 * std2**2)``
    weighting.

    Args:
        mu1, log_std1: mean and log standard deviation of the first distribution.
        mu2, log_std2: mean and log standard deviation of the second distribution.
            All four tensors must be mutually broadcastable.

    Returns:
        A scalar tensor: the element-wise surrogate averaged over all elements.
    """
    mean_gap = mu1 - mu2
    std_gap = log_std1.exp() - log_std2.exp()
    per_element = mean_gap ** 2 + 2.0 * std_gap ** 2
    return per_element.mean()


class Agent(AgentBase):
    """Preference-conditioned multi-objective agent with optional cost constraints.

    Each call to :meth:`train` does three things:

    1. It regresses the reward critic on targets from the replay buffer. The
       cost critic is regressed as well when ``cost_dim > 0``.
    2. For each ``n_steps``-long chunk of the batch, it runs one
       ``MultiTROptimizer.step``. That step returns target action means and
       log-stds for the chunk.
    3. It fits the actor to all of those targets with Adam, using the
       approximate Gaussian KL in ``NormalKLDivLoss``.
    """

    def __init__(self, args) -> None:
        """Build networks, optimizers and the replay buffer from the config ``args``.

        ``args`` is read by attribute access (e.g. an ``argparse.Namespace``).
        ``args.model`` must map ``'actor'``, ``'reward_critic'`` and, when
        ``cost_dim > 0``, ``'cost_critic'`` to network configs.
        ``len(args.con_thresholds)`` must equal ``cost_dim``.
        """
        super().__init__(
            name=args.name,
            device=args.device,
            obs_dim=args.obs_dim,
            action_dim=args.action_dim,
            reward_dim=args.reward_dim,
            norm_obs=args.norm_obs,
            norm_reward=args.norm_reward,
        )

        # for base
        self.save_dir = args.save_dir
        self.checkpoint_dir = f"{self.save_dir}/checkpoint"
        # cost_dim is optional in the config; 0 disables the cost critic entirely.
        self.cost_dim = getattr(args, 'cost_dim', 0)
        self.preference_dim = self.reward_dim
        self.action_bound_min = args.action_bound_min
        self.action_bound_max = args.action_bound_max
        self.n_steps = args.n_steps
        self.n_envs = args.n_envs

        # for learning
        self.discount_factor = args.discount_factor
        self.n_update_steps = args.n_update_steps
        self.critic_lr = args.critic_lr
        self.actor_lr = args.actor_lr
        self.n_critic_iters = args.n_critic_iters
        self.n_actor_iters = args.n_actor_iters
        self.max_grad_norm = args.max_grad_norm
        self.gae_coeff = args.gae_coeff
        self.len_replay_buffer = args.len_replay_buffer
        self.n_target_quantiles = args.n_target_quantiles

        # for trust region
        self.damping_coeff = args.damping_coeff
        self.num_conjugate = args.num_conjugate
        self.line_decay = args.line_decay
        self.max_tr_size = args.max_tr_size

        # for constraints
        # Use an explicit float dtype. Integer thresholds (e.g. ``[1, 2]``) would
        # otherwise produce an int array, and the in-place division below would
        # raise a numpy casting error.
        self.con_thresholds = np.array(args.con_thresholds, dtype=np.float64)
        # Rescale by 1 / (1 - discount_factor); presumably this converts
        # per-step thresholds to the discounted-sum scale.
        self.con_thresholds /= (1.0 - self.discount_factor)
        assert len(self.con_thresholds) == self.cost_dim

        # for entropy
        self.con_entropy = args.con_entropy
        if self.con_entropy:
            # train() appends -entropy as one extra constraint. Its threshold is
            # therefore appended after the cost thresholds.
            self.con_ent_thresh = -args.con_ent_thresh*self.action_dim
            self.con_thresholds = np.append(self.con_thresholds, self.con_ent_thresh)

        # declare actor and critic
        model_cfg = args.model
        self.actor = Actor(
            self.device, self.obs_dim, self.action_dim, self.action_bound_min,
            self.action_bound_max, self.preference_dim, model_cfg['actor']).to(self.device)
        self.reward_critic = Critic(
            self.device, self.obs_dim, self.action_dim, self.reward_dim,
            self.preference_dim, model_cfg['reward_critic']).to(self.device)
        if self.cost_dim > 0:
            self.cost_critic = Critic(
                self.device, self.obs_dim, self.action_dim, self.cost_dim,
                self.preference_dim, model_cfg['cost_critic']).to(self.device)

        # declare optimizers
        # actor_optimizer distills the actor toward the trust-region targets.
        # intermediate_actor_optimizer computes those targets.
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.actor_lr)
        self.intermediate_actor_optimizer = MultiTROptimizer(
            self.actor, self.damping_coeff, self.num_conjugate, self.line_decay, self.max_tr_size,
            self.reward_dim, self.con_thresholds, self.con_entropy, self.device)
        self.reward_critic_optimizer = torch.optim.Adam(self.reward_critic.parameters(), lr=self.critic_lr)
        if self.cost_dim > 0:
            self.cost_critic_optimizer = torch.optim.Adam(self.cost_critic.parameters(), lr=self.critic_lr)

        # declare replay buffer
        self.replay_buffer = ReplayBuffer(
            self.device, self.len_replay_buffer, self.discount_factor, self.gae_coeff,
            self.n_envs, self.n_steps, self.n_update_steps, self.n_target_quantiles)

    ################
    # Public Methods
    ################

    @torch.no_grad()
    def getAction(self, state:np.ndarray, preference:np.ndarray, deterministic:bool) -> np.ndarray:
        """Sample an action for ``state`` under ``preference``.

        The observation is normalized with ``obs_rms`` before it is fed to the
        actor. For the next :meth:`step` call, the following are cached on
        ``self``:

        - the raw (un-normalized) state and the preference,
        - the normalized action,
        - the actor's ``getNormalAction()`` output,
        - the log-probability.

        Returns:
            The un-normalized action as a NumPy array.
        """
        state_tensor = torch.tensor(self.obs_rms.normalize(state), dtype=torch.float32, device=self.device)
        preference_tensor = torch.tensor(preference, dtype=torch.float32, device=self.device)
        epsilon_tensor = torch.randn(state_tensor.shape[:-1] + (self.action_dim,), device=self.device)

        self.actor.updateActionDist(state_tensor, preference_tensor, epsilon_tensor)
        norm_action_tensor, unnorm_action_tensor = self.actor.sample(deterministic)
        normal_action_tensor = self.actor.getNormalAction()
        log_prob_tensor = self.actor.getLogProb()

        self.state = state.copy()
        self.preference = preference.copy()
        self.action = norm_action_tensor.detach().cpu().numpy()
        self.normal_action = normal_action_tensor.detach().cpu().numpy()
        self.log_prob = log_prob_tensor.detach().cpu().numpy()
        return unnorm_action_tensor.detach().cpu().numpy()

    def step(self, *args):
        """Store the transition from the last :meth:`getAction` call and update running stats.

        Accepts either ``(reward_vecs, dones, fails, next_states)`` or
        ``(reward_vecs, cost_vecs, dones, fails, next_states)``.

        Raises:
            NotImplementedError: for any other number of arguments.
        """
        if len(args) == 4:
            reward_vecs, dones, fails, next_states = args
            self.replay_buffer.addTransition(
                self.state, self.action, self.preference,
                self.normal_action, self.log_prob,
                reward_vecs, dones, fails, next_states)
        elif len(args) == 5:
            reward_vecs, cost_vecs, dones, fails, next_states = args
            self.replay_buffer.addTransition(
                self.state, self.action, self.preference,
                self.normal_action, self.log_prob,
                reward_vecs, cost_vecs, dones, fails, next_states)
        else:
            raise NotImplementedError

        # update statistics
        if self.norm_obs:
            self.obs_rms.update(self.state)
        if self.norm_reward:
            self.reward_rms.update(reward_vecs)

    def readyToTrain(self):
        """Return True once ``replay_buffer.getLen()`` reaches ``n_update_steps``."""
        return self.replay_buffer.getLen() >= self.n_update_steps

    def train(self):
        """Run one update of the critics and the actor.

        Returns:
            A dict with ``objectives``, ``actor_loss``, ``reward_critic_loss``,
            ``entropy`` and ``kl``. It also contains ``constraints`` and
            ``cost_critic_loss`` when ``cost_dim > 0``. The reported losses come
            from the last iteration of their loops, so ``n_critic_iters`` and
            ``n_actor_iters`` must both be at least 1.
        """
        # convert to tensor
        if self.cost_dim > 0:
            (states_tensor, actions_tensor, preferences_tensor,
             reward_targets_tensor, cost_targets_tensor, con_vals_tensor) = \
                self.replay_buffer.getBatches(
                    self.obs_rms, self.reward_rms, self.actor, self.reward_critic, self.cost_critic)
        else:
            states_tensor, actions_tensor, preferences_tensor, reward_targets_tensor = \
                self.replay_buffer.getBatches(
                    self.obs_rms, self.reward_rms, self.actor, self.reward_critic)

        # ======================= update critic ======================= #
        for _ in range(self.n_critic_iters):
            reward_critic_loss = self._updateCritic(
                self.reward_critic, self.reward_critic_optimizer,
                states_tensor, actions_tensor, preferences_tensor, reward_targets_tensor)
            if self.cost_dim > 0:
                cost_critic_loss = self._updateCritic(
                    self.cost_critic, self.cost_critic_optimizer,
                    states_tensor, actions_tensor, preferences_tensor, cost_targets_tensor)
        # ============================================================= #

        # ======================== update actor ======================== #
        action_means_tensor_list = []
        action_log_stds_tensor_list = []
        objectives_tensor_list = []
        constraints_tensor_list = []
        # The batch is walked in consecutive chunks of n_steps rows. Each chunk
        # gets one trust-region step and contributes its target means/log-stds.
        for prefer_idx in range(self.n_update_steps//self.n_steps):
            with torch.no_grad():
                start_idx = prefer_idx*self.n_steps
                end_idx = (prefer_idx + 1)*self.n_steps
                actor_states_tensor = states_tensor[start_idx:end_idx] # (n_steps, obs_dim)
                actor_preferences_tensor = preferences_tensor[start_idx:end_idx] # (n_steps, preference_dim)
                actor_actions_tensor = actions_tensor[start_idx:end_idx] # (n_steps, action_dim)
                # NOTE(review): only the first row's preference is passed to the
                # optimizer. This assumes the whole chunk shares one preference.
                sampled_preference = actor_preferences_tensor[0].detach().cpu().numpy()

                # Back up the old policy. The same noise (epsilons_tensor) is
                # reused inside get_obj_con_kl below.
                epsilons_tensor = torch.randn_like(actor_actions_tensor)
                self.actor.updateActionDist(actor_states_tensor, actor_preferences_tensor, epsilons_tensor)
                old_action_dists = self.actor.getDist()

            # Closure evaluated by the trust-region optimizer. It reads this
            # iteration's chunk tensors; presumably it is only invoked during the
            # step() call below.
            def get_obj_con_kl():
                self.actor.updateActionDist(actor_states_tensor, actor_preferences_tensor, epsilons_tensor)
                cur_actions = self.actor.sample(deterministic=False)[0]
                cur_action_dists = self.actor.getDist()
                objectives = self._getObjectives(
                    actor_states_tensor, cur_actions, actor_preferences_tensor)
                if self.cost_dim > 0:
                    constraints = self._getConstraints(
                        actor_states_tensor, cur_actions, actor_preferences_tensor, con_vals_tensor[prefer_idx])
                    if self.con_entropy:
                        entropy = self.actor.getEntropy()
                        constraints = torch.cat((constraints, -entropy.unsqueeze(0)), dim=0)
                else:
                    if self.con_entropy:
                        entropy = self.actor.getEntropy()
                        constraints = -entropy.unsqueeze(0)
                    else:
                        constraints = torch.zeros(len(self.con_thresholds), device=self.device)
                assert len(constraints.shape) == 1 and constraints.shape[0] == len(self.con_thresholds)
                kl = torch.mean(torch.sum(
                    torch.distributions.kl.kl_divergence(old_action_dists, cur_action_dists), dim=-1))
                return objectives, constraints, kl

            # Compute the trust-region targets for this chunk. The last two
            # return values (target stds, safety mode) are not used here.
            objectives, constraints, target_means, target_log_stds, _, _ = \
                self.intermediate_actor_optimizer.step(
                    get_obj_con_kl, sampled_preference,
                    actor_states_tensor, actor_preferences_tensor)
            action_means_tensor_list.append(target_means)
            action_log_stds_tensor_list.append(target_log_stds)
            objectives_tensor_list.append(objectives)
            constraints_tensor_list.append(constraints)

        # get old actions and entropy
        with torch.no_grad():
            target_means_tensor = torch.cat(action_means_tensor_list, dim=0)
            target_log_stds_tensor = torch.cat(action_log_stds_tensor_list, dim=0)
            # NOTE(review): targets only cover (n_update_steps // n_steps) * n_steps
            # rows, while the full states_tensor is used below. Confirm that the
            # batch size always matches.
            epsilons_tensor = torch.randn_like(target_means_tensor)
            self.actor.updateActionDist(states_tensor, preferences_tensor, epsilons_tensor)
            old_action_dists = self.actor.getDist()
            entropy = self.actor.getEntropy()

        # Distill the per-chunk targets into the actor network.
        for _ in range(self.n_actor_iters):
            cur_means, cur_log_stds, _ = self.actor(states_tensor, preferences_tensor)
            actor_loss = NormalKLDivLoss(cur_means, cur_log_stds, target_means_tensor, target_log_stds_tensor)
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
            self.actor_optimizer.step()

        # get objectives and constraints
        with torch.no_grad():
            self.actor.updateActionDist(states_tensor, preferences_tensor, epsilons_tensor)
            cur_action_dists = self.actor.getDist()
            kl = torch.mean(torch.sum(
                torch.distributions.kl.kl_divergence(
                    old_action_dists, cur_action_dists
                ), dim=-1))
            objectives = torch.stack(objectives_tensor_list, dim=0).mean(dim=0)
            constraints = torch.stack(constraints_tensor_list, dim=0).mean(dim=0)
        # ============================================================== #

        # return
        train_results = {
            'objectives': objectives.detach().cpu().numpy(),
            'actor_loss': actor_loss.item(),
            'reward_critic_loss': reward_critic_loss.item(),
            'entropy': entropy.item(),
            'kl': kl.item(),
        }
        if self.cost_dim > 0:
            train_results['constraints'] = constraints.detach().cpu().numpy()
            train_results['cost_critic_loss'] = cost_critic_loss.item()
        return train_results

    def save(self, model_num):
        """Save running statistics plus network and optimizer states for ``model_num``.

        The checkpoint is written to ``{checkpoint_dir}/model_{model_num}.pt``.
        """
        # save rms
        self.obs_rms.save(self.save_dir, model_num)
        self.reward_rms.save(self.save_dir, model_num)

        # save network models
        save_dict = {
            'actor': self.actor.state_dict(),
            # Keep the actor's Adam state too, consistent with the critics, so
            # distillation resumes with its moment estimates intact.
            'actor_optimizer': self.actor_optimizer.state_dict(),
            'reward_critic': self.reward_critic.state_dict(),
            'reward_critic_optimizer': self.reward_critic_optimizer.state_dict(),
        }
        if self.cost_dim > 0:
            save_dict['cost_critic'] = self.cost_critic.state_dict()
            save_dict['cost_critic_optimizer'] = self.cost_critic_optimizer.state_dict()
        # save() can be called without a prior load(), and only load() used to
        # create the directory.
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        torch.save(save_dict, f"{self.checkpoint_dir}/model_{model_num}.pt")
        cprint(f'[{self.name}] save success.', bold=True, color="blue")

    def load(self, model_num):
        """Load the checkpoint for ``model_num``, or re-initialize the networks if it is missing.

        Returns:
            ``int(model_num)`` on success, ``0`` when no checkpoint file exists.
        """
        # load rms
        self.obs_rms.load(self.save_dir, model_num)
        self.reward_rms.load(self.save_dir, model_num)

        # load network models
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        checkpoint_file = f"{self.checkpoint_dir}/model_{model_num}.pt"
        if os.path.isfile(checkpoint_file):
            checkpoint = torch.load(checkpoint_file, map_location=self.device)
            self.actor.load_state_dict(checkpoint['actor'])
            # Checkpoints written by older versions lack the actor optimizer state.
            if 'actor_optimizer' in checkpoint:
                self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
            self.reward_critic.load_state_dict(checkpoint['reward_critic'])
            self.reward_critic_optimizer.load_state_dict(checkpoint['reward_critic_optimizer'])
            if self.cost_dim > 0:
                self.cost_critic.load_state_dict(checkpoint['cost_critic'])
                self.cost_critic_optimizer.load_state_dict(checkpoint['cost_critic_optimizer'])
            cprint(f'[{self.name}] load success.', bold=True, color="blue")
            return int(model_num)
        else:
            self.actor.initialize()
            self.reward_critic.initialize()
            if self.cost_dim > 0:
                self.cost_critic.initialize()
            cprint(f'[{self.name}] load fail.', bold=True, color="red")
            return 0

    ################
    # private method
    ################

    def _updateCritic(self, critic, optimizer, states_tensor, actions_tensor, preferences_tensor, targets_tensor):
        """Take one clipped gradient step on ``critic`` and return the loss tensor computed before the step."""
        loss = critic.getLoss(states_tensor, actions_tensor, preferences_tensor, targets_tensor)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(critic.parameters(), self.max_grad_norm)
        optimizer.step()
        return loss

    def _getObjectives(self, states_tensor, actions_tensor, preferences_tensor):
        """Average reward-critic output over batch, critics and quantiles -> one value per reward dim."""
        reward_values = self.reward_critic(states_tensor, actions_tensor, preferences_tensor) # (batch_size, reward_dim, n_critics, n_quantiles)
        objectives = reward_values.mean(dim=(0, 2, 3)) # (reward_dim,)
        return objectives

    def _getConstraints(self, states_tensor, actions_tensor, preferences_tensor, con_vals_tensor):
        """Return constraint values equal to ``con_vals_tensor`` but carrying the cost critic's gradient."""
        cost_values = self.cost_critic(states_tensor, actions_tensor, preferences_tensor) # (batch_size, cost_dim, n_critics, n_quantiles)
        constraints = cost_values.mean(dim=(0, 2, 3)) # (cost_dim,)
        # Forward value becomes con_vals_tensor. The gradient is that of the
        # critic mean, because the detached copy cancels its value.
        constraints = constraints - constraints.detach() + con_vals_tensor
        return constraints
