from copy import deepcopy
import numpy as np
import torch
from torch.optim import Adam
from numpy import linalg as LA
import gym
import argparse
import json
from utils import redirect_stdout
import torch.nn as nn
import itertools
import torch.nn.functional as F
from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical
from spinup_copy import mpi_avg, mpi_statistics_scalar, num_procs, setup_pytorch_for_mpi, sync_params, mpi_avg_grads
import scipy.signal
from gym.spaces import Box, Discrete
import os, time
import random

def combined_shape(length, shape=None):
    """Build the shape tuple for a buffer of ``length`` items.

    Args:
        length: Number of items (leading dimension).
        shape: Per-item shape. ``None`` gives a 1-D buffer, a scalar gives a
            2-D buffer, and any other iterable is unpacked after ``length``.

    Returns:
        A tuple whose first entry is ``length``.
    """
    if shape is None:
        return (length,)
    if np.isscalar(shape):
        return (length, shape)
    return (length, *shape)

class ReplayBuffer:
    """Fixed-capacity ring buffer of (obs, act, rew, next_obs, done) transitions.

    All fields are stored as float32 NumPy arrays. Once ``max_size``
    transitions have been stored, new ones overwrite the oldest entries.
    """

    def __init__(self, obs_dim, act_dim, size):
        """Allocate zero-filled storage for ``size`` transitions."""
        obs_shape = combined_shape(size, obs_dim)
        act_shape = combined_shape(size, act_dim)
        self.obs_buf = np.zeros(obs_shape, dtype=np.float32)
        self.obs2_buf = np.zeros(obs_shape, dtype=np.float32)
        self.act_buf = np.zeros(act_shape, dtype=np.float32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)
        # ptr: next write slot; size: number of valid entries so far.
        self.ptr = 0
        self.size = 0
        self.max_size = size

    def store(self, obs, act, rew, next_obs, done):
        """Write one transition at the current slot and advance the ring."""
        slot = self.ptr
        self.obs_buf[slot] = obs
        self.obs2_buf[slot] = next_obs
        self.act_buf[slot] = act
        self.rew_buf[slot] = rew
        self.done_buf[slot] = done
        # Wrap around so the oldest entry is overwritten once full.
        self.ptr = (slot + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample_batch(self, batch_size=32):
        """Sample ``batch_size`` stored transitions uniformly with replacement.

        Returns:
            Dict with keys ``obs``, ``obs2``, ``act``, ``rew``, ``done`` mapping
            to float32 torch tensors.
        """
        picks = np.random.randint(0, self.size, size=batch_size)
        fields = (
            ('obs', self.obs_buf),
            ('obs2', self.obs2_buf),
            ('act', self.act_buf),
            ('rew', self.rew_buf),
            ('done', self.done_buf),
        )
        return {key: torch.as_tensor(store[picks], dtype=torch.float32)
                for key, store in fields}

def discount_cumsum(x, discount):
    """Return the discounted cumulative sums of ``x`` along axis 0.

    Element ``i`` of the result is ``sum_k discount**k * x[i + k]``. It is
    computed by running a first-order IIR filter over the reversed input and
    reversing the output again.
    """
    reversed_input = x[::-1]
    filtered = scipy.signal.lfilter([1], [1, float(-discount)], reversed_input, axis=0)
    return filtered[::-1]

def mlp(sizes, activation, output_activation=nn.Identity, device=None):
    """Build a fully connected network as an ``nn.Sequential``.

    Args:
        sizes: Layer widths, input first and output last. Each consecutive
            pair becomes one ``nn.Linear`` followed by an activation.
        activation: Activation class instantiated after every hidden layer.
        output_activation: Activation class instantiated after the last layer.
        device: Device to place the network on. Defaults to CUDA when it is
            available and to the CPU otherwise, so the module can still be
            built on hosts without a GPU.

    Returns:
        The assembled ``nn.Sequential``, moved to ``device``.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    last_layer = len(sizes) - 2
    layers = []
    for j, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        act = activation if j < last_layer else output_activation
        layers += [nn.Linear(n_in, n_out), act()]
    return nn.Sequential(*layers).to(device)

class DDPGActor(nn.Module):
    """Deterministic policy network: tanh-squashed MLP scaled by ``act_limit``."""

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, act_limit):
        """Build the policy MLP and store the action scale.

        Args:
            obs_dim: Size of the observation vector.
            act_dim: Size of the action vector.
            hidden_sizes: Iterable of hidden layer widths.
            activation: Activation class used between hidden layers.
            act_limit: Scalar scale applied to the tanh output.
        """
        super().__init__()
        pi_sizes = [obs_dim] + list(hidden_sizes) + [act_dim]
        self.pi = mlp(pi_sizes, activation, nn.Tanh)
        # Keep the scale on the same device as the network instead of
        # hard-coding a device here.
        device = next(self.pi.parameters()).device
        self.act_limit = torch.as_tensor(act_limit, dtype=torch.float32).to(device)

    def forward(self, obs):
        """Return the action for ``obs``, bounded by ``+/- act_limit``."""
        return self.act_limit * self.pi(obs)

class DDPGQFunction(nn.Module):
    """Action-value network Q(obs, act) used by DDPG."""

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        """Build an MLP from the concatenated (obs, act) input to one output."""
        super().__init__()
        layer_sizes = [obs_dim + act_dim, *hidden_sizes, 1]
        self.q = mlp(layer_sizes, activation)

    def forward(self, obs, act):
        """Return Q-values with the trailing singleton dimension removed."""
        joint_input = torch.cat([obs, act], dim=-1)
        # Squeezing keeps the output the same shape as the reward/done
        # vectors it is combined with in the loss.
        return self.q(joint_input).squeeze(-1)

class DDPGActorCritic(nn.Module):
    """Container holding the DDPG policy (``pi``) and Q-function (``q``)."""

    def __init__(self, observation_space, action_space, hidden_sizes=(256,256),
                 activation=nn.ReLU):
        """Build the policy and Q networks from the space descriptions.

        Args:
            observation_space: Object with a ``shape`` whose first entry is
                the observation size.
            action_space: Object with a ``shape`` (first entry is the action
                size) and a ``high`` array; ``high[0]`` is used as the action
                limit for every dimension.
            hidden_sizes: Hidden layer widths shared by both networks.
            activation: Activation class used between hidden layers.
        """
        super().__init__()

        obs_dim = observation_space.shape[0]
        act_dim = action_space.shape[0]
        act_limit = action_space.high[0]

        # build policy and value functions
        self.pi = DDPGActor(obs_dim, act_dim, hidden_sizes, activation, act_limit)
        self.q = DDPGQFunction(obs_dim, act_dim, hidden_sizes, activation)
        # Report the device the parameters actually live on rather than
        # assuming CUDA.
        self.device = next(self.parameters()).device

    def act(self, obs):
        """Return the policy's action for ``obs`` as a NumPy array (no gradients)."""
        with torch.no_grad():
            return self.pi(obs).cpu().numpy()

class DDPG(object):
    """Deep Deterministic Policy Gradient agent.

    Owns a training and a test gym environment, an actor-critic with a
    polyak-averaged target copy, one Adam optimizer per network and a replay
    buffer. The environments are used through the 4-tuple ``step`` API.
    """

    def __init__(self, env_name, ac_kwargs=None, replay_size=int(1e6), gamma=0.99, polyak=0.995, pi_lr=1e-3,
                 q_lr=1e-3, batch_size=100, act_noise=0.1, num_test_episodes=10, max_ep_len=1000):
        """Create the environments and build networks, optimizers and buffer.

        Args:
            env_name: Gym environment id passed to ``gym.make``.
            ac_kwargs: Extra keyword arguments for ``DDPGActorCritic``.
            replay_size: Capacity of the replay buffer.
            gamma: Discount factor.
            polyak: Averaging coefficient for the target networks.
            pi_lr: Policy learning rate.
            q_lr: Q-function learning rate.
            batch_size: Stored on the agent; not read by its methods.
            act_noise: Std of the Gaussian exploration noise.
            num_test_episodes: Episodes run by the evaluation helpers.
            max_ep_len: Step cap per evaluation episode.
        """
        self.name = 'ddpg'
        self.gamma = gamma
        self.polyak = polyak
        self.batch_size = batch_size
        self.act_noise = act_noise
        self.num_test_episodes = num_test_episodes
        self.max_ep_len = max_ep_len
        self.ac_kwargs = {} if ac_kwargs is None else ac_kwargs
        self.env_name = env_name
        self.env = gym.make(env_name)
        self.test_env = gym.make(env_name)
        self.pi_lr = pi_lr
        self.q_lr = q_lr

        self.obs_dim = self.env.observation_space.shape
        self.act_dim = self.env.action_space.shape[0]
        self.act_limit = self.env.action_space.high[0]
        self.replay_size = replay_size
        # Networks, optimizers and buffer are built by the same code that
        # re-initialises the agent between runs.
        self.reset()

    def _batch_on_device(self, data, *keys):
        """Return detached copies of ``data[key]`` moved to the network device."""
        return [data[key].clone().detach().to(self.ac.device) for key in keys]

    def compute_loss_q(self, data):
        """Return the mean-squared Bellman error of the Q-function on a batch."""
        o, a, r, o2, d = self._batch_on_device(data, 'obs', 'act', 'rew', 'obs2', 'done')

        q = self.ac.q(o, a)

        # Bellman backup for Q function
        with torch.no_grad():
            q_pi_targ = self.ac_targ.q(o2, self.ac_targ.pi(o2))
            backup = r + self.gamma * (1 - d) * q_pi_targ

        # MSE loss against Bellman backup
        return ((q - backup) ** 2).mean()

    def compute_loss_pi(self, data):
        """Return the policy loss: negative mean Q-value of the policy's actions."""
        (o,) = self._batch_on_device(data, 'obs')
        q_pi = self.ac.q(o, self.ac.pi(o))
        return -q_pi.mean()

    def update(self, data):
        """Run one Q step, one policy step and one polyak target update."""
        # First run one gradient descent step for Q.
        self.q_optimizer.zero_grad()
        loss_q = self.compute_loss_q(data)
        loss_q.backward()
        self.q_optimizer.step()

        # Freeze Q-network so you don't waste computational effort
        # computing gradients for it during the policy learning step.
        for p in self.ac.q.parameters():
            p.requires_grad = False

        # Next run one gradient descent step for pi.
        self.pi_optimizer.zero_grad()
        loss_pi = self.compute_loss_pi(data)
        loss_pi.backward()
        self.pi_optimizer.step()

        # Unfreeze Q-network so you can optimize it at next DDPG step.
        for p in self.ac.q.parameters():
            p.requires_grad = True

        # Finally, update target networks by polyak averaging.
        with torch.no_grad():
            for p, p_targ in zip(self.ac.parameters(), self.ac_targ.parameters()):
                # In-place ops update the target tensors instead of creating
                # new ones.
                p_targ.data.mul_(self.polyak)
                p_targ.data.add_((1 - self.polyak) * p.data)

    def get_action(self, o):
        """Return a noisy exploration action for observation ``o``, clipped to the action limit."""
        a = self.ac.act(torch.as_tensor(o, dtype=torch.float32).to(self.ac.device))
        a += self.act_noise * np.random.randn(self.act_dim)
        return np.clip(a, -self.act_limit, self.act_limit)

    def get_action_test(self, o):
        """Return the deterministic (noise-free) action for observation ``o``."""
        a = self.ac.act(torch.as_tensor(o, dtype=torch.float32).to(self.ac.device))
        return np.clip(a, -self.act_limit, self.act_limit)

    def test_agent(self):
        """Run the deterministic policy on the test env; return per-episode returns."""
        ep_rets = []
        for _ in range(self.num_test_episodes):
            o, d, ep_ret, ep_len = self.test_env.reset(), False, 0, 0
            while not (d or (ep_len == self.max_ep_len)):
                # Take deterministic actions at test time (noise_scale=0)
                o, r, d, _ = self.test_env.step(self.get_action_test(o))
                ep_ret += r
                ep_len += 1
            ep_rets.append(ep_ret)
        return ep_rets

    def test_target(self, behavior_metric):
        """Evaluate the deterministic policy against a behaviour metric.

        Args:
            behavior_metric: Callable applied to each next observation; its
                values are summed into the target return.

        Returns:
            ``(mean_true_return, mean_target_return)`` averaged over
            ``num_test_episodes`` episodes.
        """
        # Accumulate across all episodes; resetting these per episode would
        # report only the final episode divided by the episode count.
        target_ret, true_ret = 0, 0
        for _ in range(self.num_test_episodes):
            o, d, ep_len = self.test_env.reset(), False, 0
            while not (d or (ep_len == self.max_ep_len)):
                a = self.get_action_test(o)
                o, r, d, _ = self.test_env.step(a)
                ep_len += 1
                true_ret += r
                target_ret += behavior_metric(o)
        return true_ret/self.num_test_episodes, target_ret/self.num_test_episodes

    def reset(self):
        """Re-initialise networks, target copy, replay buffer and optimizers."""
        self.ac = DDPGActorCritic(self.env.observation_space, self.env.action_space, **self.ac_kwargs)
        self.ac_targ = deepcopy(self.ac)
        # The target network is only ever updated by polyak averaging.
        for p in self.ac_targ.parameters():
            p.requires_grad = False
        self.replay_buffer = ReplayBuffer(obs_dim=self.obs_dim, act_dim=self.act_dim, size=self.replay_size)
        self.pi_optimizer = Adam(self.ac.pi.parameters(), lr=self.pi_lr)
        self.q_optimizer = Adam(self.ac.q.parameters(), lr=self.q_lr)

class TD3Actor(nn.Module):
    """Deterministic TD3 policy: tanh-squashed MLP scaled by ``act_limit``."""

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, act_limit):
        """Build the policy MLP and store the action scale.

        Args:
            obs_dim: Size of the observation vector.
            act_dim: Size of the action vector.
            hidden_sizes: Iterable of hidden layer widths.
            activation: Activation class used between hidden layers.
            act_limit: Scalar scale applied to the tanh output.
        """
        super().__init__()
        pi_sizes = [obs_dim] + list(hidden_sizes) + [act_dim]
        self.pi = mlp(pi_sizes, activation, nn.Tanh)
        # Place the scale wherever the network was placed rather than
        # hard-coding a device.
        net_device = next(self.pi.parameters()).device
        self.act_limit = torch.as_tensor(act_limit, dtype=torch.float32).to(net_device)

    def forward(self, obs):
        """Return the action for ``obs``, bounded by ``+/- act_limit``."""
        return self.act_limit * self.pi(obs)

