import numpy as np
import torch
import gym
import argparse
import os
# import gym_platform
from Raw_RL import utils
# from agents import TD3
from agents import P_TD3
# from agents import OurDDPG
from agents import P_DDPG
from common import ClickPythonLiteralOption
# from common.platform_domain import PlatformFlattenedActionWrapper
from common.wrappers import ScaledStateWrapper, ScaledParameterisedActionWrapper
import matplotlib.pyplot as plt
# from common.goal_domain import GoalFlattenedActionWrapper, GoalObservationWrapper
import math
import time

import models.networks as nets
# import wandb


# Device on which every network and tensor in this module is created.
# The CUDA auto-selection below is kept for reference but disabled, so the
# script is pinned to the CPU.
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')


def save_points(args):
    """Start a Weights & Biases run that records ``args`` as its config.

    ``wandb`` is imported here instead of at module level so the script keeps
    working without the package when logging is switched off (``run`` only
    calls this function when ``args.save_points`` is truthy).  The import is
    bound as a module global because ``run`` later calls ``wandb.log``
    directly.

    Args:
        args: the parsed options; passed unchanged as the run config.

    Returns:
        The object returned by ``wandb.init``.
    """
    global wandb
    import wandb

    return wandb.init(
        project="pamdp-mpc",
        config=args,
        dir="../scratch/wandb",
    )


class ModelNets(torch.nn.Module):
    """Three learned heads sharing one input: next state, reward, continue.

    Every head is a ``nets.TanhGaussianPolicy`` with three hidden layers of
    ``args.layers`` units and ``tanh=False``.  The input is the state
    concatenated with the action, ``state_dim + k_dim + z_dim`` wide.
    """

    def __init__(self, args):
        super().__init__()

        self.s_dim = args.state_dim
        self.inp_dim = args.state_dim + args.k_dim + args.z_dim

        def build_head(out_features):
            # Identical architecture for every head; only the output width varies.
            head = nets.TanhGaussianPolicy(
                hidden_sizes=[args.layers] * 3,
                oup_dim=out_features,
                inp_dim=self.inp_dim,
                tanh=False,
            )
            return head.to(device)

        # The attribute names (including the "_dyanmics" spelling) end up in
        # the state_dict keys, so they are kept exactly as they were.
        self._dyanmics = build_head(self.s_dim)
        self._reward = build_head(1)
        # Two outputs, argmax-ed by the planner: 0 = terminate, 1 = continue.
        self._continue = build_head(2)

    def next(self, s, a, reparameterize, return_log_prob, deterministic):
        '''
        Run all three heads on ``[s, a]`` and return the first output of each.

        Flag combinations used by the callers:

                        reparameterize  return_log_prob deterministic
            train           True            True            False
            train_plan      False           True            False
            evaluate_plan   -               -               True
            estimate_value  -               -               True

        Returns ``(next_state, reward, continue_output)``; the continue output
        is returned raw (callers apply ``argmax`` themselves).
        '''
        model_input = torch.cat([s, a], dim=-1)
        sampling = dict(
            reparameterize=reparameterize,
            return_log_prob=return_log_prob,
            deterministic=deterministic,
        )

        next_state = self._dyanmics(model_input, **sampling)[0]
        reward = self._reward(model_input, **sampling)[0]
        keep_going = self._continue(model_input, **sampling)[0]

        return next_state, reward, keep_going



