
import os
import sys
import glob
import time
from datetime import datetime
import warnings
warnings.simplefilter("ignore", UserWarning)


import torch
import torch.nn as nn
from torch.distributions import Normal
import torch.nn.functional as F
from torch.distributions import MultivariateNormal
from torch.distributions import Categorical
from torch.utils.data import TensorDataset, DataLoader

import itertools
import math
import numpy as np

################################## set device ##################################
# print("============================================================================================")
# # set device to cpu or cuda
# device = torch.device('cpu')
# if(torch.cuda.is_available()):
#     device = torch.device('cuda:0')
#     torch.cuda.empty_cache()
#     print("Device set to : " + str(torch.cuda.get_device_name(device)))
# else:
#     print("Device set to : cpu")
# print("============================================================================================")


################################## PPO Policy ##################################

class RolloutBuffer:
    """On-policy storage for a single rollout.

    Every attribute is a plain Python list that the training loop appends to
    (presumably once per environment step); ``fwPPO.update`` stacks each list
    with ``torch.stack`` and then calls ``clear``.
    """

    def __init__(self):
        (self.actions, self.states, self.nextstates,
         self.logprobs, self.rewards, self.is_terminals) = [], [], [], [], [], []

    def clear(self):
        """Empty every list in place, so references held elsewhere see it emptied too."""
        for stored in (self.actions, self.states, self.nextstates,
                       self.logprobs, self.rewards, self.is_terminals):
            stored.clear()


class ReplayBuffer():
    """Fixed-capacity ring buffer of transitions backed by NumPy arrays.

    Once ``max_size`` transitions have been stored, each new transition
    overwrites slot ``data_idx % max_size`` (the oldest one).

    Args:
        action_prob_exist: if True, a ``log_prob`` column is stored as well.
        max_size: capacity of the buffer.
        state_dim: length of each ``state`` / ``next_state`` vector.
        num_action: currently unused; kept for interface compatibility.
    """

    def __init__(self, action_prob_exist, max_size, state_dim, num_action):
        self.max_size = max_size
        self.data_idx = 0  # total number of transitions ever inserted
        self.action_prob_exist = action_prob_exist
        self.data = {}

        self.data['state'] = np.zeros((self.max_size, state_dim))
        self.data['action'] = np.zeros((self.max_size))
        self.data['reward'] = np.zeros((self.max_size, 1))
        self.data['next_state'] = np.zeros((self.max_size, state_dim))
        self.data['done'] = np.zeros((self.max_size, 1))
        if self.action_prob_exist:
            self.data['log_prob'] = np.zeros((self.max_size, 1))

    def put_data(self, transition):
        """Store one transition dict (keys matching ``self.data``) in the next ring slot."""
        idx = self.data_idx % self.max_size
        self.data['state'][idx] = transition['state']
        self.data['action'][idx] = transition['action']
        self.data['reward'][idx] = transition['reward']
        self.data['next_state'][idx] = transition['next_state']
        self.data['done'][idx] = float(transition['done'])
        if self.action_prob_exist:
            self.data['log_prob'][idx] = transition['log_prob']

        self.data_idx += 1

    def sample(self, shuffle, batch_size=None):
        """Return transitions from the buffer.

        Args:
            shuffle: if False, return the whole backing ``self.data`` dict
                (including unfilled, zero rows). If True, draw distinct random rows.
            batch_size: number of rows to draw when ``shuffle`` is True. It is
                clamped to the number of stored transitions, so a partially
                filled (or empty) buffer returns fewer rows instead of raising.

        Returns:
            dict mapping each column name to the selected rows.
        """
        if not shuffle:
            return self.data
        sample_num = self.size()
        if batch_size is not None:
            # np.random.choice(replace=False) raises if asked for more rows than exist.
            batch_size = min(batch_size, sample_num)
        rand_idx = np.random.choice(sample_num, batch_size, replace=False)
        return {key: column[rand_idx] for key, column in self.data.items()}

    def size(self):
        """Number of valid transitions currently stored (at most ``max_size``)."""
        return min(self.max_size, self.data_idx)


class Pool():
    """Fixed-capacity ring buffer of state vectors.

    ``fwPPO`` draws a random subset of these states as an extra measurement
    set when estimating its functional regularisers.
    """

    def __init__(self, max_size, state_dim):
        self.max_size = max_size
        self.data_idx = 0  # total number of states ever inserted
        self.data = np.zeros((self.max_size, state_dim))

    def put_data(self, state_):
        """Store one state in the next ring slot, overwriting the oldest when full."""
        idx = self.data_idx % self.max_size
        self.data[idx] = state_
        self.data_idx += 1

    def sample(self, shuffle, batch_size=None):
        """Draw distinct stored states uniformly at random.

        Args:
            shuffle: ignored; sampling is always random. Kept for interface
                compatibility with ``ReplayBuffer.sample``.
            batch_size: number of states to draw. It is clamped to the number
                of stored states, so a pool that is not yet full (or is empty)
                returns fewer rows instead of raising ``ValueError``.

        Returns:
            np.ndarray with the selected rows of ``self.data``.
        """
        sample_num = self.size()
        if batch_size is not None:
            # np.random.choice(replace=False) cannot draw more items than exist.
            batch_size = min(batch_size, sample_num)
        rand_idx = np.random.choice(sample_num, batch_size, replace=False)
        return self.data[rand_idx]

    def size(self):
        """Number of valid states currently stored (at most ``max_size``)."""
        return min(self.max_size, self.data_idx)