class TD3QFunction(nn.Module):
    """Action-value network Q(obs, act); TD3 instantiates two of these."""

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        """Build an MLP from the concatenated (obs, act) input to one output."""
        super().__init__()
        in_features = obs_dim + act_dim
        self.q = mlp([in_features, *hidden_sizes, 1], activation)

    def forward(self, obs, act):
        """Return Q-values with the trailing singleton dimension removed."""
        stacked = torch.cat([obs, act], dim=-1)
        values = self.q(stacked)
        # Remove the last axis so the result lines up with per-sample
        # reward/done vectors.
        return values.squeeze(-1)

class TD3ActorCritic(nn.Module):
    """Container holding the TD3 policy (``pi``) and twin Q-functions (``q1``, ``q2``)."""

    def __init__(self, observation_space, action_space, hidden_sizes=(256,256),
                 activation=nn.ReLU):
        """Build the policy and both Q networks from the space descriptions.

        Args:
            observation_space: Object with a ``shape`` whose first entry is
                the observation size.
            action_space: Object with a ``shape`` (first entry is the action
                size) and a ``high`` array; ``high[0]`` is used as the action
                limit for every dimension.
            hidden_sizes: Hidden layer widths shared by all networks.
            activation: Activation class used between hidden layers.
        """
        super().__init__()

        obs_dim = observation_space.shape[0]
        act_dim = action_space.shape[0]
        act_limit = action_space.high[0]

        # build policy and value functions
        self.pi = TD3Actor(obs_dim, act_dim, hidden_sizes, activation, act_limit)
        self.q1 = TD3QFunction(obs_dim, act_dim, hidden_sizes, activation)
        self.q2 = TD3QFunction(obs_dim, act_dim, hidden_sizes, activation)
        # Report the device the parameters actually live on rather than
        # assuming CUDA.
        self.device = next(self.parameters()).device

    def act(self, obs):
        """Return the policy's action for ``obs`` as a NumPy array (no gradients)."""
        with torch.no_grad():
            return self.pi(obs).cpu().numpy()

