import numpy as np
import torch
import gymnasium as gym
import argparse
import os
import copy
import torch.nn as nn
import torch.nn.functional as F
import math

# Default encoder for Actor: True -> Poisson spike trains, False -> regular spike trains.
USE_POISSON=False
# Environments trained one after another by the __main__ loop; each entry overrides --env.
envs = {'HalfCheetah-v4'}  
# Adam learning rate for the critic (the actor's learning rate is hard-coded in TD3.__init__).
critic_lr = 3e-4
# Default for --max_timesteps; presumably 1e6 steps plus the 25e3-step random warm-up
# (the --start_timesteps default) — confirm intent.
max_timestep = 1e6+25e3
# Default for --seed.
seed =3

# The __main__ loop calls TD3.BN_update every BN_update_freq environment steps;
# each call feeds BN_update_time replay batches through the actor.
BN_update_freq=5000
BN_update_time=100

# Firing threshold of the regular-spike encoder; the encoder voltage is reduced by it after a spike.
ENCODER_REGULAR_VTH = 0.999
# LIF firing threshold; also used to initialise BatchNorm gamma and beta (NEURON_VTH / 2).
NEURON_VTH = 0.5
# Per-step multiplicative decay of the LIF synaptic current.
NEURON_CDECAY = 1 / 2
# Per-step multiplicative decay of the LIF membrane voltage (voltage is reset after a spike).
NEURON_VDECAY = 3 / 4
# Half-width of the window around NEURON_VTH where the spike surrogate gradient is passed (zero outside).
SPIKE_PSEUDO_GRAD_WINDOW = 0.5

# Device used by the TD3 networks and by the state tensor built in select_action.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



def batch_norm_update(X, moving_mean, moving_var, num1, num2):
    """Merge the statistics of batch ``X`` into running (mean, variance) estimates.

    The running estimates carry weight ``num1`` and the new batch carries weight
    ``num2``; the result is the exact pooled mean and (population) variance of
    the two groups.

    :param X: 2-D tensor (statistics over dim 0) or 4-D tensor (statistics per
        channel over dims 0, 2, 3, kept with ``keepdim=True``)
    :param moving_mean: current running mean
    :param moving_var: current running variance
    :param num1: weight of the running estimates
    :param num2: weight of the new batch
    :return: ``(new_mean, new_var)``
    """
    assert X.dim() in (2, 4)
    if X.dim() == 2:
        reduce_dims, keep = 0, False
    else:
        reduce_dims, keep = (0, 2, 3), True
    batch_mean = X.mean(dim=reduce_dims, keepdim=keep)
    batch_var = ((X - batch_mean) ** 2).mean(dim=reduce_dims, keepdim=keep)

    total = num1 + num2
    old_weight = num1 / total
    new_weight = num2 / total
    # Pooled variance = weighted within-group variances + between-group term.
    # Uses the *old* running mean, so it must be computed before the mean is updated.
    merged_var = (old_weight * moving_var
                  + new_weight * batch_var
                  + num1 * num2 / total / total * torch.square(moving_mean - batch_mean))
    merged_mean = old_weight * moving_mean + new_weight * batch_mean
    return merged_mean, merged_var

