# yapf: disable
from collections import deque
import copy
import argparse

import numpy as np
import torch

from garage.torch.policies import TanhGaussianMLPPolicy
from garage.torch.q_functions import ContinuousMLPQFunction
from garage.envs import GymEnv

import random
import gym
import d4rl
# yapf: enable

# Run on the GPU when one is available; fall back to the CPU so the module can
# still be imported and used on machines without CUDA.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Disable denormal floating-point numbers on CPU (returns False where unsupported).
torch.set_flush_denormal(True)

def normalized_sum(loss, reg, w):
    """Combine ``loss`` and ``reg`` so that neither term is scaled up by ``w``.

    If ``w > 1`` the primary term is shrunk instead (``loss / w + reg``);
    otherwise the regularizer is weighted directly (``loss + w * reg``).
    """
    if not w > 1:
        return loss + w * reg
    return loss / w + reg

def l2_projection(constraint):
    """Return a ``Module.apply`` callback that projects weights onto an L2 ball.

    The callback rescales, in place, every submodule ``weight`` whose Frobenius
    norm exceeds ``constraint`` so that its norm equals ``constraint``. Smaller
    weights are left untouched. A non-positive ``constraint`` disables the
    projection. Modules whose ``weight`` attribute is ``None`` (e.g. a
    ``LayerNorm`` without affine parameters) are skipped.
    """
    @torch.no_grad()
    def fn(module):
        if not constraint > 0:
            return
        w = getattr(module, 'weight', None)
        if w is None:
            return
        norm = torch.norm(w)
        # Scale by min(1, constraint / norm): shrink only when the norm is too
        # large (a zero norm gives inf, which is clamped to 1).
        w.mul_(torch.clamp(constraint / norm, max=1))
    return fn