class TD3(object):
    """Twin Delayed DDPG agent.

    Owns a training and a test gym environment, an actor-critic with twin
    Q-functions and a polyak-averaged target copy, Adam optimizers and a
    replay buffer. The environments are used through the 4-tuple ``step`` API.
    """

    def __init__(self, env_name, ac_kwargs=None, replay_size=int(1e6), gamma=0.99,
        polyak=0.995, pi_lr=1e-3, q_lr=1e-3, batch_size=100, act_noise=0.1, target_noise=0.2,
        noise_clip=0.5, policy_delay=2, num_test_episodes=10, max_ep_len=1000):
        """Create the environments and build networks, optimizers and buffer.

        Args:
            env_name: Gym environment id passed to ``gym.make``.
            ac_kwargs: Extra keyword arguments for ``TD3ActorCritic``.
            replay_size: Capacity of the replay buffer.
            gamma: Discount factor.
            polyak: Averaging coefficient for the target networks.
            pi_lr: Policy learning rate.
            q_lr: Q-function learning rate.
            batch_size: Stored on the agent; not read by its methods.
            act_noise: Std of the Gaussian exploration noise.
            target_noise: Std of the smoothing noise added to target actions.
            noise_clip: Clip range for the target smoothing noise.
            policy_delay: The policy/targets update when ``timer`` is a
                multiple of this value.
            num_test_episodes: Episodes run by the evaluation helpers.
            max_ep_len: Step cap per evaluation episode.
        """
        self.name = 'td3'
        self.gamma = gamma
        self.polyak = polyak
        self.batch_size = batch_size
        self.act_noise = act_noise
        self.target_noise = target_noise
        self.noise_clip = noise_clip
        self.policy_delay = policy_delay
        self.num_test_episodes = num_test_episodes
        self.max_ep_len = max_ep_len
        self.ac_kwargs = {} if ac_kwargs is None else ac_kwargs
        self.env_name = env_name
        self.env = gym.make(env_name)
        self.test_env = gym.make(env_name)
        self.pi_lr = pi_lr
        self.q_lr = q_lr
        self.obs_dim = self.env.observation_space.shape
        self.act_dim = self.env.action_space.shape[0]
        self.act_limit = self.env.action_space.high[0]
        self.replay_size = replay_size
        # Networks, optimizers and buffer are built by the same code that
        # re-initialises the agent between runs.
        self.reset()

    def _batch_on_device(self, data, *keys):
        """Return detached copies of ``data[key]`` moved to the network device."""
        return [data[key].clone().detach().to(self.ac.device) for key in keys]

    def _set_q_requires_grad(self, flag):
        """Set ``requires_grad`` on every parameter of both Q-networks.

        The parameters are read from the networks on each call: a stored
        ``itertools.chain`` can be iterated only once and would already be
        exhausted by the optimizer constructor.
        """
        for p in itertools.chain(self.ac.q1.parameters(), self.ac.q2.parameters()):
            p.requires_grad = flag

    def compute_loss_q(self, data):
        """Return the summed MSE Bellman error of both Q-functions on a batch."""
        o, a, r, o2, d = self._batch_on_device(data, 'obs', 'act', 'rew', 'obs2', 'done')

        q1 = self.ac.q1(o, a)
        q2 = self.ac.q2(o, a)

        # Bellman backup for Q functions
        with torch.no_grad():
            pi_targ = self.ac_targ.pi(o2)

            # Target policy smoothing
            epsilon = torch.randn_like(pi_targ) * self.target_noise
            epsilon = torch.clamp(epsilon, -self.noise_clip, self.noise_clip)
            a2 = pi_targ + epsilon
            a2 = torch.clamp(a2, -self.act_limit, self.act_limit)

            # Target Q-values: the smaller of the two target estimates.
            q1_pi_targ = self.ac_targ.q1(o2, a2)
            q2_pi_targ = self.ac_targ.q2(o2, a2)
            q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
            backup = r + self.gamma * (1 - d) * q_pi_targ

        # MSE loss against Bellman backup
        loss_q1 = ((q1 - backup) ** 2).mean()
        loss_q2 = ((q2 - backup) ** 2).mean()
        return loss_q1 + loss_q2

    def compute_loss_pi(self, data):
        """Return the policy loss: negative mean Q1-value of the policy's actions."""
        (o,) = self._batch_on_device(data, 'obs')
        q1_pi = self.ac.q1(o, self.ac.pi(o))
        return -q1_pi.mean()

    def update(self, data, timer):
        """Run one Q step and, every ``policy_delay`` calls, a policy/target step.

        Args:
            data: Batch dict as produced by the replay buffer.
            timer: Update counter; the policy and target networks are updated
                only when ``timer % policy_delay == 0``.
        """
        # First run one gradient descent step for Q1 and Q2
        self.q_optimizer.zero_grad()
        loss_q = self.compute_loss_q(data)
        loss_q.backward()
        self.q_optimizer.step()

        # Possibly update pi and target networks
        if timer % self.policy_delay == 0:

            # Freeze Q-networks so you don't waste computational effort
            # computing gradients for them during the policy learning step.
            self._set_q_requires_grad(False)

            # Next run one gradient descent step for pi.
            self.pi_optimizer.zero_grad()
            loss_pi = self.compute_loss_pi(data)
            loss_pi.backward()
            self.pi_optimizer.step()

            # Unfreeze Q-networks so they can be optimized at the next step.
            self._set_q_requires_grad(True)

            # Finally, update target networks by polyak averaging.
            with torch.no_grad():
                for p, p_targ in zip(self.ac.parameters(), self.ac_targ.parameters()):
                    # In-place ops update the target tensors instead of
                    # creating new ones.
                    p_targ.data.mul_(self.polyak)
                    p_targ.data.add_((1 - self.polyak) * p.data)

    def get_action(self, o):
        """Return a noisy exploration action for observation ``o``, clipped to the action limit."""
        a = self.ac.act(torch.as_tensor(o, dtype=torch.float32).to(self.ac.device))
        a += self.act_noise * np.random.randn(self.act_dim)
        return np.clip(a, -self.act_limit, self.act_limit)

    def get_action_test(self, o):
        """Return the deterministic (noise-free) action for observation ``o``."""
        a = self.ac.act(torch.as_tensor(o, dtype=torch.float32).to(self.ac.device))
        return np.clip(a, -self.act_limit, self.act_limit)

    def test_agent(self):
        """Run the deterministic policy on the test env; return per-episode returns."""
        ep_rets = []
        for _ in range(self.num_test_episodes):
            o, d, ep_ret, ep_len = self.test_env.reset(), False, 0, 0
            while not (d or (ep_len == self.max_ep_len)):
                # Take deterministic actions at test time (noise_scale=0)
                o, r, d, _ = self.test_env.step(self.get_action_test(o))
                ep_ret += r
                ep_len += 1
            ep_rets.append(ep_ret)
        return ep_rets

    def test_target(self, behavior_metric):
        """Evaluate the deterministic policy against a behaviour metric.

        Args:
            behavior_metric: Callable applied to each next observation; its
                values are summed into the target return.

        Returns:
            ``(mean_true_return, mean_target_return)`` averaged over
            ``num_test_episodes`` episodes.
        """
        target_ret, true_ret = 0, 0
        for _ in range(self.num_test_episodes):
            o, d, ep_len = self.test_env.reset(), False, 0
            while not (d or (ep_len == self.max_ep_len)):
                a = self.get_action_test(o)
                o, r, d, _ = self.test_env.step(a)
                ep_len += 1
                true_ret += r
                target_ret += behavior_metric(o)
        return true_ret/self.num_test_episodes, target_ret/self.num_test_episodes

    def reset(self):
        """Re-initialise networks, target copy, optimizers and replay buffer."""
        self.ac = TD3ActorCritic(self.env.observation_space, self.env.action_space, **self.ac_kwargs)
        self.ac_targ = deepcopy(self.ac)
        # The target network is only ever updated by polyak averaging.
        for p in self.ac_targ.parameters():
            p.requires_grad = False
        # Materialise as a list so the collection can be iterated repeatedly.
        self.q_params = list(itertools.chain(self.ac.q1.parameters(), self.ac.q2.parameters()))
        self.pi_optimizer = Adam(self.ac.pi.parameters(), lr=self.pi_lr)
        self.q_optimizer = Adam(self.q_params, lr=self.q_lr)
        self.replay_buffer = ReplayBuffer(obs_dim=self.obs_dim, act_dim=self.act_dim, size=self.replay_size)