class MPC:
    def __init__(self, args):
        """Write the planner constants onto ``args`` and build model, losses, optimiser.

        ``args`` is mutated in place: every entry of the table below is stored
        on it, overwriting any value that was already there.  It must already
        carry ``state_dim``, ``k_dim`` and ``z_dim``.
        """
        planner_constants = {
            "seed_steps": 2_000,       # steps during which select_action acts randomly
            "use_policy": 0,           # policy-guided branches of select_action off
            "use_model": 1,
            "mpc_horizon": 1,          # planning horizon in steps
            "mpc_popsize": 1_000,      # sampled action sequences per CEM iteration
            "embed": 0,                # embedded-action path off
            "cem_iter": 6,             # CEM refinement iterations
            "mpc_gamma": 0.99,         # discount used in estimate_value
            "mpc_num_elites": 400,     # top sequences kept per iteration
            "mpc_temperature": 0.5,    # softmax temperature for elite weights
            "mpc_alpha": 0.2,          # weight of the old mean/std in the update
            "layers": 128,             # hidden width read by ModelNets
        }
        for option, value in planner_constants.items():
            setattr(args, option, value)

        self.args = args

        self.s_dim = args.state_dim
        self.inp_dim = args.state_dim + args.k_dim + args.z_dim

        self.model = ModelNets(args)

        self.mse_loss = torch.nn.MSELoss()
        self.ce_loss = torch.nn.CrossEntropyLoss()

        self.optim = torch.optim.Adam(self.model.parameters(), lr=3e-4)
    
    @torch.no_grad()
    def sample_from_N(self, mean, std):
        if self.args.embed:
            raise "not implemented embed yet"
        else:
            kmean = mean['k']
            zmean, zstd = mean['z'], std

            k_int = torch.multinomial(kmean, self.args.mpc_popsize, replacement=True)
            k_onehot = torch.nn.functional.one_hot(k_int, num_classes=self.args.k_dim).to(device)
            
            z_all = torch.clamp(zmean.unsqueeze(1) + zstd.unsqueeze(1) * \
                    torch.randn(self.args.mpc_horizon, self.args.mpc_popsize, self.args.all_z_dim, device=zstd.device), self.args.lb, self.args.ub)
            
            offsets = torch.tensor(self.args.offset).to(device)[k_int.flatten()].unsqueeze(-1).repeat(1, self.args.z_dim) + torch.arange(self.args.z_dim, device=device)

            z_one = torch.zeros([self.args.mpc_horizon*self.args.mpc_popsize, self.args.all_z_dim+self.args.z_dim], device=device)
            z_one[:, :self.args.all_z_dim] = z_all.reshape([-1, self.args.all_z_dim])
            
            zs = torch.gather(z_one, 1, offsets)
            
            size = torch.from_numpy(self.args.par_size).to(device)[k_int.flatten()].unsqueeze(-1).repeat(1, self.args.z_dim)
            mask = torch.arange(self.args.z_dim).to(device).repeat(len(size), 1)
            mask = torch.where(mask<size, 1., 0.)
            zs = zs * mask
            
            zs = zs.reshape([self.args.mpc_horizon, self.args.mpc_popsize, self.args.z_dim])
            return torch.cat([k_onehot, zs], dim=-1)
    
    @torch.no_grad()
    def estimate_value(self, s, actions, horizon, local_step):
        """Roll the learned model forward and return the discounted return.

        Starting from states ``s`` (one row per trajectory), ``actions[t]`` is
        applied for ``t`` in ``range(horizon)`` using deterministic model
        predictions.  Rewards are discounted by ``args.mpc_gamma`` and stop
        counting once the continue head's argmax has been 0 for a trajectory.
        ``local_step`` is accepted but not used.
        """
        total, gamma_t = 0, 1
        alive = torch.ones([s.shape[0], 1], device=device)

        for t in range(horizon):
            next_s, step_reward, cont = self.model.next(
                s, actions[t], reparameterize=False, return_log_prob=False, deterministic=True)

            # Argmax of the two continue outputs: 0 = terminate, 1 = continue.
            cont = cont.argmax(-1).unsqueeze(-1)

            # The reward of this step still counts; termination masks later ones.
            total = total + gamma_t * step_reward * alive
            gamma_t = gamma_t * self.args.mpc_gamma
            alive *= cont

            s = next_s

        return total

    def rand_action(self):
        k = np.random.randint(low=0, high=self.args.k_dim)
        if k:
            z = 0
        else:
            z = np.random.random() * 2 - 1
        return k, z

    def select_action(self, state, step, local_step, eval_mode=False, t0=False):
        """Choose a hybrid action ``(k, z)`` for ``state`` by CEM planning.

        Args:
            state: a single observation; converted to a ``[1, state_dim]``
                float tensor below.
            step: global step counter; while it is below ``args.seed_steps``
                (and not in eval mode) a random action is returned instead.
            local_step: forwarded to ``estimate_value`` only.
            eval_mode: when True, skips the random phase and the final
                exploration noise.
            t0: when False and a previous plan exists, the previous mean
                shifted by one step is used as the starting mean.

        Returns:
            ``(k, z)`` where ``k`` is the discrete action index.  On the
            planning path ``z`` is a Python float, forced to 0 when ``k`` is
            non-zero.

        Note:
            ``self.pi``, ``self.dealRaw`` and ``self.embed_cem`` are not
            defined in the visible part of this class; the branches using them
            are switched off by ``use_policy = 0`` / ``embed = 0`` set in
            ``__init__``.
        """
        if step < self.args.seed_steps and not eval_mode:
            return self.rand_action()

        # Sampling flags forwarded to the policy / model calls below.
        if eval_mode:
            reparameterize = False
            return_log_prob = False
            deterministic = True
        else:
            reparameterize = False
            return_log_prob = True
            deterministic = False

            # self.timestep += 1

        # Sample policy trajectories
        state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
        # horizon = int(min(self.args.horizon, h.linear_schedule(self.args.horizon_schedule, step)))

        # Policy-only shortcut (no planning).
        if self.args.use_policy and (not self.args.use_model):
            k, z = self.pi(state, self.args.min_std, reparameterize=reparameterize, return_log_prob=return_log_prob, deterministic=deterministic)
            k, z = k.flatten(), z.flatten()
            k = k.argmax().item()
            z = z[:self.args.par_size[k]]
            return k, z

        horizon = self.args.mpc_horizon
        num_pi_trajs = int(self.args.mixture_coef * self.args.mpc_popsize) if self.args.use_policy else 0

        # Optional policy-proposed action sequences that are mixed into the
        # sampled population.
        if num_pi_trajs > 0:
            pi_actions = torch.zeros(horizon, num_pi_trajs, self.args.action_dim, device=device)
            s = state.repeat(num_pi_trajs, 1)
            for t in range(horizon):
                k, z = self.pi(s, self.args.min_std, reparameterize=reparameterize, return_log_prob=return_log_prob, deterministic=deterministic)

                pi_actions[t] = self.dealRaw(k, z)

                s, _, _ = self.model.next(s, pi_actions[t], reparameterize=reparameterize, return_log_prob=return_log_prob, deterministic=deterministic)

        # Initialize state and parameters
        s = state.repeat(self.args.mpc_popsize+num_pi_trajs, 1)

        if self.args.embed:
            return self.embed_cem(state, t0, eval_mode)

        # Start from uniform weights over the discrete actions ...
        kmean = torch.ones(horizon, self.args.k_dim, device=device)
        kmean /= kmean.sum(-1).unsqueeze(-1)

        # ... and a zero-mean, std-2 Gaussian over all continuous parameters.
        zmean = torch.zeros(horizon, self.args.all_z_dim, device=device)
        std = 2*torch.ones(horizon, self.args.all_z_dim, device=device)
        mean = {'k': kmean, 'z': zmean}
        # Warm start: reuse the previous plan shifted by one step (the last
        # step keeps the fresh initial values).
        if not t0 and hasattr(self, '_prev_mean'):
            mean['k'][:-1] = self._prev_mean['k'][1:]
            mean['z'][:-1] = self._prev_mean['z'][1:]

        # Iterate CEM
        for i in range(self.args.cem_iter):
            actions = self.sample_from_N(mean, std)
            if num_pi_trajs > 0:
                actions = torch.cat([actions, pi_actions], dim=1)

            # Compute elite actions
            value = self.estimate_value(s, actions, horizon, local_step).nan_to_num_(0)
            elite_idxs = torch.topk(value.squeeze(1), self.args.mpc_num_elites, dim=0).indices
            elite_value = value[elite_idxs]  # [num_elite, 1]
            elite_actions = actions[:, elite_idxs]  # [horizon, num_elite, a_dim]

            max_value = elite_value.max(0)[0]

            # Update k parameters
            # k_score is k weights, softmax(elite_value-max)
            k_score = torch.exp(self.args.mpc_temperature*(elite_value - max_value))
            k_score /= k_score.sum(0)  # [num_elite, 1]
            kelites = elite_actions[:, :, :self.args.k_dim]
            _kmean = torch.sum(k_score.unsqueeze(0) * kelites, dim=1) / (k_score.sum(0) + 1e-9)

            # Update z parameters
            zelites = elite_actions[:, :, self.args.k_dim:]
            k_all = kelites.argmax(-1).unsqueeze(-1)  # [horizon, num_elite, 1]
            z_score = elite_value.unsqueeze(0).repeat([horizon, 1, 1])  # [horizon, num_elite, 1]
            _zmean, _std = torch.zeros_like(mean['z']), torch.zeros_like(std)

            # Each discrete action's parameters are refitted only from the
            # elites that actually chose that action.
            for ki in range(self.args.k_dim):
                selected_ind = (k_all == ki)  # selected discrete type, [horizon, num_elite, 1]
                zis = zelites[:, :, :self.args.par_size[ki]]
                # zi: [horizon, num_elite, z_dim], = zi if selected else 0
                zi = torch.where(selected_ind, zis, torch.zeros_like(zis).to(device))

                # weight: [horizon, num_elite, z_dim], = softmax(selected(z))
                # Non-selected elites get -inf, i.e. weight exp(-inf) = 0.
                weight = torch.where(selected_ind, z_score, torch.tensor([float("-Inf")]).to(device))
                weight = torch.exp(self.args.mpc_temperature*(weight - max_value))
                weight_sum = weight.squeeze(-1).sum(1).reshape([-1, 1, 1]).repeat(1, self.args.mpc_num_elites, 1)
                weight /= (weight_sum + 1e-9)

                # Weighted mean and standard deviation over the elite axis.
                _zimean = torch.sum(weight * zi, dim=1) / (weight.sum(1) + 1e-9)
                _zistd = torch.sqrt(torch.sum(weight * (zi - _zimean.unsqueeze(1)) ** 2, dim=1) / (weight.sum(1) + 1e-9))

                # Column range of this action's parameters in the flat vector.
                ind_start = self.args.offset[ki]
                ind_end = ind_start + self.args.par_size[ki]

                # If no elite picked this action at a step, keep the old
                # mean/std there instead of the (empty) estimate.
                if_non_select = selected_ind.squeeze(-1).sum(1).unsqueeze(-1)
                _zimean = torch.where(if_non_select==0, mean['z'][:, ind_start:ind_end], _zimean)
                _zistd = torch.where(if_non_select==0, std[:, ind_start:ind_end], _zistd)

                _zmean[:, ind_start:ind_end] = _zimean
                _std[:, ind_start:ind_end] = _zistd

            # Blend old and new estimates; mpc_alpha is the weight of the old one.
            mean['k'] = self.args.mpc_alpha * mean['k'] + (1 - self.args.mpc_alpha) * _kmean
            mean['z'] = self.args.mpc_alpha * mean['z'] + (1 - self.args.mpc_alpha) * _zmean
            std = self.args.mpc_alpha * std + (1 - self.args.mpc_alpha) * _std

        # Outputs
        # Pick one elite sequence at random, with probability equal to its
        # softmax weight from the last iteration.
        score = k_score.squeeze(1).cpu().numpy()
        actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)]
        self._prev_mean = mean
        # From here on `mean` / `std` are rebound to the chosen first-step
        # action and the last iteration's first-step std (not the blended std).
        mean, std = actions[0], _std[0]

        k = mean[:self.args.k_dim].argmax()
        z = mean[self.args.k_dim:self.args.k_dim+self.args.par_size[k]]

        if not eval_mode:
            # Exploration noise on the chosen parameters.
            # NOTE(review): `z` looks like a view into `actions`, so this
            # in-place add would also change the `actions[0]` passed to the
            # diagnostic prediction below -- confirm whether that is intended.
            ind_start = self.args.offset[k]
            ind_end = ind_start + self.args.par_size[k]
            z += std[ind_start:ind_end] * torch.randn(self.args.par_size[k], device=device)

        k = k.item()
        if k:
            z = 0
        else:
            z = z.item()

        # Diagnostic only: nothing below affects the returned action.  It
        # compares a hand-written reward rule against the model's predicted
        # reward for the chosen action and prints both.
        # NOTE(review): state columns 2:4 and 4 are indexed directly; their
        # meaning depends on the observation layout -- TODO confirm.
        # # # print(state.shape, actions[0].shape)
        pred_s, reward, ci = self.model.next(state, actions[0].unsqueeze(0), reparameterize=False, return_log_prob=False, deterministic=True)

        episilon = 1e-4
        dist2 = torch.sum(torch.square(pred_s[:, 2:4]), dim=-1)
        hard_r = torch.where(state[:, 4]-5/6>-episilon, 0., - dist2)

        condition = torch.logical_and((state[:, 4])-5/6<episilon, torch.tensor(k))
        condition = torch.logical_and(condition, dist2-0.04<episilon)

        hard_r = torch.where(condition, 20-dist2, hard_r)
        hard_r = hard_r.unsqueeze(-1)

        # Computed for parity with train(); not used afterwards.
        hard_c = (1 - torch.logical_or((hard_r > 4), hard_r.abs()<episilon).long())
        print("pred:", hard_r.item(), reward.item())
        # print("pred:", pred_s.cpu().numpy()[0][:4], hard_r.cpu().numpy().item(), 1.-hard_c.cpu().numpy().item())

        return k, z

    def train(self, replay_buffer, batch_size):
        """Take one optimiser step on the model using a replay batch.

        The loss is the sum of the next-state MSE, the reward MSE and the
        cross-entropy of the continue head against ``not_done``.  After the
        step, a hand-written reward/termination rule is evaluated on the same
        batch (with the pre-step state prediction) purely as a diagnostic.

        Returns:
            ``(state_loss, reward_loss, continue_loss, hard_reward_mse,
            hard_continue_accuracy)`` as Python floats.
        """
        state, k, z, _, _, _, next_state, _, reward, not_done = replay_buffer.sample(batch_size)

        # Model input action: one-hot discrete part followed by the parameters.
        action = torch.cat(
            [torch.nn.functional.one_hot(k.long().flatten(), num_classes=self.args.k_dim), z],
            dim=-1,
        )

        self.optim.zero_grad()

        pred_state, pred_reward, pred_cont = self.model.next(
            state, action, reparameterize=True, return_log_prob=True, deterministic=False)

        state_loss = self.mse_loss(pred_state, next_state)
        reward_loss = self.mse_loss(pred_reward, reward)
        cont_loss = self.ce_loss(pred_cont, not_done.flatten().long())

        (state_loss + reward_loss + cont_loss).backward()
        self.optim.step()

        with torch.no_grad():
            # NOTE(review): columns 2:4 and 4 are indexed directly; their
            # meaning depends on the observation layout -- TODO confirm.
            eps = 1e-4
            gap = state[:, 4] - 5/6
            dist2 = torch.sum(torch.square(pred_state[:, 2:4]), dim=-1)

            rule_reward = torch.where(gap > -eps, 0., -dist2)

            bonus = torch.logical_and(gap < eps, k.flatten())
            bonus = torch.logical_and(bonus, dist2 - 0.04 < eps)
            rule_reward = torch.where(bonus, 20 - dist2, rule_reward).unsqueeze(-1)

            # Rule says "terminate" when the reward is > 4 or (almost) zero.
            rule_cont = 1 - torch.logical_or(rule_reward > 4, rule_reward.abs() < eps).long()

            hard_rl = self.mse_loss(rule_reward, reward)
            hard_ca = (not_done == rule_cont).flatten().float().mean()

        return state_loss.item(), reward_loss.item(), cont_loss.item(), hard_rl.item(), hard_ca.item()


