import sys,os,time
sys.path.append(os.path.abspath(os.path.dirname(__file__) + '/' + '..'))

import torch 
import torch.nn as nn
import gym
import numpy as np
from gym.spaces.box import Box
from torch.distributions.normal import Normal
from torch.nn import functional as F
from types import SimpleNamespace
import algo_envs.algo_base as AlgoBase
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
import libs.utils as utils
import random
from collections import deque
import csv

"""
Per-environment settings used for training. Each entry provides:
     env_name:  gym environment id
      obs_dim:  dimension of the observation vector
      act_dim:  dimension of the action vector
     hide_dim:  width of the hidden layers of the networks

Example:
    >>> obs_dim = train_envs[current_env_name].obs_dim
    >>> act_dim = train_envs[current_env_name].act_dim
    >>> hide_dim = train_envs[current_env_name].hide_dim
"""


# Per-environment settings keyed by a short name. `env_name` is the id passed
# to gym.make; the *_dim fields size the policy and value networks.
train_envs = {
    'Swimmer': SimpleNamespace(env_name="Swimmer-v3", obs_dim=8, act_dim=2, hide_dim=64),
    'HalfCheetah': SimpleNamespace(env_name="HalfCheetah-v3", obs_dim=17, act_dim=6, hide_dim=64),
    'Ant': SimpleNamespace(env_name="Ant-v3", obs_dim=111, act_dim=8, hide_dim=256),
    'Hopper': SimpleNamespace(env_name="Hopper-v3", obs_dim=11, act_dim=3, hide_dim=64),
    'Pusher': SimpleNamespace(env_name="Pusher-v2", obs_dim=23, act_dim=7, hide_dim=64),
    'Humanoid': SimpleNamespace(env_name="Humanoid-v3", obs_dim=376, act_dim=17, hide_dim=512),
    'Walker2d': SimpleNamespace(env_name="Walker2d-v3", obs_dim=17, act_dim=6, hide_dim=64),
}

# Key into train_envs selecting the environment that is trained.
current_env_name = 'HalfCheetah'

# Training hyper-parameters.
train_config = {
    # GAE lambda
    'gae_lambda': 0.95,

    # discount factor
    'gamma': 0.99,

    # PPO clipping range around a probability ratio of 1
    'clip_coef': 0.2,

    # upper bound on the ratio for negative advantages (see enable_clip_max)
    'max_clip_coef': 4,

    # weight of the value loss
    'vf_coef': 1,

    # weight of the entropy loss
    'ent_coef': 0.0,

    # Adam learning rate
    'learning_rate': 2.5e-4,

    # weight of the per-dimension ratio when mixing it with the product ratio
    'ratio_coef': 0.5,

    # maximum gradient norm (only used when enable_grad_norm is set)
    'grad_norm': 0.5,

    # version at which the linear decay schedules reach their end
    'max_version': 10 ** 6,

    # policy-gradient ratio: 0 = per-dimension ratio, 1 = product over action
    # dimensions, 2 = weighted mix of both ratios, 3 = mean of both losses
    'pg_loss_type': 1,

    # whether the ratio is also cut off at max_clip_coef for negative advantages
    'enable_clip_max': True,

    # whether the policy uses NoisyLinear layers
    'use_noise': True,

    # whether ratio_coef follows a version-based schedule
    'enable_ratio_decay': False,

    # whether the entropy coefficient decays
    'enable_entropy_decay': False,

    # whether the learning rate decays
    'enable_lr_decay': False,

    # whether gradients are clipped to grad_norm
    'enable_grad_norm': False,

    # whether advantages are normalized
    'enable_adv_norm': True,

    # whether updates are split into mini batches
    'enable_mini_batch': True,

    # number of leading transitions per environment that are trained on in
    # each batch
    'gae_length': 256,
}

# number of parallel environments, chosen so that
# num_envs * gae_length == 16 * 256 samples per batch
train_config['num_envs'] = (16 * 256) // train_config['gae_length']

train_config.update({
    # number of transitions buffered per environment before a batch is returned
    'num_steps': 512,

    # whether to train on a GPU when one is available
    'use_gpu': False,

    # mini batch size; samples beyond a whole multiple of it are not used
    'mini_batch_size': 512,
})

# suffix used for tensorboard naming
train_config['tensorboard_comment'] = 'gae_length_{}_num_steps_{}_adv'.format(
    train_config['gae_length'], train_config['num_steps'])