def make_mini_batch(*value):
    """Yield shuffled, equally sized mini-batches from parallel arrays.

    Args:
        *value: ``value[0]`` is the mini-batch size. Every following element
            is an indexable array (NumPy / torch) sharing the same leading
            length; all of them are indexed with one common random
            permutation, so rows stay aligned across arrays.

    Yields:
        list: one slice per input array. A trailing partial batch is dropped.
    """
    batch_size = value[0]
    arrays = value[1:]
    order = np.arange(len(arrays[0]))
    np.random.shuffle(order)
    n_full = len(order) // batch_size
    for start in range(0, n_full * batch_size, batch_size):
        chosen = order[start:start + batch_size]
        yield [arr[chosen] for arr in arrays]


def weights_init(m):
    """Initialise an ``nn.Linear`` in place: Xavier-normal weights, N(0, 1) biases.

    Meant to be used as ``module.apply(weights_init)``; modules that are not
    ``nn.Linear`` are left untouched. Layers built with ``bias=False`` have
    ``m.bias is None`` and only get their weights initialised.
    """
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            torch.nn.init.normal_(m.bias.data)

def init_norm_layer(input_dim, norm_layer):
    """Build the normalisation module selected by ``norm_layer``.

    Args:
        input_dim: number of features to normalise.
        norm_layer: ``"batchnorm"`` or ``None``.

    Returns:
        nn.Module: for ``"batchnorm"``, a non-affine ``nn.BatchNorm1d`` that
        keeps no running statistics (it always normalises with the current
        batch's statistics); for ``None``, ``nn.Identity()``.

    Raises:
        ValueError: for any other ``norm_layer`` value.
    """
    if norm_layer is None:
        return nn.Identity()
    if norm_layer == "batchnorm":
        # NOTE(review): eps=0 means a feature with zero batch variance yields
        # a 0/0 division (NaN); confirm this is intended.
        return nn.BatchNorm1d(input_dim, eps=0, momentum=None,
                              affine=False, track_running_stats=False)
    raise ValueError(f"Unsupported norm_layer {norm_layer!r}; expected 'batchnorm' or None")