def pad_action(act, act_param):
    """Build the environment action for discrete choice ``act``.

    The result is a one-element list holding the flat array
    ``[1, act_param * pi, a, b]`` where ``(a, b)`` is ``(1, 0)`` for
    ``act == 0`` and ``(0, 1)`` otherwise.
    """
    tail = ([1], [0]) if act == 0 else ([0], [1])
    return [np.hstack(([1], act_param * math.pi, *tail))]


# Runs the policy for `episodes` episodes and returns the mean return,
# the success rate and the mean index of the final step.
# The evaluation environment is not re-seeded here.
def evaluate(env, policy, max_steps, episodes=10, render=True):
    """Roll out ``policy`` in ``env`` and summarise the episodes.

    Args:
        env: environment whose ``reset``/``step`` return per-agent lists; only
            agent 0's observation and reward are used.
        policy: object exposing ``select_action(state, step, local_step,
            eval_mode, t0)``.
        max_steps: step limit per episode; also normalises the time feature.
        episodes: number of episodes to run.
        render: call ``env.render()`` after reset and after every step.
            Defaults to True, which is what this function always did.

    Returns:
        ``(mean_return, success_rate, mean_last_step_index)``.  An episode
        counts as a success when some step reward exceeds 4.
    """
    returns = []
    success = []
    epioside_steps = []

    for _ in range(episodes):
        state = env.reset()

        if render:
            env.render()

        total_reward = 0.
        valid_time = 0
        caught = False
        for j in range(max_steps):
            # np.asarray copies only when it has to, which is what
            # np.array(..., copy=False) meant on NumPy 1.x; on NumPy >= 2.0
            # copy=False raises whenever a copy is unavoidable.
            state = np.asarray(state, dtype=np.float32)[0]
            # Two extra features scaled to [-1, 1]: catch counter and time.
            state_time = (j / max_steps) * 2 - 1
            state_catch = (valid_time / 12) * 2 - 1
            state = np.concatenate([state, [state_catch, state_time]])

            with torch.no_grad():
                discrete_action, all_parameter_action = policy.select_action(state, 0, j, True, j == 0)

            # NOTE(review): the training loop in run() increments this counter
            # only under extra conditions (a cap and a distance test); here it
            # counts every non-zero discrete action.  Confirm which rule the
            # model is supposed to see.
            if discrete_action:
                valid_time += 1

            action = pad_action(discrete_action, all_parameter_action)
            state, reward, done_n, _ = env.step(action)

            if render:
                env.render()

            done = all(done_n)
            reward = reward[0]
            total_reward += reward
            if reward > 4:
                caught = True
                done = True
            if reward == 0:
                done = True

            print('true:', reward)

            if done or j == max_steps - 1:
                epioside_steps.append(j)
                break

        success.append(1 if caught else 0)
        returns.append(total_reward)

    mean_return = np.array(returns).mean()
    success_rate = np.array(success).mean()
    mean_steps = np.array(epioside_steps).mean()

    print("---------------------------------------")
    print(
        f"Evaluation over {episodes} episodes: {mean_return:.3f} {success_rate:.3f} "
        f"{mean_steps:.3f} ")
    print("---------------------------------------")
    return mean_return, success_rate, mean_steps