class ATAC():
    """Adversarially Trained Actor Critic (ATAC) learned from trajectory preferences.

    Critics are fit with a preference (Bradley-Terry) loss on pairs of
    trajectories drawn from ``dataset``. The utility of a trajectory is
    ``sum_t qf(s_t, a_t) - discount * qf(s_{t+1}, a'_{t+1})``, with ``a'``
    sampled from the current policy. After the warm-start phase, ATAC's
    adversarial (pessimism) term is added to the critic loss and the actor
    maximizes ``qf1(s, pi(s))`` plus an entropy bonus. During warm start the
    actor is behavior-cloned instead.
    """

    def __init__(
            self,
            env_spec,
            policy,
            qf1,
            qf2,
            dataset,
            *,  # Everything after this is numbers.
            max_episode_length_eval=None,
            gradient_steps_per_itr=2000,
            fixed_alpha=None,
            target_entropy=None,
            initial_log_entropy=0.,
            discount=0.99,
            buffer_batch_size=256,
            min_buffer_size=int(1e4),
            target_update_tau=5e-3,
            policy_lr=5e-7,
            qf_lr=5e-4,
            reward_scale=1.0,
            optimizer='Adam',
            steps_per_epoch=1,
            num_evaluation_episodes=10,
            eval_env=None,
            use_deterministic_evaluation=True,
            # ATAC parameters
            beta=1.0,  # the regularization coefficient in front of the Bellman error
            lambd=0., # coeff for global pessimism
            init_observations=None, # for ATAC0 (None or np.ndarray)
            n_warmstart_steps=200000,
            norm_constraint=100,
            q_eval_mode='0.5_0.5', # 'max' 'w1_w2', 'adaptive'
            q_eval_loss='MSELoss', # 'MSELoss', 'SmoothL1Loss'
            use_two_qfs=False,
            terminal_value=None,
            Vmin=-float('inf'), # min value of Q (used in target backup)
            Vmax=float('inf'), # max value of Q (used in target backup)
            debug=False,
            stats_avg_rate=0.99, # for logging
            bellman_surrogate='td', #'td', None, 'target'
            ):
        """Store hyper-parameters, move networks to DEVICE and build optimizers.

        ``env_spec`` is read for ``action_space`` (default target entropy) and
        is also used as the evaluation environment by ``test`` (``reset`` and
        ``step``). ``dataset`` must provide ``num`` and
        ``sample(idx) -> (obs1, obs2, act1, act2, pref)``.
        """
        #############################################################################################

        assert beta>=0
        assert norm_constraint>=0
        # Parsing (getattr instead of eval: only attribute names are accepted).
        optimizer = getattr(torch.optim, optimizer)
        # Use the shared critic lr if no valid policy lr is provided.
        policy_lr = qf_lr if policy_lr is None or policy_lr < 0 else policy_lr
        self.dataset = dataset
        ## ATAC parameters
        # These tensors are combined with network outputs, so they live on
        # DEVICE: a shape-(1,) CPU tensor cannot be mixed with CUDA tensors.
        self.beta = torch.tensor([float(beta)], device=DEVICE)  # weight of the preference (Bellman-surrogate) loss
        self._lambd = torch.tensor([float(lambd)], device=DEVICE)  # global pessimism coefficient
        # If provided, ATAC0 is run (pessimism evaluated at initial states only).
        self._init_observations = (
            torch.as_tensor(init_observations, dtype=torch.float32, device=DEVICE)
            if init_observations is not None else None)
        self._n_warmstart_steps = n_warmstart_steps  # number of updates spent in warm start
        # q update parameters
        self._norm_constraint = norm_constraint  # L2 bound on each qf weight; 0 disables the projection
        # The following settings are kept for configuration compatibility; the
        # preference-based losses in this class do not read them.
        self._q_eval_mode = ([float(w) for w in q_eval_mode.split('_')]
                             if '_' in q_eval_mode else q_eval_mode)
        self._q_eval_loss = getattr(torch.nn, q_eval_loss)(reduction='none')
        self._use_two_qfs = use_two_qfs
        self._Vmin = Vmin  # lower bound on the target
        self._Vmax = Vmax  # upper bound on the target
        self._terminal_value = terminal_value if terminal_value is not None else lambda r, gamma: 0.

        # Stepsizes
        self._alpha_lr = qf_lr
        self._bc_policy_lr = qf_lr  # policy lr during warm start

        # Logging and algorithm state
        self._debug = debug
        self._n_updates_performed = 0  # number of gradient steps performed
        self._cac_learning = False  # becomes True once warm start has ended
        self._stats_avg_rate = stats_avg_rate
        self._bellman_surrogate = bellman_surrogate
        self._avg_bellman_error = 1.
        self._avg_terminal_td_error = 1

        #############################################################################################
        # Original SAC parameters
        self._qf1 = qf1
        self._qf2 = qf2
        self._tau = target_update_tau
        self._policy_lr = policy_lr
        self._qf_lr = qf_lr
        self._initial_log_entropy = initial_log_entropy
        self._gradient_steps = gradient_steps_per_itr
        self._optimizer = optimizer
        self._num_evaluation_episodes = num_evaluation_episodes
        self._eval_env = eval_env

        self._min_buffer_size = min_buffer_size
        self._steps_per_epoch = steps_per_epoch
        self._buffer_batch_size = buffer_batch_size
        self._discount = discount
        self._reward_scale = reward_scale
        self.max_episode_length = 1000
        self._max_episode_length_eval = 1000

        if max_episode_length_eval is not None:
            self._max_episode_length_eval = max_episode_length_eval
        self._use_deterministic_evaluation = use_deterministic_evaluation

        self.policy = policy
        self.env_spec = env_spec

        # Target copies of both critics (soft-updated in _update_targets).
        self._target_qf1 = copy.deepcopy(self._qf1)
        self._target_qf2 = copy.deepcopy(self._qf2)
        for module in (self._qf1, self._qf2, self._target_qf1,
                       self._target_qf2, self.policy):
            module.to(DEVICE)

        self._policy_optimizer = self._optimizer(self.policy.parameters(),
                                                 lr=self._bc_policy_lr)  # lr for warm start
        self._qf1_optimizer = self._optimizer(self._qf1.parameters(),
                                              lr=self._qf_lr)
        self._qf2_optimizer = self._optimizer(self._qf2.parameters(),
                                              lr=self._qf_lr)

        # Automatic entropy coefficient tuning
        self._use_automatic_entropy_tuning = fixed_alpha is None
        self._fixed_alpha = fixed_alpha
        if self._use_automatic_entropy_tuning:
            if target_entropy is None:
                # Default: minus the action dimensionality.
                target_entropy = -np.prod(self.env_spec.action_space.shape).item()
            self._target_entropy = torch.tensor(float(target_entropy), device=DEVICE)
            # Created directly on DEVICE as a leaf so that the optimizer updates
            # the same tensor the losses read.
            self._log_alpha = torch.tensor([float(self._initial_log_entropy)],
                                           device=DEVICE, requires_grad=True)
            self._alpha_optimizer = optimizer([self._log_alpha],
                                              lr=self._alpha_lr)
        else:
            self._target_entropy = None
            self._log_alpha = torch.tensor([float(self._fixed_alpha)],
                                           device=DEVICE).log()

        self.episode_rewards = deque(maxlen=30)

    @staticmethod
    def _to_tensor(array):
        """Convert ``array`` to a float32 tensor on DEVICE."""
        return torch.as_tensor(array, dtype=torch.float32, device=DEVICE)

    @staticmethod
    def _preference_loss(utility_1, utility_2, pref):
        """Negative log-likelihood of label ``pref`` under a Bradley-Terry model.

        Mathematically equal to ``-log(exp(u_pref) / (exp(u_1) + exp(u_2)))``.
        It is computed with ``softplus`` so that large utilities do not overflow
        ``exp`` and produce NaN.

        Raises:
            ValueError: if ``pref`` is neither 1 nor 2.
        """
        if pref == 1:
            return torch.nn.functional.softplus(utility_2 - utility_1)
        if pref == 2:
            return torch.nn.functional.softplus(utility_1 - utility_2)
        raise ValueError(f'pref must be 1 or 2, got {pref!r}')

    def _trajectory_utility(self, traj_obs, traj_act, qf):
        """Sum over a trajectory of ``qf(s_t, a_t) - discount * qf(s_{t+1}, a')``.

        ``a'`` is sampled from the current policy and treated as a constant, so
        gradients reach only ``qf``. This assumes ``traj_obs`` holds one more
        step than ``traj_act`` (the slicing below relies on it).
        """
        q_vals = qf(traj_obs[:-1], traj_act)
        with torch.no_grad():
            next_dist = self.policy(traj_obs[1:])[0]
            _, next_actions = next_dist.rsample_with_pre_tanh_value()
        q_next_vals = qf(traj_obs[1:], next_actions)
        return (q_vals - self._discount * q_next_vals).sum(dim=0)

    def _utility_loss(self, traj_obs1, traj_act1, traj_obs2, traj_act2, pref, qf):
        """Preference loss of critic ``qf`` on one labeled trajectory pair."""
        utility_1 = self._trajectory_utility(traj_obs1, traj_act1, qf)
        utility_2 = self._trajectory_utility(traj_obs2, traj_act2, qf)
        return self._preference_loss(utility_1, utility_2, pref)

    def _adversarial_losses(self, traj_obs, traj_act, new_actions, warmstart):
        """Return ATAC's pessimism terms ``(gan_qf1_loss, gan_qf2_loss)``.

        Both terms are zero during warm start. The second term stays zero unless
        ``use_two_qfs`` is set. Policy actions are treated as constants.
        """
        gan_qf1_loss = gan_qf2_loss = torch.zeros((), device=DEVICE)
        if warmstart:
            return gan_qf1_loss, gan_qf2_loss
        new_actions = new_actions.detach()
        if self._init_observations is None:
            # Value difference between policy actions and dataset actions.
            qf1_pred = self._qf1(traj_obs, traj_act).flatten()
            qf1_new_actions = self._qf1(traj_obs, new_actions).flatten()
            gan_qf1_loss = (qf1_new_actions * (1 + self._lambd) - qf1_pred).mean()
            if self._use_two_qfs:
                qf2_pred = self._qf2(traj_obs, traj_act).flatten()
                qf2_new_actions = self._qf2(traj_obs, new_actions).flatten()
                gan_qf2_loss = (qf2_new_actions * (1 + self._lambd) - qf2_pred).mean()
        else:  # ATAC0: initial-state pessimism
            idx_ = np.random.choice(len(self._init_observations), self._buffer_batch_size)
            init_observations = self._init_observations[idx_]
            with torch.no_grad():
                init_dist = self.policy(init_observations)[0]
                _, init_actions = init_dist.rsample_with_pre_tanh_value()
            gan_qf1_loss = self._qf1(init_observations, init_actions).mean()
            if self._use_two_qfs:
                gan_qf2_loss = self._qf2(init_observations, init_actions).mean()
        return gan_qf1_loss, gan_qf2_loss

    def optimize_policy(self,
                        batch_idx,
                        warmstart=False):
        """Run one update of the critics, the temperature and the policy.

        For each index in ``batch_idx``, a trajectory pair and its preference
        label are drawn from ``self.dataset``. Per-pair losses are summed, and
        one optimizer step is taken per component.

        Args:
            batch_idx (iterable of int): indices passed to ``self.dataset.sample``.
            warmstart (bool): if True, the policy is behavior-cloned on dataset
                actions and the critics see only the preference loss (they are
                updated only when ``beta > 0``). Otherwise ATAC's adversarial
                term is added to the critic loss and the policy maximizes ``qf1``.

        Returns:
            None. An empty ``batch_idx`` is a no-op.
        """
        batch_idx = list(batch_idx)
        if not batch_idx:
            return

        beta = self.beta
        update_critic = bool(beta > 0) or not warmstart
        with torch.no_grad():
            # The temperature is read once, before the alpha step below.
            alpha = self._log_alpha.exp()

        qf1_loss = qf2_loss = policy_loss = alpha_loss = 0.
        # (obs, sampled actions, entropy) for the actor's Q term. qf1 must be
        # evaluated after the critic step: that step and the L2 projection
        # modify qf1's weights in place, which invalidates earlier graphs.
        actor_terms = []

        for idx in batch_idx:
            traj_obs1, traj_obs2, traj_act1, traj_act2, pref = self.dataset.sample(idx)
            traj_obs1 = self._to_tensor(traj_obs1)
            traj_obs2 = self._to_tensor(traj_obs2)
            traj_act1 = self._to_tensor(traj_act1)
            traj_act2 = self._to_tensor(traj_act2)

            # State-action pairs of both trajectories (last observation dropped).
            traj_obs = torch.cat((traj_obs1[:-1], traj_obs2[:-1]))
            traj_act = torch.cat((traj_act1, traj_act2))
            new_actions_dist = self.policy(traj_obs)[0]
            new_actions_pre_tanh, new_actions = new_actions_dist.rsample_with_pre_tanh_value()

            ##### Critic losses #####
            if update_critic:
                gan_qf1_loss, gan_qf2_loss = self._adversarial_losses(
                    traj_obs, traj_act, new_actions, warmstart)
                qf1_utility_loss = self._utility_loss(
                    traj_obs1, traj_act1, traj_obs2, traj_act2, pref, self._qf1)
                qf1_loss = qf1_loss + normalized_sum(gan_qf1_loss, qf1_utility_loss, beta)
                if self._use_two_qfs:
                    qf2_utility_loss = self._utility_loss(
                        traj_obs1, traj_act1, traj_obs2, traj_act2, pref, self._qf2)
                    qf2_loss = qf2_loss + normalized_sum(gan_qf2_loss, qf2_utility_loss, beta)

            ##### Actor / temperature losses #####
            log_pi_new_actions = new_actions_dist.log_prob(
                value=new_actions, pre_tanh_value=new_actions_pre_tanh)
            policy_entropy = -log_pi_new_actions.mean()

            if self._use_automatic_entropy_tuning:
                # Accumulated over the whole batch. The entropy is detached
                # because this term only trains log_alpha.
                alpha_loss = alpha_loss + self._log_alpha * (
                    policy_entropy.detach() - self._target_entropy)

            if warmstart:  # behavior cloning of dataset actions
                policy_log_prob = new_actions_dist.log_prob(traj_act)
                policy_loss = policy_loss + normalized_sum(
                    -policy_log_prob.mean(), -policy_entropy, alpha)
            else:
                actor_terms.append((traj_obs, new_actions, policy_entropy))

        ##### Critic update #####
        if update_critic:
            self._qf1_optimizer.zero_grad()
            qf1_loss.backward()
            self._qf1_optimizer.step()
            self._qf1.apply(l2_projection(self._norm_constraint))

            if self._use_two_qfs:
                self._qf2_optimizer.zero_grad()
                qf2_loss.backward()
                self._qf2_optimizer.step()
                self._qf2.apply(l2_projection(self._norm_constraint))

        ##### Actor objective against the updated qf1 #####
        for traj_obs, new_actions, policy_entropy in actor_terms:
            lower_bound = self._qf1(traj_obs, new_actions).mean()
            policy_loss = policy_loss + normalized_sum(-lower_bound, -policy_entropy, alpha)

        ##### Temperature update #####
        if self._use_automatic_entropy_tuning:
            self._alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self._alpha_optimizer.step()

        ##### Policy update #####
        # zero_grad also discards gradients the critic losses left in the policy.
        self._policy_optimizer.zero_grad()
        policy_loss.backward()
        self._policy_optimizer.step()

    def train(self, num_epoch, step_per_epoch, batch_size=64):
        """Train for ``num_epoch`` epochs and print the mean evaluation return.

        Each step samples ``batch_size`` distinct pair indices (capped at the
        dataset size) and runs ``train_once`` on them.
        """
        for epoch in range(num_epoch):
            for _ in range(step_per_epoch):
                batch_idx = random.sample(range(self.dataset.num),
                                          min(batch_size, self.dataset.num))
                self.train_once(batch_idx)
            test_ep_rets = self.test()
            print('epoch:', epoch, 'pfm:', sum(test_ep_rets)/len(test_ep_rets))

    def train_once(self, batch_idx):
        """Perform one ATAC gradient step on the preference pairs ``batch_idx``.

        The first call after ``n_warmstart_steps`` updates switches from warm
        start to adversarial training. It re-creates the temperature and policy
        optimizers, because the objective changes. The target critics are
        soft-updated after every step.
        """
        warmstart = self._n_updates_performed < self._n_warmstart_steps
        if not warmstart and not self._cac_learning:
            self._cac_learning = True
            # Reset optimizers since the objective changes.
            if self._use_automatic_entropy_tuning:
                self._log_alpha = torch.tensor([float(self._initial_log_entropy)],
                                               device=self._log_alpha.device,
                                               requires_grad=True)
                self._alpha_optimizer = self._optimizer([self._log_alpha], lr=self._alpha_lr)
            self._policy_optimizer = self._optimizer(self.policy.parameters(), lr=self._policy_lr)

        self.optimize_policy(batch_idx, warmstart=warmstart)
        self._update_targets()
        self._n_updates_performed += 1

    def _update_targets(self):
        """Polyak-average the online critics into their targets with rate ``tau``.

        Only qf1's target is updated unless ``use_two_qfs`` is set. The target
        critics are not read by any loss in this class.
        """
        pairs = [(self._target_qf1, self._qf1)]
        if self._use_two_qfs:
            pairs.append((self._target_qf2, self._qf2))
        tau = self._tau
        with torch.no_grad():
            for target_qf, qf in pairs:
                for t_param, param in zip(target_qf.parameters(), qf.parameters()):
                    t_param.mul_(1.0 - tau).add_(param, alpha=tau)

    def test(self, num_episodes=3):
        """Roll out the policy in ``self.env_spec`` and return episode returns.

        Episodes end on ``done`` or after ``self.max_episode_length`` steps.
        The policy's mean action is used when ``use_deterministic_evaluation``
        is set; otherwise an action is sampled. The environment is assumed to
        follow the 4-tuple ``step`` API used below.
        """
        ep_rets = []
        with torch.no_grad():
            for _ in range(num_episodes):
                obs, done, ep_ret, ep_len = self.env_spec.reset(), False, 0, 0
                while not done and ep_len < self.max_episode_length:
                    obs_t = torch.as_tensor(obs, dtype=torch.float32, device=DEVICE)
                    act_dist = self.policy(obs_t)[0]
                    if self._use_deterministic_evaluation:
                        act = act_dist.mean
                    else:
                        _, act = act_dist.rsample_with_pre_tanh_value()
                    obs, rew, done, _ = self.env_spec.step(act.cpu().numpy())
                    ep_ret += rew
                    ep_len += 1
                ep_rets.append(ep_ret)
        return ep_rets