# Clamp range applied to the policy's predicted log standard deviation.
LOG_STD_MAX = 2
LOG_STD_MIN = -20

class SACActor(nn.Module):
    """Tanh-squashed Gaussian policy for SAC."""

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, act_limit):
        """Build the shared trunk and the mean / log-std heads.

        Args:
            obs_dim: Size of the observation vector.
            act_dim: Size of the action vector.
            hidden_sizes: Hidden layer widths; must be non-empty because the
                heads are sized from ``hidden_sizes[-1]``.
            activation: Activation class used after every trunk layer.
            act_limit: Scalar scale applied to the tanh output.
        """
        super().__init__()
        self.net = mlp([obs_dim] + list(hidden_sizes), activation, activation)
        # Put the heads and the scale on the same device as the trunk
        # instead of hard-coding a device.
        device = next(self.net.parameters()).device
        self.mu_layer = nn.Linear(hidden_sizes[-1], act_dim).to(device)
        self.log_std_layer = nn.Linear(hidden_sizes[-1], act_dim).to(device)
        self.act_limit = torch.as_tensor(act_limit, dtype=torch.float32).to(device)

    def forward(self, obs, deterministic=False, with_logprob=True):
        """Sample (or take the mean) action and optionally its log-probability.

        Args:
            obs: Observation tensor; a detached copy is moved to the policy's
                device before use.
            deterministic: If true, use the Gaussian mean instead of sampling.
            with_logprob: If true, also return the log-probability of the
                squashed action; otherwise ``None`` is returned in its place.

        Returns:
            ``(action, logp)`` where ``action`` is scaled by ``act_limit``.
        """
        obs = obs.clone().detach().to(self.act_limit.device)
        net_out = self.net(obs)
        mu = self.mu_layer(net_out)
        log_std = torch.clamp(self.log_std_layer(net_out), LOG_STD_MIN, LOG_STD_MAX)
        std = torch.exp(log_std)

        # Pre-squash distribution and sample
        pi_distribution = Normal(mu, std)
        if deterministic:
            # Only used for evaluating policy at test time.
            pi_action = mu
        else:
            pi_action = pi_distribution.rsample()

        if with_logprob:
            # Gaussian log-prob followed by the change-of-variables correction
            # for the tanh squashing (SAC paper, arXiv 1801.01290, appendix C),
            # written in a numerically stable softplus form. Both reductions
            # use the last axis so unbatched observations work too.
            logp_pi = pi_distribution.log_prob(pi_action).sum(axis=-1)
            logp_pi -= (2*(np.log(2) - pi_action - F.softplus(-2*pi_action))).sum(axis=-1)
        else:
            logp_pi = None

        pi_action = self.act_limit * torch.tanh(pi_action)

        return pi_action, logp_pi

