from algos.common.critic_distribution import CriticSAPreferDist as Critic
from algos.common.agent_base import AgentBase
from utils import cprint

from .multiplier import Multiplier
from .storage import ReplayBuffer
from .actor import Actor

import numpy as np
import torch
import os

# Small numerical-stability constant; not referenced by the code in this module
# (possibly imported elsewhere).
EPS = 1.0e-8

class Agent(AgentBase):
    """Preference-conditioned multi-objective actor-critic agent, optionally constrained.

    The actor and the quantile critics all take a preference vector (one weight
    per reward dimension, ``preference_dim == reward_dim``) as an extra input.
    Training follows a TD3-like recipe: clipped Gaussian noise on target
    actions, actor/target updates on their own (possibly less frequent)
    schedule, and Polyak-averaged target networks. Every environment
    transition is additionally stored ``n_HER_samples`` times with preferences
    drawn from a flat Dirichlet (uniform over the simplex).

    When ``args`` provides ``cost_dim > 0``, a cost critic and a
    preference-conditioned multiplier network (``Multiplier``) are created; the
    actor then also minimises ``sum(lambda * Q_cost)`` and the multipliers are
    trained to grow when the estimated cost exceeds ``con_thresholds``.
    """

    def __init__(self, args) -> None:
        """Build networks, optimizers and the replay buffer from ``args``.

        Args:
            args: configuration namespace. Besides the fields forwarded to
                ``AgentBase`` it must provide the RL hyper-parameters read
                below and ``args.model`` with ``'actor'`` and
                ``'reward_critic'`` entries (plus ``'cost_critic'``,
                ``'multipliers'``, ``con_lambdas_lr`` and ``con_thresholds``
                when ``cost_dim > 0``). ``cost_dim`` is optional and defaults
                to 0 (unconstrained).

        Raises:
            ValueError: if ``len(args.con_thresholds) != cost_dim``.
        """
        super().__init__(
            name=args.name,
            device=args.device,
            obs_dim=args.obs_dim,
            action_dim=args.action_dim,
            reward_dim=args.reward_dim,
            norm_obs=args.norm_obs,
            norm_reward=args.norm_reward,
        )

        # base
        self.save_dir = args.save_dir
        self.checkpoint_dir = f"{self.save_dir}/checkpoint"
        # cost_dim is optional in the config; 0 means no constraints
        self.cost_dim = args.__dict__.get('cost_dim', 0)
        # one preference weight per reward dimension
        self.preference_dim = self.reward_dim
        self.action_bound_min = args.action_bound_min
        self.action_bound_max = args.action_bound_max
        self.n_steps = args.n_steps
        self.n_envs = args.n_envs

        # for RL
        self.discount_factor = args.discount_factor
        self.critic_lr = args.critic_lr
        self.actor_lr = args.actor_lr
        self.batch_size = args.batch_size
        self.n_actor_update_iters = args.n_actor_update_iters
        self.n_critic_update_iters = args.n_critic_update_iters
        self.max_grad_norm = args.max_grad_norm
        self.len_replay_buffer = args.len_replay_buffer
        # weight kept on the *target* parameters in _softUpdate (0.0 = hard copy)
        self.soft_update_ratio = args.soft_update_ratio
        # extra copies of each transition stored with random preferences
        self.n_HER_samples = args.n_HER_samples
        self.explore_action_noise_scale = args.explore_action_noise_scale
        self.target_action_noise_scale = args.target_action_noise_scale
        self.target_action_noise_clip = args.target_action_noise_clip
        self.dir_angle_loss_coeff = args.dir_angle_loss_coeff
        self.model_cfg = args.model

        # for preference interpolation
        # key preferences: one one-hot vector per objective, followed by the
        # uniform preference -> shape (preference_dim + 1, preference_dim)
        self.key_preferences = np.vstack([
            np.eye(self.preference_dim),
            np.ones((1, self.preference_dim))/self.preference_dim,
        ])
        # per-key-preference reward sums, initialised to -inf; not updated in this file
        self.key_reward_sums = -np.inf*np.ones((len(self.key_preferences), self.reward_dim))
        # Optional callable assigned outside this class: maps a (batch, preference_dim)
        # numpy array of preferences to projected preferences used by the
        # direction-angle loss. None disables that loss.
        self.preferInterp = None

        # for model
        self.actor = Actor(
            self.device, self.obs_dim, self.action_dim, self.reward_dim, 
            self.action_bound_min, self.action_bound_max, self.model_cfg['actor']).to(self.device)
        self.actor_target = Actor(
            self.device, self.obs_dim, self.action_dim, self.reward_dim, 
            self.action_bound_min, self.action_bound_max, self.model_cfg['actor']).to(self.device)
        self.reward_critic = Critic(
            self.device, self.obs_dim, self.action_dim, self.reward_dim, 
            self.preference_dim, self.model_cfg['reward_critic']).to(self.device)
        self.reward_critic_target = Critic(
            self.device, self.obs_dim, self.action_dim, self.reward_dim, 
            self.preference_dim, self.model_cfg['reward_critic']).to(self.device)
        if self.cost_dim > 0:
            self.cost_critic = Critic(
                self.device, self.obs_dim, self.action_dim, self.cost_dim, 
                self.preference_dim, self.model_cfg['cost_critic']).to(self.device)
            self.cost_critic_target = Critic(
                self.device, self.obs_dim, self.action_dim, self.cost_dim, 
                self.preference_dim, self.model_cfg['cost_critic']).to(self.device)

        # for constraints
        if self.cost_dim > 0:
            self.con_lambdas_lr = args.con_lambdas_lr
            self.con_lambdas = Multiplier(self.device, self.preference_dim, self.cost_dim, self.model_cfg['multipliers']).to(self.device)
            self.con_lambdas_optimizer = torch.optim.Adam(self.con_lambdas.parameters(), lr=self.con_lambdas_lr)
            # Scaled by 1/(1 - gamma): the discounted sum of a constant per-step value.
            # NOTE(review): presumably the config gives per-step cost limits — confirm.
            con_thresholds = np.array(args.con_thresholds)/(1.0 - self.discount_factor)
            if len(con_thresholds) != self.cost_dim:
                raise ValueError(
                    f"expected {self.cost_dim} constraint thresholds, got {len(con_thresholds)}")
            self.con_thresholds = torch.tensor(con_thresholds, dtype=torch.float32, device=self.device)

        # for replay buffer
        self.replay_buffer = ReplayBuffer(self.len_replay_buffer, self.batch_size, self.discount_factor, self.device)

        # for optimizers
        self.actor_params = list(self.actor.parameters())
        self.reward_critic_params = list(self.reward_critic.parameters())
        self.actor_optimizer = torch.optim.Adam(self.actor_params, lr=self.actor_lr)
        self.reward_critic_optimizer = torch.optim.Adam(self.reward_critic_params, lr=self.critic_lr)
        if self.cost_dim > 0:
            self.cost_critic_params = list(self.cost_critic.parameters())
            self.cost_critic_optimizer = torch.optim.Adam(self.cost_critic_params, lr=self.critic_lr)

    # ------------------------------------------------------------------
    # public functions
    # ------------------------------------------------------------------
    @torch.no_grad()
    def getAction(self, state:np.ndarray, preference:np.ndarray, deterministic:bool) -> np.ndarray:
        """Select actions for a batch of states under the given preferences.

        Side effect: caches the raw ``state``, ``preference`` and the
        *normalized* action so that the next :meth:`step` call can store the
        transition.

        Args:
            state: raw observations; normalized with ``obs_rms`` before use.
                NOTE(review): presumably (n_envs, obs_dim) — confirm.
            preference: preference weights, 2-D (batch, preference_dim) as
                required by :meth:`step`.
            deterministic: forwarded to ``actor.sample``.

        Returns:
            The un-normalized (environment-scale) actions as a numpy array.
        """
        state_tensor = torch.tensor(self.obs_rms.normalize(state), dtype=torch.float32, device=self.device)
        preference_tensor = torch.tensor(preference, dtype=torch.float32, device=self.device)

        self.actor.updateActionDist(state_tensor, preference_tensor, self.explore_action_noise_scale)
        norm_action_tensor, unnorm_action_tensor = self.actor.sample(deterministic)

        self.np_state = state.copy()
        self.np_preference = preference.copy()
        self.np_action = norm_action_tensor.detach().cpu().numpy()
        return unnorm_action_tensor.detach().cpu().numpy()

    def step(self, *args) -> None:
        """Store the transition produced by the last :meth:`getAction` call.

        Accepts either ``(reward, done, fail, next_state)`` or
        ``(reward, cost, done, fail, next_state)``. Those values follow the
        cached state/action/preference in ``ReplayBuffer.addTransition``. The
        transition is stored once with the acting preference and
        ``n_HER_samples`` more times with preferences drawn from a flat
        Dirichlet. Finally the observation/reward statistics are updated
        when normalization is enabled.

        Raises:
            ValueError: if the number of arguments is not 4 or 5.
        """
        if len(args) not in (4, 5):
            raise ValueError("Invalid number of arguments")
        np_reward = args[0]

        # store with the preference that was actually used for acting
        self.replay_buffer.addTransition(self.np_state, self.np_action, self.np_preference, *args)
        # preference relabelling: same transition, random preferences on the simplex
        batch_size, preference_dim = self.np_preference.shape
        for _ in range(self.n_HER_samples):
            np_preference = np.random.dirichlet(np.ones(preference_dim), size=batch_size) # (batch_size, preference_dim)
            self.replay_buffer.addTransition(self.np_state, self.np_action, np_preference, *args)

        # update statistics
        if self.norm_obs:
            self.obs_rms.update(self.np_state)
        if self.norm_reward:
            self.reward_rms.update(np_reward)

    def readyToTrain(self):
        """Return True once the replay buffer holds at least one full batch."""
        return self.replay_buffer.getLen() >= self.batch_size

    def train(self):
        """Run one round of critic and actor updates on freshly sampled batches.

        ``max(n_actor_update_iters, n_critic_update_iters)`` iterations are run.
        The critics are updated every ``total // n_critic_update_iters``
        iterations. The actor, the target networks and the multipliers are
        updated every ``total // n_actor_update_iters`` iterations.

        Returns:
            dict with the last ``'reward_critic_loss'``, ``'actor_loss'`` and
            ``'entropy'`` as floats. When constrained, it also has
            ``'cost_critic_loss'`` (float) and ``'con_lambdas'`` /
            ``'cost_values'`` (batch means as numpy arrays).

        Raises:
            ValueError: if either iteration count is smaller than 1.
        """
        if self.n_actor_update_iters < 1 or self.n_critic_update_iters < 1:
            raise ValueError(
                f"n_actor_update_iters ({self.n_actor_update_iters}) and "
                f"n_critic_update_iters ({self.n_critic_update_iters}) must both be >= 1")
        n_total_iters = max(self.n_actor_update_iters, self.n_critic_update_iters)
        # Integer intervals: with non-divisible counts the less frequent update
        # runs more often than its nominal count (e.g. total=5, n=3 -> every iter).
        critic_interval = n_total_iters // self.n_critic_update_iters
        actor_interval = n_total_iters // self.n_actor_update_iters

        for iter_idx in range(n_total_iters):
            with torch.no_grad():
                batch_list = self.replay_buffer.getBatches(self.obs_rms, self.reward_rms)
            if (iter_idx + 1) % critic_interval == 0:
                if self.cost_dim > 0:
                    reward_critic_loss, cost_critic_loss = self._trainCritic(batch_list)
                else:
                    reward_critic_loss = self._trainCritic(batch_list)
            if (iter_idx + 1) % actor_interval == 0:
                if self.cost_dim > 0:
                    actor_loss, entropy, con_lambdas_tensor, cost_values_tensor = self._trainActor(batch_list)
                else:
                    actor_loss, entropy = self._trainActor(batch_list)

        results = {
            'reward_critic_loss': reward_critic_loss.item(),
            'actor_loss': actor_loss.item(),
            'entropy': entropy.item(),
        }
        if self.cost_dim > 0:
            results['cost_critic_loss'] = cost_critic_loss.item()
            results['con_lambdas'] = con_lambdas_tensor.detach().cpu().numpy()
            results['cost_values'] = cost_values_tensor.detach().cpu().numpy()
        return results

    def save(self, model_num):
        """Save normalizer statistics and a checkpoint ``model_{model_num}.pt``.

        The checkpoint directory is created if needed, so saving works even
        when :meth:`load` was never called.
        """
        # save rms
        self.obs_rms.save(self.save_dir, model_num)
        self.reward_rms.save(self.save_dir, model_num)

        # save network models
        save_dict = {
            'actor': self.actor.state_dict(),
            'actor_target': self.actor_target.state_dict(),
            'reward_critic': self.reward_critic.state_dict(),
            'reward_critic_target': self.reward_critic_target.state_dict(),
            'actor_optimizer': self.actor_optimizer.state_dict(),
            'reward_critic_optimizer': self.reward_critic_optimizer.state_dict(),            
        }
        if self.cost_dim > 0:
            save_dict.update({
                'cost_critic': self.cost_critic.state_dict(),
                'cost_critic_target': self.cost_critic_target.state_dict(),
                'cost_critic_optimizer': self.cost_critic_optimizer.state_dict(),
                'con_lambdas': self.con_lambdas.state_dict(),
                'con_lambdas_optimizer': self.con_lambdas_optimizer.state_dict(),
            })
        # torch.save does not create missing parent directories
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        torch.save(save_dict, f"{self.checkpoint_dir}/model_{model_num}.pt")
        cprint(f'[{self.name}] save success.', bold=True, color="blue")

    def load(self, model_num):
        """Restore normalizers and, if the checkpoint exists, all networks/optimizers.

        If ``model_{model_num}.pt`` is missing, networks are freshly
        initialized and targets are hard-copied from their sources instead.

        Returns:
            ``int(model_num)`` on success, 0 when no checkpoint was found.
        """
        # load rms
        self.obs_rms.load(self.save_dir, model_num)
        self.reward_rms.load(self.save_dir, model_num)

        # load network models
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        checkpoint_file = f"{self.checkpoint_dir}/model_{model_num}.pt"
        if os.path.isfile(checkpoint_file):
            # torch.load unpickles the file: only load checkpoints you trust.
            checkpoint = torch.load(checkpoint_file, map_location=self.device)
            self.actor.load_state_dict(checkpoint['actor'])
            self.actor_target.load_state_dict(checkpoint['actor_target'])
            self.reward_critic.load_state_dict(checkpoint['reward_critic'])
            self.reward_critic_target.load_state_dict(checkpoint['reward_critic_target'])
            self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
            self.reward_critic_optimizer.load_state_dict(checkpoint['reward_critic_optimizer'])
            if self.cost_dim > 0:
                self.cost_critic.load_state_dict(checkpoint['cost_critic'])
                self.cost_critic_target.load_state_dict(checkpoint['cost_critic_target'])
                self.cost_critic_optimizer.load_state_dict(checkpoint['cost_critic_optimizer'])
                self.con_lambdas.load_state_dict(checkpoint['con_lambdas'])
                self.con_lambdas_optimizer.load_state_dict(checkpoint['con_lambdas_optimizer'])
            cprint(f'[{self.name}] load success.', bold=True, color="blue")
            return int(model_num)
        else:
            self.actor.initialize()
            self.reward_critic.initialize()
            # polyak = 0.0 -> targets become exact copies of the fresh networks
            self._softUpdate(self.actor_target, self.actor, 0.0)
            self._softUpdate(self.reward_critic_target, self.reward_critic, 0.0)
            if self.cost_dim > 0:
                self.con_lambdas.initialize()
                self.cost_critic.initialize()
                self._softUpdate(self.cost_critic_target, self.cost_critic, 0.0)
            cprint(f'[{self.name}] load fail.', bold=True, color="red")
            return 0

    # ------------------------------------------------------------------
    # private functions
    # ------------------------------------------------------------------
    def _softUpdate(self, target, source, polyak):
        """In place: ``target <- polyak * target + (1 - polyak) * source``."""
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(target_param.data * polyak + param.data * (1.0 - polyak))

    def _unpackBatch(self, batch_list):
        """Split a replay batch into named tensors.

        Returns an 8-tuple ``(states, actions, rewards, costs, preferences,
        dones, fails, next_states)``; ``costs`` is None when unconstrained.
        """
        if self.cost_dim > 0:
            states_tensor, actions_tensor, rewards_tensor, costs_tensor, preferences_tensor, \
                    dones_tensor, fails_tensor, next_states_tensor = batch_list
        else:
            states_tensor, actions_tensor, rewards_tensor, preferences_tensor, \
                    dones_tensor, fails_tensor, next_states_tensor = batch_list
            costs_tensor = None
        return (states_tensor, actions_tensor, rewards_tensor, costs_tensor, preferences_tensor,
                dones_tensor, fails_tensor, next_states_tensor)

    def _dirAngleLoss(self, preferences_tensor, reward_values_tensor):
        """Mean angle in degrees between projected preferences and reward values.

        The cosine is clamped to ``[0, 0.9999]`` before ``acos``, so the angle
        is capped at 90 degrees and ``acos`` never sees exactly 1. Requires
        ``self.preferInterp`` to be set.
        """
        projected_preferences = self.preferInterp(preferences_tensor.detach().cpu().numpy()) # (batch_size, preference_dim)
        projected_preferences_tensor = torch.tensor(projected_preferences, device=self.device, dtype=torch.float32)
        cos_sim_tensor = torch.nn.functional.cosine_similarity(projected_preferences_tensor, reward_values_tensor)
        return torch.mean(torch.rad2deg(torch.acos(torch.clamp(cos_sim_tensor, 0, 0.9999))))

    def _trainCritic(self, batch_list):
        """One gradient step on the reward critic (and the cost critic if constrained).

        Targets use target-policy smoothing (clipped Gaussian noise on the
        target actor's action, clamped to [-1, 1]). The quantiles from all
        critics are pooled and sorted; the 2 largest reward quantiles and the
        2 smallest cost quantiles are dropped, which biases both targets
        pessimistically.

        Returns:
            ``reward_critic_loss`` or ``(reward_critic_loss, cost_critic_loss)``.
        """
        states_tensor, actions_tensor, rewards_tensor, costs_tensor, preferences_tensor, \
                _, fails_tensor, next_states_tensor = self._unpackBatch(batch_list)

        with torch.no_grad():
            # get next actions
            batch_size = states_tensor.shape[0]
            next_actions_tensor = self.actor_target(next_states_tensor, preferences_tensor)
            noise_tensor = torch.clamp(torch.randn_like(next_actions_tensor) * self.target_action_noise_scale, 
                                       -self.target_action_noise_clip, self.target_action_noise_clip)
            next_actions_tensor = torch.clamp(next_actions_tensor + noise_tensor, -1.0, 1.0)

            # calculate critic targets; `fails` (not `dones`) stops bootstrapping
            next_reward_quantiles_tensor = self.reward_critic_target(next_states_tensor, next_actions_tensor, preferences_tensor) # (batch_size, reward_dim, n_critics, n_quantiles)
            next_reward_quantiles_tensor = torch.sort(next_reward_quantiles_tensor.view(batch_size, self.reward_dim, -1), dim=-1)[0] # (batch_size, reward_dim, n_critic*n_quantiles)
            next_reward_quantiles_tensor = next_reward_quantiles_tensor[:, :, :-2]  # drop 2 largest
            reward_targets_tensor = rewards_tensor.view(batch_size, self.reward_dim, 1) + \
                    self.discount_factor*(1.0 - fails_tensor.view(batch_size, 1, 1)) * next_reward_quantiles_tensor
            if self.cost_dim > 0:
                next_cost_quantiles_tensor = self.cost_critic_target(next_states_tensor, next_actions_tensor, preferences_tensor) # (batch_size, cost_dim, n_critics, n_quantiles)
                next_cost_quantiles_tensor = torch.sort(next_cost_quantiles_tensor.view(batch_size, self.cost_dim, -1), dim=-1)[0] # (batch_size, cost_dim, n_critic*n_quantiles)
                next_cost_quantiles_tensor = next_cost_quantiles_tensor[:, :, 2:]  # drop 2 smallest
                cost_targets_tensor = costs_tensor.view(batch_size, self.cost_dim, 1) + \
                        self.discount_factor*(1.0 - fails_tensor.view(batch_size, 1, 1)) * next_cost_quantiles_tensor

        # reward critic loss (+ direction angle loss when a preference projector is set)
        if self.preferInterp is not None:
            reward_values_tensor = self.reward_critic(states_tensor, actions_tensor, preferences_tensor).mean(dim=(2, 3)) # (batch_size, reward_dim)
            # NOTE(review): unlike _trainActor, this term is not scaled by
            # dir_angle_loss_coeff — confirm that is intended.
            dir_angle_loss = self._dirAngleLoss(preferences_tensor, reward_values_tensor)
            reward_critic_loss = self.reward_critic.getLoss(states_tensor, actions_tensor, preferences_tensor, reward_targets_tensor) + dir_angle_loss
        else:
            reward_critic_loss = self.reward_critic.getLoss(states_tensor, actions_tensor, preferences_tensor, reward_targets_tensor)
        self.reward_critic_optimizer.zero_grad()
        reward_critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.reward_critic_params, self.max_grad_norm)
        self.reward_critic_optimizer.step()

        if self.cost_dim > 0:
            cost_critic_loss = self.cost_critic.getLoss(states_tensor, actions_tensor, preferences_tensor, cost_targets_tensor)
            self.cost_critic_optimizer.zero_grad()
            cost_critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.cost_critic_params, self.max_grad_norm)
            self.cost_critic_optimizer.step()

            return reward_critic_loss, cost_critic_loss
        else:
            return reward_critic_loss

    def _trainActor(self, batch_list):
        """One actor step, target soft updates, and (if constrained) one multiplier step.

        The actor maximises the preference-weighted mean reward value
        ``w . Q_r``, optionally adds ``dir_angle_loss_coeff`` times the
        direction-angle loss, and when constrained adds
        ``sum(lambda * Q_c)`` with the multipliers held fixed. The multipliers
        then minimise ``sum(lambda * (threshold - Q_c))``, so they grow where
        the estimated cost exceeds its threshold.

        Returns:
            ``(actor_loss, entropy)``; when constrained, also the batch-mean
            multipliers and cost values.
        """
        states_tensor, _, _, _, preferences_tensor, _, _, _ = self._unpackBatch(batch_list)

        self.actor.updateActionDist(states_tensor, preferences_tensor, self.explore_action_noise_scale)
        entropy = self.actor.getEntropy()
        current_actions_tensor = self.actor.sample(deterministic=True)[0]
        reward_values_tensor = self.reward_critic(states_tensor, current_actions_tensor, preferences_tensor).mean(dim=(2, 3)) # (batch_size, reward_dim)
        if self.cost_dim > 0:
            cost_values_tensor = self.cost_critic(states_tensor, current_actions_tensor, preferences_tensor).mean(dim=(2, 3)) # (batch_size, cost_dim)

        # scalarised objective: batched dot product w . Q_r
        actor_loss = -torch.mean(torch.bmm(preferences_tensor.unsqueeze(-2), reward_values_tensor.unsqueeze(-1)))
        if self.preferInterp is not None:
            actor_loss = actor_loss + self.dir_angle_loss_coeff * self._dirAngleLoss(preferences_tensor, reward_values_tensor)
        if self.cost_dim > 0:
            with torch.no_grad():
                con_lambdas_tensor = self.con_lambdas(preferences_tensor) # (batch_size, cost_dim)
            actor_loss = actor_loss + torch.sum(cost_values_tensor * con_lambdas_tensor, dim=-1).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor_params, self.max_grad_norm)
        self.actor_optimizer.step()

        # soft update
        self._softUpdate(self.reward_critic_target, self.reward_critic, self.soft_update_ratio)
        if self.cost_dim > 0:
            self._softUpdate(self.cost_critic_target, self.cost_critic, self.soft_update_ratio)
        self._softUpdate(self.actor_target, self.actor, self.soft_update_ratio)

        # constraint lambda update
        if self.cost_dim > 0:
            con_lambdas_tensor = self.con_lambdas(preferences_tensor) # (batch_size, cost_dim)
            con_lambdas_loss = torch.sum(con_lambdas_tensor*(self.con_thresholds - cost_values_tensor.detach()), dim=-1).mean()
            self.con_lambdas_optimizer.zero_grad()
            con_lambdas_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.con_lambdas.parameters(), self.max_grad_norm)
            self.con_lambdas_optimizer.step()
            return actor_loss, entropy, con_lambdas_tensor.mean(dim=0), cost_values_tensor.mean(dim=0)
        else:
            return actor_loss, entropy