class BatchNorm(nn.Module):
    """Batch normalisation applied per time step to a spike-train pre-activation.

    ``forward`` reads its input as ``X[:, :, step]`` for ``step in
    range(spike_ts)``. It normalises every step with one shared (mean, var)
    pair, applies ``gamma * x_hat + beta``, and stacks the steps back along
    dim 2. The modes are checked in this order:

    * ``BN_update`` truthy (an ``[i, total]`` pair): under ``torch.no_grad``,
      fold this batch's statistics into ``temp_mean``/``temp_var``. They are
      reset when ``i == 0`` and committed to ``moving_mean``/``moving_var``
      when ``total == i + 1``. The output uses the accumulated statistics.
    * ``update`` truthy: normalise with this batch's statistics, with gradient
      flowing through them. Then move ``moving_mean``/``moving_var`` toward
      those statistics using a per-feature gain ``p / (p + noise)``. Here
      ``p`` is a 0.8/0.2 exponential average of the squared drift.
    * otherwise: normalise with ``moving_mean``/``moving_var``.
    """

    def __init__(self, num_features, spike_ts, momentum=0.9, num_dims=2):
        """
        :param num_features: number of features (dim 1 of the input)
        :param spike_ts: number of time steps along dim 2 of the input
        :param momentum: accepted for API compatibility; not used
        :param num_dims: 2 -> statistics of shape (1, F); otherwise (1, F, 1, 1)
        """
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # Affine parameters start at gamma = beta = NEURON_VTH / 2.
        self.gamma = nn.Parameter(torch.ones(shape)*NEURON_VTH/2.0)
        self.beta = nn.Parameter(torch.zeros(shape)+NEURON_VTH/2.0)
        # Inference statistics: registered (saved in state_dict) but not trained by gradient.
        self.moving_mean = nn.Parameter(torch.zeros(shape),requires_grad=False)
        self.moving_var = nn.Parameter(torch.ones(shape),requires_grad=False)
        # Scratch state (plain tensors: not in state_dict, not moved by Module.to()).
        self.temp_mean = torch.zeros(shape)
        self.temp_var = torch.ones(shape)
        self.p_mean = torch.zeros(shape)
        self.p_var = torch.zeros(shape)
        self.K_mean = torch.zeros(shape)
        self.K_var = torch.zeros(shape)
        self.spike_ts=spike_ts
        self.eps=1e-5

    def _sync_device(self, device):
        """Move all statistics to ``device`` (the scratch tensors are not moved by ``Module.to``)."""
        if self.moving_mean.device != device:
            # Re-wrap as Parameters: assigning a plain Tensor to a registered
            # Parameter attribute raises TypeError in nn.Module.__setattr__.
            self.moving_mean = nn.Parameter(self.moving_mean.detach().to(device), requires_grad=False)
            self.moving_var = nn.Parameter(self.moving_var.detach().to(device), requires_grad=False)
        for name in ("temp_mean", "temp_var", "p_mean", "p_var", "K_mean", "K_var"):
            tensor = getattr(self, name)
            if tensor.device != device:
                setattr(self, name, tensor.to(device))

    def _normalize(self, X, mean, var):
        """Return ``gamma * (X[:, :, t] - mean) / sqrt(var + eps) + beta`` stacked over t along dim 2."""
        std = torch.sqrt(var + self.eps)
        return torch.stack(
            [self.gamma * ((X[:, :, step] - mean) / std) + self.beta for step in range(self.spike_ts)],
            dim=2,
        )

    def forward(self, X, update, BN_update):
        """
        :param X: pre-activations indexed as ``X[:, :, step]``
        :param update: if truthy (and ``BN_update`` is not), use batch statistics
            and adapt the moving statistics
        :param BN_update: falsy, or an ``[i, total]`` pair selecting the
            exact-statistics accumulation mode
        :return: normalised tensor with the same layout as ``X``
        """
        self._sync_device(X.device)

        if BN_update:
            with torch.no_grad():
                if BN_update[0] == 0:
                    self.temp_mean = torch.zeros_like(self.temp_mean)
                    self.temp_var = torch.zeros_like(self.temp_var)
                # Each time step counts as one equally weighted sub-batch; the
                # running count continues across the BN_update[1] calls.
                for step in range(self.spike_ts):
                    self.temp_mean, self.temp_var = batch_norm_update(
                        X[:, :, step], self.temp_mean, self.temp_var,
                        step + BN_update[0] * self.spike_ts, 1)

                if BN_update[1] == BN_update[0] + 1:
                    self.moving_mean = nn.Parameter(self.temp_mean, requires_grad=False)
                    self.moving_var = nn.Parameter(self.temp_var, requires_grad=False)

                return self._normalize(X, self.temp_mean, self.temp_var)

        if update:
            # Batch statistics are kept graph-attached locally so the gradient
            # flows through them, exactly as in standard BN training.
            batch_mean = torch.zeros_like(self.temp_mean)
            batch_var = torch.zeros_like(self.temp_var)
            for step in range(self.spike_ts):
                batch_mean, batch_var = batch_norm_update(X[:, :, step], batch_mean, batch_var, step, 1)
            # Store only detached copies on the module: keeping the graph here
            # would hold it alive until the next call and make
            # copy.deepcopy(module) fail on a non-leaf tensor.
            self.temp_mean = batch_mean.detach()
            self.temp_var = batch_var.detach()

            with torch.no_grad():
                delta_mean = self.temp_mean - self.moving_mean
                delta_var = self.temp_var - self.moving_var
                # NOTE(review): 255 looks like (batch_size - 1) for the default
                # batch of 256 — confirm before changing the batch size.
                mean_var = self.temp_var / 255
                var_var = 2 * torch.square(self.temp_var) / 255
                self.p_mean = self.p_mean * 0.8 + 0.2 * torch.square(delta_mean)
                self.p_var = self.p_var * 0.8 + 0.2 * torch.square(delta_var)
                self.K_mean = self.p_mean / (self.p_mean + mean_var)
                self.K_var = self.p_var / (self.p_var + var_var)
                self.moving_mean = nn.Parameter(self.moving_mean + self.K_mean * delta_mean, requires_grad=False)
                self.moving_var = nn.Parameter(self.moving_var + self.K_var * delta_var, requires_grad=False)

            return self._normalize(X, batch_mean, batch_var)

        return self._normalize(X, self.moving_mean, self.moving_var)