class SACQFunction(nn.Module):
    """Action-value network Q(obs, act); SAC instantiates two of these."""

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        """Build an MLP from the concatenated (obs, act) input to one output."""
        super().__init__()
        widths = [obs_dim + act_dim]
        widths.extend(hidden_sizes)
        widths.append(1)
        self.q = mlp(widths, activation)

    def forward(self, obs, act):
        """Return Q-values with the trailing singleton dimension removed."""
        pair = torch.cat([obs, act], dim=-1)
        # The squeeze matters: it keeps Q the same shape as the per-sample
        # reward/done vectors in the loss.
        return self.q(pair).squeeze(-1)

class SACActorCritic(nn.Module):
    """Container holding the SAC policy (``pi``) and twin Q-functions (``q1``, ``q2``)."""

    def __init__(self, observation_space, action_space, hidden_sizes=(256,256),
                 activation=nn.ReLU):
        """Build the policy and both Q networks from the space descriptions.

        Args:
            observation_space: Object with a ``shape`` whose first entry is
                the observation size.
            action_space: Object with a ``shape`` (first entry is the action
                size) and a ``high`` array; ``high[0]`` is used as the action
                limit for every dimension.
            hidden_sizes: Hidden layer widths shared by all networks.
            activation: Activation class used between hidden layers.
        """
        super().__init__()

        obs_dim = observation_space.shape[0]
        act_dim = action_space.shape[0]
        act_limit = action_space.high[0]

        # build policy and value functions
        self.pi = SACActor(obs_dim, act_dim, hidden_sizes, activation, act_limit)
        self.q1 = SACQFunction(obs_dim, act_dim, hidden_sizes, activation)
        self.q2 = SACQFunction(obs_dim, act_dim, hidden_sizes, activation)
        # Report the device the parameters actually live on rather than
        # assuming CUDA.
        self.device = next(self.parameters()).device

    def act(self, obs, deterministic=False):
        """Return an action for ``obs`` as a NumPy array (no gradients).

        The log-probability is not computed; ``deterministic`` selects the
        policy mean instead of a sample.
        """
        with torch.no_grad():
            a, _ = self.pi(obs, deterministic, False)
            return a.cpu().numpy()