class Dataset():
    """Preference dataset of trajectory pairs loaded from a ``.npz`` file.

    The file is ``{root}/{env}_{traj}_{sample}_{pref}_{data}.npz`` and must
    contain the arrays ``traj_obs``, ``traj_act``, ``traj_rew``,
    ``traj_idx_1``, ``traj_idx_2`` and ``pref``. Pair ``i`` compares the
    trajectories ``traj_idx_1[i]`` and ``traj_idx_2[i]`` and has label
    ``pref[i]``.
    """

    def __init__(self, env, traj, sample, pref, data, root='../dataset'):
        path = '%s/%s_%s_%s_%s_%d.npz' % (root, env, traj, sample, pref, data)
        # Use the NpzFile as a context manager so its file handle is closed
        # once the arrays have been read into memory.
        with np.load(path) as traj_file:
            self.traj_obs = traj_file['traj_obs']
            self.traj_act = traj_file['traj_act']
            self.traj_rew = traj_file['traj_rew']
            self.traj_idx_1 = traj_file['traj_idx_1']
            self.traj_idx_2 = traj_file['traj_idx_2']
            self.pref = traj_file['pref']
        self.num = len(self.pref)  # number of labeled pairs

    def sample(self, idx):
        """Return ``(obs1, obs2, act1, act2, pref)`` for preference pair ``idx``."""
        first = self.traj_idx_1[idx]
        second = self.traj_idx_2[idx]
        return (self.traj_obs[first], self.traj_obs[second],
                self.traj_act[first], self.traj_act[second], self.pref[idx])

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='Hopper-v3')
    parser.add_argument('--traj', type=str, default='medium')
    parser.add_argument('--sample', type=str, default='uniform')
    parser.add_argument('--pref', type=str, default='regular')
    parser.add_argument('--data', type=int, default=100000)

    args = parser.parse_args()

    # Network hyper-parameters (critics use a linear output layer).
    policy_hidden_sizes = (64, 64, 64)
    policy_activation = 'ReLU'
    policy_init_std = 1.0
    value_hidden_sizes = (64, 64, 64)
    value_activation = 'ReLU'
    min_std = 1e-5

    env = gym.make(args.env)
    # One garage spec shared by the policy and both critics.
    env_spec = GymEnv(env).spec
    policy = TanhGaussianMLPPolicy(
                env_spec=env_spec,
                hidden_sizes=policy_hidden_sizes,
                hidden_nonlinearity=getattr(torch.nn, policy_activation),
                init_std=policy_init_std,
                min_std=min_std)

    def make_qf():
        """Build a critic with the shared value-network configuration."""
        return ContinuousMLPQFunction(
                env_spec=env_spec,
                hidden_sizes=value_hidden_sizes,
                hidden_nonlinearity=getattr(torch.nn, value_activation),
                output_nonlinearity=None)

    qf1 = make_qf()
    qf2 = make_qf()

    dataset = Dataset(args.env, args.traj, args.sample, args.pref, args.data)

    # The raw gym env is passed as env_spec: ATAC reads its action_space and
    # also uses it for evaluation rollouts.
    trainer = ATAC(env, policy, qf1, qf2, dataset)
    trainer.train(300, 30)