class PPOMujocoNormalShareGAENet(AlgoBase.AlgoBaseNet):
    """
    Policy and value networks for PPO on continuous-action MuJoCo tasks.

    The policy is a diagonal Gaussian. An MLP ending in ``tanh`` produces the
    mean; when ``train_config['use_noise']`` is set its last two linear layers
    are ``AlgoBase.NoisyLinear`` layers. The standard deviation comes from a
    learnable, state-independent ``log_std`` parameter. The state value is
    produced by a separate MLP. Layer sizes are read from
    ``train_envs[current_env_name]``.

    Example:
        >>> current_env_name = 'Ant' # Mujoco environment
        >>> train_net = PPOMujocoNormalShareGAENet()
        >>> states = torch.randn(64,111) # minibatch = 64, state dimension = 111
        >>> actions = torch.rand(64,8) # action dimension = 8
        >>> train_net.get_distris(states) # get policy distributions
        Normal(loc: torch.Size([64,8]), scale: torch.Size([64,8]))

        >>> values, log_probs, distris_entropy = train_net(states,actions)
        >>> print(values.size(), log_probs.size(), distris_entropy.size())
        torch.Size([64,1]) torch.Size([64,8]) torch.Size([64,8])
    """

    def __init__(self):
        super(PPOMujocoNormalShareGAENet,self).__init__()

        obs_dim = train_envs[current_env_name].obs_dim
        act_dim = train_envs[current_env_name].act_dim
        hide_dim = train_envs[current_env_name].hide_dim

        if train_config['use_noise']:
            # The noisy layers are kept as attributes so that sample_noise()
            # can reach them; they are also the last two layers of self.mu.
            self.noise_layer_out = AlgoBase.NoisyLinear(hide_dim,act_dim)
            self.noise_layer_hide = AlgoBase.NoisyLinear(hide_dim,hide_dim)

            # mean of the normal distribution
            self.mu = nn.Sequential(
                    AlgoBase.layer_init(nn.Linear(obs_dim, hide_dim)),
                    nn.ReLU(),
                    AlgoBase.layer_init(nn.Linear(hide_dim, hide_dim)),
                    nn.ReLU(),
                    self.noise_layer_hide,
                    nn.ReLU(),
                    self.noise_layer_out,
                    nn.Tanh()
                )
        else:
            # mean of the normal distribution
            self.mu = nn.Sequential(
                    AlgoBase.layer_init(nn.Linear(obs_dim, hide_dim)),
                    nn.ReLU(),
                    AlgoBase.layer_init(nn.Linear(hide_dim, hide_dim)),
                    nn.ReLU(),
                    AlgoBase.layer_init(nn.Linear(hide_dim, hide_dim)),
                    nn.ReLU(),
                    AlgoBase.layer_init(nn.Linear(hide_dim, act_dim)),
                    nn.Tanh()
                )

        # State-independent log standard deviation, initialised to -0.5 for
        # every action dimension.
        log_std = -0.5 * np.ones(act_dim, dtype=np.float32)
        self.log_std = nn.Parameter(torch.as_tensor(log_std))

        self.value = nn.Sequential(
                AlgoBase.layer_init(nn.Linear(obs_dim, hide_dim)),
                nn.ReLU(),
                AlgoBase.layer_init(nn.Linear(hide_dim, hide_dim)),
                nn.ReLU(),
                AlgoBase.layer_init(nn.Linear(hide_dim, 1))
            )

    def get_distris(self,states):
        """
        Build the action distribution for the given states.

        Args:
            states: observation tensor

        Return:
            ``Normal`` distribution with mean ``self.mu(states)`` and standard
            deviation ``exp(self.log_std)``

        Example:
            >>> current_env_name = 'Ant'
            >>> train_net = PPOMujocoNormalShareGAENet()
            >>> states = torch.randn(64,111)
            >>> train_net.get_distris(states)
            Normal(loc: torch.Size([64,8]), scale: torch.Size([64,8]))
        """

        mus = self.mu(states)
        distris = Normal(mus,torch.exp(self.log_std))
        return distris

    def get_value(self,states):
        """
        Calculate state values

        Args:
            states: observation tensor

        Return:
            state values produced by the value network

        Example:
            >>> current_env_name = 'Ant'
            >>> train_net = PPOMujocoNormalShareGAENet()
            >>> states = torch.randn(64,111)
            >>> train_net.get_value(states)
        """
        return self.value(states)

    def forward(self,states,actions):
        """
        Calculate state values, per-dimension log probabilities of the given
        actions and per-dimension entropy of the action distribution.

        Args:
            states: observation tensor
            actions: action tensor evaluated under the current policy

        Return:
            tuple ``(values, log_probs, entropy)``

        Example:
            >>> current_env_name = 'Ant' # Mujoco environment
            >>> train_net = PPOMujocoNormalShareGAENet()
            >>> states = torch.randn(64,111)
            >>> actions = torch.rand(64,8)
            >>> values, log_probs, distris_entropy = train_net(states,actions)
            >>> print(values.size(), log_probs.size(), distris_entropy.size())
            torch.Size([64,1]) torch.Size([64,8]) torch.Size([64,8])
        """
        values = self.value(states)
        distris = self.get_distris(states)
        log_probs = distris.log_prob(actions)
        return values,log_probs,distris.entropy()

    def get_sample_data(self,states):
        """
        Sample actions from the policy and return them with their
        per-dimension log probabilities.

        Args:
            states: observation tensor

        Return:
            tuple ``(actions, log_probs)``

        Example:
            >>> current_env_name = 'Ant'
            >>> train_net = PPOMujocoNormalShareGAENet()
            >>> states = torch.randn(64,111)
            >>> actions, log_probs = train_net.get_sample_data(states)
        """

        distris = self.get_distris(states)
        actions = distris.sample()
        log_probs = distris.log_prob(actions)
        return actions,log_probs

    def get_check_data(self,states):
        """
        Return the data needed for deterministic evaluation.

        Args:
            states: observation tensor

        Return:
            tuple ``(mus, entropy, log_probs, values)``: the distribution
            means, the distribution entropy, the log probabilities of the
            means and the state values

        Example:
            >>> current_env_name = 'Ant'
            >>> train_net = PPOMujocoNormalShareGAENet()
            >>> states = torch.randn(64,111)
            >>> mus, entropy, log_probs, values = train_net.get_check_data(states)
        """
        distris = self.get_distris(states)
        # The distribution was built from self.mu(states), so its mean is
        # reused instead of running the policy network a second time.
        mus = distris.mean
        log_probs = distris.log_prob(mus)
        values = self.get_value(states)
        return mus,distris.entropy(),log_probs,values

    def get_calculate_data(self,states,actions):
        """
        Return the same ``(values, log_probs, entropy)`` tuple as ``forward``.

        Args:
            states: observation tensor
            actions: action tensor evaluated under the current policy
        """
        return self.forward(states,actions)

    def sample_noise(self):
        """Resample the noise of both NoisyLinear layers; see the NoisyLinear class.

        Does nothing when ``train_config['use_noise']`` is false.

        Example:
            >>> current_env_name = 'Ant'
            >>> train_net = PPOMujocoNormalShareGAENet()
            >>> train_net.sample_noise()
        """
        if train_config['use_noise']:
            self.noise_layer_out.sample_noise()
            self.noise_layer_hide.sample_noise()
    