def run(args):
    """Build the environment and the MPC planner, then evaluate a checkpoint.

    As written, the function loads a fixed checkpoint, evaluates it for ten
    episodes and calls ``exit()``; the training loop after that call is kept
    in the file but is never reached.

    ``args`` is extended in place with the dimensions and bounds the planner
    reads (``state_dim``, ``k_dim``, ``z_dim``, ``par_size``, ``offset``, ...).
    """
    file_name = f"{args.policy}_{args.env}_{args.seed}"
    print("---------------------------------------")
    print(f"Policy: {args.policy}, Env: {args.env}, Seed: {args.seed}")
    print("---------------------------------------")

    if not os.path.exists("./results"):
        os.makedirs("./results")

    # if args.save_model and not os.path.exists("./models"):
    #     os.makedirs("./models")

    env = make_env(args.env)
    obs_shape_n = [env.observation_space[i].shape for i in range(env.n)]
    obs_n = env.reset()

    # Set seeds
    env.seed(args.seed)
    np.random.seed(args.seed)
    print(obs_shape_n)
    torch.manual_seed(args.seed)

    # Agent 0's observation size plus the two appended features
    # (state_catch and state_time, see below).
    state_dim = obs_shape_n[0][0]
    state_dim = state_dim + 2

    discrete_action_dim = 2

    # Number of continuous parameters per discrete action: action 0 has one,
    # action 1 has none.
    action_parameter_sizes = np.array([1, 0])

    parameter_action_dim = 1
    discrete_emb_dim = discrete_action_dim
    parameter_emb_dim = parameter_action_dim
    max_action = 1.0

    print("state_dim", state_dim)
    print("discrete_action_dim", discrete_action_dim)
    print("parameter_action_dim", parameter_action_dim)

    # Not used further down in this function.
    kwargs = {
        "state_dim": state_dim,
        "discrete_action_dim": discrete_action_dim,
        "parameter_action_dim": parameter_action_dim,
        "max_action": max_action,
        "discount": args.discount,
        "tau": args.tau,
    }

    # Values read by MPC / ModelNets.
    args.state_dim = state_dim
    args.k_dim = discrete_action_dim
    args.all_z_dim = parameter_action_dim
    args.par_size = action_parameter_sizes
    args.ub, args.lb = 1., -1.

    args.z_dim = action_parameter_sizes.max()
    args.action_dim = args.k_dim + args.z_dim
    args.max_action = args.ub

    args.scale = args.ub - args.lb
    args.offsets = args.lb
    # Start column of each discrete action's parameters in the flat vector.
    args.offset = [args.par_size[:i].sum() for i in range(args.k_dim)]

    if args.save_points:
        save_points(args)

    policy = MPC(args)

    replay_buffer = utils.ReplayBuffer(state_dim, discrete_action_dim=1,
                                       parameter_action_dim=1,
                                       all_parameter_action_dim=parameter_action_dim,
                                       discrete_emb_dim=discrete_emb_dim,
                                       parameter_emb_dim=parameter_emb_dim,
                                       max_size=int(1e5))

    # Evaluate untrained policy
    # evaluations = [eval_policy(policy, args.env, args.seed)]
    total_reward = 0.
    Reward = []
    Reward_100 = []
    Test_Reward = []
    max_steps = 25
    cur_step=0
    flag=0
    Test_success=[]
    returns=[]
    success=[]
    Test_epioside_step=[]
    total_timesteps = 0

    # Evaluation-only path: load a fixed checkpoint, evaluate, and stop.
    # Everything after exit() is unreachable while these lines are present.
    policy.model.load_state_dict(torch.load("result/TD3/direction_catch/mpc_catch/model_25.pth", map_location=torch.device('cpu')))
    evaluate(env, policy, max_steps=25, episodes=10)
    exit()

    while total_timesteps < args.max_timesteps:
        t = 0
        valid_time = 0
        used_k = 0

        state = env.reset()

        # NOTE(review): np.array(..., copy=False) raises on NumPy >= 2.0 when
        # a copy is needed; np.asarray is the portable spelling -- confirm the
        # NumPy version this is run with.
        state = np.array(state, dtype=np.float32, copy=False)[0]
        # Append the two extra features, both scaled to [-1, 1].
        state_time = (t/max_steps) * 2 - 1
        state_catch= (valid_time / 12) * 2 - 1
        state = np.concatenate([state, [state_catch, state_time]])

        with torch.no_grad():
            discrete_action, all_parameter_action = policy.select_action(state, total_timesteps, 0, t0=True)
        # Exploration
        # if t < args.epsilon_steps:
        #     epsilon = args.expl_noise_initial - (args.expl_noise_initial - args.expl_noise) * (
        #             t / args.epsilon_steps)
        # else:
        #     epsilon = args.expl_noise

        # all_discrete_action = (
        #         all_discrete_action + np.random.normal(0, max_action * args.expl_noise, size=discrete_action_dim)
        # ).clip(-max_action, max_action)
        # all_parameter_action = (
        #         all_parameter_action + np.random.normal(0, max_action * args.expl_noise, size=parameter_action_dim)
        # ).clip(-max_action, max_action)
        # discrete_action = np.argmax(all_discrete_action)

        action = pad_action(discrete_action, all_parameter_action)
        episode_reward = 0.
        flag = 0

        for i in range(max_steps):
            total_timesteps += 1
            cur_step = cur_step + 1
            next_state, reward, done_n, _ = env.step(action)
            done = all(done_n)
            reward = reward[0]
            # A reward above 4 marks a success and ends the episode; a reward
            # of exactly 0 also ends it.
            if reward > 4:
                flag = 1
                done = True
            if reward == 0:
                done = True

            # if discrete_action:
            #     valid_time += 1
            used_k = used_k+1 if discrete_action else used_k
            # The catch counter only advances while it is <= 10, the discrete
            # action is non-zero, and the squared norm of next-state entries
            # 2:4 exceeds 0.04.
            if valid_time<=10 and discrete_action and np.sum(np.square(next_state[0][2:4]))>0.04:
                valid_time += 1

            next_state = np.array(next_state, dtype=np.float32, copy=False)[0]
            state_time = ((i+1) / max_steps) * 2 - 1
            state_catch= (valid_time / 12) * 2 - 1
            next_state = np.concatenate([next_state, [state_catch, state_time]])

            replay_buffer.add(state, discrete_action=discrete_action, parameter_action=all_parameter_action, all_parameter_action=None,
                              discrete_emb=None,
                              parameter_emb=None,
                              next_state=next_state,
                              state_next_state=None,
                              reward=reward, done=done)

            # Plan the next action before training on this step.
            with torch.no_grad():
                next_discrete_action, next_all_parameter_action = policy.select_action(next_state, total_timesteps, i+1)

            # next_all_discrete_action = (
            #         next_all_discrete_action + np.random.normal(0, max_action * args.expl_noise,
            #                                                     size=discrete_action_dim)
            # ).clip(-max_action, max_action)
            # next_all_parameter_action = (
            #         next_all_parameter_action + np.random.normal(0, max_action * args.expl_noise,
            #                                                      size=parameter_action_dim)
            # ).clip(-max_action, max_action)
            # next_discrete_action = np.argmax(next_all_discrete_action)

            next_action = pad_action(next_discrete_action, next_all_parameter_action)

            discrete_action, all_parameter_action, action = next_discrete_action, next_all_parameter_action, next_action

            state = next_state

            # One model update per environment step once enough steps exist.
            if cur_step >= args.start_timesteps:
                sl, rl, cl, hard_rl, hard_ca = policy.train(replay_buffer, args.batch_size)
                if args.save_points:
                    wandb.log({"s_loss": sl, "r_loss": rl, "c_loss": cl,
                               "hard_rl": hard_rl, "hard_ca": hard_ca})

            episode_reward += reward

            # Periodic report, evaluation, CSV dump and checkpoint.
            if total_timesteps % args.eval_freq == 0:
                # `t` is reset to 0 at the top of every episode, so (t + 1)
                # is always 1 here.
                print(
                    '{0:5s} R:{1:.4f} r100:{2:.4f} success:{3:.4f}'.format(str(total_timesteps), total_reward / (t + 1),
                                                                           np.array(returns[-100:]).mean(),
                                                                           np.array(success[-100:]).mean()))
                Reward.append(total_reward / (t + 1))
                Reward_100.append(np.array(returns[-100:]).mean())

                Test_Reward_50, Test_success_rate, Test_epioside_step_50 = evaluate(env, policy, max_steps=25,
                                                                                    episodes=10)

                if args.save_points:
                    wandb.log({"test_r": Test_Reward_50, "test_s": Test_success_rate, "test_epi": Test_epioside_step_50})

                Test_Reward.append(Test_Reward_50)
                Test_success.append(Test_success_rate)
                Test_epioside_step.append(Test_epioside_step_50)

                save_txt(args, Reward_100, Test_Reward, Test_success, Test_epioside_step)
                torch.save(policy.model.state_dict(), f"result/TD3/direction_catch/mpc_catch/model_{args.seed}.pth")

            if done or i == max_steps - 1:
                obs_n = env.reset()
                break

        if flag == 1:
            success.append(1)
        else:
            success.append(0)

        t += 1
        returns.append(episode_reward)
        total_reward += episode_reward

