from algos.common.critic_preference import CriticSAPreference as Critic
from algos.common.agent_base import AgentBase

from .storage import ReplayBuffer
from .actor import Actor
from utils import cprint

import numpy as np
import torch
import os

# Small numerical epsilon (1e-8). Not referenced by the code in this section;
# presumably kept for numerical guards elsewhere — confirm before removing.
EPS: float = 1.0e-8

class Agent(AgentBase):
    """Preference-conditioned, TD3-style actor-critic agent.

    The actor and the twin reward critics receive a preference vector (one
    weight per reward dimension) as an extra input, and the reward critics
    output one value per reward dimension. When ``args.cost_dim > 0`` (read
    with a default of 0), twin cost critics and one Lagrange multiplier per
    cost dimension (stored in log space as ``log_con_lambdas``) are trained too.
    """

    def __init__(self, args) -> None:
        super().__init__(
            name=args.name,
            device=args.device,
            obs_dim=args.obs_dim,
            action_dim=args.action_dim,
            reward_dim=args.reward_dim,
            norm_obs=args.norm_obs,
            norm_reward=args.norm_reward,
        )

        # base
        self.save_dir = args.save_dir
        self.checkpoint_dir = f'{self.save_dir}/checkpoint'
        # cost_dim is optional; 0 disables every constraint-related member below.
        self.cost_dim = args.__dict__.get('cost_dim', 0)
        # One preference weight per reward dimension.
        self.preference_dim = self.reward_dim
        self.action_bound_min = args.action_bound_min
        self.action_bound_max = args.action_bound_max
        self.n_steps = args.n_steps
        self.n_envs = args.n_envs

        # for RL
        self.discount_factor = args.discount_factor
        self.actor_lr = args.actor_lr
        self.critic_lr = args.critic_lr
        self.n_actor_update_iters = args.n_actor_update_iters
        self.n_critic_update_iters = args.n_critic_update_iters
        self.batch_size = args.batch_size
        self.max_grad_norm = args.max_grad_norm
        self.len_replay_buffer = args.len_replay_buffer
        # Weight kept on the target parameters in _softUpdate.
        self.soft_update_ratio = args.soft_update_ratio
        self.n_HER_samples = args.n_HER_samples
        self.explore_action_noise_scale = args.explore_action_noise_scale
        self.target_action_noise_scale = args.target_action_noise_scale
        self.target_action_noise_clip = args.target_action_noise_clip
        self.dir_angle_loss_coeff = args.dir_angle_loss_coeff
        self.model_cfg = args.model

        # for preference interpolation: one one-hot preference per objective,
        # followed by the uniform preference -> shape (preference_dim + 1, preference_dim).
        self.key_preferences = np.concatenate([
            np.eye(self.preference_dim),
            np.ones((1, self.preference_dim))/self.preference_dim,
        ], axis=0)
        self.key_reward_sums = -np.inf*np.ones((len(self.key_preferences), self.reward_dim))
        # Optional callable taking a (batch_size, preference_dim) numpy array of
        # preferences; the direction-angle losses are only used while it is truthy.
        self.preferInterp = None

        # for model
        def make_actor():
            return Actor(
                self.device, self.obs_dim, self.action_dim, self.reward_dim,
                self.action_bound_min, self.action_bound_max, self.model_cfg['actor']).to(self.device)

        def make_critic(out_dim, cfg_key):
            return Critic(
                self.device, self.obs_dim, self.action_dim, out_dim,
                self.preference_dim, self.model_cfg[cfg_key]).to(self.device)

        # Construction order is kept fixed: it decides the draw order of any
        # random initialisation performed inside the constructors.
        self.actor = make_actor()
        self.actor_target = make_actor()
        self.reward_critic1 = make_critic(self.reward_dim, 'reward_critic')
        self.reward_critic2 = make_critic(self.reward_dim, 'reward_critic')
        self.reward_critic_target1 = make_critic(self.reward_dim, 'reward_critic')
        self.reward_critic_target2 = make_critic(self.reward_dim, 'reward_critic')
        if self.cost_dim > 0:
            self.cost_critic1 = make_critic(self.cost_dim, 'cost_critic')
            self.cost_critic2 = make_critic(self.cost_dim, 'cost_critic')
            self.cost_critic_target1 = make_critic(self.cost_dim, 'cost_critic')
            self.cost_critic_target2 = make_critic(self.cost_dim, 'cost_critic')

        # for constraints
        if self.cost_dim > 0:
            # Multipliers are optimized in log space so exp() keeps them positive.
            self.log_con_lambdas = torch.zeros(
                self.cost_dim, dtype=torch.float32, device=self.device, requires_grad=True)
            self.getConLambdas = lambda: torch.exp(self.log_con_lambdas)
            self.con_lambdas_lr = args.con_lambdas_lr
            # Thresholds are scaled by 1/(1 - discount_factor), the discounted sum of
            # a constant per-step value (presumably thresholds are given per step).
            con_thresholds = np.array(args.con_thresholds)/(1.0 - self.discount_factor)
            assert len(con_thresholds) == self.cost_dim
            self.con_thresholds = torch.tensor(con_thresholds, dtype=torch.float32, device=self.device)

        # for replay buffer
        self.replay_buffer = ReplayBuffer(self.len_replay_buffer, self.batch_size, self.discount_factor, self.device)

        # for optimizers (both critics of a pair share one optimizer)
        self.actor_params = list(self.actor.parameters())
        self.reward_critic_params = list(self.reward_critic1.parameters()) + list(self.reward_critic2.parameters())
        self.actor_optimizer = torch.optim.Adam(self.actor_params, lr=self.actor_lr)
        self.reward_critic_optimizer = torch.optim.Adam(self.reward_critic_params, lr=self.critic_lr)
        if self.cost_dim > 0:
            self.cost_critic_params = list(self.cost_critic1.parameters()) + list(self.cost_critic2.parameters())
            self.cost_critic_optimizer = torch.optim.Adam(self.cost_critic_params, lr=self.critic_lr)
            self.con_lambdas_optimizer = torch.optim.Adam([self.log_con_lambdas], lr=self.con_lambdas_lr)

    # ------------------------------------------------------------------
    # public functions
    # ------------------------------------------------------------------
    @torch.no_grad()
    def getAction(self, state: np.ndarray, preference: np.ndarray, deterministic: bool) -> np.ndarray:
        """Select an action for ``state`` under ``preference``.

        The state is normalized with ``obs_rms`` before it is fed to the actor,
        whose action distribution uses ``explore_action_noise_scale``. The raw
        state, the preference and the normalized action (first output of
        ``actor.sample``) are cached for the following ``step`` call.

        Returns:
            The unnormalized action (second output of ``actor.sample``) as numpy.
        """
        state_tensor = torch.tensor(self.obs_rms.normalize(state), dtype=torch.float32, device=self.device)
        preference_tensor = torch.tensor(preference, dtype=torch.float32, device=self.device)

        self.actor.updateActionDist(state_tensor, preference_tensor, self.explore_action_noise_scale)
        norm_action_tensor, unnorm_action_tensor = self.actor.sample(deterministic)

        self.np_state = state.copy()
        self.np_preference = preference.copy()
        self.np_action = norm_action_tensor.detach().cpu().numpy()
        return unnorm_action_tensor.detach().cpu().numpy()

    def step(self, *args) -> None:
        """Store the transition that followed the last ``getAction`` call.

        Accepts ``(reward, done, fail, next_state)`` or
        ``(reward, cost, done, fail, next_state)``. Besides the transition with
        the preference actually used, ``n_HER_samples`` extra copies are stored,
        each with preferences drawn from a flat Dirichlet distribution (one row
        per entry of the cached 2-D preference array).

        Raises:
            ValueError: if the number of arguments is neither 4 nor 5.
        """
        if len(args) not in (4, 5):
            raise ValueError("Invalid number of arguments")
        np_reward = args[0]

        # For both layouts the buffer expects (state, action, preference, *args).
        self.replay_buffer.addTransition(self.np_state, self.np_action, self.np_preference, *args)
        batch_size, preference_dim = self.np_preference.shape
        for _ in range(self.n_HER_samples):
            np_preference = np.random.dirichlet(np.ones(preference_dim), size=batch_size) # (batch_size, preference_dim)
            self.replay_buffer.addTransition(self.np_state, self.np_action, np_preference, *args)

        # update statistics
        if self.norm_obs:
            self.obs_rms.update(self.np_state)
        if self.norm_reward:
            self.reward_rms.update(np_reward)

    def readyToTrain(self):
        """Return True once the replay buffer holds at least ``batch_size`` entries."""
        return self.replay_buffer.getLen() >= self.batch_size

    def train(self):
        """Run one round of critic and actor updates.

        ``max(n_actor_update_iters, n_critic_update_iters)`` iterations are run
        with a fresh batch each time; critics and actor are updated every
        ``int(n_total_iters / n_*_update_iters)`` iterations.

        NOTE(review): if either iteration count is 0 the corresponding losses
        are never assigned and building ``results`` fails.

        Returns:
            Dict with the losses (and, with costs, lambdas and cost values) of
            the last update.
        """
        n_total_iters = max(self.n_actor_update_iters, self.n_critic_update_iters)
        # Update periods are loop-invariant; compute them once.
        critic_period = int(n_total_iters/self.n_critic_update_iters)
        actor_period = int(n_total_iters/self.n_actor_update_iters)
        for iter_idx in range(n_total_iters):
            with torch.no_grad():
                batch_list = self.replay_buffer.getBatches(self.obs_rms, self.reward_rms)
            if (iter_idx + 1) % critic_period == 0:
                if self.cost_dim > 0:
                    reward_critic_loss, cost_critic_loss = self._trainCritic(batch_list)
                else:
                    reward_critic_loss = self._trainCritic(batch_list)
            if (iter_idx + 1) % actor_period == 0:
                if self.cost_dim > 0:
                    actor_loss, entropy, con_lambdas_tensor, cost_values_tensor = self._trainActor(batch_list)
                else:
                    actor_loss, entropy = self._trainActor(batch_list)

        results = {
            'reward_critic_loss': reward_critic_loss.item(),
            'actor_loss': actor_loss.item(),
            'entropy': entropy.item(),
        }
        if self.cost_dim > 0:
            results['cost_critic_loss'] = cost_critic_loss.item()
            results['con_lambdas'] = con_lambdas_tensor.detach().cpu().numpy()
            results['cost_values'] = cost_values_tensor.detach().cpu().numpy()
        return results

    def save(self, model_num):
        """Save normalizer statistics and all networks/optimizers for ``model_num``.

        The checkpoint directory is created if missing, so saving works even
        when ``load`` was never called.
        """
        # save rms
        self.obs_rms.save(self.save_dir, model_num)
        self.reward_rms.save(self.save_dir, model_num)

        # save network models
        save_dict = {
            'actor': self.actor.state_dict(),
            'actor_target': self.actor_target.state_dict(),
            'reward_critic1': self.reward_critic1.state_dict(),
            'reward_critic2': self.reward_critic2.state_dict(),
            'reward_critic_target1': self.reward_critic_target1.state_dict(),
            'reward_critic_target2': self.reward_critic_target2.state_dict(),
            'actor_optimizer': self.actor_optimizer.state_dict(),
            'reward_critic_optimizer': self.reward_critic_optimizer.state_dict(),
        }
        if self.cost_dim > 0:
            save_dict.update({
                'cost_critic1': self.cost_critic1.state_dict(),
                'cost_critic2': self.cost_critic2.state_dict(),
                'cost_critic_target1': self.cost_critic_target1.state_dict(),
                'cost_critic_target2': self.cost_critic_target2.state_dict(),
                'cost_critic_optimizer': self.cost_critic_optimizer.state_dict(),
                'log_con_lambdas': self.log_con_lambdas.data,
                'con_lambdas_optimizer': self.con_lambdas_optimizer.state_dict(),
            })
        # torch.save does not create parent directories.
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        torch.save(save_dict, f"{self.checkpoint_dir}/model_{model_num}.pt")
        cprint(f'[{self.name}] save success.', bold=True, color="blue")

    def load(self, model_num):
        """Restore normalizers and networks for ``model_num``.

        If no checkpoint file exists, the online networks are re-initialized
        and the targets are hard-copied from them.

        Returns:
            ``int(model_num)`` on success, ``0`` when no checkpoint was found.
        """
        # load rms
        self.obs_rms.load(self.save_dir, model_num)
        self.reward_rms.load(self.save_dir, model_num)

        # load network models
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        checkpoint_file = f"{self.checkpoint_dir}/model_{model_num}.pt"
        if os.path.isfile(checkpoint_file):
            checkpoint = torch.load(checkpoint_file, map_location=self.device)
            self.actor.load_state_dict(checkpoint['actor'])
            self.actor_target.load_state_dict(checkpoint['actor_target'])
            self.reward_critic1.load_state_dict(checkpoint['reward_critic1'])
            self.reward_critic2.load_state_dict(checkpoint['reward_critic2'])
            self.reward_critic_target1.load_state_dict(checkpoint['reward_critic_target1'])
            self.reward_critic_target2.load_state_dict(checkpoint['reward_critic_target2'])
            self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
            self.reward_critic_optimizer.load_state_dict(checkpoint['reward_critic_optimizer'])
            if self.cost_dim > 0:
                self.cost_critic1.load_state_dict(checkpoint['cost_critic1'])
                self.cost_critic2.load_state_dict(checkpoint['cost_critic2'])
                self.cost_critic_target1.load_state_dict(checkpoint['cost_critic_target1'])
                self.cost_critic_target2.load_state_dict(checkpoint['cost_critic_target2'])
                self.cost_critic_optimizer.load_state_dict(checkpoint['cost_critic_optimizer'])
                self.log_con_lambdas.data = checkpoint['log_con_lambdas']
                self.con_lambdas_optimizer.load_state_dict(checkpoint['con_lambdas_optimizer'])
            cprint(f'[{self.name}] load success.', bold=True, color="blue")
            return int(model_num)
        else:
            self.actor.initialize()
            self.reward_critic1.initialize()
            self.reward_critic2.initialize()
            # polyak = 0.0 -> exact copy of the freshly initialized networks.
            self._softUpdate(self.actor_target, self.actor, 0.0)
            self._softUpdate(self.reward_critic_target1, self.reward_critic1, 0.0)
            self._softUpdate(self.reward_critic_target2, self.reward_critic2, 0.0)
            if self.cost_dim > 0:
                self.cost_critic1.initialize()
                self.cost_critic2.initialize()
                self._softUpdate(self.cost_critic_target1, self.cost_critic1, 0.0)
                self._softUpdate(self.cost_critic_target2, self.cost_critic2, 0.0)
            cprint(f'[{self.name}] load fail.', bold=True, color="red")
            return 0

    # ------------------------------------------------------------------
    # private functions
    # ------------------------------------------------------------------
    def _softUpdate(self, target, source, polyak):
        """In place: target <- polyak * target + (1 - polyak) * source."""
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(target_param.data * polyak + param.data * (1.0 - polyak))

    def _unpackBatch(self, batch_list):
        """Split ``batch_list`` into named tensors; ``costs`` is None when cost_dim == 0.

        Returns:
            (states, actions, rewards, costs, preferences, dones, fails, next_states)
        """
        if self.cost_dim > 0:
            states, actions, rewards, costs, preferences, dones, fails, next_states = batch_list
        else:
            states, actions, rewards, preferences, dones, fails, next_states = batch_list
            costs = None
        return states, actions, rewards, costs, preferences, dones, fails, next_states

    @staticmethod
    def _scalarize(preferences_tensor, values_tensor):
        """Per-sample dot product of preference and value vectors, shape (batch_size,).

        Explicit squeeze dims keep the batch axis even when batch_size == 1.
        """
        return torch.bmm(
            preferences_tensor.unsqueeze(-2), values_tensor.unsqueeze(-1)).squeeze(-1).squeeze(-1)

    @staticmethod
    def _dirAngleLoss(projected_preferences_tensor, values_tensor):
        """Mean angle in degrees between projected preferences and value vectors.

        Cosine similarity is clamped to [0, 0.9999] before ``acos``.
        """
        cos_sim_tensor = torch.nn.functional.cosine_similarity(projected_preferences_tensor, values_tensor)
        return torch.mean(torch.rad2deg(torch.acos(torch.clamp(cos_sim_tensor, 0, 0.9999))))

    def _projectPreferences(self, preferences_tensor):
        """Apply ``preferInterp`` to the preferences and return a float32 tensor on ``device``."""
        projected_preferences = self.preferInterp(preferences_tensor.detach().cpu().numpy()) # (batch_size, preference_dim)
        return torch.tensor(projected_preferences, device=self.device, dtype=torch.float32)

    def _trainCritic(self, batch_list):
        """Take one gradient step on the reward critics (and cost critics, if any).

        Targets use clipped Gaussian noise on the target actor's action, and
        the action is then clamped to [-1, 1]. Of the two target reward critics,
        the one with the smaller preference-weighted value supplies the whole
        target vector (ties -> critic 1). For costs, the element-wise maximum
        of the two target cost critics is used. Bootstrapping is masked by
        ``fails``.

        Returns:
            ``reward_critic_loss``, or ``(reward_critic_loss, cost_critic_loss)``
            when ``cost_dim > 0``.
        """
        states_tensor, actions_tensor, rewards_tensor, costs_tensor, preferences_tensor, \
            _, fails_tensor, next_states_tensor = self._unpackBatch(batch_list)

        with torch.no_grad():
            # get next actions
            next_actions_tensor = self.actor_target(next_states_tensor, preferences_tensor)
            noise_tensor = torch.clamp(torch.randn_like(next_actions_tensor) * self.target_action_noise_scale,
                                       -self.target_action_noise_clip, self.target_action_noise_clip)
            next_actions_tensor = torch.clamp(next_actions_tensor + noise_tensor, -1.0, 1.0)

            # calculate critic targets
            next_reward_values1_tensor = self.reward_critic_target1(next_states_tensor, next_actions_tensor, preferences_tensor) # (batch_size, reward_dim)
            next_reward_values2_tensor = self.reward_critic_target2(next_states_tensor, next_actions_tensor, preferences_tensor) # (batch_size, reward_dim)
            use_second_tensor = (
                self._scalarize(preferences_tensor, next_reward_values2_tensor)
                < self._scalarize(preferences_tensor, next_reward_values1_tensor)
            ).unsqueeze(-1) # (batch_size, 1)
            # torch.where avoids the 0 * inf = nan of an arithmetic blend.
            min_reward_values_tensor = torch.where(
                use_second_tensor, next_reward_values2_tensor, next_reward_values1_tensor)
            not_fail_tensor = 1.0 - fails_tensor.unsqueeze(-1)
            reward_targets_tensor = rewards_tensor + self.discount_factor * not_fail_tensor * min_reward_values_tensor # (batch_size, reward_dim)
            if self.cost_dim > 0:
                next_cost_values1_tensor = self.cost_critic_target1(next_states_tensor, next_actions_tensor, preferences_tensor) # (batch_size, cost_dim)
                next_cost_values2_tensor = self.cost_critic_target2(next_states_tensor, next_actions_tensor, preferences_tensor) # (batch_size, cost_dim)
                max_cost_values_tensor = torch.max(next_cost_values1_tensor, next_cost_values2_tensor) # (batch_size, cost_dim)
                cost_targets_tensor = costs_tensor + self.discount_factor * not_fail_tensor * max_cost_values_tensor

        # direction angle losses (only when a preference projection is available)
        if self.preferInterp:
            projected_preferences_tensor = self._projectPreferences(preferences_tensor)
            reward_values1_tensor = self.reward_critic1(states_tensor, actions_tensor, preferences_tensor) # (batch_size, reward_dim)
            reward_values2_tensor = self.reward_critic2(states_tensor, actions_tensor, preferences_tensor) # (batch_size, reward_dim)
            dir_angle_loss1 = self._dirAngleLoss(projected_preferences_tensor, reward_values1_tensor)
            dir_angle_loss2 = self._dirAngleLoss(projected_preferences_tensor, reward_values2_tensor)

        # reward critic loss
        reward_critic_loss = self.reward_critic1.getLoss(states_tensor, actions_tensor, preferences_tensor, reward_targets_tensor) \
                            + self.reward_critic2.getLoss(states_tensor, actions_tensor, preferences_tensor, reward_targets_tensor)
        if self.preferInterp:
            # NOTE(review): unlike the actor loss, these terms are not scaled by
            # dir_angle_loss_coeff — confirm this is intended.
            reward_critic_loss = reward_critic_loss + dir_angle_loss1 + dir_angle_loss2
        self.reward_critic_optimizer.zero_grad()
        reward_critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.reward_critic_params, self.max_grad_norm)
        self.reward_critic_optimizer.step()

        if self.cost_dim > 0:
            cost_critic_loss = self.cost_critic1.getLoss(states_tensor, actions_tensor, preferences_tensor, cost_targets_tensor) \
                                + self.cost_critic2.getLoss(states_tensor, actions_tensor, preferences_tensor, cost_targets_tensor)
            self.cost_critic_optimizer.zero_grad()
            cost_critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.cost_critic_params, self.max_grad_norm)
            self.cost_critic_optimizer.step()

            return reward_critic_loss, cost_critic_loss
        else:
            return reward_critic_loss

    def _trainActor(self, batch_list):
        """Take one actor step, soft-update all targets, then update the multipliers.

        The actor maximizes the preference-weighted value of ``reward_critic1``
        (plus ``dir_angle_loss_coeff`` times the angle loss when ``preferInterp``
        is set). With costs, the Lagrangian using detached multipliers is divided
        by ``1 + sum(lambdas)``. The multiplier loss ``lambda · (threshold - cost)``
        raises lambda when the batch-mean cost exceeds its threshold. The log
        multipliers are clamped to [-8, 8].

        Returns:
            ``(actor_loss, entropy)``, or with costs
            ``(actor_loss, entropy, con_lambdas, cost_values)``.
        """
        states_tensor, _, _, _, preferences_tensor, _, _, _ = self._unpackBatch(batch_list)

        self.actor.updateActionDist(states_tensor, preferences_tensor, self.explore_action_noise_scale)
        entropy = self.actor.getEntropy()
        current_actions_tensor = self.actor.sample(deterministic=True)[0]
        reward_values_tensor = self.reward_critic1(states_tensor, current_actions_tensor, preferences_tensor) # (batch_size, reward_dim)
        if self.cost_dim > 0:
            cost_values_tensor = self.cost_critic1(states_tensor, current_actions_tensor, preferences_tensor).mean(dim=0) # (cost_dim,)

        actor_loss = -torch.mean(self._scalarize(preferences_tensor, reward_values_tensor))
        if self.preferInterp:
            projected_preferences_tensor = self._projectPreferences(preferences_tensor)
            actor_loss = actor_loss \
                + self.dir_angle_loss_coeff * self._dirAngleLoss(projected_preferences_tensor, reward_values_tensor)
        if self.cost_dim > 0:
            with torch.no_grad():
                con_lambdas_tensor = self.getConLambdas()
            actor_loss = (actor_loss + torch.dot(con_lambdas_tensor, cost_values_tensor)) \
                                / (1.0 + torch.sum(con_lambdas_tensor))
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor_params, self.max_grad_norm)
        self.actor_optimizer.step()

        # soft update
        self._softUpdate(self.reward_critic_target1, self.reward_critic1, self.soft_update_ratio)
        self._softUpdate(self.reward_critic_target2, self.reward_critic2, self.soft_update_ratio)
        if self.cost_dim > 0:
            self._softUpdate(self.cost_critic_target1, self.cost_critic1, self.soft_update_ratio)
            self._softUpdate(self.cost_critic_target2, self.cost_critic2, self.soft_update_ratio)
        self._softUpdate(self.actor_target, self.actor, self.soft_update_ratio)

        # constraint lambda update
        if self.cost_dim > 0:
            con_lambdas_tensor = self.getConLambdas()
            con_lambdas_loss = torch.dot(con_lambdas_tensor, self.con_thresholds - cost_values_tensor.detach())
            self.con_lambdas_optimizer.zero_grad()
            con_lambdas_loss.backward()
            self.con_lambdas_optimizer.step()
            self.log_con_lambdas.data.copy_(self.log_con_lambdas.data.clamp(-8.0, 8.0))
            return actor_loss, entropy, con_lambdas_tensor, cost_values_tensor
        else:
            return actor_loss, entropy