class PPOMujocoNormalShareGAEUtils(AlgoBase.AlgoBaseUtils):
    # Placeholder subclass: defines nothing of its own, so all behaviour is
    # whatever AlgoBase.AlgoBaseUtils provides.
    pass
                    
class PPOMujocoNormalShareGAEAgent(AlgoBase.AlgoBaseAgent):
    """
    Agent used with continuous-action PPO; it either collects training data
    or evaluates the current policy.

    Args:
        sample_net: policy network (PPOMujocoNormalShareGAENet)
        model_dict: a dict of model configuration; ``model_dict['train_version']``
                    is stored with every collected transition
        is_checker: if true, the agent owns a single environment and is used
                    through ``check_env`` to evaluate the policy by acting
                    with the distribution mean; otherwise it owns
                    ``train_config['num_envs']`` environments and is used
                    through ``sample_env`` to collect training data.

    Example:
        >>> train_net = PPOMujocoNormalShareGAENet()
        >>> # Collecting training data
        >>> sample_agent = PPOMujocoNormalShareGAEAgent(train_net,model_dict,is_checker=False)
        >>> train_exps, info = sample_agent.sample_env()
        >>> # Evaluating agent
        >>> check_agent = PPOMujocoNormalShareGAEAgent(train_net,model_dict,is_checker=True)
        >>> info = check_agent.check_env()
        >>> print(info['sum_rewards'], info['mean_entropys'], info['mean_mus'], info['mean_log_probs'])

    """

    def __init__(self,sample_net:PPOMujocoNormalShareGAENet,model_dict,is_checker):
        super(PPOMujocoNormalShareGAEAgent,self).__init__()
        self.sample_net = sample_net
        self.model_dict = model_dict
        self.num_steps = train_config['num_steps']
        self.num_envs = train_config['num_envs']
        self.gae_length = train_config['gae_length']
        self.step = 0
        # Rolling windows of recent rewards / log probabilities used only for
        # the statistics returned by sample_env.
        self.rewards_ = deque(maxlen=50000)
        self.log_probs_ = deque(maxlen=50000)

        env_name = train_envs[current_env_name].env_name

        if not is_checker:
            # Sampling mode: one environment, current state and transition
            # buffer per parallel environment.
            self.envs = [gym.make(env_name) for _ in range(self.num_envs)]
            self.states = [self.envs[i].reset() for i in range(self.num_envs)]
            self.exps_list = [[] for _ in range(self.num_envs)]
        else:
            # Checking mode: self.envs / self.states hold a single
            # environment and a single state, not lists.
            print("PPOMujocoNormalShare check mujoco env is",env_name)
            self.envs = gym.make(env_name)
            self.states = self.envs.reset()

    def get_comment_info(self):
        """Return the environment name joined with the tensorboard comment."""
        return current_env_name + "_" + train_config['tensorboard_comment']

    def sample_env(self):
        """Collect training data (sampling mode only).

        All environments are stepped in lockstep until every buffer holds
        ``num_steps`` transitions. Each transition is the list
        ``[state, action, reward, done, log_prob, train_version]``. After the
        call the first ``gae_length`` transitions of every buffer are dropped,
        so the next call only has to collect ``gae_length`` new steps.

        Returns:
            tuple ``(train_exps, step_record_dict)``: ``train_exps`` is the
            list of per-environment transition buffers, and
            ``step_record_dict`` has ``'sum_rewards'`` (mean of the recent
            rewards times 1000) and ``'mean_log_probs'``.

        Example:
            >>> train_net = PPOMujocoNormalShareGAENet()
            >>> sample_agent = PPOMujocoNormalShareGAEAgent(train_net,model_dict,is_checker=False)
            >>> train_exps, info = sample_agent.sample_env()
        """

        step_record_dict = dict()
        while len(self.exps_list[0]) < self.num_steps:
            actions,log_probs = self.get_sample_actions(self.states)
            for i, env in enumerate(self.envs):
                next_state_n, reward_n, done_n, _ = env.step(actions[i])
                if done_n:
                    next_state_n = env.reset()

                self.exps_list[i].append([self.states[i],actions[i],reward_n,done_n,log_probs[i],self.model_dict['train_version']])
                self.states[i] = next_state_n
                self.rewards_.append(reward_n)
                # Record only this environment's log probabilities. Appending
                # the whole batch here would store the same array once per
                # environment.
                self.log_probs_.append(log_probs[i])
        # Hand the full buffers to the trainer ...
        train_exps = self.exps_list
        # ... and keep only the part after the first gae_length steps. The
        # slices are new lists, so train_exps is left untouched.
        self.exps_list = [ exps[self.gae_length:self.num_steps] for exps in self.exps_list ]

        step_record_dict['sum_rewards'] = np.mean(self.rewards_)*1000
        step_record_dict['mean_log_probs'] = np.mean(self.log_probs_)

        return train_exps,step_record_dict

    def check_env(self):
        """Evaluate the agent (checking mode only).

        Runs 1000 environment steps acting with the distribution mean and
        resetting the environment whenever an episode ends. It then computes,
        over the first 500 steps, the GAE advantages together with their
        reward-only and value-only parts.

        Returns:
            dict with ``'sum_rewards'``, ``'mean_entropys'``, ``'mean_mus'``,
            ``'mean_log_probs'`` and the lists ``'advantages'``, ``'gaers'``
            (reward part) and ``'gaevs'`` (value part).

        Example:
            >>> train_net = PPOMujocoNormalShareGAENet()
            >>> check_agent = PPOMujocoNormalShareGAEAgent(train_net,model_dict,is_checker=True)
            >>> info = check_agent.check_env()
        """
        step_record_dict = dict()

        mus = []
        rewards = []
        entropys = []
        log_probs = []
        values = []
        dones = []
        check_length = 1000
        for _ in range(check_length):
            mu,entropy,log_prob,value = self.get_check_action(self.states)
            next_state_n, reward_n, is_done, _ = self.envs.step(mu)
            if is_done:
                next_state_n = self.envs.reset()
            self.states = next_state_n
            rewards.append(reward_n)
            mus.append(mu)
            entropys.append(entropy)
            log_probs.append(log_prob)
            dones.append(is_done)
            values.append(value[0])

        values = np.array(values).tolist()
        step_record_dict['sum_rewards'] = np.sum(rewards)
        step_record_dict['mean_entropys'] = np.mean(entropys)
        step_record_dict['mean_mus'] = np.mean(mus)
        step_record_dict['mean_log_probs'] = np.mean(log_probs)

        gamma = train_config['gamma']
        gae_lambda = train_config['gae_lambda']
        last_gae_lam = 0
        last_ar = 0
        last_av = 0
        std_length = 500
        advantages = np.zeros((std_length,))
        gaers = np.zeros((std_length,))
        gaevs = np.zeros((std_length,))
        # Backward GAE recursion, kept separately for the reward term and the
        # value term so that advantages == gaers + gaevs.
        for t in reversed(range(std_length)):
            if (t == std_length-1) or dones[t]:
                # End of the window or of an episode: no bootstrap value and
                # no accumulated tail.
                delta_v = - values[t]
                delta_r = rewards[t]
                delta = delta_v+delta_r
                advantages[t] = last_gae_lam = delta
                gaers[t] = last_ar = delta_r
                gaevs[t] = last_av = delta_v
            else:
                next_values = values[t + 1]
                delta_v = gamma * next_values - values[t]
                delta_r = rewards[t]
                delta = delta_v+delta_r
                advantages[t] = last_gae_lam = delta + gamma * gae_lambda * last_gae_lam
                gaers[t] = last_ar = delta_r + gamma * gae_lambda * last_ar
                gaevs[t] = last_av = delta_v + gamma * gae_lambda * last_av
        step_record_dict['advantages'] = np.array(advantages).tolist()
        step_record_dict['gaers'] = np.array(gaers).tolist()
        step_record_dict['gaevs'] = np.array(gaevs).tolist()
        return step_record_dict

    @torch.no_grad()
    def get_sample_actions(self,states):
        """Sample one action per state and return the log probabilities.

        Args:
            states: sequence of observations, one per environment

        Returns:
            actions: numpy array of sampled actions
            log_probs: numpy array of per-dimension log probabilities

        Example:
            >>> train_net = PPOMujocoNormalShareGAENet()
            >>> sample_agent = PPOMujocoNormalShareGAEAgent(train_net,model_dict,is_checker=False)
            >>> states = torch.randn(64,111)
            >>> actions, log_probs = sample_agent.get_sample_actions(states)
        """
        states_v = torch.Tensor(np.array(states))
        actions,log_probs = self.sample_net.get_sample_data(states_v)
        return actions.cpu().numpy(),log_probs.cpu().numpy()

    @torch.no_grad()
    def get_check_action(self,state):
        """Return the distribution mean for one state together with the
        distribution entropy, the log probability of the mean and the state
        value, all as numpy arrays.

        Example:
            >>> train_net = PPOMujocoNormalShareGAENet()
            >>> check_agent = PPOMujocoNormalShareGAEAgent(train_net,model_dict,is_checker=True)
            >>> state = torch.randn(111)
            >>> mu, entropy, log_prob, value = check_agent.get_check_action(state)
        """

        state_v = torch.Tensor(np.array(state))
        mu,entropy,log_prob,value = self.sample_net.get_check_data(state_v)
        return mu.cpu().numpy(),entropy.cpu().numpy(),log_prob.cpu().numpy(),value.cpu().numpy()

    def close(self):
        """Close all environments of a sampling-mode agent."""
        for env in self.envs:
            env.close()
        print('agent closed')

    def check_close(self):
        """Close the single environment of a checking-mode agent."""
        self.envs.close()
        print('check closed')
            
