from algos.common.actor_preference import ActorGaussianPreference as ActorGaussian
from algos.common.actor_preference import ActorSquashPreference as ActorSquash
from algos.common.critic_distribution import CriticSAPreferDist as Critic
from algos.common.agent_base import AgentBase
from utils import cprint

from .aggregation import aggregation
from .multiplier import Multiplier
from .storage import ReplayBuffer

from typing import Tuple
import numpy as np
import torch
import os

# Small positive constant (presumably intended for numerical stability);
# it is not referenced by the Agent code shown in this part of the file.
EPS = 1e-8

class Agent(AgentBase):
    def __init__(self, args) -> None:
        """Build networks, multipliers, the replay buffer and optimizers from `args`.

        `args` is read as an attribute container. Besides the fields forwarded
        to `AgentBase`, this constructor reads: save_dir, action_bound_min,
        action_bound_max, n_steps, n_envs, discount_factor, critic_lr,
        actor_lr, batch_size, n_update_iters, max_grad_norm,
        len_replay_buffer, soft_update_ratio, model, con_entropy and, when
        applicable, entropy_threshold / ent_alpha_lr (entropy constraint) and
        con_lambdas_lr / con_thresholds (cost constraints). `cost_dim` is
        optional and defaults to 0, which disables everything cost-related.

        Note that the target critics are constructed independently of the
        online critics here; they are only synchronised in `load()`, either
        from a checkpoint or via `_softUpdate(..., 0.0)` when none exists.
        """
        super().__init__(
            name=args.name,
            device=args.device,
            obs_dim=args.obs_dim,
            action_dim=args.action_dim,
            reward_dim=args.reward_dim,
            norm_obs=args.norm_obs,
            norm_reward=args.norm_reward,
        )

        # base
        self.save_dir = args.save_dir
        self.checkpoint_dir = f"{self.save_dir}/checkpoint"
        # cost_dim is looked up in the instance dict of `args`; missing -> 0 (no constraints)
        self.cost_dim = args.__dict__.get('cost_dim', 0)
        # one preference weight per reward component
        self.preference_dim = self.reward_dim
        self.action_bound_min = args.action_bound_min
        self.action_bound_max = args.action_bound_max
        self.n_steps = args.n_steps
        self.n_envs = args.n_envs

        # for RL
        self.discount_factor = args.discount_factor
        self.critic_lr = args.critic_lr
        self.actor_lr = args.actor_lr
        self.batch_size = args.batch_size
        self.n_update_iters = args.n_update_iters
        self.max_grad_norm = args.max_grad_norm
        self.len_replay_buffer = args.len_replay_buffer
        self.soft_update_ratio = args.soft_update_ratio
        self.model_cfg = args.model

        # for model
        # the actor class is selected by config; only these two types are accepted
        assert self.model_cfg['actor']['type'] in ['gaussian', 'squash']
        Actor = ActorGaussian if self.model_cfg['actor']['type'] == 'gaussian' else ActorSquash
        self.actor = Actor(
            self.device, self.obs_dim, self.action_dim, self.action_bound_min, 
            self.action_bound_max, self.reward_dim, self.model_cfg['actor']).to(self.device)
        # online / target reward critics share the same configuration
        self.reward_critic = Critic(
            self.device, self.obs_dim, self.action_dim, self.reward_dim, 
            self.preference_dim, self.model_cfg['reward_critic']).to(self.device)
        self.reward_critic_target = Critic(
            self.device, self.obs_dim, self.action_dim, self.reward_dim, 
            self.preference_dim, self.model_cfg['reward_critic']).to(self.device)
        if self.cost_dim > 0:
            # online / target cost critics exist only when there are constraints
            self.cost_critic = Critic(
                self.device, self.obs_dim, self.action_dim, self.cost_dim, 
                self.preference_dim, self.model_cfg['cost_critic']).to(self.device)
            self.cost_critic_target = Critic(
                self.device, self.obs_dim, self.action_dim, self.cost_dim, 
                self.preference_dim, self.model_cfg['cost_critic']).to(self.device)
            
        # for constraints
        if self.cost_dim > 0:
            self.con_lambdas_lr = args.con_lambdas_lr
            # multipliers are produced by a module that takes the preference as input
            self.con_lambdas = Multiplier(self.device, self.preference_dim, self.cost_dim, args.model['multipliers']).to(self.device)
            self.con_lambdas_optimizer = torch.optim.Adam(self.con_lambdas.parameters(), lr=self.con_lambdas_lr)
            # thresholds are scaled by 1/(1 - discount_factor)
            # (presumably per-step limits converted to discounted-sum limits -- TODO confirm)
            con_thresholds = np.array(args.con_thresholds)/(1.0 - self.discount_factor)
            assert len(con_thresholds) == self.cost_dim
            self.con_thresholds = torch.tensor(con_thresholds, dtype=torch.float32, device=self.device)

        # for entropy
        self.con_entropy = args.con_entropy
        if self.con_entropy:
            # target entropy scales linearly with the action dimension
            self.target_entropy = self.action_dim*args.entropy_threshold
            # alpha is parameterised in log space; log(alpha)=0 means alpha starts at 1
            self.log_entropy_alpha = torch.tensor(0.0, requires_grad=True, device=self.device)
            self.getEntropyAlpha = lambda: torch.exp(self.log_entropy_alpha)
            self.ent_alpha_lr = args.ent_alpha_lr
        else:
            # no entropy constraint: the entropy coefficient is a constant zero
            self.getEntropyAlpha = lambda: torch.tensor(0.0, device=self.device)

        # for replay buffer
        self.replay_buffer = ReplayBuffer(
            self.device, self.len_replay_buffer, self.discount_factor, self.batch_size, self.n_envs,
            self.obs_dim, self.action_dim, self.preference_dim, self.reward_dim, self.cost_dim)

        # for optimizers
        # parameter lists are kept so that gradient clipping can reuse them in _train()
        self.actor_params = list(self.actor.parameters())
        self.reward_critic_params = list(self.reward_critic.parameters())
        self.actor_optimizer = torch.optim.Adam(self.actor_params, lr=self.actor_lr)
        self.reward_critic_optimizer = torch.optim.Adam(self.reward_critic_params, lr=self.critic_lr)
        if self.con_entropy:
            self.entropy_alpha_optimizer = torch.optim.Adam([self.log_entropy_alpha], lr=self.ent_alpha_lr)
        if self.cost_dim > 0:
            self.cost_critic_params = list(self.cost_critic.parameters())
            self.cost_critic_optimizer = torch.optim.Adam(self.cost_critic_params, lr=self.critic_lr)

    """ public functions
    """
    @torch.no_grad()
    def getAction(self, state:np.ndarray, preference:np.ndarray, deterministic:bool) -> np.ndarray:
        state_tensor = torch.tensor(self.obs_rms.normalize(state), dtype=torch.float32, device=self.device)
        preference_tensor = torch.tensor(preference, dtype=torch.float32, device=self.device)
        epsilon_tensor = torch.randn(state_tensor.shape[:-1] + (self.action_dim,), device=self.device)

        self.actor.updateActionDist(state_tensor, preference_tensor, epsilon_tensor)
        norm_action_tensor, unnorm_action_tensor = self.actor.sample(deterministic)

        self.state = state.copy()
        self.preference = preference.copy()
        self.action = norm_action_tensor.detach().cpu().numpy()
        return unnorm_action_tensor.detach().cpu().numpy()

    def step(self, *args) -> None:
        if len(args) == 4:
            reward_vecs, dones, fails, next_states = args
            self.replay_buffer.addTransition(
                self.state, self.action, self.preference, reward_vecs, dones, fails, next_states)
        elif len(args) == 5:
            reward_vecs, cost_vecs, dones, fails, next_states = args
            self.replay_buffer.addTransition(
                self.state, self.action, self.preference, reward_vecs, cost_vecs, dones, fails, next_states)
        else:
            raise ValueError("Invalid number of arguments")

        # update statistics
        if self.norm_obs:
            self.obs_rms.update(self.state)
        if self.norm_reward:
            self.reward_rms.update(reward_vecs)

    def readyToTrain(self):
        return self.replay_buffer.getLen() >= self.batch_size

    def train(self):
        for _ in range(self.n_update_iters):
            if self.cost_dim > 0:
                reward_critic_loss, cost_critic_loss, policy_loss, entropy, entropy_alpha_tensor, con_lambdas_tensor, cost_values_tensor = self._train()
            else:
                reward_critic_loss, policy_loss, entropy, entropy_alpha_tensor = self._train()
        results = {
            'reward_critic_loss': reward_critic_loss.item(),
            'policy_loss': policy_loss.item(),
            'entropy': entropy.item(),
            'ent_alpha': entropy_alpha_tensor.item(),
        }
        if self.cost_dim > 0:
            results.update({
                'cost_critic_loss': cost_critic_loss.item(),
                'con_lambdas': con_lambdas_tensor.detach().cpu().numpy(),
                'cost_values': cost_values_tensor.detach().cpu().numpy(),
            })
        return results

    def save(self, model_num):
        """Write the running statistics and a model checkpoint for `model_num`.

        The checkpoint is stored at `<checkpoint_dir>/model_<model_num>.pt` and
        contains the state dicts of the networks and of every optimizer that
        `load()` restores. Entropy-alpha and constraint-related entries are
        included only when the corresponding feature is enabled.

        Args:
            model_num: identifier used in the checkpoint (and rms) file names.
        """
        # save rms
        self.obs_rms.save(self.save_dir, model_num)
        self.reward_rms.save(self.save_dir, model_num)

        # save network models
        save_dict = {
            'actor': self.actor.state_dict(),
            'reward_critic': self.reward_critic.state_dict(),
            'reward_critic_target': self.reward_critic_target.state_dict(),
            # load() reads 'actor_optimizer'; without this entry every
            # checkpoint written here failed to load with a KeyError.
            'actor_optimizer': self.actor_optimizer.state_dict(),
            'reward_critic_optimizer': self.reward_critic_optimizer.state_dict(),
        }
        if self.con_entropy:
            save_dict.update({
                'log_entropy_alpha': self.log_entropy_alpha.data,
                'entropy_alpha_optimizer': self.entropy_alpha_optimizer.state_dict(),
            })
        if self.cost_dim > 0:
            save_dict.update({
                'cost_critic': self.cost_critic.state_dict(),
                'cost_critic_target': self.cost_critic_target.state_dict(),
                'cost_critic_optimizer': self.cost_critic_optimizer.state_dict(),
                'con_lambdas': self.con_lambdas.state_dict(),
                'con_lambdas_optimizer': self.con_lambdas_optimizer.state_dict(),
            })

        # torch.save does not create missing directories; make sure the
        # target exists even when load() was never called first.
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        torch.save(save_dict, f"{self.checkpoint_dir}/model_{model_num}.pt")
        cprint(f'[{self.name}] save success.', bold=True, color="blue")

    def load(self, model_num):
        """Restore the agent from checkpoint `model_num`, or initialise it fresh.

        The running statistics are always loaded first. If
        `<checkpoint_dir>/model_<model_num>.pt` exists, networks and optimizers
        are restored from it. Otherwise the actor, critics (and multipliers,
        when constraints are enabled) are initialised and each target critic is
        made an exact copy of its online critic.

        Args:
            model_num: identifier used in the checkpoint (and rms) file names.

        Returns:
            `int(model_num)` when a checkpoint was loaded, `0` otherwise.
        """
        # load rms
        self.obs_rms.load(self.save_dir, model_num)
        self.reward_rms.load(self.save_dir, model_num)

        # load network models
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        checkpoint_file = f"{self.checkpoint_dir}/model_{model_num}.pt"

        if not os.path.isfile(checkpoint_file):
            # no checkpoint: start from freshly initialised networks
            self.actor.initialize()
            self.reward_critic.initialize()
            # polyak = 0.0 copies the online parameters into the target verbatim
            self._softUpdate(self.reward_critic_target, self.reward_critic, 0.0)
            if self.cost_dim > 0:
                self.con_lambdas.initialize()
                self.cost_critic.initialize()
                self._softUpdate(self.cost_critic_target, self.cost_critic, 0.0)
            cprint(f'[{self.name}] load fail.', bold=True, color="red")
            return 0

        # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_file, map_location=self.device)
        self.actor.load_state_dict(checkpoint['actor'])
        self.reward_critic.load_state_dict(checkpoint['reward_critic'])
        self.reward_critic_target.load_state_dict(checkpoint['reward_critic_target'])
        # Checkpoints written by save() versions that omitted the actor
        # optimizer state lack this key; keep the fresh optimizer instead of
        # failing with a KeyError.
        if 'actor_optimizer' in checkpoint:
            self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
        else:
            cprint(f'[{self.name}] actor optimizer state not found in checkpoint; using a fresh optimizer.',
                   bold=True, color="red")
        self.reward_critic_optimizer.load_state_dict(checkpoint['reward_critic_optimizer'])
        if self.con_entropy:
            self.log_entropy_alpha.data = checkpoint['log_entropy_alpha']
            self.entropy_alpha_optimizer.load_state_dict(checkpoint['entropy_alpha_optimizer'])
        if self.cost_dim > 0:
            self.cost_critic.load_state_dict(checkpoint['cost_critic'])
            self.cost_critic_target.load_state_dict(checkpoint['cost_critic_target'])
            self.cost_critic_optimizer.load_state_dict(checkpoint['cost_critic_optimizer'])
            self.con_lambdas.load_state_dict(checkpoint['con_lambdas'])
            self.con_lambdas_optimizer.load_state_dict(checkpoint['con_lambdas_optimizer'])
        cprint(f'[{self.name}] load success.', bold=True, color="blue")
        return int(model_num)

    """ private functions
    """
    def _softUpdate(self, target, source, polyak):
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(target_param.data * polyak + param.data * (1.0 - polyak))

    def _train(self):
        """Perform one gradient update of critics, actor, entropy alpha and multipliers.

        Steps, in order:
          1. sample a batch from the replay buffer;
          2. build truncated quantile targets and update the reward critic
             (and the cost critic when `cost_dim > 0`), soft-updating each
             target network right after its critic step;
          3. update the actor on the aggregated, entropy-augmented objective,
             penalised by multiplier-weighted cost values when constrained;
          4. update the entropy coefficient (only if `con_entropy`);
          5. update the constraint multipliers (only if `cost_dim > 0`).

        Returns:
            With constraints:
                (reward_critic_loss, cost_critic_loss, policy_loss, entropy,
                 entropy_alpha_tensor, batch-mean multipliers, batch-mean cost values)
            Without constraints:
                (reward_critic_loss, policy_loss, entropy, entropy_alpha_tensor)
        """
        # get batches
        with torch.no_grad():
            if self.cost_dim > 0:
                states_tensor, actions_tensor, rewards_tensor, costs_tensor, preferences_tensor, \
                    dones_tensor, fails_tensor, next_states_tensor = self.replay_buffer.getBatches(self.obs_rms, self.reward_rms)
            else:
                states_tensor, actions_tensor, rewards_tensor, preferences_tensor, \
                    dones_tensor, fails_tensor, next_states_tensor = self.replay_buffer.getBatches(self.obs_rms, self.reward_rms)
        # note: dones_tensor is unpacked but not used below; only fails_tensor masks the bootstrap

        # ================== Critic Update ================== #
        # calculate critic targets
        with torch.no_grad():
            batch_size = states_tensor.shape[0]
            entropy_alpha_tensor = self.getEntropyAlpha()            
            epsilons_tensor = torch.randn_like(actions_tensor)
            # next actions are sampled under the *same* preference as the stored transition
            self.actor.updateActionDist(next_states_tensor, preferences_tensor, epsilons_tensor)
            next_actions_tensor = self.actor.sample(deterministic=False)[0] # (batch_size, action_dim)
            # NOTE(review): the first output of sample() is clamped to the action bounds here,
            # while the log-prob below is taken from the actor's distribution as-is -- confirm intended
            next_actions_tensor = torch.clamp(next_actions_tensor, 
                                              torch.tensor(self.action_bound_min, dtype=torch.float32, device=self.device), 
                                              torch.tensor(self.action_bound_max, dtype=torch.float32, device=self.device)) 
            next_log_probs_tensor = self.actor.getLogProb() # (batch_size,)
            next_reward_quantiles_tensor = self.reward_critic_target(next_states_tensor, next_actions_tensor, preferences_tensor) # (batch_size, reward_dim, n_critics, n_quantiles)
            # pool the quantiles of all critics and sort them in ascending order
            next_reward_quantiles_tensor = torch.sort(next_reward_quantiles_tensor.view(batch_size, self.reward_dim, -1), dim=-1)[0] # (batch_size, reward_dim, n_critic*n_quantiles)
            # drop the two largest pooled quantiles -> last dim becomes n_critic*n_quantiles - 2
            next_reward_quantiles_tensor = next_reward_quantiles_tensor[:, :, :-2]
            # bootstrap is masked by `fails`; the entropy term is subtracted from every quantile
            reward_targets_tensor = rewards_tensor.view(batch_size, self.reward_dim, 1) + \
                    self.discount_factor*(1.0 - fails_tensor.view(batch_size, 1, 1)) * \
                    (next_reward_quantiles_tensor - entropy_alpha_tensor*next_log_probs_tensor.view(batch_size, 1, 1)) # (batch_size, reward_dim, n_critic*n_quantiles - 2)
            if self.cost_dim > 0:
                next_cost_quantiles_tensor = self.cost_critic_target(next_states_tensor, next_actions_tensor, preferences_tensor) # (batch_size, cost_dim, n_critics, n_quantiles)
                next_cost_quantiles_tensor = torch.sort(next_cost_quantiles_tensor.view(batch_size, self.cost_dim, -1), dim=-1)[0] # (batch_size, cost_dim, n_critic*n_quantiles)
                # for costs the two *smallest* pooled quantiles are dropped instead
                next_cost_quantiles_tensor = next_cost_quantiles_tensor[:, :, 2:]
                # cost targets carry no entropy term
                cost_targets_tensor = costs_tensor.view(batch_size, self.cost_dim, 1) + \
                        self.discount_factor*(1.0 - fails_tensor.view(batch_size, 1, 1)) * next_cost_quantiles_tensor # (batch_size, cost_dim, n_critic*n_quantiles - 2)

        # reward critic update
        reward_critic_loss = self.reward_critic.getLoss(states_tensor, actions_tensor, preferences_tensor, reward_targets_tensor)
        self.reward_critic_optimizer.zero_grad()
        reward_critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.reward_critic_params, self.max_grad_norm)
        self.reward_critic_optimizer.step()

        # soft update (target <- ratio * target + (1 - ratio) * online)
        self._softUpdate(self.reward_critic_target, self.reward_critic, self.soft_update_ratio)

        if self.cost_dim > 0:
            # cost critic update
            cost_critic_loss = self.cost_critic.getLoss(states_tensor, actions_tensor, preferences_tensor, cost_targets_tensor)
            self.cost_critic_optimizer.zero_grad()
            cost_critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.cost_critic_params, self.max_grad_norm)
            self.cost_critic_optimizer.step()

            # soft update
            self._softUpdate(self.cost_critic_target, self.cost_critic, self.soft_update_ratio)
        # ================================================== #

        # ================= Policy Update ================= #
        with torch.no_grad():
            # alpha and the multipliers are treated as constants in the policy loss
            entropy_alpha_tensor = self.getEntropyAlpha()
            epsilons_tensor = torch.randn_like(actions_tensor)
            if self.cost_dim > 0:
                con_lambdas_tensor = self.con_lambdas(preferences_tensor) # (batch_size, cost_dim)

        self.actor.updateActionDist(states_tensor, preferences_tensor, epsilons_tensor)
        sampled_actions_tensor = self.actor.sample(deterministic=False)[0]
        log_probs_tensor = self.actor.getLogProb() # (batch_size,)
        entropy = -torch.mean(log_probs_tensor)
        # NOTE(review): the policy objective is evaluated with the *target* critics
        # (already soft-updated above), not the online critics -- confirm intended
        objectives_tensor = self.reward_critic_target(states_tensor, sampled_actions_tensor, preferences_tensor).mean(dim=(0, 2, 3)) # (reward_dim,)
        # the same entropy bonus is added to every reward objective
        objectives_tensor += entropy_alpha_tensor*entropy
        # NOTE(review): only the first preference of the batch is passed to aggregation();
        # this presumably assumes a single shared preference per batch -- TODO confirm
        aggregated_objective_tensor = aggregation(objectives_tensor, preferences_tensor[0], self.actor, self.device) # (,)
        if self.cost_dim > 0:
            cost_values_tensor = self.cost_critic_target(states_tensor, sampled_actions_tensor, preferences_tensor).mean(dim=(2, 3)) # (batch_size, cost_dim)
            summed_cost_values_tensor = torch.sum(cost_values_tensor*con_lambdas_tensor, dim=-1) # (batch_size,)
            # Lagrangian-style loss, rescaled by (1 + batch-mean of the summed multipliers)
            policy_loss = (-aggregated_objective_tensor + torch.mean(summed_cost_values_tensor))/(1.0 + con_lambdas_tensor.sum(dim=-1).mean())
        else:
            policy_loss = -aggregated_objective_tensor
        self.actor_optimizer.zero_grad()
        policy_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor_params, self.max_grad_norm)
        self.actor_optimizer.step()
        # ================================================= #

        # entropy alpha update
        if self.con_entropy:
            # re-evaluated with grad enabled; the entropy gap is detached so only log-alpha gets a gradient
            entropy_alpha_tensor = self.getEntropyAlpha()
            entropy_alpha_loss = torch.mean(entropy_alpha_tensor*(entropy - self.target_entropy).detach())
            self.entropy_alpha_optimizer.zero_grad()
            entropy_alpha_loss.backward()
            self.entropy_alpha_optimizer.step()
            # keep log(alpha) within [-8, 8]
            self.log_entropy_alpha.data.copy_(self.log_entropy_alpha.data.clamp(-8.0, 8.0))

        # constraint lambda update
        if self.cost_dim > 0:
            # re-evaluated with grad enabled; cost values are detached so only the multipliers get a gradient
            con_lambdas_tensor = self.con_lambdas(preferences_tensor) # (batch_size, cost_dim)
            # minimising this pushes a multiplier up where the cost value exceeds its threshold
            con_lambdas_loss = torch.sum(con_lambdas_tensor*(self.con_thresholds - cost_values_tensor.detach()), dim=-1).mean()
            self.con_lambdas_optimizer.zero_grad()
            con_lambdas_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.con_lambdas.parameters(), self.max_grad_norm)
            self.con_lambdas_optimizer.step()
            # multipliers and cost values are reported as per-constraint batch means
            return reward_critic_loss, cost_critic_loss, policy_loss, entropy, entropy_alpha_tensor, con_lambdas_tensor.mean(dim=0), cost_values_tensor.mean(dim=0)
        else:
            return reward_critic_loss, policy_loss, entropy, entropy_alpha_tensor