class BNNMLP(nn.Module):
    def __init__(self, n_in, n_out, W_mu=None, b_mu=None, W_std=None,
                 b_std=None, scaled_variance=False, prior_per='parameter'):
        """Bayesian linear layer with a factorised Gaussian posterior over its weights.

        Posterior means are initialised from U(-1/sqrt(n_in), 1/sqrt(n_in)) and
        posterior standard deviations are initialised to 0.01; both are learnable.
        A standard-normal prior (mean 0, std 1) is stored alongside as plain
        tensors (not parameters or buffers).

        Args:
            n_in: int, the size of the input data.
            n_out: int, the size of the output.
            W_mu, b_mu, W_std, b_std: accepted for interface compatibility but
                currently NOT used; initial posterior and prior values are the
                fixed ones described above.
            scaled_variance: bool, if True every sampled weight matrix is
                divided by sqrt(n_in) in `forward` and `sample_predict`.
            prior_per: str, `parameter` gives every weight and bias its own
                mean and std; `layer` shares one mean/std across all weights
                (and one across all biases) of the layer.
        """
        super(BNNMLP, self).__init__()

        self.n_in = n_in
        self.n_out = n_out
        self.scaled_variance = scaled_variance

        if prior_per == "layer":
            # Tuples, not `(1)` (which is just the int 1): the sampler needs
            # an iterable shape.
            W_shape, b_shape = (1,), (1,)
        elif prior_per == "parameter":
            W_shape, b_shape = [self.n_in, self.n_out], [self.n_out]
        else:
            raise ValueError("Accepted values: `parameter` or `layer`")

        # initial posterior means ~ U(-bound, bound)
        bound = 1. / math.sqrt(self.n_in)
        m = torch.distributions.uniform.Uniform(torch.tensor([-bound]), torch.tensor([bound]))
        # `m` has batch shape [1], so sample(shape) returns shape + [1]. Drop
        # only that trailing axis: a bare `.squeeze()` would also drop size-1
        # n_in / n_out axes (e.g. a 1-D action output), after which W_mu
        # broadcasts against W_std into an [n_in, n_in] matrix.
        W_mu_tmp = m.sample(W_shape).squeeze(-1)
        b_mu_tmp = m.sample(b_shape).squeeze(-1)

        self.W_mu = nn.Parameter(W_mu_tmp, requires_grad=True)
        self.b_mu = nn.Parameter(b_mu_tmp, requires_grad=True)

        # posterior std, used directly (not through exp/softplus)
        self.W_std = nn.Parameter(torch.zeros(W_shape) + 0.01, requires_grad=True)
        self.b_std = nn.Parameter(torch.zeros(b_shape) + 0.01, requires_grad=True)

        # save prior
        self.W_mu_prior = torch.zeros(W_shape)
        self.b_mu_prior = torch.zeros(b_shape)
        self.W_std_prior = torch.ones(W_shape)
        self.b_std_prior = torch.ones(b_shape)

    def forward(self, X):
        """Forward pass with one freshly sampled set of weights (reparameterised).

        Args:
            X: torch.tensor, [batch_size, n_in], the input data.
        Returns:
            output: torch.tensor, [batch_size, n_out], the output data.
        """
        W = self.W_mu + self.W_std * \
            torch.randn((self.n_in, self.n_out), device=self.W_std.device)
        if self.scaled_variance:
            W = W / math.sqrt(self.n_in)
        b = self.b_mu + self.b_std * \
            torch.randn((self.n_out), device=self.b_std.device)

        return torch.mm(X, W) + b

    def sample_predict(self, X, n_samples):
        """Makes predictions using `n_samples` independently sampled weight sets.

        Args:
            X: torch.tensor, [n_samples, batch_size, n_in], the input data
                (cast to float32).
            n_samples: int, the number of weight samples; must match X's
                leading axis.
        Returns:
            torch.tensor, [n_samples, batch_size, n_out], the output data.
        """
        X = X.float()

        Ws = self.W_mu + self.W_std * \
            torch.randn([n_samples, self.n_in, self.n_out],
                        device=self.W_std.device)
        if self.scaled_variance:
            Ws = Ws / math.sqrt(self.n_in)
        # the singleton middle axis broadcasts each bias sample over the batch
        bs = self.b_mu + self.b_std * \
            torch.randn([n_samples, 1, self.n_out],
                        device=self.b_std.device)

        return torch.matmul(X, Ws) + bs