def save_txt(args, Reward_100, Test_Reward, Test_success, Test_epioside_step):
    """Write the four metric series to CSV files named after ``args.seed``.

    Files go to ``result/TD3/direction_catch/mpc_catch`` (created if it does
    not exist), one file per series, overwriting earlier dumps.
    """
    print("save txt")
    redir = os.path.join("result/TD3/direction_catch", "mpc_catch")
    if not os.path.exists(redir):
        os.makedirs(redir)
    print("redir", redir)

    suffix = str(args.seed) + ".csv"
    series = (
        ("Reward_100_td3_direction_catch_", Reward_100),
        ("Test_Reward_td3_direction_catch_", Test_Reward),
        ("Test_success_td3_direction_catch_", Test_success),
        ("Test_epioside_step_td3_direction_catch_", Test_epioside_step),
    )
    for prefix, values in series:
        np.savetxt(os.path.join(redir, prefix + suffix), values, delimiter=',')


def make_env(scenario_name):
    """Create a ``MultiAgentEnv`` for the scenario file ``<scenario_name>.py``.

    The ``multiagent`` package is imported lazily so it is only required when
    an environment is actually built.
    """
    from multiagent.environment import MultiAgentEnv
    import multiagent.scenarios as scenarios

    # Load the scenario module and instantiate its Scenario class.
    chosen = scenarios.load(scenario_name + ".py").Scenario()
    # The environment is driven by the scenario's world and its callbacks.
    return MultiAgentEnv(chosen.make_world(), chosen.reset_world, chosen.reward, chosen.observation)

