from algos.common.actor_preference import ActorGaussianPreference as ActorGaussian
from algos.common.actor_preference import ActorSquashPreference as ActorSquash
from algos.common.critic_distribution import CriticSAPreferDist as Critic
from algos.common.agent_base import AgentBase
from utils import cprint

from .multiplier import Multiplier
from .storage import ReplayBuffer

from typing import Tuple
import numpy as np
import torch
import os

# Small numerical-stability epsilon (not referenced by the code in this module).
EPS: float = 1.0e-08

class Agent(AgentBase):
    """Preference-conditioned off-policy actor-critic agent with optional cost constraints.

    Components built in ``__init__``:

    * a stochastic actor, Gaussian or squashed depending on
      ``args.model['actor']['type']``;
    * a distributional reward critic plus a soft-updated target copy;
    * when ``cost_dim > 0``: a distributional cost critic with its own target copy,
      and a preference-conditioned multiplier network (``Multiplier``) whose output
      weights the cost values in the policy loss;
    * an entropy temperature that is learned when ``args.con_entropy`` is set and
      fixed to 0 otherwise;
    * a ``ReplayBuffer`` from which ``train`` samples mini-batches.
    """

    def __init__(self, args) -> None:
        super().__init__(
            name=args.name,
            device=args.device,
            obs_dim=args.obs_dim,
            action_dim=args.action_dim,
            reward_dim=args.reward_dim,
            norm_obs=args.norm_obs,
            norm_reward=args.norm_reward,
        )

        # base
        self.save_dir = args.save_dir
        self.checkpoint_dir = f"{self.save_dir}/checkpoint"
        # `cost_dim` is optional in the config; 0 means "no constraints".
        self.cost_dim = args.__dict__.get('cost_dim', 0)
        # one preference weight per reward objective
        self.preference_dim = self.reward_dim
        self.action_bound_min = args.action_bound_min
        self.action_bound_max = args.action_bound_max
        self.n_steps = args.n_steps
        self.n_envs = args.n_envs

        # for RL
        self.discount_factor = args.discount_factor
        self.critic_lr = args.critic_lr
        self.actor_lr = args.actor_lr
        self.batch_size = args.batch_size
        self.n_update_iters = args.n_update_iters
        self.max_grad_norm = args.max_grad_norm
        self.len_replay_buffer = args.len_replay_buffer
        self.soft_update_ratio = args.soft_update_ratio
        self.model_cfg = args.model

        # for model
        assert self.model_cfg['actor']['type'] in ['gaussian', 'squash']
        Actor = ActorGaussian if self.model_cfg['actor']['type'] == 'gaussian' else ActorSquash
        self.actor = Actor(
            self.device, self.obs_dim, self.action_dim, self.action_bound_min, 
            self.action_bound_max, self.reward_dim, self.model_cfg['actor']).to(self.device)
        self.reward_critic = Critic(
            self.device, self.obs_dim, self.action_dim, self.reward_dim, 
            self.preference_dim, self.model_cfg['reward_critic']).to(self.device)
        self.reward_critic_target = Critic(
            self.device, self.obs_dim, self.action_dim, self.reward_dim, 
            self.preference_dim, self.model_cfg['reward_critic']).to(self.device)
        # Target critics only change through `_softUpdate` / `load_state_dict`, never
        # through an optimizer. Freezing them stops the actor's backward pass (which
        # runs through the targets) from accumulating unused `.grad` buffers on their
        # parameters at every update. Gradients still flow to the sampled actions.
        self._freezeParams(self.reward_critic_target)
        if self.cost_dim > 0:
            self.cost_critic = Critic(
                self.device, self.obs_dim, self.action_dim, self.cost_dim, 
                self.preference_dim, self.model_cfg['cost_critic']).to(self.device)
            self.cost_critic_target = Critic(
                self.device, self.obs_dim, self.action_dim, self.cost_dim, 
                self.preference_dim, self.model_cfg['cost_critic']).to(self.device)
            self._freezeParams(self.cost_critic_target)

        # for constraints
        if self.cost_dim > 0:
            self.con_lambdas_lr = args.con_lambdas_lr
            self.con_lambdas = Multiplier(self.device, self.preference_dim, self.cost_dim, args.model['multipliers']).to(self.device)
            self.con_lambdas_optimizer = torch.optim.Adam(self.con_lambdas.parameters(), lr=self.con_lambdas_lr)
            # Dividing by (1 - gamma) rescales the thresholds to the discounted-sum scale
            # of the cost critic. This presumably means the configured thresholds are
            # per-step cost limits; confirm against the config documentation.
            con_thresholds = np.array(args.con_thresholds)/(1.0 - self.discount_factor)
            assert len(con_thresholds) == self.cost_dim
            self.con_thresholds = torch.tensor(con_thresholds, dtype=torch.float32, device=self.device)

        # for entropy
        self.con_entropy = args.con_entropy
        if self.con_entropy:
            # learned temperature alpha = exp(log_alpha); log_alpha is clamped to [-8, 8] in `_train`
            self.target_entropy = self.action_dim*args.entropy_threshold
            self.log_entropy_alpha = torch.tensor(0.0, requires_grad=True, device=self.device)
            self.getEntropyAlpha = lambda: torch.exp(self.log_entropy_alpha)
            self.ent_alpha_lr = args.ent_alpha_lr
        else:
            # no entropy regularisation: alpha is a constant zero
            self.getEntropyAlpha = lambda: torch.tensor(0.0, device=self.device)

        # for replay buffer
        self.replay_buffer = ReplayBuffer(self.len_replay_buffer, self.batch_size, self.discount_factor, self.device)

        # for optimizers
        self.actor_params = list(self.actor.parameters())
        self.reward_critic_params = list(self.reward_critic.parameters())
        self.actor_optimizer = torch.optim.Adam(self.actor_params, lr=self.actor_lr)
        self.reward_critic_optimizer = torch.optim.Adam(self.reward_critic_params, lr=self.critic_lr)
        if self.con_entropy:
            self.entropy_alpha_optimizer = torch.optim.Adam([self.log_entropy_alpha], lr=self.ent_alpha_lr)
        if self.cost_dim > 0:
            self.cost_critic_params = list(self.cost_critic.parameters())
            self.cost_critic_optimizer = torch.optim.Adam(self.cost_critic_params, lr=self.critic_lr)

    """ public functions
    """
    @torch.no_grad()
    def getAction(self, state:np.ndarray, preference:np.ndarray, deterministic:bool) -> np.ndarray:
        """Sample an action for ``state`` conditioned on ``preference``.

        The observation is normalised with ``self.obs_rms`` before it reaches the actor.
        The state, the preference and the actor's *normalised* action are cached on
        ``self`` so that the next ``step`` call can store the transition.

        Args:
            state: raw observation(s); the last axis is the observation dimension.
            preference: preference weights matching ``state``'s leading shape.
            deterministic: forwarded to ``actor.sample``.

        Returns:
            The unnormalised action as a numpy array.
        """
        state_tensor = torch.tensor(self.obs_rms.normalize(state), dtype=torch.float32, device=self.device)
        preference_tensor = torch.tensor(preference, dtype=torch.float32, device=self.device)
        # one standard-normal noise vector per leading index of the state batch
        epsilon_tensor = torch.randn(state_tensor.shape[:-1] + (self.action_dim,), device=self.device)

        self.actor.updateActionDist(state_tensor, preference_tensor, epsilon_tensor)
        norm_action_tensor, unnorm_action_tensor = self.actor.sample(deterministic)

        self.state = state.copy()
        self.preference = preference.copy()
        self.action = norm_action_tensor.detach().cpu().numpy()
        return unnorm_action_tensor.detach().cpu().numpy()

    def step(self, *args) -> None:
        """Store the transition that follows the last ``getAction`` call and update running stats.

        Accepted call forms:
            ``step(reward_vecs, dones, fails, next_states)``
            ``step(reward_vecs, cost_vecs, dones, fails, next_states)``

        Raises:
            ValueError: for any other number of positional arguments.
        """
        if len(args) == 4:
            reward_vecs, dones, fails, next_states = args
            self.replay_buffer.addTransition(
                self.state, self.action, self.preference, reward_vecs, dones, fails, next_states)
        elif len(args) == 5:
            reward_vecs, cost_vecs, dones, fails, next_states = args
            self.replay_buffer.addTransition(
                self.state, self.action, self.preference, reward_vecs, cost_vecs, dones, fails, next_states)
        else:
            raise ValueError("Invalid number of arguments")

        # update statistics
        if self.norm_obs:
            self.obs_rms.update(self.state)
        if self.norm_reward:
            self.reward_rms.update(reward_vecs)

    def readyToTrain(self):
        """Return True once the replay buffer holds at least one full batch."""
        return self.replay_buffer.getLen() >= self.batch_size

    def train(self):
        """Run ``n_update_iters`` update steps and return logging metrics.

        Returns:
            dict: losses, entropy and temperature from the LAST update step. With
            constraints it also holds the batch-averaged multipliers and cost values
            as numpy arrays.

        Raises:
            ValueError: if ``n_update_iters`` is smaller than 1, because no metrics
                would be produced.
        """
        if self.n_update_iters < 1:
            raise ValueError(f"n_update_iters must be >= 1, got {self.n_update_iters}")
        for _ in range(self.n_update_iters):
            if self.cost_dim > 0:
                reward_critic_loss, cost_critic_loss, policy_loss, entropy, entropy_alpha_tensor, con_lambdas_tensor, cost_values_tensor = self._train()
            else:
                reward_critic_loss, policy_loss, entropy, entropy_alpha_tensor = self._train()
        results = {
            'reward_critic_loss': reward_critic_loss.item(),
            'policy_loss': policy_loss.item(),
            'entropy': entropy.item(),
            'ent_alpha': entropy_alpha_tensor.item(),
        }
        if self.cost_dim > 0:
            results.update({
                'cost_critic_loss': cost_critic_loss.item(),
                'con_lambdas': con_lambdas_tensor.detach().cpu().numpy(),
                'cost_values': cost_values_tensor.detach().cpu().numpy(),
            })
        return results

    def save(self, model_num):
        """Save normalisation statistics, networks and optimizer states for ``model_num``.

        The checkpoint file is written to ``{checkpoint_dir}/model_{model_num}.pt``.
        The directory is created if needed, so ``save`` does not depend on ``load``
        having run first.
        """
        # save rms
        self.obs_rms.save(self.save_dir, model_num)
        self.reward_rms.save(self.save_dir, model_num)

        # save network models
        save_dict = {
            'actor': self.actor.state_dict(),
            # Without the actor optimizer state, resuming would reset the actor's Adam moments.
            'actor_optimizer': self.actor_optimizer.state_dict(),
            'reward_critic': self.reward_critic.state_dict(),
            'reward_critic_target': self.reward_critic_target.state_dict(),
            'reward_critic_optimizer': self.reward_critic_optimizer.state_dict(),
        }
        if self.con_entropy:
            save_dict.update({
                'log_entropy_alpha': self.log_entropy_alpha.data,
                'entropy_alpha_optimizer': self.entropy_alpha_optimizer.state_dict(),
            })
        if self.cost_dim > 0:
            save_dict.update({
                'cost_critic': self.cost_critic.state_dict(),
                'cost_critic_target': self.cost_critic_target.state_dict(),
                'cost_critic_optimizer': self.cost_critic_optimizer.state_dict(),
                'con_lambdas': self.con_lambdas.state_dict(),
                'con_lambdas_optimizer': self.con_lambdas_optimizer.state_dict(),
            })
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        torch.save(save_dict, f"{self.checkpoint_dir}/model_{model_num}.pt")
        cprint(f'[{self.name}] save success.', bold=True, color="blue")

    def load(self, model_num):
        """Restore the agent from checkpoint ``model_num``, or initialise it fresh.

        Returns:
            ``int(model_num)`` if the checkpoint file exists and was loaded. Otherwise
            0, after freshly initialising the actor, the critics (with their targets
            copied from them) and, when constrained, the multipliers.
        """
        # load rms
        self.obs_rms.load(self.save_dir, model_num)
        self.reward_rms.load(self.save_dir, model_num)

        # load network models
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        checkpoint_file = f"{self.checkpoint_dir}/model_{model_num}.pt"
        if os.path.isfile(checkpoint_file):
            checkpoint = torch.load(checkpoint_file, map_location=self.device)
            self.actor.load_state_dict(checkpoint['actor'])
            # older checkpoints were written without the actor optimizer state
            if 'actor_optimizer' in checkpoint:
                self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
            self.reward_critic.load_state_dict(checkpoint['reward_critic'])
            self.reward_critic_target.load_state_dict(checkpoint['reward_critic_target'])
            self.reward_critic_optimizer.load_state_dict(checkpoint['reward_critic_optimizer'])
            if self.con_entropy:
                self.log_entropy_alpha.data = checkpoint['log_entropy_alpha']
                self.entropy_alpha_optimizer.load_state_dict(checkpoint['entropy_alpha_optimizer'])
            if self.cost_dim > 0:
                self.cost_critic.load_state_dict(checkpoint['cost_critic'])
                self.cost_critic_target.load_state_dict(checkpoint['cost_critic_target'])
                self.cost_critic_optimizer.load_state_dict(checkpoint['cost_critic_optimizer'])
                self.con_lambdas.load_state_dict(checkpoint['con_lambdas'])
                self.con_lambdas_optimizer.load_state_dict(checkpoint['con_lambdas_optimizer'])
            cprint(f'[{self.name}] load success.', bold=True, color="blue")
            return int(model_num)
        else:
            self.actor.initialize()
            self.reward_critic.initialize()
            # polyak = 0.0 copies the online critic into the target
            self._softUpdate(self.reward_critic_target, self.reward_critic, 0.0)
            if self.cost_dim > 0:
                self.con_lambdas.initialize()
                self.cost_critic.initialize()
                self._softUpdate(self.cost_critic_target, self.cost_critic, 0.0)
            cprint(f'[{self.name}] load fail.', bold=True, color="red")
            return 0

    """ private functions
    """
    def _freezeParams(self, module):
        """Disable gradient tracking for every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad_(False)

    def _softUpdate(self, target, source, polyak):
        """Blend parameters in place: ``target <- polyak * target + (1 - polyak) * source``.

        ``polyak = 0`` copies ``source`` into ``target``. Larger values keep more of
        the old target.
        """
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(target_param.data * polyak + param.data * (1.0 - polyak))

    def _train(self):
        """Perform one update on a replay batch: critics, actor, temperature, then multipliers.

        Returns:
            With constraints: ``(reward_critic_loss, cost_critic_loss, policy_loss,
            entropy, entropy_alpha_tensor, batch-mean con_lambdas, batch-mean
            cost_values)``. Otherwise: ``(reward_critic_loss, policy_loss, entropy,
            entropy_alpha_tensor)``. All entries are tensors.
        """
        # get batches
        with torch.no_grad():
            if self.cost_dim > 0:
                states_tensor, actions_tensor, rewards_tensor, costs_tensor, preferences_tensor, \
                    dones_tensor, fails_tensor, next_states_tensor = self.replay_buffer.getBatches(self.obs_rms, self.reward_rms)
            else:
                states_tensor, actions_tensor, rewards_tensor, preferences_tensor, \
                    dones_tensor, fails_tensor, next_states_tensor = self.replay_buffer.getBatches(self.obs_rms, self.reward_rms)

        # ================== Critic Update ================== #
        # calculate critic targets
        with torch.no_grad():
            batch_size = states_tensor.shape[0]
            entropy_alpha_tensor = self.getEntropyAlpha()
            epsilons_tensor = torch.randn_like(actions_tensor)
            self.actor.updateActionDist(next_states_tensor, preferences_tensor, epsilons_tensor)
            next_actions_tensor = self.actor.sample(deterministic=False)[0] # (batch_size, action_dim)
            next_actions_tensor = torch.clamp(next_actions_tensor, 
                                              torch.tensor(self.action_bound_min, dtype=torch.float32, device=self.device), 
                                              torch.tensor(self.action_bound_max, dtype=torch.float32, device=self.device)) 
            next_log_probs_tensor = self.actor.getLogProb() # (batch_size,)
            next_reward_quantiles_tensor = self.reward_critic_target(next_states_tensor, next_actions_tensor, preferences_tensor) # (batch_size, reward_dim, n_critics, n_quantiles)
            # Pool the quantiles of all critics, sort them in ascending order, and drop
            # the 2 largest. This makes the reward target pessimistic.
            next_reward_quantiles_tensor = torch.sort(next_reward_quantiles_tensor.view(batch_size, self.reward_dim, -1), dim=-1)[0] # (batch_size, reward_dim, n_critic*n_quantiles)
            next_reward_quantiles_tensor = next_reward_quantiles_tensor[:, :, :-2]
            # Only `fails` stops bootstrapping. `dones_tensor` is fetched but not used here,
            # presumably so that time-limit terminations still bootstrap.
            reward_targets_tensor = rewards_tensor.view(batch_size, self.reward_dim, 1) + \
                    self.discount_factor*(1.0 - fails_tensor.view(batch_size, 1, 1)) * \
                    (next_reward_quantiles_tensor - entropy_alpha_tensor*next_log_probs_tensor.view(batch_size, 1, 1)) # (batch_size, reward_dim, n_critic*n_quantiles)
            if self.cost_dim > 0:
                next_cost_quantiles_tensor = self.cost_critic_target(next_states_tensor, next_actions_tensor, preferences_tensor) # (batch_size, cost_dim, n_critics, n_quantiles)
                # For costs, drop the 2 SMALLEST pooled quantiles instead. This makes the
                # cost target pessimistic (i.e. higher).
                next_cost_quantiles_tensor = torch.sort(next_cost_quantiles_tensor.view(batch_size, self.cost_dim, -1), dim=-1)[0] # (batch_size, cost_dim, n_critic*n_quantiles)
                next_cost_quantiles_tensor = next_cost_quantiles_tensor[:, :, 2:]
                cost_targets_tensor = costs_tensor.view(batch_size, self.cost_dim, 1) + \
                        self.discount_factor*(1.0 - fails_tensor.view(batch_size, 1, 1)) * next_cost_quantiles_tensor # (batch_size, cost_dim, n_critic*n_quantiles)

        # reward critic update
        reward_critic_loss = self.reward_critic.getLoss(states_tensor, actions_tensor, preferences_tensor, reward_targets_tensor)
        self.reward_critic_optimizer.zero_grad()
        reward_critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.reward_critic_params, self.max_grad_norm)
        self.reward_critic_optimizer.step()

        # soft update
        self._softUpdate(self.reward_critic_target, self.reward_critic, self.soft_update_ratio)

        if self.cost_dim > 0:
            # cost critic update
            cost_critic_loss = self.cost_critic.getLoss(states_tensor, actions_tensor, preferences_tensor, cost_targets_tensor)
            self.cost_critic_optimizer.zero_grad()
            cost_critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.cost_critic_params, self.max_grad_norm)
            self.cost_critic_optimizer.step()

            # soft update
            self._softUpdate(self.cost_critic_target, self.cost_critic, self.soft_update_ratio)
        # ================================================== #

        # ================= Policy Update ================= #
        with torch.no_grad():
            entropy_alpha_tensor = self.getEntropyAlpha()
            epsilons_tensor = torch.randn_like(actions_tensor)
            if self.cost_dim > 0:
                con_lambdas_tensor = self.con_lambdas(preferences_tensor) # (batch_size, cost_dim)

        self.actor.updateActionDist(states_tensor, preferences_tensor, epsilons_tensor)
        sampled_actions_tensor = self.actor.sample(deterministic=False)[0]
        log_probs_tensor = self.actor.getLogProb() # (batch_size,)
        # NOTE(review): the actor is optimised against the *target* critics rather than
        # the online ones (standard SAC/TQC use the online critic). Confirm this is
        # intentional.
        reward_values_tensor = self.reward_critic_target(states_tensor, sampled_actions_tensor, preferences_tensor).mean(dim=(2, 3)) # (batch_size, reward_dim)
        # scalarise the multi-objective value with the preference weights
        summed_reward_values_tensor = torch.sum(reward_values_tensor*preferences_tensor, dim=-1) # (batch_size,)
        if self.cost_dim > 0:
            cost_values_tensor = self.cost_critic_target(states_tensor, sampled_actions_tensor, preferences_tensor).mean(dim=(2, 3)) # (batch_size, cost_dim)
            summed_cost_values_tensor = torch.sum(cost_values_tensor*con_lambdas_tensor, dim=-1) # (batch_size,)
            policy_loss = torch.mean(entropy_alpha_tensor*log_probs_tensor - summed_reward_values_tensor + summed_cost_values_tensor)
        else:
            policy_loss = torch.mean(entropy_alpha_tensor*log_probs_tensor - summed_reward_values_tensor)
        self.actor_optimizer.zero_grad()
        policy_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor_params, self.max_grad_norm)
        self.actor_optimizer.step()
        # ================================================= #

        # entropy alpha update
        entropy = -torch.mean(log_probs_tensor)
        if self.con_entropy:
            # Minimising alpha * (H - H_target) shrinks alpha while the entropy exceeds
            # the target, and grows it otherwise.
            entropy_alpha_tensor = self.getEntropyAlpha()
            entropy_alpha_loss = torch.mean(entropy_alpha_tensor*(entropy - self.target_entropy).detach())
            self.entropy_alpha_optimizer.zero_grad()
            entropy_alpha_loss.backward()
            self.entropy_alpha_optimizer.step()
            self.log_entropy_alpha.data.copy_(self.log_entropy_alpha.data.clamp(-8.0, 8.0))

        # constraint lambda update
        if self.cost_dim > 0:
            # Minimising lambda * (threshold - cost) raises lambda when the cost value
            # exceeds its threshold, and lowers it otherwise.
            con_lambdas_tensor = self.con_lambdas(preferences_tensor) # (batch_size, cost_dim)
            con_lambdas_loss = torch.sum(con_lambdas_tensor*(self.con_thresholds - cost_values_tensor.detach()), dim=-1).mean()
            self.con_lambdas_optimizer.zero_grad()
            con_lambdas_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.con_lambdas.parameters(), self.max_grad_norm)
            self.con_lambdas_optimizer.step()
            return reward_critic_loss, cost_critic_loss, policy_loss, entropy, entropy_alpha_tensor, con_lambdas_tensor.mean(dim=0), cost_values_tensor.mean(dim=0)
        else:
            return reward_critic_loss, policy_loss, entropy, entropy_alpha_tensor