class LipschitzFunction1(nn.Module):
    """Linear critic: one affine map from ``input_dim`` features to a single score.

    Its constructor has the same signature as ``LipschitzFunction``, but
    ``hidden_dim`` is not used here.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.lin1 = nn.Linear(input_dim, 1)

    def forward(self, x):
        """Cast ``x`` to float32 and apply ``lin1`` (last axis becomes size 1)."""
        return self.lin1(x.float())

class LipschitzFunction(nn.Module):
    """MLP critic ``input_dim -> hidden_dim -> hidden_dim -> 1`` with Softplus activations.

    Note: the activation attributes are named ``relu1`` / ``relu2`` but hold
    ``nn.Softplus`` modules (a smooth activation; presumably chosen so the
    input gradients used by the gradient penalty are smooth).
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.lin1 = nn.Linear(input_dim, hidden_dim)
        self.relu1 = nn.Softplus()
        self.lin2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.Softplus()
        self.lin3 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Cast ``x`` to float32 and score every vector on its last axis."""
        h = x.float()
        for stage in (self.lin1, self.relu1, self.lin2, self.relu2, self.lin3):
            h = stage(h)
        return h

class WassersteinDistance():
    """Critic-based (dual-form) estimate of a Wasserstein-1 distance between two sample sets.

    A learnable critic ``lipschitz_f`` (a ``LipschitzFunction``) scores every
    ``lipschitz_input_dim``-vector on the last axis of the samples.
    ``wasserstein_optimisation`` trains it to maximise the gap between its
    mean scores on the two sets while a gradient penalty keeps it
    approximately 1-Lipschitz; ``calculate`` then returns that gap.

    In ``fwPPO`` the samples have shape ``[n_samples, batch, action_dim]``.
    """

    def __init__(self,
                 lipschitz_input_dim,
                 lipschitz_hidden_dim,
                 lipschitz_constraint_type="gp",
                 wasserstein_lr=0.01,
                 device='cpu'
                 ):
        """
        Args:
            lipschitz_input_dim: size of the last axis of the samples.
            lipschitz_hidden_dim: hidden width of the critic MLP.
            lipschitz_constraint_type: "gp" (two-sided penalty) or "lp"
                (one-sided penalty); see ``_lipschitz_penalty``.
            wasserstein_lr: Adam learning rate for the critic.
            device: device the critic is moved to.
        """
        self.device = device
        self.lipschitz_input_dim = lipschitz_input_dim
        self.lipschitz_hidden_dim = lipschitz_hidden_dim
        self.lipschitz_constraint_type = lipschitz_constraint_type
        assert self.lipschitz_constraint_type in ["gp", "lp"]

        self.lipschitz_f = LipschitzFunction(lipschitz_input_dim, lipschitz_hidden_dim)
        self.lipschitz_f = self.lipschitz_f.to(self.device)

        self.wasserstein_lr = wasserstein_lr
        self.optimiser = torch.optim.Adam(self.lipschitz_f.parameters(), lr=wasserstein_lr)
        self.penalty_coeff = 10  # weight of the Lipschitz penalty in the critic objective

    @staticmethod
    def _mean_gap(f_a, f_b):
        """Mean absolute difference between the sample-means (axis 0) of two score tensors."""
        return torch.mean(torch.abs(torch.mean(f_a, 0) - torch.mean(f_b, 0)))

    def _lipschitz_penalty(self, f_gradient_norm):
        """Penalty on gradient norms that deviate from ("gp") or exceed ("lp") 1."""
        if self.lipschitz_constraint_type == "gp":
            # Gulrajani2017, Improved Training of Wasserstein GANs
            return ((f_gradient_norm - 1) ** 2).mean()
        if self.lipschitz_constraint_type == "lp":
            # On the Regularization of Wasserstein GANs (2018), Eq (8) in Section 5
            return ((torch.clamp(f_gradient_norm - 1, 0., np.inf)) ** 2).mean()
        return 0

    def calculate(self, nnet_samples, gp_samples):
        """Distance estimate under the current critic.

        Gradients are NOT disabled, so the result can be back-propagated into
        whatever produced the samples. Puts the critic in eval mode.
        """
        self.lipschitz_f.eval()
        return self._mean_gap(self.lipschitz_f(nnet_samples), self.lipschitz_f(gp_samples))

    def calculate_gradientnorm(self, nnet_samples, nnet_num, gp_samples, gp_num):
        """Critic objective with gradient penalty.

        ``nnet_num`` / ``gp_num`` are unused. ``nnet_samples`` and ``gp_samples``
        must have the same shape (they are interpolated element-wise).

        Returns:
            (objective, d): ``-d + penalty_coeff * penalty`` (to be minimised)
            and the current distance estimate ``d``.
        """
        d = self._mean_gap(self.lipschitz_f(nnet_samples), self.lipschitz_f(gp_samples))

        # Random element-wise interpolates between the two sets. They are cut
        # from any upstream graph and made a fresh leaf so the critic can be
        # differentiated w.r.t. them (setting `requires_grad` on a non-leaf
        # tensor would raise).
        eps = torch.rand(nnet_samples.shape).to(nnet_samples.device)
        all_samples = (eps * nnet_samples + (1 - eps) * gp_samples).detach().requires_grad_(True)
        f_all_samples = self.lipschitz_f(all_samples)
        gradients = torch.autograd.grad(f_all_samples, all_samples,
                                        grad_outputs=torch.ones_like(f_all_samples),
                                        create_graph=True, retain_graph=True)[0]
        # The critic maps each vector on the LAST axis to one score, so its input
        # gradient at a point is the last-axis slice. Take the norm there (one
        # norm per sample and batch point), not across the batch axis.
        f_gradient_norm = gradients.norm(2, dim=-1)

        return -d + self.penalty_coeff * self._lipschitz_penalty(f_gradient_norm), d

    def calculate_gradientnorm1(self, nnet_samples, nnet_num, gp_samples, gp_num):
        """Variant that penalises the norm of ``lipschitz_f.lin1.weight``.

        That norm equals the gradient norm only for a single-linear-layer critic
        (``LipschitzFunction1``); for the MLP critic it covers the first layer only.
        """
        d = self._mean_gap(self.lipschitz_f(nnet_samples), self.lipschitz_f(gp_samples))
        f_gradient_norm = self.lipschitz_f.lin1.weight.norm()
        return -d + self.penalty_coeff * self._lipschitz_penalty(f_gradient_norm), d

    def wasserstein_optimisation(self, n_nnet, nnet_samples, m_gp, gp_samples, n_steps=10, threshold=None):
        """Train the critic for up to ``n_steps`` Adam steps on the two sample sets.

        Returns:
            0 if training stopped early because the critic's parameter-gradient
            norm fell below ``threshold``, otherwise 1. The critic is left in
            eval mode either way.
        """
        self.lipschitz_f.train()

        for _ in range(n_steps):
            objective, _ = self.calculate_gradientnorm(nnet_samples, n_nnet, gp_samples, m_gp)
            self.optimiser.zero_grad()
            # retain_graph=True: if the caller passes samples still attached to a
            # network graph, that graph must survive repeated backward passes.
            objective.backward(retain_graph=True)
            self.optimiser.step()

            if threshold is not None:
                grad_norm = torch.cat([p.grad.flatten() for p in self.lipschitz_f.parameters()]).norm()
                if grad_norm < threshold:
                    self.lipschitz_f.eval()
                    return 0

        self.lipschitz_f.eval()
        return 1


class BNN(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dims, activation_fn, has_continuous_action_space,
                 W_mu=None, b_mu=None, W_std=None, b_std=None, scaled_variance=False, norm_layer=None):
        """Three-layer network: a deterministic input layer followed by two Bayesian layers.

        Layout: ``nn.Linear(input_dim, hidden_dims[0])`` -> norm -> activation
        -> ``BNNMLP(hidden_dims[0], hidden_dims[1])`` -> norm -> activation
        -> ``BNNMLP(hidden_dims[1], output_dim)``.

        Args:
            input_dim: int, the size of the input data.
            output_dim: int, the size of the output data.
            hidden_dims: sequence of int; only the first two entries are used.
            activation_fn: one of 'cos', 'tanh', 'relu', 'softplus',
                'leaky_relu'; any other value is used unchanged as the
                activation (so it must be callable).
            has_continuous_action_space: stored on the instance only.
            W_mu, b_mu, W_std, b_std: forwarded to both ``BNNMLP`` layers.
            scaled_variance: forwarded to both ``BNNMLP`` layers.
            norm_layer: 'batchnorm' or None, passed to ``init_norm_layer``.
        """
        super(BNN, self).__init__()

        self.has_continuous_action_space = has_continuous_action_space
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims
        self.norm_layer = norm_layer

        named_activations = {'cos': torch.cos, 'tanh': torch.tanh, 'relu': F.relu,
                             'softplus': F.softplus, 'leaky_relu': F.leaky_relu}
        self.activation_fn = named_activations.get(activation_fn, activation_fn)

        bayes_kwargs = dict(W_mu=W_mu, b_mu=b_mu, W_std=W_std, b_std=b_std,
                            scaled_variance=scaled_variance)

        # Submodules are created in this order (it fixes parameter / state_dict order).
        self.input_layer = nn.Linear(input_dim, hidden_dims[0])
        self.norm_layer1 = init_norm_layer(hidden_dims[0], self.norm_layer)
        self.mid_layer = BNNMLP(hidden_dims[0], hidden_dims[1], **bayes_kwargs)
        self.norm_layer2 = init_norm_layer(hidden_dims[1], self.norm_layer)
        self.output_layer = BNNMLP(hidden_dims[1], output_dim, **bayes_kwargs)

    def _trunk(self, h, mid_fn):
        """Input layer -> norm -> activation -> ``mid_fn`` -> norm -> activation."""
        h = self.activation_fn(self.norm_layer1(self.input_layer(h)))
        return self.activation_fn(self.norm_layer2(mid_fn(h)))

    def forward(self, X):
        """One stochastic forward pass (a fresh weight sample per Bayesian layer).

        Args:
            X: torch.tensor, viewed as [-1, input_dim].
        Returns:
            torch.tensor, [batch_size, output_dim]; a leading axis of size 1 is
            squeezed away, so a single input row yields [output_dim].
        """
        flat = X.view(-1, self.input_dim)
        features = self._trunk(flat, self.mid_layer)
        return self.output_layer(features).squeeze(0)

    def sample_functions(self, X, n_samples):
        """Evaluate the network under `n_samples` independent weight samples.

        Args:
            X: torch.tensor, viewed as [batch_size, input_dim].
            n_samples: int, the number of weight samples.
        Returns:
            torch.tensor, [n_samples, batch_size, output_dim].

        NOTE(review): with norm_layer='batchnorm' the norm layers receive a 3-D
        [n_samples, batch, features] tensor here, which BatchNorm1d interprets
        as (N, C, L); confirm before enabling batchnorm.
        """
        stacked = X.view(-1, self.input_dim).unsqueeze(0).repeat([n_samples, 1, 1])
        features = self._trunk(stacked, lambda h: self.mid_layer.sample_predict(h, n_samples))
        return self.output_layer.sample_predict(features, n_samples)

class BNNActorCritic(nn.Module):
    """Gaussian actor-critic with a Bayesian actor.

    * ``actor``: ``BNN`` giving the action mean; the action std is
      ``exp(action_logstd)``, a learnable, state-independent vector.
    * ``actor_prior``: a second ``BNN`` of the same architecture, put in eval mode.
    * ``critic``: deterministic tanh MLP giving the state value.
    """

    def __init__(self, state_dim, action_dim, activation_fn, has_continuous_action_space, device, norm_layer=None):
        """
        Args:
            state_dim, action_dim: observation / action sizes.
            activation_fn: activation name for both BNN actors.
            has_continuous_action_space: forwarded to the actors.
            device: device all submodules and ``action_logstd`` are created on.
            norm_layer: stored on the instance only; not forwarded to the actors.
        """
        super(BNNActorCritic, self).__init__()

        self.has_continuous_action_space = has_continuous_action_space
        self.device = device
        self.norm_layer = norm_layer

        # actor
        self.action_dim = action_dim
        # Created directly on `device`: `Parameter.to(...)` is not in-place, so
        # moving it afterwards without re-assigning the result has no effect.
        self.action_logstd = nn.Parameter(-1.0 * torch.ones(action_dim, device=self.device))

        n_units = 64
        n_hidden = 2
        hidden_dims = [n_units] * n_hidden

        self.actor = BNN(state_dim, action_dim, hidden_dims, activation_fn, has_continuous_action_space,
                         scaled_variance=True).to(self.device)
        self.actor_prior = BNN(state_dim, action_dim, hidden_dims, activation_fn, has_continuous_action_space,
                               scaled_variance=True).to(self.device)
        self.actor_prior.eval()

        # critic (placed on `device` like the actors)
        self.critic = nn.Sequential(
            nn.Linear(state_dim, n_units),
            nn.Tanh(),
            nn.Linear(n_units, n_units),
            nn.Tanh(),
            nn.Linear(n_units, 1)
        ).to(self.device)

    def forward(self):
        raise NotImplementedError

    def get_action(self, state):
        """Return ``(action_mean, action_std)`` for ``state`` (one stochastic actor pass)."""
        action_mean = self.actor(state)
        action_std = torch.exp(self.action_logstd)

        return action_mean, action_std

    def select_action(self, state):
        """Same as ``get_action`` but without building an autograd graph."""
        with torch.no_grad():
            return self.get_action(state)

    def v(self, state):
        """State value from the critic."""
        return self.critic(state)


class fwPPO:
    """PPO agent whose actor loss is regularised by functional Wasserstein distances.

    On top of the clipped PPO surrogate and an entropy bonus, the actor loss adds
      * ``prior_coeff * fw_prior``    - squared distance between current and prior actor,
      * ``dist_coeff  * fw_old``      - squared distance between current and old actor,
      * ``magl_coeff  * fw_marginal`` - distance between two independent batches
        of function samples from the current actor.
    Each distance has its own ``WassersteinDistance`` critic and is evaluated on
    the rollout states plus a random subset of states from ``self.pool``.
    """

    def __init__(self, state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip,
                 has_continuous_action_space, max_ep_len,
                 dist_coeff=1.0, prior_coeff=1.0, magl_coeff=1.0, n_samples=10, activation_fn='tanh', device='cpu'):
        """
        Args:
            state_dim, action_dim: observation / action sizes.
            lr_actor, lr_critic: Adam learning rates for the actor and critic.
            gamma: discount factor.
            K_epochs: number of full-batch update epochs per ``update`` call.
            eps_clip: PPO ratio clipping range.
            has_continuous_action_space: forwarded to the actor-critic networks.
            max_ep_len: sizes the state pool (``5 * max_ep_len`` states).
            dist_coeff, prior_coeff, magl_coeff: weights of fw_old, fw_prior
                and fw_marginal in the actor loss.
            n_samples: number of function samples per regulariser term.
            activation_fn: actor activation name.
            device: torch device for all networks.
        """
        self.has_continuous_action_space = has_continuous_action_space

        self.gamma = gamma
        self.gae = False  # if True use GAE(lambda_) advantages, else discounted returns
        self.lambda_ = 0.95
        self.entropy_coef = 0.01

        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.batch_size = 64  # number of pool states added to each measurement set
        self.critic_coef = 0.5
        self.max_grad_norm = 0.5  # currently unused (gradient clipping disabled)

        self.n_samples = n_samples
        self.m_samples = n_samples

        self.device = device

        self.dist_coeff = dist_coeff
        self.prior_coeff = prior_coeff
        self.magl_coeff = magl_coeff

        self.activation_fn = activation_fn
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_ep_len = max_ep_len

        self.buffer = RolloutBuffer()
        self.pool = Pool(max_size=max_ep_len * 5, state_dim=state_dim)

        self.policy = BNNActorCritic(state_dim, action_dim, activation_fn, has_continuous_action_space,
                                     device=self.device).to(self.device)
        # Only the actor's BNN is optimised here; action_logstd is not in any optimiser.
        self.actor_optimizer = torch.optim.Adam([
            {'params': self.policy.actor.parameters(), 'lr': lr_actor}
        ])
        self.critic_optimizer = torch.optim.Adam([{'params': self.policy.critic.parameters(), 'lr': lr_critic}])

        self.critic_loss = nn.SmoothL1Loss()

        self.policy_old = BNNActorCritic(state_dim, action_dim, activation_fn, has_continuous_action_space,
                                         device=self.device).to(self.device)
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.policy_old.eval()

        # for wasserstein distance: critics score action vectors
        self.lipschitz_input_dim = action_dim
        self.lipschitz_hidden_dim = 64

        self.lipschitz_constraint_type = "gp"
        self.wasserstein_steps = 10
        self.wasserstein_threshold = 0.01
        self.wasserstein_lr = 0.01

        # critic for the distance to the prior actor
        self.wasserstein = WassersteinDistance(self.lipschitz_input_dim, self.lipschitz_hidden_dim,
                                               lipschitz_constraint_type=self.lipschitz_constraint_type,
                                               wasserstein_lr=self.wasserstein_lr, device=self.device)
        self.wasserstein.lipschitz_f.apply(weights_init)

        # critic for the distance to the old policy
        self.wasserstein_old = WassersteinDistance(self.lipschitz_input_dim, self.lipschitz_hidden_dim,
                                                   lipschitz_constraint_type=self.lipschitz_constraint_type,
                                                   wasserstein_lr=self.wasserstein_lr, device=self.device)
        self.wasserstein_old.lipschitz_f.apply(weights_init)

        # critic for the distance between marginals of the current actor
        self.wasserstein_marginal = WassersteinDistance(self.lipschitz_input_dim, self.lipschitz_hidden_dim,
                                                        lipschitz_constraint_type=self.lipschitz_constraint_type,
                                                        wasserstein_lr=self.wasserstein_lr, device=self.device)
        self.wasserstein_marginal.lipschitz_f.apply(weights_init)

    def put_data_pool(self, state):
        """Add one state to the measurement-state pool."""
        self.pool.put_data(state)

    def update(self, time_step, writer):
        """Run ``K_epochs`` full-batch updates on the rollout stored in ``self.buffer``.

        Args:
            time_step: x-axis value for the logged scalars.
            writer: object with ``add_scalar(tag, value, step)`` and ``flush()``
                (e.g. a TensorBoard ``SummaryWriter``).

        The buffer is read as one transition per entry: ``is_terminals[t]`` is
        taken to be the done flag of the step that produced ``rewards[t]``, and
        ``nextstates[-1]`` the state reached after the last step (used to
        bootstrap). After the epochs ``policy_old`` is synced with ``policy``
        and the buffer is cleared.
        """
        old_states = torch.squeeze(torch.stack(self.buffer.states, dim=0)).detach().to(self.device)
        old_actions = torch.squeeze(torch.stack(self.buffer.actions, dim=0)).detach().to(self.device)
        old_logprobs = torch.squeeze(torch.stack(self.buffer.logprobs, dim=0)).detach().to(self.device)
        # Per-step scalars: flatten to [N] (squeeze gives 0-dim for N == 1) and use
        # float32 so returns/advantages match the critic's output dtype.
        old_rewards = torch.stack(self.buffer.rewards, dim=0).detach().to(self.device).reshape(-1).float()
        old_is_terminals = torch.stack(self.buffer.is_terminals, dim=0).detach().to(self.device).reshape(-1).float()
        num_buffer = old_rewards.shape[0]

        with torch.no_grad():
            old_values = self.policy.v(old_states).float().reshape(-1)
            # The bootstrap state must live on the critic's device.
            last_next_state = self.buffer.nextstates[-1].to(self.device)
            next_value = self.policy.v(last_next_state).float().squeeze()
            returns, advantages = self._compute_returns_and_advantages(
                old_rewards, old_values, 1.0 - old_is_terminals, next_value)

        for _ in range(self.K_epochs):
            # PPO clipped surrogate. Keep an explicit [N, action_dim] layout:
            # squeezing would collapse a 1-D action space to [N] and break the
            # per-sample sums over the action axis.
            curr_mu, curr_sigma = self.policy.get_action(old_states)
            curr_mu = curr_mu.reshape(num_buffer, -1)
            curr_dist = torch.distributions.Normal(curr_mu, curr_sigma)
            entropy = curr_dist.entropy().sum(-1)
            curr_log_prob = curr_dist.log_prob(old_actions.reshape(num_buffer, -1)).sum(-1)
            new_values = self.policy.v(old_states).reshape(-1)

            ratio = torch.exp(curr_log_prob - old_logprobs.detach())
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            # Measurement set = rollout states + a random subset of pooled states.
            states_mset = torch.from_numpy(
                self.pool.sample(shuffle=True, batch_size=self.batch_size)).float().to(self.device)
            states_combined = torch.cat((old_states, states_mset))

            fw_prior, fw_old, fw_marginal = self._functional_regularisers(states_combined)

            actor_loss = (-torch.min(surr1, surr2) - self.entropy_coef * entropy).mean()
            actor_loss = actor_loss + self.prior_coeff * fw_prior + self.dist_coeff * fw_old + self.magl_coeff * fw_marginal

            critic_loss = self.critic_coef * self.critic_loss(new_values, returns)

            writer.add_scalar("actor_loss", actor_loss, time_step)
            writer.add_scalar("critic_loss", critic_loss, time_step)
            writer.add_scalar("fw_prior", fw_prior, time_step)
            writer.add_scalar("fw_old", fw_old, time_step)
            writer.add_scalar("fw_marginal", fw_marginal, time_step)
            # NOTE: logs the mean action std (exp of the log-std) under this tag.
            writer.add_scalar("action_logstd", torch.exp(self.policy.action_logstd).mean(), time_step)
            writer.flush()

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

        # Copy new weights into old policy
        self.policy_old.load_state_dict(self.policy.state_dict())

        # clear buffer
        self.buffer.clear()

    def _compute_returns_and_advantages(self, rewards, values, nonterminals, next_value):
        """Discounted returns and advantages for one rollout (1-D tensors of length N).

        ``nonterminals[t]`` (1 - done of step t) masks the bootstrap out of step
        t, so step t never looks past an episode end. ``next_value`` is V of
        the state after the last step. Uses GAE(``lambda_``) when ``self.gae``
        is set, otherwise Monte-Carlo returns with a bootstrapped tail.
        """
        num_steps = rewards.shape[0]
        if self.gae:
            advantages = torch.zeros_like(rewards)
            last_gae = 0.0
            for t in reversed(range(num_steps)):
                next_val = next_value if t == num_steps - 1 else values[t + 1]
                delta = rewards[t] + self.gamma * next_val * nonterminals[t] - values[t]
                last_gae = delta + self.gamma * self.lambda_ * nonterminals[t] * last_gae
                advantages[t] = last_gae
            returns = advantages + values
        else:
            returns = torch.zeros_like(rewards)
            next_return = next_value
            for t in reversed(range(num_steps)):
                next_return = rewards[t] + self.gamma * nonterminals[t] * next_return
                returns[t] = next_return
            advantages = returns - values
        return returns, advantages

    def _functional_regularisers(self, states_combined):
        """Return ``(fw_prior, fw_old, fw_marginal)`` evaluated on ``states_combined``.

        Each distance critic is first fitted on detached function samples, then
        evaluated with gradients flowing into the current actor only.
        """
        # Function samples have shape [n_samples, batch, action_dim].
        nnet_samples = self.policy.actor.sample_functions(states_combined, self.n_samples)
        # Reference samples need no graph: neither the prior nor the old actor is
        # optimised here, and the gradient reaching the current actor is unchanged.
        with torch.no_grad():
            prior_nnet_samples = self.policy.actor_prior.sample_functions(states_combined, self.n_samples)
            old_nnet_samples = self.policy_old.actor.sample_functions(states_combined, self.n_samples)
        nnet_samples_m = self.policy.actor.sample_functions(states_combined, self.m_samples)

        # distance to the prior actor (squared)
        self.wasserstein.wasserstein_optimisation(self.n_samples, prior_nnet_samples,
                                                  self.n_samples, nnet_samples.detach(),
                                                  n_steps=self.wasserstein_steps,
                                                  threshold=self.wasserstein_threshold)
        fw_prior = (torch.mean(self.wasserstein.calculate(prior_nnet_samples, nnet_samples))) ** 2

        # distance to the old policy (squared)
        self.wasserstein_old.wasserstein_optimisation(self.n_samples, old_nnet_samples,
                                                      self.n_samples, nnet_samples.detach(),
                                                      n_steps=self.wasserstein_steps,
                                                      threshold=self.wasserstein_threshold)
        fw_old = (torch.mean(self.wasserstein_old.calculate(old_nnet_samples, nnet_samples))) ** 2

        # distance between two independent sample batches of the current actor
        self.wasserstein_marginal.wasserstein_optimisation(self.m_samples, nnet_samples_m.detach(),
                                                           self.n_samples, nnet_samples.detach(),
                                                           n_steps=self.wasserstein_steps,
                                                           threshold=self.wasserstein_threshold)
        fw_marginal = torch.mean(self.wasserstein_marginal.calculate(nnet_samples_m, nnet_samples))

        return fw_prior, fw_old, fw_marginal

    def save(self, checkpoint_path):
        """Save ``self.policy``'s state dict to ``checkpoint_path``."""
        torch.save(self.policy.state_dict(), checkpoint_path)

    def load(self, checkpoint_path):
        """Load a state dict written by ``save`` into both ``policy`` and ``policy_old``.

        Tensors are loaded onto CPU storage first and then copied into the
        modules' existing parameters.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        self.policy.load_state_dict(state_dict)
        # Keep the old policy in sync so the next update's fw_old term (and any
        # acting done through policy_old) uses the loaded weights.
        self.policy_old.load_state_dict(state_dict)
        
        
       