def build_arg_parser():
    """Create the command-line parser for this script.

    Every float-valued option declares ``type=float``: argparse applies
    ``type`` only to strings, so without it a value given on the command line
    would reach the code as ``str`` while the default stayed a float.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--policy", default="P-TD3")  # Policy name (used for the banner / file name)
    parser.add_argument("--env", default='simple_catch')  # Scenario name passed to make_env
    parser.add_argument("--seed", default=0, type=int)  # Seed for the environment, NumPy and PyTorch
    parser.add_argument("--start_timesteps", default=128, type=int)  # Steps collected before model training starts
    parser.add_argument("--eval_freq", default=50, type=int)  # How often (time steps) we evaluate
    parser.add_argument("--max_episodes", default=5000, type=int)  # Max episodes
    # The default was the float 1e5 despite type=int; argparse does not
    # convert non-string defaults, so an int is given explicitly.
    parser.add_argument("--max_embedding_episodes", default=100_000, type=int)  # Max embedding episodes
    parser.add_argument("--max_timesteps", default=1_000_000, type=float)  # Max time steps to run environment for
    parser.add_argument("--epsilon_steps", default=1000, type=int)  # Steps over which exploration noise is annealed
    parser.add_argument("--expl_noise_initial", default=1.0, type=float)  # Initial std of Gaussian exploration noise
    parser.add_argument("--expl_noise", default=0.1, type=float)  # Final std of Gaussian exploration noise
    parser.add_argument("--batch_size", default=128, type=int)  # Batch size for model training
    parser.add_argument("--discount", default=0.99, type=float)  # Discount factor
    parser.add_argument("--tau", default=0.005, type=float)  # Target network update rate
    parser.add_argument("--policy_noise", default=0.2, type=float)  # Noise added to target policy during critic update
    parser.add_argument("--noise_clip", default=0.5, type=float)  # Range to clip target policy noise
    parser.add_argument("--policy_freq", default=2, type=int)  # Frequency of delayed policy updates
    parser.add_argument("--save_model", action="store_true")  # Save model and optimizer parameters
    parser.add_argument("--load_model", default="")  # Model load file name, "" doesn't load, "default" uses file_name

    parser.add_argument("--save_points", default=0, type=int)  # Non-zero enables Weights & Biases logging
    return parser


if __name__ == "__main__":

    args = build_arg_parser().parse_args()

    run(args)

    # for i in range(0,5):
    #     args.seed=i
    #     run(args)