class SAC(object):
    """Soft Actor-Critic agent with a fixed entropy coefficient ``alpha``.

    Owns a training and a test gym environment, an actor-critic with twin
    Q-functions and a polyak-averaged target copy, Adam optimizers and a
    replay buffer. The environments are used through the 4-tuple ``step`` API.
    """

    def __init__(self, env_name, ac_kwargs=None, replay_size=int(1e6), gamma=0.99,
        polyak=0.995, lr=1e-3, alpha=0.2, batch_size=100, num_test_episodes=10, max_ep_len=1000):
        """Create the environments and build networks, optimizers and buffer.

        Args:
            env_name: Gym environment id passed to ``gym.make``.
            ac_kwargs: Extra keyword arguments for ``SACActorCritic``.
            replay_size: Capacity of the replay buffer.
            gamma: Discount factor.
            polyak: Averaging coefficient for the target networks.
            lr: Learning rate shared by the policy and Q optimizers.
            alpha: Entropy regularisation coefficient.
            batch_size: Stored on the agent; not read by its methods.
            num_test_episodes: Episodes run by the evaluation helpers.
            max_ep_len: Step cap per evaluation episode.
        """
        self.name = 'sac'
        self.gamma = gamma
        self.polyak = polyak
        self.batch_size = batch_size
        self.lr = lr
        self.alpha = alpha
        self.num_test_episodes = num_test_episodes
        self.max_ep_len = max_ep_len
        self.ac_kwargs = {} if ac_kwargs is None else ac_kwargs
        self.env_name = env_name
        self.env = gym.make(env_name)
        self.test_env = gym.make(env_name)
        self.obs_dim = self.env.observation_space.shape
        self.act_dim = self.env.action_space.shape[0]
        self.act_limit = self.env.action_space.high[0]
        self.replay_size = replay_size
        # Networks, optimizers and buffer are built by the same code that
        # re-initialises the agent between runs.
        self.reset()

    def _batch_on_device(self, data, *keys):
        """Return detached copies of ``data[key]`` moved to the network device."""
        return [data[key].clone().detach().to(self.ac.device) for key in keys]

    def _set_q_requires_grad(self, flag):
        """Set ``requires_grad`` on every parameter of both Q-networks.

        The parameters are read from the networks on each call: a stored
        ``itertools.chain`` can be iterated only once and would already be
        exhausted by the optimizer constructor.
        """
        for p in itertools.chain(self.ac.q1.parameters(), self.ac.q2.parameters()):
            p.requires_grad = flag

    def compute_loss_q(self, data):
        """Return the summed MSE soft-Bellman error of both Q-functions."""
        o, a, r, o2, d = self._batch_on_device(data, 'obs', 'act', 'rew', 'obs2', 'done')
        q1 = self.ac.q1(o, a)
        q2 = self.ac.q2(o, a)

        # Bellman backup for Q functions
        with torch.no_grad():
            # Target actions come from *current* policy
            a2, logp_a2 = self.ac.pi(o2)

            # Target Q-values: the smaller of the two target estimates,
            # with the entropy bonus subtracted.
            q1_pi_targ = self.ac_targ.q1(o2, a2)
            q2_pi_targ = self.ac_targ.q2(o2, a2)
            q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
            backup = r + self.gamma * (1 - d) * (q_pi_targ - self.alpha * logp_a2)

        # MSE loss against Bellman backup
        loss_q1 = ((q1 - backup) ** 2).mean()
        loss_q2 = ((q2 - backup) ** 2).mean()
        return loss_q1 + loss_q2

    def compute_loss_pi(self, data):
        """Return the entropy-regularised policy loss on a batch."""
        (o,) = self._batch_on_device(data, 'obs')
        pi, logp_pi = self.ac.pi(o)
        q1_pi = self.ac.q1(o, pi)
        q2_pi = self.ac.q2(o, pi)
        q_pi = torch.min(q1_pi, q2_pi)

        # Entropy-regularized policy loss
        return (self.alpha * logp_pi - q_pi).mean()

    def update(self, data):
        """Run one Q step, one policy step and one polyak target update."""
        # First run one gradient descent step for Q1 and Q2
        self.q_optimizer.zero_grad()
        loss_q = self.compute_loss_q(data)
        loss_q.backward()
        self.q_optimizer.step()

        # Freeze Q-networks so you don't waste computational effort
        # computing gradients for them during the policy learning step.
        self._set_q_requires_grad(False)

        # Next run one gradient descent step for pi.
        self.pi_optimizer.zero_grad()
        loss_pi = self.compute_loss_pi(data)
        loss_pi.backward()
        self.pi_optimizer.step()

        # Unfreeze Q-networks so they can be optimized at the next SAC step.
        self._set_q_requires_grad(True)

        # Finally, update target networks by polyak averaging.
        with torch.no_grad():
            for p, p_targ in zip(self.ac.parameters(), self.ac_targ.parameters()):
                # In-place ops update the target tensors instead of creating
                # new ones.
                p_targ.data.mul_(self.polyak)
                p_targ.data.add_((1 - self.polyak) * p.data)

    def get_action(self, o):
        """Return a stochastic action sampled from the policy for ``o``."""
        return self.ac.act(torch.as_tensor(o, dtype=torch.float32), False)

    def get_action_test(self, o):
        """Return the deterministic (mean) action for ``o``."""
        return self.ac.act(torch.as_tensor(o, dtype=torch.float32), True)

    def test_agent(self):
        """Run the deterministic policy on the test env; return per-episode returns."""
        ep_rets = []
        for _ in range(self.num_test_episodes):
            o, d, ep_ret, ep_len = self.test_env.reset(), False, 0, 0
            while not (d or (ep_len == self.max_ep_len)):
                # Take deterministic actions at test time
                o, r, d, _ = self.test_env.step(self.get_action_test(o))
                ep_ret += r
                ep_len += 1
            ep_rets.append(ep_ret)
        return ep_rets

    def test_target(self, behavior_metric):
        """Evaluate the deterministic policy against a behaviour metric.

        Args:
            behavior_metric: Callable applied to each next observation; its
                values are summed into the target return.

        Returns:
            ``(mean_true_return, mean_target_return)`` averaged over
            ``num_test_episodes`` episodes.
        """
        # Accumulate across all episodes; resetting these per episode would
        # report only the final episode divided by the episode count.
        target_ret, true_ret = 0, 0
        for _ in range(self.num_test_episodes):
            o, d, ep_len = self.test_env.reset(), False, 0
            while not (d or (ep_len == self.max_ep_len)):
                a = self.get_action_test(o)
                o, r, d, _ = self.test_env.step(a)
                ep_len += 1
                true_ret += r
                target_ret += behavior_metric(o)
        return true_ret/self.num_test_episodes, target_ret/self.num_test_episodes

    def reset(self):
        """Re-initialise networks, target copy, optimizers and replay buffer."""
        self.ac = SACActorCritic(self.env.observation_space, self.env.action_space, **self.ac_kwargs)
        self.ac_targ = deepcopy(self.ac)
        # The target network is only ever updated by polyak averaging.
        for p in self.ac_targ.parameters():
            p.requires_grad = False
        # Materialise as a list so the collection can be iterated repeatedly.
        self.q_params = list(itertools.chain(self.ac.q1.parameters(), self.ac.q2.parameters()))
        self.pi_optimizer = Adam(self.ac.pi.parameters(), lr=self.lr)
        self.q_optimizer = Adam(self.q_params, lr=self.lr)
        self.replay_buffer = ReplayBuffer(obs_dim=self.obs_dim, act_dim=self.act_dim, size=self.replay_size)