class ReplayBuffer(object):
    """Fixed-capacity ring buffer of transitions backed by NumPy arrays.

    Once ``max_size`` transitions have been stored, each new one overwrites
    the oldest. ``sample`` draws indices uniformly with replacement and
    returns float32 tensors on ``self.device``.
    """

    def __init__(self, state_dim, action_dim, max_size=int(1e6)):
        self.max_size = max_size
        self.ptr = 0    # next slot to write
        self.size = 0   # number of filled slots (never exceeds max_size)

        def column(width):
            return np.zeros((max_size, width))

        self.state = column(state_dim)
        self.action = column(action_dim)
        self.next_state = column(state_dim)
        self.reward = column(1)
        self.not_done = column(1)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def add(self, state, action, next_state, reward, done):
        """Write one transition at the current slot and advance the ring pointer."""
        slot = self.ptr
        self.state[slot] = state
        self.action[slot] = action
        self.next_state[slot] = next_state
        self.reward[slot] = reward
        self.not_done[slot] = 1. - done

        self.ptr = (slot + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        """Return (state, action, next_state, reward, not_done) tensors for a random batch."""
        ind = np.random.randint(0, self.size, size=batch_size)
        arrays = (self.state, self.action, self.next_state, self.reward, self.not_done)
        return tuple(torch.FloatTensor(arr[ind]).to(self.device) for arr in arrays)



class PseudoEncoderSpikeRegular(torch.autograd.Function):
    """Straight-through spike for the regular encoder.

    Forward emits 1.0 where the input exceeds ENCODER_REGULAR_VTH, else 0.0;
    backward passes the incoming gradient through unchanged.
    """
    @staticmethod
    def forward(ctx, input):
        fired = input > ENCODER_REGULAR_VTH
        return fired.float()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()


class PseudoEncoderSpikePoisson(torch.autograd.Function):
    """Straight-through spike for the Poisson encoder.

    Forward draws a Bernoulli sample with the input as firing probability;
    backward passes the incoming gradient through unchanged.
    """
    @staticmethod
    def forward(ctx, input):
        sampled = torch.bernoulli(input)
        return sampled.float()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()

class PopSpikeEncoderRegularSpike(nn.Module):
    """Population-coding spike encoder producing regular (deterministic) spike trains.

    Each observation dimension is covered by ``pop_dim`` Gaussian receptive
    fields whose centres and widths are trainable ``nn.Parameter``s.
    """
    def __init__(self, obs_dim, pop_dim, spike_ts, mean_range, std, device):
        """
        :param obs_dim: observation dimension
        :param pop_dim: population dimension (neurons per observation dim)
        :param spike_ts: spike timesteps
        :param mean_range: (low, high) of the evenly spaced receptive-field centres
        :param std: initial receptive-field width
        :param device: device for the generated spike tensors
        """
        super().__init__()
        self.obs_dim = obs_dim
        self.pop_dim = pop_dim
        self.encoder_neuron_num = obs_dim * pop_dim
        self.spike_ts = spike_ts
        self.device = device
        self.pseudo_spike = PseudoEncoderSpikeRegular.apply
        # Centres evenly spaced over [low, high], identical for every observation dim.
        low, high = mean_range[0], mean_range[1]
        spacing = (high - low) / (pop_dim - 1)
        centres = torch.tensor([low + spacing * k for k in range(pop_dim)])
        mean_grid = centres.view(1, 1, pop_dim).repeat(1, obs_dim, 1)
        self.mean = nn.Parameter(mean_grid)
        self.std = nn.Parameter(torch.zeros_like(mean_grid) + std)

    def forward(self, obs, batch_size):
        """
        :param obs: observation
        :param batch_size: batch size
        :return: spikes of shape (batch_size, obs_dim * pop_dim, spike_ts)
        """
        obs = obs.view(-1, self.obs_dim, 1)
        sq_dist = (obs - self.mean).pow(2)
        # Gaussian receptive-field activation in (0, 1]
        pop_act = torch.exp(-(1. / 2.) * sq_dist / self.std.pow(2)).view(-1, self.encoder_neuron_num)
        volt = torch.zeros(batch_size, self.encoder_neuron_num, device=self.device)
        spikes = torch.zeros(batch_size, self.encoder_neuron_num, self.spike_ts, device=self.device)
        # Integrate-and-fire: accumulate activation, spike above threshold, subtract threshold.
        for t in range(self.spike_ts):
            volt = volt + pop_act
            spikes[:, :, t] = self.pseudo_spike(volt)
            volt = volt - spikes[:, :, t] * ENCODER_REGULAR_VTH
        return spikes


class PopSpikeEncoderPoissonSpike(PopSpikeEncoderRegularSpike):
    """Population-coding spike encoder producing Poisson (Bernoulli-per-step) spike trains.

    Shares the trainable Gaussian receptive fields of the regular encoder; each
    time step samples spikes with the receptive-field activation as probability.
    """
    def __init__(self, obs_dim, pop_dim, spike_ts, mean_range, std, device):
        """
        :param obs_dim: observation dimension
        :param pop_dim: population dimension
        :param spike_ts: spike timesteps
        :param mean_range: (low, high) of the receptive-field centres
        :param std: initial receptive-field width
        :param device: device for the generated spike tensors
        """
        super().__init__(obs_dim, pop_dim, spike_ts, mean_range, std, device)
        self.pseudo_spike = PseudoEncoderSpikePoisson.apply

    def forward(self, obs, batch_size):
        """
        :param obs: observation
        :param batch_size: batch size
        :return: spikes of shape (batch_size, obs_dim * pop_dim, spike_ts)
        """
        obs = obs.view(-1, self.obs_dim, 1)
        sq_dist = (obs - self.mean).pow(2)
        fire_prob = torch.exp(-(1. / 2.) * sq_dist / self.std.pow(2)).view(-1, self.encoder_neuron_num)
        spikes = torch.zeros(batch_size, self.encoder_neuron_num, self.spike_ts, device=self.device)
        for t in range(self.spike_ts):
            spikes[:, :, t] = self.pseudo_spike(fire_prob)
        return spikes


class PopSpikeDecoder(nn.Module):
    """Population-coding decoder: one learned weighted sum per action population.

    A grouped ``Conv1d`` whose kernel spans the whole population maps each
    action's ``pop_dim`` activities to one value, which then goes through
    ``output_activation``.
    """
    def __init__(self, act_dim, pop_dim, output_activation=nn.Tanh):
        """
        :param act_dim: action dimension
        :param pop_dim: population dimension per action
        :param output_activation: activation class, instantiated with no arguments
        """
        super().__init__()
        self.act_dim = act_dim
        self.pop_dim = pop_dim
        self.decoder = nn.Conv1d(act_dim, act_dim, pop_dim, groups=act_dim)
        self.output_activation = output_activation()

    def forward(self, pop_act):
        """
        :param pop_act: output population activity, viewable as (-1, act_dim, pop_dim)
        :return: raw_act of shape (-1, act_dim)
        """
        grouped = pop_act.view(-1, self.act_dim, self.pop_dim)
        per_action = self.decoder(grouped).view(-1, self.act_dim)
        return self.output_activation(per_action)


class PseudoSpikeRect(torch.autograd.Function):
    """Spike nonlinearity with a rectangular surrogate gradient.

    Forward emits 1.0 where the voltage exceeds NEURON_VTH. Backward passes the
    gradient only where ``|volt - NEURON_VTH| < SPIKE_PSEUDO_GRAD_WINDOW``.
    """
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return (input > NEURON_VTH).float()

    @staticmethod
    def backward(ctx, grad_output):
        (volt,) = ctx.saved_tensors
        inside_window = (volt - NEURON_VTH).abs() < SPIKE_PSEUDO_GRAD_WINDOW
        return grad_output.clone() * inside_window.float()


class SpikeMLP(nn.Module):
    """Spiking MLP of LIF neurons with batch-normalised pre-activations.

    Every layer (hidden layers and the output population) first applies its
    ``nn.Linear`` to each time step and normalises the result with its
    ``BatchNorm``. It then runs the LIF neuron over the time steps. The
    output is the mean firing rate of the output population.
    """
    def __init__(self, in_pop_dim, out_pop_dim, hidden_sizes, spike_ts, device):
        """
        :param in_pop_dim: input population dimension
        :param out_pop_dim: output population dimension
        :param hidden_sizes: list of hidden layer sizes (at least one)
        :param spike_ts: spike timesteps
        :param device: device for the neuron state tensors
        """
        super().__init__()
        self.in_pop_dim = in_pop_dim
        self.out_pop_dim = out_pop_dim
        self.hidden_sizes = hidden_sizes
        self.hidden_num = len(hidden_sizes)
        self.spike_ts = spike_ts
        self.device = device
        self.pseudo_spike = PseudoSpikeRect.apply
        # Define Layers (Hidden Layers + Output Population)
        layer_inputs = [in_pop_dim] + list(hidden_sizes[:-1])
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(layer_inputs, hidden_sizes)])
        self.hidden_norms = nn.ModuleList([BatchNorm(n_out, spike_ts) for n_out in hidden_sizes])
        self.out_pop_layer = nn.Linear(hidden_sizes[-1], out_pop_dim)
        self.out_pop_norm = BatchNorm(out_pop_dim, spike_ts)

    def neuron_model(self, pre_layer_output, current, volt, spike):
        """
        LIF neuron update for one time step.
        :param pre_layer_output: input current contribution for this step
        :param current: current of last step
        :param volt: voltage of last step
        :param spike: spike of last step
        :return: current, volt, spike
        """
        current = current * NEURON_CDECAY + pre_layer_output
        # (1 - spike) resets the decayed voltage of neurons that fired last step.
        volt = volt * NEURON_VDECAY * (1. - spike) + current
        spike = self.pseudo_spike(volt)
        return current, volt, spike

    def _run_layer(self, linear, norm, in_spikes, batch_size, out_dim, norm_update, BN_update):
        """Run one linear + BatchNorm + LIF layer over all time steps; return spikes (B, out_dim, T)."""
        pre_act = torch.stack([linear(in_spikes[:, :, step]) for step in range(self.spike_ts)], dim=2)
        normed = norm(pre_act, update=norm_update, BN_update=BN_update)
        current, volt, spike = (torch.zeros(batch_size, out_dim, device=self.device) for _ in range(3))
        spikes = []
        for step in range(self.spike_ts):
            current, volt, spike = self.neuron_model(normed[:, :, step], current, volt, spike)
            spikes.append(spike)
        return torch.stack(spikes, dim=2)

    def forward(self, in_pop_spikes, batch_size, norm_update, BN_update):
        """
        :param in_pop_spikes: input population spikes, shape (batch_size, in_pop_dim, spike_ts)
        :param batch_size: batch size
        :param norm_update: forwarded to every BatchNorm as ``update``
        :param BN_update: forwarded to every BatchNorm as ``BN_update``
        :return: out_pop_act, firing rate of the output population (batch_size, out_pop_dim)
        """
        spikes = in_pop_spikes
        for linear, norm, size in zip(self.hidden_layers, self.hidden_norms, self.hidden_sizes):
            spikes = self._run_layer(linear, norm, spikes, batch_size, size, norm_update, BN_update)
        out_spikes = self._run_layer(self.out_pop_layer, self.out_pop_norm, spikes, batch_size,
                                     self.out_pop_dim, norm_update, BN_update)
        return out_spikes.sum(dim=2) / self.spike_ts


