from algos.common.actor_preference import ActorGaussianPreference as Actor
from algos.common.critic_preference import CriticSAPreference as Critic
from algos.common.optimizer_tr import TROptimizer
from algos.common.agent_base import AgentBase
from utils import cprint

from .storage import ReplayBuffer

from scipy.optimize import minimize
import numpy as np
import torch
import os

EPS = 1e-8

class Agent(AgentBase):
    """Preference-conditioned actor-critic agent with optional cost constraints.

    The actor and the critics take a reward-preference vector as an extra
    input (``preference_dim == reward_dim``). Each ``train()`` call:
      1. fits the reward critic (and, when ``cost_dim > 0``, the cost critic)
         to the targets produced by the replay buffer;
      2. for ``n_prefer_samples`` randomly drawn preferences, samples actions
         from the current actor, solves a per-objective temperature dual and
         turns the critic values of those actions into softmax weights;
      3. updates the actor with a trust-region optimizer on the
         weight-averaged log-likelihood of the sampled actions;
      4. (with costs) moves the cost preferences according to the gap
         between the estimated constraint values and the thresholds.
    """

    def __init__(self, args) -> None:
        """Build networks, optimizers and the replay buffer from ``args``.

        ``args`` must expose every attribute read below. ``cost_dim`` is
        optional and defaults to 0, which disables every cost-related part
        (cost critic, thresholds, cost preferences).

        Raises:
            ValueError: if a maximum preference is not positive, or if the
                number of constraint thresholds differs from ``cost_dim``.
        """
        super().__init__(
            name=args.name,
            device=args.device,
            obs_dim=args.obs_dim,
            action_dim=args.action_dim,
            reward_dim=args.reward_dim,
            norm_obs=args.norm_obs,
            norm_reward=args.norm_reward,
        )

        # for base
        self.save_dir = args.save_dir
        self.checkpoint_dir = f"{self.save_dir}/checkpoint"
        # getattr also works for config objects whose attributes are not
        # stored in the instance __dict__.
        self.cost_dim = getattr(args, 'cost_dim', 0)
        self.preference_dim = self.reward_dim
        self.action_bound_min = args.action_bound_min
        self.action_bound_max = args.action_bound_max
        self.n_steps = args.n_steps
        self.n_envs = args.n_envs

        # for RL
        self.discount_factor = args.discount_factor
        self.n_update_steps = args.n_update_steps
        self.critic_lr = args.critic_lr
        self.n_critic_iters = args.n_critic_iters
        self.len_replay_buffer = args.len_replay_buffer
        self.gae_coeff = args.gae_coeff
        self.max_grad_norm = args.max_grad_norm
        self.n_action_samples = args.n_action_samples

        # for trust region
        self.damping_coeff = args.damping_coeff
        self.num_conjugate = args.num_conjugate
        self.line_decay = args.line_decay
        self.max_kl = args.max_kl

        # for constraints
        if self.cost_dim > 0:
            # Thresholds are scaled by 1/(1 - discount_factor); presumably this
            # converts per-step thresholds to the discounted-sum scale of the
            # cost critic -- confirm against the config documentation.
            self.con_thresholds = torch.tensor(
                np.array(args.con_thresholds)/(1.0 - self.discount_factor), dtype=torch.float32, device=self.device)
            if len(self.con_thresholds) != self.cost_dim:
                raise ValueError(
                    f"expected {self.cost_dim} constraint thresholds, got {len(self.con_thresholds)}")

        # for preferences
        self.max_reward_preference = args.max_reward_preference
        if not self.max_reward_preference > 0:
            raise ValueError(
                f"max_reward_preference must be positive, got {self.max_reward_preference}")
        if self.cost_dim > 0:
            self.cost_preferences = torch.tensor(np.array(args.init_cost_preferences), dtype=torch.float32, device=self.device)
            self.max_cost_preference = args.max_cost_preference
            if not self.max_cost_preference > 0:
                raise ValueError(
                    f"max_cost_preference must be positive, got {self.max_cost_preference}")
        self.n_prefer_samples = args.n_prefer_samples
        self.preference_lr = args.preference_lr

        # for models
        model_cfg = args.model
        self.actor = Actor(
            self.device, self.obs_dim, self.action_dim, self.action_bound_min, 
            self.action_bound_max, self.reward_dim, model_cfg['actor']).to(self.device)
        self.reward_critic = Critic(
            self.device, self.obs_dim, self.action_dim, self.reward_dim, 
            self.preference_dim, model_cfg['reward_critic']).to(self.device)
        if self.cost_dim > 0:
            # The cost critic is conditioned on the reward preference too.
            self.cost_critic = Critic(
                self.device, self.obs_dim, self.action_dim, self.cost_dim, 
                self.preference_dim, model_cfg['cost_critic']).to(self.device)

        # for temperatures
        def dual_problem(temperature, preference, sampled_values):
            """Temperature dual for one objective.

            ``sampled_values`` has sampled actions on axis 0 and batch entries
            on axis 1. The result equals
            ``temperature * preference + mean_s[temperature * log mean_a exp(values / temperature)]``.
            The per-state maximum is subtracted before ``exp`` for numerical
            stability.
            """
            max_values = np.max(sampled_values, axis=0)
            logmeanexp = np.log(np.mean(np.exp((sampled_values - max_values.reshape(1, -1)) / temperature), axis=0))
            return temperature * preference + np.mean(max_values) + temperature * np.mean(logmeanexp)
        self.dual_problem = dual_problem

        # for optimizers
        self.actor_optimizer = TROptimizer(
            self.actor, self.damping_coeff, self.num_conjugate, self.line_decay, self.max_kl, self.device)
        self.reward_critic_optimizer = torch.optim.Adam(self.reward_critic.parameters(), lr=self.critic_lr)
        if self.cost_dim > 0:
            self.cost_critic_optimizer = torch.optim.Adam(self.cost_critic.parameters(), lr=self.critic_lr)

        # for replay buffer
        self.replay_buffer = ReplayBuffer(
            self.device, self.len_replay_buffer, self.discount_factor, 
            self.gae_coeff, self.n_envs, self.n_steps, self.n_update_steps)

    ################
    # public methods
    ################

    @torch.no_grad()
    def getAction(self, state:np.ndarray, preference:np.ndarray, deterministic:bool) -> np.ndarray:
        """Sample an action for ``state`` under ``preference``.

        The (raw) state, the preference, the normalized action and its
        log-probability are cached on ``self`` so that ``step`` can store the
        transition.

        Returns:
            The unnormalized action (second output of ``actor.sample``) as a
            numpy array.
        """
        state_tensor = torch.tensor(self.obs_rms.normalize(state), dtype=torch.float32, device=self.device)
        preference_tensor = torch.tensor(preference, dtype=torch.float32, device=self.device)
        # One noise vector of size action_dim per leading index of the state.
        epsilon_tensor = torch.randn(state_tensor.shape[:-1] + (self.action_dim,), device=self.device)

        self.actor.updateActionDist(state_tensor, preference_tensor, epsilon_tensor)
        norm_action_tensor, unnorm_action_tensor = self.actor.sample(deterministic)
        log_prob_tensor = self.actor.getLogProb()

        self.state = state.copy()
        self.preference = preference.copy()
        self.action = norm_action_tensor.detach().cpu().numpy()
        self.log_prob = log_prob_tensor.detach().cpu().numpy()
        return unnorm_action_tensor.detach().cpu().numpy()

    def step(self, *args):
        """Store the transition produced by the last ``getAction`` call.

        Accepts ``(reward_vecs, dones, fails, next_states)`` or
        ``(reward_vecs, cost_vecs, dones, fails, next_states)``. Afterwards the
        observation and reward normalizers are updated when enabled.

        Raises:
            NotImplementedError: for any other number of arguments.
        """
        if len(args) == 4:
            reward_vecs, dones, fails, next_states = args
            self.replay_buffer.addTransition(self.state, self.action, self.preference, \
                                    self.log_prob, reward_vecs, dones, fails, next_states)
        elif len(args) == 5:
            reward_vecs, cost_vecs, dones, fails, next_states = args
            self.replay_buffer.addTransition(self.state, self.action, self.preference, \
                                    self.log_prob, reward_vecs, cost_vecs, dones, fails, next_states)
        else:
            raise NotImplementedError(
                f"step() expects 4 or 5 positional arguments, got {len(args)}")

        # update statistics
        if self.norm_obs:
            self.obs_rms.update(self.state)
        if self.norm_reward:
            self.reward_rms.update(reward_vecs)

    def readyToTrain(self):
        """Return True when the replay buffer holds at least one entry."""
        return self.replay_buffer.getLen() > 0

    def _solve_temperature(self, preference, sampled_values):
        """Minimize ``self.dual_problem`` over the temperature for one objective.

        Args:
            preference: scalar multiplying the temperature in the dual.
            sampled_values: array with sampled actions on axis 0 and batch
                entries on axis 1, for a single objective.

        Returns:
            The minimizing temperature (bounded below by ``EPS``).
        """
        res = minimize(
            lambda eta: self.dual_problem(eta, preference, sampled_values),
            # Start from a feasible point: 0 violates the bound, and the
            # dual divides by the temperature.
            np.full(1, EPS),
            method='SLSQP',
            bounds=[(EPS, None)],
        )
        return res.x[0]

    def train(self):
        """Run one update of the critics, the actor and (with costs) the cost preferences.

        Returns:
            dict with the last critic loss(es), the pre-update policy entropy,
            the trust-region outputs (``objective``, ``kl``, ``max_kl``,
            ``beta``) and, when ``cost_dim > 0``, the constraint estimates,
            the cost-critic loss and the updated cost preferences.
        """
        # convert to tensor
        with torch.no_grad():
            if self.cost_dim > 0:
                states_tensor, actions_tensor, preferences_tensor, reward_targets_tensor, cost_targets_tensor = \
                    self.replay_buffer.getBatches(self.obs_rms, self.reward_rms, self.actor, self.reward_critic, self.cost_critic)
            else:
                states_tensor, actions_tensor, preferences_tensor, reward_targets_tensor = \
                    self.replay_buffer.getBatches(self.obs_rms, self.reward_rms, self.actor, self.reward_critic)

        # ======================= update critic ======================= #
        # The losses of the final iteration are the ones reported below.
        for _ in range(self.n_critic_iters):
            # for reward
            reward_critic_loss = self.reward_critic.getLoss(
                states_tensor, actions_tensor, preferences_tensor, reward_targets_tensor)
            self.reward_critic_optimizer.zero_grad()
            reward_critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.reward_critic.parameters(), self.max_grad_norm)
            self.reward_critic_optimizer.step()

            # for cost
            if self.cost_dim > 0:
                cost_critic_loss = self.cost_critic.getLoss(
                    states_tensor, actions_tensor, preferences_tensor, cost_targets_tensor)
                self.cost_critic_optimizer.zero_grad()
                cost_critic_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.cost_critic.parameters(), self.max_grad_norm)
                self.cost_critic_optimizer.step()
        # ============================================================= #

        states_tensor_list = []
        actions_tensor_list = []
        preferences_tensor_list = []
        reward_action_dists_tensor_list = []
        if self.cost_dim > 0:
            cost_action_dists_tensor_list = []
        for _ in range(self.n_prefer_samples):
            # One preference drawn uniformly from the simplex, shared by the whole batch.
            np_reward_preference = np.random.dirichlet(np.ones(self.preference_dim))
            np_reward_preferences = np.tile(np_reward_preference, (states_tensor.shape[0], 1))
            reward_preferences_tensor = torch.tensor(np_reward_preferences, dtype=torch.float32, device=self.device) # batch x reward_dim

            # for update temperatures and actor
            with torch.no_grad():
                sampled_actions_tensor_list = []
                sampled_reward_values_tensor_list = []
                if self.cost_dim > 0:
                    sampled_cost_values_tensor_list = []
                for _ in range(self.n_action_samples):
                    epsilons_tensor = torch.randn(actions_tensor.shape[:-1] + (self.action_dim,), device=self.device)
                    self.actor.updateActionDist(states_tensor, reward_preferences_tensor, epsilons_tensor)
                    sampled_actions_tensor = self.actor.sample(deterministic=False)[0]
                    sampled_actions_tensor_list.append(sampled_actions_tensor)
                    sampled_reward_values_tensor_list.append(self.reward_critic(states_tensor, sampled_actions_tensor, reward_preferences_tensor))
                    if self.cost_dim > 0:
                        sampled_cost_values_tensor_list.append(self.cost_critic(states_tensor, sampled_actions_tensor, reward_preferences_tensor))
                sampled_actions_tensor = torch.stack(sampled_actions_tensor_list, dim=0) # n_action_samples x batch x action_dim
                sampled_reward_values_tensor = torch.stack(sampled_reward_values_tensor_list, dim=0) # n_action_samples x batch x reward_dim
                if self.cost_dim > 0:
                    sampled_cost_values_tensor = torch.stack(sampled_cost_values_tensor_list, dim=0) # n_action_samples x batch x cost_dim

            # ===================== update temperatures ===================== #
            # Rescale so the largest preference component equals max_reward_preference.
            reward_preferences = self.max_reward_preference*np_reward_preference/np.max(np_reward_preference) # (reward_dim,)
            sampled_reward_values = sampled_reward_values_tensor.cpu().numpy()
            reward_temperatures = np.zeros(self.reward_dim)
            if self.cost_dim > 0:
                cost_preferences = self.cost_preferences.detach().cpu().numpy() # (cost_dim,)
                sampled_cost_values = sampled_cost_values_tensor.cpu().numpy()
                cost_temperatures = np.zeros(self.cost_dim)
            for reward_idx in range(self.reward_dim):
                reward_temperatures[reward_idx] = self._solve_temperature(
                    reward_preferences[reward_idx], sampled_reward_values[:, :, reward_idx])
            # Costs are negated: lower cost is treated like higher value.
            for cost_idx in range(self.cost_dim):
                cost_temperatures[cost_idx] = self._solve_temperature(
                    cost_preferences[cost_idx], -sampled_cost_values[:, :, cost_idx])
            # =============================================================== #

            with torch.no_grad():
                # Softmax over the action-sample axis: per state and objective,
                # the weights of the sampled actions sum to 1.
                reward_temperatures = torch.tensor(reward_temperatures, dtype=torch.float32, device=self.device)
                reward_action_dists_tensor = torch.softmax(sampled_reward_values_tensor/reward_temperatures, dim=0) # n_action_samples x batch x reward_dim
                if self.cost_dim > 0:
                    cost_temperatures = torch.tensor(cost_temperatures, dtype=torch.float32, device=self.device)
                    cost_action_dists_tensor = torch.softmax(-sampled_cost_values_tensor/cost_temperatures, dim=0) # n_action_samples x batch x cost_dim

            states_tensor_list.append(states_tensor) # batch x state_dim
            actions_tensor_list.append(sampled_actions_tensor) # n_action_samples x batch x action_dim
            preferences_tensor_list.append(reward_preferences_tensor) # batch x reward_dim
            reward_action_dists_tensor_list.append(reward_action_dists_tensor) # n_action_samples x batch x reward_dim
            if self.cost_dim > 0:
                cost_action_dists_tensor_list.append(cost_action_dists_tensor) # n_action_samples x batch x cost_dim

        # Stack the preference samples along the batch axis (N = n_prefer_samples * batch).
        states_tensor = torch.cat(states_tensor_list, dim=0) # N x state_dim
        sampled_actions_tensor = torch.cat(actions_tensor_list, dim=1) # n_action_samples x N x action_dim
        reward_preferences_tensor = torch.cat(preferences_tensor_list, dim=0) # N x reward_dim
        reward_action_dists_tensor = torch.cat(reward_action_dists_tensor_list, dim=1) # n_action_samples x N x reward_dim
        if self.cost_dim > 0:
            cost_action_dists_tensor = torch.cat(cost_action_dists_tensor_list, dim=1) # n_action_samples x N x cost_dim

        # ======================== update actor ======================== #
        with torch.no_grad():
            old_means_tensor, _, old_stds_tensor = self.actor(states_tensor, reward_preferences_tensor)
            old_dists_tensor = torch.distributions.Normal(old_means_tensor, old_stds_tensor)
            entropy_tensor = torch.mean(torch.sum(old_dists_tensor.entropy(), dim=-1))

        def get_obj_kl():
            # Objective: log-likelihood of each sampled action, weighted by the sum
            # of its per-objective weights, averaged over states and samples.
            # KL: mean KL(old || new) summed over action dimensions.
            means_tensor, _, stds_tensor = self.actor(states_tensor, reward_preferences_tensor)
            dists_tensor = torch.distributions.Normal(means_tensor, stds_tensor)
            objective = 0.0
            for sampled_idx in range(self.n_action_samples):
                log_probs = dists_tensor.log_prob(sampled_actions_tensor[sampled_idx]).sum(-1) # N
                nonparam_action_probs = reward_action_dists_tensor[sampled_idx].sum(-1) # N
                if self.cost_dim > 0:
                    nonparam_action_probs += cost_action_dists_tensor[sampled_idx].sum(-1) # N
                objective = objective + torch.mean(nonparam_action_probs*log_probs)
            objective = objective/self.n_action_samples
            kl = torch.mean(torch.sum(torch.distributions.kl.kl_divergence(old_dists_tensor, dists_tensor), dim=-1))
            return objective, kl

        # update actor
        objective, kl, max_kl, beta = self.actor_optimizer.step(get_obj_kl)
        # ============================================================== #

        # ====================== update preference ====================== #
        if self.cost_dim > 0:
            with torch.no_grad():
                # Constraint estimate: cost critic at the mean of the sampled
                # actions, averaged over the batch.
                constraints = torch.mean(self.cost_critic(states_tensor, sampled_actions_tensor.mean(dim=0), reward_preferences_tensor), dim=0)
                self.cost_preferences += self.preference_lr * (constraints - self.con_thresholds)
                self.cost_preferences = torch.clamp(self.cost_preferences, min=0.0, max=self.max_cost_preference)
        # =============================================================== #

        train_results = {
            'reward_critic_loss': reward_critic_loss.item(),
            'entropy': entropy_tensor.item(),
            'objective': objective,
            'max_kl': max_kl,
            'beta': beta,
            'kl': kl,
        }
        if self.cost_dim > 0:
            train_results['constraint'] = constraints.detach().cpu().numpy()
            train_results['cost_critic_loss'] = cost_critic_loss.item()
            train_results['cost_preferences'] = self.cost_preferences.detach().cpu().numpy()
        return train_results

    def save(self, model_num):
        """Save normalizer statistics and network/optimizer state for ``model_num``.

        Normalizer saving is delegated to ``obs_rms``/``reward_rms`` (given
        ``save_dir``). Networks, critic optimizers and cost preferences go to
        ``{checkpoint_dir}/model_{model_num}.pt``.
        """
        # save rms
        self.obs_rms.save(self.save_dir, model_num)
        self.reward_rms.save(self.save_dir, model_num)

        # save network models
        save_dict = {
            'actor': self.actor.state_dict(),
            'reward_critic': self.reward_critic.state_dict(),
            'reward_critic_optimizer': self.reward_critic_optimizer.state_dict(),
        }
        if self.cost_dim > 0:
            save_dict['cost_critic'] = self.cost_critic.state_dict()
            save_dict['cost_critic_optimizer'] = self.cost_critic_optimizer.state_dict()
            save_dict['cost_preferences'] = self.cost_preferences.data
        # Only load() created this directory before; saving first must not fail.
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        torch.save(save_dict, f"{self.checkpoint_dir}/model_{model_num}.pt")
        cprint(f'[{self.name}] save success.', bold=True, color="blue")

    def load(self, model_num):
        """Restore the state saved by ``save(model_num)``.

        Normalizer statistics are always loaded. If the checkpoint file is
        missing, the networks are re-initialized instead.

        Returns:
            ``int(model_num)`` when the checkpoint was loaded, otherwise 0.
        """
        # load rms
        self.obs_rms.load(self.save_dir, model_num)
        self.reward_rms.load(self.save_dir, model_num)

        # load network models
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        checkpoint_file = f"{self.checkpoint_dir}/model_{model_num}.pt"
        if os.path.isfile(checkpoint_file):
            checkpoint = torch.load(checkpoint_file, map_location=self.device)
            self.actor.load_state_dict(checkpoint['actor'])
            self.reward_critic.load_state_dict(checkpoint['reward_critic'])
            self.reward_critic_optimizer.load_state_dict(checkpoint['reward_critic_optimizer'])
            if self.cost_dim > 0:
                self.cost_critic.load_state_dict(checkpoint['cost_critic'])
                self.cost_critic_optimizer.load_state_dict(checkpoint['cost_critic_optimizer'])
                self.cost_preferences.data = checkpoint['cost_preferences']
            cprint(f'[{self.name}] load success.', bold=True, color="blue")
            return int(model_num)
        else:
            self.actor.initialize()
            self.reward_critic.initialize()
            if self.cost_dim > 0:
                self.cost_critic.initialize()
            cprint(f'[{self.name}] load fail.', bold=True, color="red")
            return 0