class PPOMujocoNormalShareGAECalculate(AlgoBase.AlgoBaseCalculate):

    """
    training class used with continues PPO

    Args:
        share_model: policy network (default: PPOMujocoNormalShareGAENet)
        model_dict: a dict of model configuration
        calculate_index: the :math:`calculate_index`th agent for training

    
    Example:
        >>> train_net = PPOMujocoNormalShareGAENet()
        >>> calculate = PPOMujocoNormalShareGAECalculate(train_net,model_dict,calculate_index)
        >>> # samples are from transitions
        >>> calculate.begin_batch_train(samples)
        >>> for _ in range(REPEAT_TIMES):
        >>>    calculate.generate_grads()
        >>> calculate.end_batch_train()

    """
    
    def __init__(self,share_model:PPOMujocoNormalShareGAENet,model_dict,calculate_index):
        super(PPOMujocoNormalShareGAECalculate,self).__init__()
        self.model_dict = model_dict
        self.share_model = share_model
        self.gae_length = train_config['gae_length']
        self.calculate_number = self.model_dict['num_trainer']
        self.calculate_index = calculate_index
        self.train_version = 0
        
        if train_config['use_gpu'] and torch.cuda.is_available():
            device_count = torch.cuda.device_count()
            device_index = self.calculate_index % device_count
            self.device = torch.device('cuda',device_index)
        else:
            self.device = torch.device('cpu')
                        
        self.calculate_net = PPOMujocoNormalShareGAENet()
        self.calculate_net.to(self.device)
        #self.calculate_net.load_state_dict(self.share_model.state_dict())
    
        self.share_optim = torch.optim.Adam(params=self.share_model.parameters(), lr=train_config['learning_rate'])
        
        # Clear training data
        self.states_list = None
        self.actions_list = None
        self.rewards_list = None
        self.dones_list = None
        self.old_log_probs_list = None
        self.advantage_list = None
        self.returns_list = None
        
    def begin_batch_train(self, samples_list: list): 
        """store training data
        Example:
            >>> train_net = PPOMujocoNormalShareGAENet()
            >>> calculate = PPOMujocoNormalShareGAECalculate(train_net,model_dict,calculate_index)
            >>> calculate.begin_batch_train(samples)
        """  

        s_states = [np.array([s[0] for s in samples]) for samples in samples_list]
        s_actions = [np.array([s[1] for s in samples]) for samples in samples_list]
        s_log_probs = [np.array([s[4] for s in samples]) for samples in samples_list]
        
        s_rewards = [np.array([s[2] for s in samples]) for samples in samples_list]
        s_dones = [np.array([s[3] for s in samples]) for samples in samples_list]
        
        #s_versions = [s[6] for s in samples]
        
        self.states = [torch.Tensor(states).to(self.device) for states in s_states]
        self.actions = [torch.Tensor(actions).to(self.device) for actions in s_actions]
        self.old_log_probs = [torch.Tensor(log_probs).to(self.device) for log_probs in s_log_probs]
        self.rewards = s_rewards
        self.dones = s_dones
        
        self.states_list = torch.cat([states[0:self.gae_length] for states in self.states])
        self.actions_list = torch.cat([actions[0:self.gae_length] for actions in self.actions])
        self.old_log_probs_list = torch.cat([old_log_probs[0:self.gae_length] for old_log_probs in self.old_log_probs])
        
    def end_batch_train(self):
        """clear training data, update learning rate and noisy network
        Example:
            >>> train_net = PPOMujocoNormalShareGAENet()
            >>> calculate = PPOMujocoNormalShareGAECalculate(train_net,model_dict,calculate_index)
            >>> calculate.end_batch_train(samples)
        """
        
        self.states_list = None
        self.actions_list = None
        self.rewards_list = None
        self.dones_list = None
        self.old_log_probs_list = None
        self.advantage_list = None
        self.returns_list = None
        
        train_version = self.model_dict[self.calculate_index]
        self.decay_lr(train_version)
        
        # Resetting sample noise
        if self.calculate_index == self.calculate_number - 1:
            self.share_model.sample_noise()

    def calculate_samples_gae(self):
        """Calculate samples generalized advantage estimator, more details see in algo_base.calculate_gae
        """
        gamma = train_config['gamma']
        gae_lambda = train_config['gae_lambda']
        
        np_advantages = []
        np_returns = []
        
        for states,rewards,dones in zip(self.states,self.rewards,self.dones):
            with torch.no_grad():
                values = self.calculate_net.get_value(states)
                            
            advantages,returns = AlgoBase.calculate_gae(values.cpu().numpy().reshape(-1),rewards,dones,gamma,gae_lambda)
            np_advantages.extend(advantages[0:self.gae_length])
            np_returns.extend(returns[0:self.gae_length])
            
        return np.array(np_advantages), np.array(np_returns)
            
    def decay_lr(self, version):
        """decrease learning rate:
        :math:`lr = lr(1- ve / max_ve )`
        where :math:`lr` is learning rate, :math:`ve` is current version and :math:`max_ve` is the highest version.
        Minimum learning rate is equal to 1e-6

        Example:
            >>> train_net = PPOMujocoNormalShareGAENet()
            >>> calculate = PPOMujocoNormalShareGAECalculate(train_net,model_dict,calculate_index)
            >>> calculate.decay_lr(calculate_index)
            
        """

        if train_config['enable_lr_decay']:
            lr_now = train_config['learning_rate'] * (1 - version*1.0 / train_config['max_version'])
            if lr_now <= 1e-6:
                lr_now = 1e-6
            
            if self.share_optim is not None:
                for param in self.share_optim.param_groups:
                    param['lr'] = lr_now
                                                                                   
    def generate_grads(self): 
        """ update share network parameters. 
        
        If action is discrete, then :math:`ratio1 = exp(new_log_probs - old_log_probs)`, if action is continues, then
        :math:`ratio2 = \prod{ratio1}` and expand to the same dimention as :math:`ratio1`, if action is mixed, then 
        :math:`ratio3 = ratio1 * ratio_coef + ratio2 * (1.0 - ratio_coef)`, where :math:`ratio_coef` is weight coefficent.
        

        Example:
            >>> train_net = PPOMujocoNormalShareGAENet()
            >>> calculate = PPOMujocoNormalShareGAECalculate(train_net,model_dict,calculate_index)
            >>> calculate.begin_batch_train(samples)
            >>> REPEAT_TIMES = 10
            >>> for _ in range(REPEAT_TIMES):
            >>>     calculate.generate_grads()
            >>> calculate.end_batch_train()
        """ 
        train_version = self.model_dict[self.calculate_index]
        vf_coef = train_config['vf_coef']
        pg_loss_type = train_config['pg_loss_type']
        grad_norm = train_config['grad_norm']
        mini_batch_size = train_config['mini_batch_size']

        ent_coef = train_config['ent_coef']
        ratio_coef = self.get_ratio_coef(train_version)
    
        self.calculate_net.load_state_dict(self.share_model.state_dict())
                        
        #start = timer()
        np_advantages,np_returns = self.calculate_samples_gae()
        
        #run_time = timer() - start
        #print("CPU function took %f seconds." % run_time)
        #renp_advantages,renp_returns =np_advantages,np_returns
        if train_config['enable_adv_norm']:
            np_advantages = (np_advantages - np_advantages.mean()) / (np_advantages.std() + 1e-8)
                                                    
        advantage_list = torch.Tensor(np_advantages.reshape(-1,1)).to(self.device)    
        returns_list = torch.Tensor(np_returns.reshape(-1,1)).to(self.device)
        
        if train_config['enable_mini_batch']:
            mini_batch_number = advantage_list.shape[0] // mini_batch_size
        else:
            mini_batch_number = 1
            mini_batch_size = advantage_list.shape[0]

        for i in range(mini_batch_number):
            start_index = i*mini_batch_size
            end_index = (i+1)* mini_batch_size
            
            mini_states = self.states_list[start_index:end_index]
            mini_actions = self.actions_list[start_index:end_index]
            mini_old_log_probs = self.old_log_probs_list[start_index:end_index]
            
            self.calculate_net.load_state_dict(self.share_model.state_dict())
                
            mini_new_values,mini_new_log_probs,mini_entropys = self.calculate_net(mini_states,mini_actions)
            
            mini_advantage = advantage_list[start_index:end_index]
            mini_returns = returns_list[start_index:end_index]

            #discrete ratio
            ratio1 = torch.exp(mini_new_log_probs-mini_old_log_probs)

            #prod ratio
            #ratio2 = torch.exp(t_new_log_probs.sum(1) - old_log_probs.sum(1)).reshape(-1,1).expand_as(ratio1)
            
            ratio2 = ratio1.prod(1,keepdim=True).expand_as(ratio1)
            #ratio2 = AlgoBase.GradCoef.apply(ratio2,1.0/ratio2.shape[1])
            
            #ratio2 = self.get_prod_ratio(ratio1)
            
            #mixed ratio
            #ratio3 = (AlgoBase.GradCoef.apply(ratio1,ratio_coef) + AlgoBase.GradCoef.apply(ratio2, 2.0 - ratio_coef)) / 2
            ratio3 = ratio1 * ratio_coef + ratio2 * (1.0 - ratio_coef)

            #discrete
            if pg_loss_type == 0:
                pg_loss = self.get_pg_loss(ratio1,mini_advantage)
                
            #prod
            elif pg_loss_type == 1:
                pg_loss = self.get_pg_loss(ratio2,mini_advantage)
                
            #mixed
            elif pg_loss_type == 2:
                pg_loss = self.get_pg_loss(ratio3,mini_advantage)
                
            #last_mixed
            elif pg_loss_type == 3:
                pg_loss1 = self.get_pg_loss(ratio1,mini_advantage)
                pg_loss2 = self.get_pg_loss(ratio2,mini_advantage)
                pg_loss = (pg_loss1+pg_loss2)/2
                            
            # Policy loss
            pg_loss = -torch.mean(pg_loss)
            
            v_loss = F.mse_loss(mini_returns, mini_new_values) * vf_coef
            
            e_loss = -torch.mean(mini_entropys) * ent_coef
            
            loss = pg_loss + v_loss + e_loss
            
            self.calculate_net.zero_grad()

            loss.backward()
            
            grads = [
                param.grad.data.cpu().numpy()
                if param.grad is not None else None
                for param in self.calculate_net.parameters()
            ]
            
            # Updating network parameters
            for param, grad in zip(self.share_model.parameters(), grads):
                param.grad = torch.FloatTensor(grad)

            if train_config['enable_grad_norm']:
                torch.nn.utils.clip_grad_norm_(self.share_model.parameters(),grad_norm)  
            self.share_optim.step()
        #return np.std(renp_advantages),np.std(renp_returns)
    
    def get_pg_loss(self,ratio,advantage):
        """Calculate policy gradient loss
        If :math:`enable_clip_max` is false, then ratio between :math:`1 - clip_coef` to :math:`1 + clip_coef`, otherwise is equal to 0
        else ratio between :math:`1- clip_coef` to :math:`min( 1 + clip_coef, max_clip_coef)`, otherwise is equal to 0
        """

        clip_coef = train_config['clip_coef']
        max_clip_coef = train_config['max_clip_coef']
        enable_clip_max = train_config['enable_clip_max']
        
        # base_value = ratio * advantage
        # clip_value = torch.clamp(ratio,1.0 - clip_coef,1.0 + clip_coef) * advantage
        # min_loss_policy = torch.min(base_value, clip_value)        
        # max_loss_policy = torch.max(min_loss_policy,max_clip_coef * advantage)
        
        # return torch.where(advantage>=0,min_loss_policy,max_loss_policy)
        
        positive = torch.where(ratio >= 1.0 + clip_coef, 0 * advantage,advantage)
        if enable_clip_max:
            negtive = torch.where(ratio <= 1.0 - clip_coef,0 * advantage,torch.where(ratio >= max_clip_coef, 0 * advantage,advantage))
        else:
            negtive = torch.where(ratio <= 1.0 - clip_coef,0 * advantage,advantage)
        
        return torch.where(advantage>=0,positive,negtive)*ratio
    
    def get_ent_coef(self,version):
        """Return the entropy coefficient to use at training ``version``.

        When ``train_config['enable_entropy_decay']`` is set, the coefficient
        decays linearly with the version:

        :math:`ef = ent\\_coef \\cdot (1 - ve / max\\_ve)`

        where :math:`ve` is ``version`` and :math:`max\\_ve` is
        ``train_config['max_version']``. The result never goes below 1e-8.
        When decay is disabled, ``train_config['ent_coef']`` is returned as is.
        """
        if not train_config['enable_entropy_decay']:
            return train_config['ent_coef']

        remaining = 1 - version*1.0 / train_config['max_version']
        decayed = train_config['ent_coef'] * remaining
        # Floor the coefficient so it stays strictly positive.
        return 1e-8 if decayed <= 1e-8 else decayed

    def get_ratio_coef(self,version):
        """Return the ratio coefficient to use at training ``version``.

        When ``train_config['enable_ratio_decay']`` is set, the coefficient is
        ``version / train_config['max_version']``; once that fraction reaches
        1.0 or more, 0.95 is returned instead. When the option is off,
        ``train_config['ratio_coef']`` is returned unchanged.

        NOTE(review): fractions in (0.95, 1.0) are returned unchanged before
        the value falls back to 0.95, so the schedule is not monotonic --
        confirm whether a cap at 0.95 was intended.
        """
        if not train_config['enable_ratio_decay']:
            return train_config['ratio_coef']

        progress = version/train_config['max_version']
        return 0.95 if progress >= 1.0 else progress
        