def training_target_atk(agent, dir, atk_agent, atk_params,
         steps_per_epoch=4000, epochs=100, batch_size=100, start_steps=10000,
         update_after=1000, update_every=50,max_ep_len=1000, n_runs=1, switch=False, rad=None):
    """Train ``agent`` while its rewards are perturbed relative to ``atk_agent``.

    At every environment step the reward stored in the replay buffer may be
    reduced by a perturbation, as long as attack budget remains. Per-epoch
    sums of the action distance to ``atk_agent`` and of the budget spent are
    collected for every run and written to ``<dir>/outputs.json``.

    Args:
        agent: Learner exposing ``env``, ``replay_buffer``, ``get_action``,
            ``update``, ``reset``, ``name``, ``act_dim`` and ``act_limit``.
        dir: Directory that receives ``outputs.json``.
        atk_agent: Agent whose deterministic action is the attack target;
            must expose ``get_action_test`` and ``name``.
        atk_params: Pair ``[C, B]``. The total budget is
            ``C * total_steps * B`` and ``B`` bounds a single perturbation.
        steps_per_epoch: Environment steps per logging epoch.
        epochs: Number of epochs per run.
        batch_size: Batch size sampled from the replay buffer per update.
        start_steps: Steps with ``t <= start_steps`` use random actions.
        update_after: First step index at which updates may happen.
        update_every: Update period; that many updates run each time.
        max_ep_len: Episode length cap; hitting it is not stored as ``done``.
        n_runs: Number of independent runs (``agent.reset()`` after each).
        switch: Unused.
        rad: Attack mode selector. ``None`` scales the perturbation with the
            step distance; a value strictly between 0 and 10 applies ``B``
            once the accumulated distance exceeds ``rad``; any other value
            applies a uniformly random perturbation in ``[-B, B]``.
    """
    total_steps = steps_per_epoch * epochs

    budget_fraction, B = atk_params
    initial_budget = budget_fraction * total_steps * B
    high_end = 2 * ((agent.act_dim)**0.5)
    distance_log = []
    budget_spend_log = []

    for i in range(n_runs):
        print('n_run:', i)
        # Every run starts with the full attack budget.
        C = initial_budget
        budget_spends, distances = [], []
        o, ep_len = agent.env.reset(), 0
        distance = 0
        budget_spend = 0
        for t in range(total_steps):
            if t > start_steps:
                a = agent.get_action(o)
            else:
                a = agent.env.action_space.sample()

            o2, r, d, _ = agent.env.step(a)

            a_atk = atk_agent.get_action_test(o)

            # Cast to a builtin float so the logged sums stay JSON-serialisable
            # regardless of the dtype of the action arrays.
            distance += float(LA.norm(a - a_atk))
            if atk_agent.name == 'ppo':
                a_atk = np.clip(a_atk, -agent.act_limit, agent.act_limit)

            if C > 0:
                perturb = None
                if rad is None:
                    # Proportional to the step distance, capped at B.
                    perturb = min(B, B * float(LA.norm(a - a_atk)) / high_end)
                elif 0 < rad < 10:
                    # NOTE(review): compares the distance accumulated over the
                    # current epoch, not the single-step distance -- confirm
                    # this is intended.
                    if distance > rad:
                        perturb = B
                else:
                    perturb = B * random.uniform(-1, 1)

                if perturb is not None:
                    r -= perturb
                    C -= abs(perturb)
                    budget_spend += abs(perturb)

            ep_len += 1

            # A time-limit cut-off is not stored as a terminal transition.
            d = False if ep_len == max_ep_len else d

            agent.replay_buffer.store(o, a, r, o2, d)

            o = o2

            # End of trajectory handling
            if d or (ep_len == max_ep_len):
                o, ep_len = agent.env.reset(), 0

            # Update handling
            if t >= update_after and t % update_every == 0:
                for j in range(update_every):
                    batch = agent.replay_buffer.sample_batch(batch_size)
                    if agent.name == 'td3':
                        agent.update(data=batch, timer = j)
                    else:
                        # Every agent other than TD3 takes only the batch.
                        # The former "== 'ddpg' or 'sac'" test was always
                        # true, so this keeps the same dispatch.
                        agent.update(data=batch)

            if (t + 1) % steps_per_epoch == 0:
                epoch = (t + 1) // steps_per_epoch

                distances.append(distance)
                budget_spends.append(budget_spend)
                print('epoch:', epoch, 'distance:', distance, 'budget_spend:', budget_spend, 'budget_left:', C)

                distance = 0
                budget_spend = 0

        distance_log.append(distances)
        budget_spend_log.append(budget_spends)
        print('total_distance:', np.sum(distances))
        agent.reset()

    data = dict()
    data['distance'] = distance_log
    data['budget spend']=budget_spend_log
    print('total_distance_exp:', np.sum([np.sum(distances) for distances in distance_log]))
    with open(os.path.join(dir,'outputs.json'), 'w') as f:
        f.write(json.dumps(data))