class Actor(nn.Module):
    """Population-coded spiking actor: encoder -> spiking MLP -> decoder.

    The encoder's receptive-field centres and widths are ``nn.Parameter``s, so
    they are part of ``self.parameters()`` and trained with the rest.
    """
    def __init__(self, obs_dim, act_dim, act_limit, en_pop_dim=10, de_pop_dim=10, hidden_sizes=(256, 256),
                 mean_range=(-1,1), std=math.sqrt(0.05), spike_ts=5, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), use_poisson=USE_POISSON):
        """
        :param obs_dim: observation dimension
        :param act_dim: action dimension
        :param act_limit: action limit (the decoder output in (-1, 1) is scaled by it)
        :param en_pop_dim: encoder population dimension
        :param de_pop_dim: decoder population dimension
        :param hidden_sizes: sequence of hidden layer sizes (tuple default avoids a shared mutable default)
        :param mean_range: mean range for encoder
        :param std: std for encoder
        :param spike_ts: spike timesteps
        :param device: device
        :param use_poisson: if true use Poisson spikes for encoder
        """
        super().__init__()
        self.act_limit = act_limit
        encoder_cls = PopSpikeEncoderPoissonSpike if use_poisson else PopSpikeEncoderRegularSpike
        self.encoder = encoder_cls(obs_dim, en_pop_dim, spike_ts, mean_range, std, device)
        # Pass a fresh list so SpikeMLP never shares the caller's (or the default) object.
        self.snn = SpikeMLP(obs_dim*en_pop_dim, act_dim*de_pop_dim, list(hidden_sizes), spike_ts, device)
        self.decoder = PopSpikeDecoder(act_dim, de_pop_dim)

    def forward(self, obs, norm_update=False, BN_update=False):
        """
        :param obs: observation batch; batch size is taken from ``obs.size()[0]``
        :param norm_update: forwarded to the BatchNorm layers as ``update``
        :param BN_update: falsy, or an ``[i, total]`` pair forwarded to the BatchNorm layers
        :return: action scaled by act_limit
        """
        batch_size = obs.size()[0]
        # Observations are squashed by tanh before population encoding.
        in_pop_spikes = self.encoder(torch.tanh(obs), batch_size)
        out_pop_activity = self.snn(in_pop_spikes, batch_size, norm_update, BN_update)
        action = self.act_limit * self.decoder(out_pop_activity)
        return action