def _write_rows_csv(path, rows):
    """Write each item of ``rows`` as one CSV row to ``path`` (overwriting it)."""
    with open(path, mode='w', newline='') as csv_file:
        csv.writer(csv_file).writerows(rows)


if __name__ == "__main__":
    # Script entry point: train on `current_env_name`, log the sampler's infos
    # to TensorBoard after every version and periodically evaluate the policy,
    # writing rewards to a text note and per-episode statistics to CSV files.
    #
    # The statements stay at module level (not wrapped in a function) so the
    # names created here remain module globals, exactly as before.

    # Both output folders are relative to the working directory. Create them
    # up front: previously a missing 'csv' folder only failed after the first
    # full evaluation pass, losing the work done so far.
    os.makedirs('record', exist_ok=True)
    os.makedirs('csv', exist_ok=True)

    # hyperparameters of this driver loop
    MAX_VERSION = 1000      # number of sample/train iterations
    REPEAT_TIMES = 10       # gradient passes per sampled batch
    CHECK_INTERVAL = 250    # evaluate every this many versions
    CHECK_EPISODES = 2000   # number of check_env() calls per evaluation

    note_tittle = 'record/'+current_env_name+time.strftime('%Y%m%d%H%M%S',time.localtime())+'.txt'
    # The context manager guarantees the note file is closed even if training raises.
    with open(note_tittle, mode='w') as Note:
        for repeat in range(1):
            seed = random.randint(0,65536)
            utils.setup_seed(seed)
            # used for tensorboard naming
            comment = "_PPOMujoco_" + current_env_name + "_" + train_config['tensorboard_comment'] + "_" + str(seed)
            writer = SummaryWriter(comment=comment)

            # initialize training network
            train_net = PPOMujocoNormalShareGAENet()
            model_dict = {0: 0, 'num_trainer': 1, 'train_version': 0}

            # sample agent collects training data, check agent evaluates,
            # calculate computes the gradients
            sample_agent = PPOMujocoNormalShareGAEAgent(train_net,model_dict,is_checker=False)
            check_agent = PPOMujocoNormalShareGAEAgent(train_net,model_dict,is_checker=True)
            calculate = PPOMujocoNormalShareGAECalculate(train_net,model_dict,0)

            try:
                for version in range(MAX_VERSION):
                    samples_list,infos = sample_agent.sample_env()

                    # Pass a fresh list, as the original element-by-element copy did.
                    calculate.begin_batch_train(list(samples_list))
                    for _ in range(REPEAT_TIMES):
                        calculate.generate_grads()
                    calculate.end_batch_train()

                    model_dict[0] = model_dict[0] + 1
                    model_dict['train_version'] = model_dict[0]

                    for key, value in infos.items():
                        writer.add_scalar(key, value, model_dict[0])
                    # single-line progress indicator (carriage return, no newline)
                    print('--'*(version//25+1),MAX_VERSION-version,end='\r')

                    if (version+1) % CHECK_INTERVAL != 0:
                        continue

                    # periodic evaluation
                    check_reward = []
                    all_advs = []
                    all_rs = []
                    all_vs = []
                    for _ in range(CHECK_EPISODES):
                        check_infos = check_agent.check_env()
                        check_reward.append(check_infos['sum_rewards'])
                        all_advs.append(check_infos['advantages'])
                        all_rs.append(check_infos['gaers'])
                        all_vs.append(check_infos['gaevs'])

                    Note.write("check_reward:"+str(check_reward)+' \n')
                    Note.write("average_reward:"+str(np.mean(check_reward))+' \n')

                    # one CSV row per evaluation episode
                    csv_suffix = str(version+1)+'.csv'
                    _write_rows_csv('csv/'+current_env_name+'_gae_'+csv_suffix, all_advs)
                    _write_rows_csv('csv/'+current_env_name+'_rewards_'+csv_suffix, all_rs)
                    _write_rows_csv('csv/'+current_env_name+'_values_'+csv_suffix, all_vs)
            finally:
                # Release the TensorBoard writer and the agents even when
                # training or evaluation fails part-way through.
                writer.close()
                sample_agent.close()
                check_agent.check_close()