if __name__ == '__main__':
    # Command-line entry point: build a learner and an attack-target agent,
    # optionally load pretrained weights for the latter, then run the
    # reward-perturbation training loop.
    parser = argparse.ArgumentParser()
    parser.add_argument('--alg', type=str, default='ddpg')
    parser.add_argument('--env', type=str, default='HalfCheetah-v2')
    parser.add_argument('--hid', type=int, default=256)
    parser.add_argument('--l', type=int, default=2)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--seed', '-s', type=int, default=0)
    parser.add_argument('--epochs', type=int, default=150)
    parser.add_argument("--atk_params", nargs="+", type=float)
    parser.add_argument('--atk_alg', type=str, default=None)
    parser.add_argument("--atk_pfm", type=int, default=-2000)
    parser.add_argument("--n_runs", type=int, default=1)
    parser.add_argument("--steps_per_epoch", type=int, default=4000)
    parser.add_argument("--dir", type=str, default='../tmp')
    parser.add_argument("--rad", type=float, default=None)

    args = parser.parse_args()

    # Resolve algorithm names through an explicit table instead of eval(),
    # so a command-line string can never be executed as code.
    algorithms = {'ddpg': DDPG, 'td3': TD3, 'sac': SAC}
    if args.alg.lower() not in algorithms:
        parser.error('--alg must be one of %s' % sorted(algorithms))
    if args.atk_alg is None or args.atk_alg.lower() not in algorithms:
        parser.error('--atk_alg is required and must be one of %s' % sorted(algorithms))
    if args.atk_params is None or len(args.atk_params) != 2:
        parser.error('--atk_params expects exactly two values: C B')

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    # The training loop also draws from the stdlib generator.
    random.seed(args.seed)

    # The log file is handed to redirect_stdout and stays open for the
    # lifetime of the process.
    if args.dir != '../tmp':
        redirect_stdout(open(os.path.join(args.dir, 'outputs.txt'), 'w'))
    else:
        t = int(time.time())
        redirect_stdout(open(os.path.join(args.dir, 'outputs_%s_%s_%s_%d.txt' % (args.env, args.alg, args.atk_pfm, t)), 'w'))

    print('seed:', args.seed)
    print('alg:', args.alg)
    print('env_name:', args.env)
    print('atk_alg:', args.atk_alg)
    print('atk_pfm:', args.atk_pfm)
    print('atk_params:', args.atk_params)

    ac_kwargs = dict(hidden_sizes=[args.hid] * args.l)
    agent = algorithms[args.alg.lower()](env_name=args.env, ac_kwargs=ac_kwargs, gamma=args.gamma)
    agent.env.seed(args.seed)

    atk_key = args.atk_alg.lower()
    atk_agent = algorithms[atk_key](env_name=args.env, ac_kwargs=dict(hidden_sizes=[args.hid] * args.l),
                                    gamma=args.gamma)
    pfm = args.atk_pfm
    # Checkpoints are only loaded for performance levels above -1000;
    # otherwise the attack target keeps its random initialisation.
    if pfm > -1000:
        model_prefix = '../models/%s_model/%s' % (args.atk_alg, args.env)
        atk_agent.ac.pi.load_state_dict(torch.load('%s_pi_%d' % (model_prefix, pfm)))
        if atk_key == 'ddpg':
            atk_agent.ac.q.load_state_dict(torch.load('%s_q_%d' % (model_prefix, pfm)))
        else:
            # td3 and sac both carry twin Q-functions.
            atk_agent.ac.q1.load_state_dict(torch.load('%s_q1_%d' % (model_prefix, pfm)))
            atk_agent.ac.q2.load_state_dict(torch.load('%s_q2_%d' % (model_prefix, pfm)))

    training_target_atk(agent=agent, dir = args.dir, atk_agent=atk_agent, atk_params=args.atk_params, steps_per_epoch=args.steps_per_epoch,
                        epochs=args.epochs, n_runs=args.n_runs, rad=args.rad)