class Critic(nn.Module):
    """Twin Q-networks (two independent 256-256 ReLU MLPs) over concatenated (state, action)."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        in_dim = state_dim + action_dim

        # First Q head
        self.l1 = nn.Linear(in_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, 1)

        # Second Q head
        self.l4 = nn.Linear(in_dim, 256)
        self.l5 = nn.Linear(256, 256)
        self.l6 = nn.Linear(256, 1)

    @staticmethod
    def _head(sa, first, hidden, out):
        """Apply Linear -> ReLU -> Linear -> ReLU -> Linear."""
        return out(F.relu(hidden(F.relu(first(sa)))))

    def forward(self, state, action):
        """Return both Q estimates, each of shape (batch, 1)."""
        sa = torch.cat([state, action], 1)
        return self._head(sa, self.l1, self.l2, self.l3), self._head(sa, self.l4, self.l5, self.l6)

    def Q1(self, state, action):
        """Return only the first head's Q estimate."""
        return self._head(torch.cat([state, action], 1), self.l1, self.l2, self.l3)


class TD3(object):
    """TD3 agent with a population-coded spiking actor and a twin-Q MLP critic."""

    def __init__(
            self,
            state_dim,
            action_dim,
            max_action,
            discount=0.99,
            tau=0.005,
            BN_update_time=BN_update_time,
            policy_noise=0.2,
            noise_clip=0.5,
            policy_freq=2
    ):

        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)

        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.BN_update_time = BN_update_time  # replay batches per BN_update() call
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq

        self.total_it = 0

    def select_action(self, state):
        """Return the deterministic actor action for one state as a flat numpy array."""
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            return self.actor(state).cpu().numpy().flatten()

    @staticmethod
    def _soft_update(source, target, tau):
        """Polyak-average ``source`` parameters into ``target`` in place."""
        for param, target_param in zip(source.parameters(), target.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    def train(self, replay_buffer, batch_size=256):
        """Do one critic update and, every ``policy_freq`` calls, one actor + target update.

        :return: ``(critic_loss, actor_Q, True)`` when the actor was updated,
            else ``(critic_loss, 0, False)``. The losses are detached tensors, so
            callers can accumulate them without keeping autograd graphs alive.
        """
        self.total_it += 1

        # Sample replay buffer
        state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

        with torch.no_grad():
            # Select action according to policy and add clipped noise
            noise = (
                    torch.randn_like(action) * self.policy_noise
            ).clamp(-self.noise_clip, self.noise_clip)

            next_action = (
                    self.actor_target(next_state) + noise
            ).clamp(-self.max_action, self.max_action)

            # Compute the target Q value
            target_Q1, target_Q2 = self.critic_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + not_done * self.discount * target_Q

        # Get current Q estimates
        current_Q1, current_Q2 = self.critic(state, action)
        # Compute critic loss
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Delayed policy updates
        if self.total_it % self.policy_freq == 0:

            # Compute actor loss; norm_update=True also adapts the actor's BatchNorm statistics.
            actor_loss = -self.critic.Q1(state, self.actor(state, norm_update=True)).mean()
            # Optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Update the frozen target models
            self._soft_update(self.critic, self.critic_target, self.tau)
            self._soft_update(self.actor, self.actor_target, self.tau)

            return critic_loss.detach(), -actor_loss.detach(), True
        return critic_loss.detach(), 0, False

    def BN_update(self, batch_size=256, replay_buffer=None):
        """Recompute the actor's BatchNorm statistics from ``BN_update_time`` replay batches.

        :param batch_size: states per batch
        :param replay_buffer: buffer to sample from; if omitted, falls back to a
            module-level ``replay_buffer`` (as the original script relied on)
        :raises RuntimeError: if no replay buffer is available
        """
        if replay_buffer is None:
            replay_buffer = globals().get("replay_buffer")
        if replay_buffer is None:
            raise RuntimeError("BN_update needs a replay buffer: pass replay_buffer=...")
        # Only the statistics side effect matters; the actor output is discarded.
        with torch.no_grad():
            for i in range(self.BN_update_time):
                state = replay_buffer.sample(batch_size)[0]
                self.actor(state, BN_update=[i, self.BN_update_time])

    def save(self, filename):
        """Save network and optimizer state dicts next to ``filename``."""
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")
        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")

    @staticmethod
    def _load_optimizer(optimizer, path):
        """Restore optimizer state if ``path`` exists (older saves did not write it)."""
        if os.path.exists(path):
            optimizer.load_state_dict(torch.load(path, map_location=device))

    def load(self, filename):
        """Load what ``save`` wrote; target networks are re-cloned from the loaded ones."""
        self.critic.load_state_dict(torch.load(filename + "_critic", map_location=device))
        self._load_optimizer(self.critic_optimizer, filename + "_critic_optimizer")
        self.critic_target = copy.deepcopy(self.critic)

        self.actor.load_state_dict(torch.load(filename + "_actor", map_location=device))
        self._load_optimizer(self.actor_optimizer, filename + "_actor_optimizer")
        self.actor_target = copy.deepcopy(self.actor)


# Runs policy for X episodes and returns average reward
# A fixed seed is used for the eval environment
def eval_policy(policy, env_name, eval_seed, eval_episodes=10):
    """Run ``eval_episodes`` episodes with ``policy.select_action`` and return the mean return.

    The evaluation environment is seeded once with ``eval_seed + 100`` and is
    always closed, even if an episode raises.
    """
    eval_env = gym.make(env_name)
    try:
        eval_env.reset(seed=eval_seed + 100)

        total_reward = 0.
        for _ in range(eval_episodes):
            state, _info = eval_env.reset()
            done = False
            while not done:
                action = policy.select_action(np.array(state))
                state, reward, terminated, truncated, _ = eval_env.step(action)
                done = terminated or truncated
                total_reward += reward
    finally:
        eval_env.close()

    avg_reward = total_reward / eval_episodes

    print("---------------------------------------")
    print(f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}")
    print("---------------------------------------")
    return avg_reward


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--policy", default="TD3", choices=["TD3"])  # Only TD3 is implemented in this script
    parser.add_argument("--env", default="Ant-v4")  # Gymnasium environment name (overridden by the `envs` loop below)
    parser.add_argument("--seed", default=seed, type=int)  # Sets Gym, PyTorch and Numpy seeds
    parser.add_argument("--start_timesteps", default=25e3, type=int)  # Time steps initial random policy is used
    parser.add_argument("--eval_freq", default=5e3, type=int)  # How often (time steps) we evaluate
    parser.add_argument("--max_timesteps", default=max_timestep, type=int)  # Max time steps to run environment
    parser.add_argument("--expl_noise", default=0.1, type=float)  # Std of Gaussian exploration noise
    parser.add_argument("--batch_size", default=256, type=int)  # Batch size for both actor and critic
    parser.add_argument("--discount", default=0.99, type=float)  # Discount factor
    parser.add_argument("--tau", default=0.005, type=float)  # Target network update rate
    # type=float is required: otherwise a CLI value arrives as str and `* max_action` fails.
    parser.add_argument("--policy_noise", default=0.2, type=float)  # Noise added to target policy during critic update
    parser.add_argument("--noise_clip", default=0.5, type=float)  # Range to clip target policy noise
    parser.add_argument("--policy_freq", default=2, type=int)  # Frequency of delayed policy updates
    parser.add_argument("--save_model", action="store_true")  # NOTE: currently ignored; the model is always saved
    parser.add_argument("--load_model", default="")  # Model load file name, "" doesn't load, "default" uses file_name
    args = parser.parse_known_args()[0]

    for env_name in envs:
        args.env = env_name
        file_name = f"{args.policy}_{args.env}_{args.seed}"
        with open(file_name + '.csv', "w") as file:
            print("---------------------------------------")
            print(f"Policy: {args.policy}, Env: {args.env}, Seed: {args.seed}")
            print("---------------------------------------")

            os.makedirs("./results", exist_ok=True)
            os.makedirs("./models", exist_ok=True)

            env = gym.make(args.env)

            # Set seeds
            env.reset(seed=args.seed)
            env.action_space.seed(args.seed)
            torch.manual_seed(args.seed)
            np.random.seed(args.seed)

            state_dim = env.observation_space.shape[0]
            action_dim = env.action_space.shape[0]
            max_action = float(env.action_space.high[0])

            kwargs = {
                "state_dim": state_dim,
                "action_dim": action_dim,
                "max_action": max_action,
                "discount": args.discount,
                "tau": args.tau,
                # Target policy smoothing is scaled wrt the action scale
                "policy_noise": args.policy_noise * max_action,
                "noise_clip": args.noise_clip * max_action,
                "policy_freq": args.policy_freq,
            }
            policy = TD3(**kwargs)

            if args.load_model != "":
                policy_file = file_name if args.load_model == "default" else args.load_model
                policy.load(f"./models/{policy_file}")

            # Kept as a module-level name: policy.BN_update() below is called
            # without arguments and samples from this global.
            replay_buffer = ReplayBuffer(state_dim, action_dim)

            # Evaluate untrained policy
            evaluations = [eval_policy(policy, args.env, args.seed)]

            state, _ = env.reset()
            episode_reward = 0
            episode_timesteps = 0
            episode_num = 0

            actor_update = 0
            critic_update = 0
            actor_loss = 0
            critic_loss = 0

            for t in range(int(args.max_timesteps)):

                episode_timesteps += 1

                # Select action randomly or according to policy
                if t < args.start_timesteps:
                    action = env.action_space.sample()
                else:
                    action = (
                            policy.select_action(np.array(state))
                            + np.random.normal(0, max_action * args.expl_noise, size=action_dim)
                    ).clip(-max_action, max_action)

                # Perform action
                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                # Time-limit endings are not stored as terminal transitions.
                done_bool = float(done) if episode_timesteps < env._max_episode_steps else 0

                # Store data in replay buffer
                replay_buffer.add(state, action, next_state, reward, done_bool)

                state = next_state
                episode_reward += reward

                # Train agent after collecting sufficient data
                if t >= args.start_timesteps:
                    loss1, loss2, isactorupdate = policy.train(replay_buffer, args.batch_size)
                    critic_loss += loss1
                    critic_update += 1
                    if isactorupdate:
                        actor_loss += loss2
                        actor_update += 1

                if done:
                    if actor_update != 0:
                        actor_loss /= actor_update
                    if critic_update != 0:
                        critic_loss /= critic_update

                    # Columns: total steps, episode number, episode length, return, mean critic loss, mean actor Q
                    file.write(f"{t + 1}\t{episode_num + 1}\t{episode_timesteps}\t{episode_reward:.3f}\t{critic_loss:.5f}\t{actor_loss:.5f}\n")
                    # Reset environment
                    state, _ = env.reset()
                    episode_reward = 0
                    episode_timesteps = 0
                    episode_num += 1

                    actor_update = 0
                    critic_update = 0
                    actor_loss = 0
                    critic_loss = 0

                # Periodic BatchNorm statistics refresh and evaluation
                if (t + 1) % BN_update_freq == 0:
                    policy.BN_update()
                if (t + 1) % args.eval_freq == 0:
                    evaluations.append(eval_policy(policy, args.env, args.seed))
                    np.save(f"./results/{file_name}", evaluations)

            policy.save(f"./models/{file_name}")
            env.close()