import numpy as np
import torch
import gym
import argparse
import os
import gym_platform

from common import ClickPythonLiteralOption
from common.platform_domain import PlatformFlattenedActionWrapper
from common.goal_domain import GoalFlattenedActionWrapper, GoalObservationWrapper
from common.wrappers import ScaledStateWrapper, ScaledParameterisedActionWrapper

from mpc_model.mpc_buffer import ReplayBuffer
# from mpc_model.mpc_model import DynamicModel, RewardModel
from mpc_model.better_mpc_model import nDynamicModel
from mpc_model.hierachical_model import hmodel, hreplayBuffer

from mpc_model.mpc_planning import MPC
import matplotlib.pyplot as plt

from mpc_model.utils import train
from HyAR_RL.utils import ReplayBuffer as embedbuffer
from embedding import ActionRepresentation_vae
from mpc_model.hyar_model import vae_train

import inspect
import wandb
import math

from mpc_model.hard_code_goal_utils import SCALE_VECTOR, SHIFT_VECTOR



def save_points(args, project="pamdp-mpc", wandb_dir="../scratch/wandb"):
    """Start a Weights & Biases run that records ``args`` as its config.

    Args:
        args: Mapping of hyper-parameters (``run`` passes ``vars(args)``).
        project: W&B project name.
        wandb_dir: Directory handed to ``wandb.init`` as ``dir``.

    Returns:
        The run object returned by ``wandb.init``. It used to be stored in an
        unused local that shadowed the module-level ``run`` function.
    """
    return wandb.init(project=project, config=args, dir=wandb_dir)


def pad_action(act, act_param, par_size, envname):
    """Convert a (discrete action, parameter) pair into the env's action format.

    Args:
        act: Index of the chosen discrete action.
        act_param: Continuous parameter value(s) for ``act``.
        par_size: Parameter size of every discrete action, in action order.
        envname: Environment name; ``"simple_catch"`` uses its own layout.

    Returns:
        For ``simple_catch``: a one-element list holding the flat array
        ``[1, act_param * pi, m0, m1]``, where ``(m0, m1)`` is ``(1, 0)`` for
        action 0 and ``(0, 1)`` otherwise.
        For every other env: ``(act, params)``. ``params`` has one float32 zero
        vector per discrete action (sized by ``par_size``), and only
        ``params[act]`` is filled with ``act_param``.
    """
    if envname == "simple_catch":
        # Parameters are kept in [-1, 1] by the caller; multiplying by pi
        # presumably maps them to an angle in [-pi, pi] (TODO confirm with env).
        flags = [1, 0] if act == 0 else [0, 1]
        return [np.hstack(([1], act_param * math.pi, flags))]

    # One zero-filled slot per discrete action. This used to be hard-coded to
    # three slots: it raised IndexError for fewer actions and silently dropped
    # extra ones.
    params = [np.zeros((size,), dtype=np.float32) for size in par_size]
    params[act][:] = act_param
    return (act, params)


def unwrap_action(mpc, dm, rm, cm, obs, training=True, debug=False,
                  action_parameter_sizes=(2, 1, 1),
                  max_action=1.0,
                  action_rep=None,
                  expl_noise=None):
    """Plan with the MPC and turn the plan into an environment action.

    Args:
        mpc: Planner exposing ``act(...)`` and ``envname``.
        dm, rm, cm: Dynamics, reward and termination ("continuous") models,
            forwarded to ``mpc.act``.
        obs: Current observation.
        training: When True, add Gaussian exploration noise to the parameter
            vector and clip it to ``[-max_action, max_action]``.
        debug: Forwarded to ``mpc.act``.
        action_parameter_sizes: Parameter size of each discrete action.
        max_action: Bound of the parameter space; also scales the noise std.
        action_rep: Optional action representation forwarded to ``mpc.act``.
        expl_noise: Exploration noise std relative to ``max_action``.
            ``None`` keeps the legacy behaviour of reading
            ``args.expl_noise`` from the script's global ``args``.

    Returns:
        ``(all_discrete_action, all_parameter_action, action, pred_s)``.
        ``action`` is the env-ready action built by ``pad_action``, and
        ``pred_s`` is the third value returned by ``mpc.act``.
    """
    parameter_action_dim = max(action_parameter_sizes)

    with torch.no_grad():
        all_discrete_action, all_parameter_action, pred_s = mpc.act(
            dm, rm, cm, obs, action_rep, debug)

    if training:
        if expl_noise is None:
            expl_noise = args.expl_noise  # legacy: global created under __main__
        # float() guards against the value arriving as a CLI string.
        noise = np.random.normal(0, max_action * float(expl_noise),
                                 size=parameter_action_dim)
        all_parameter_action = (all_parameter_action + noise).clip(-max_action, max_action)

    discrete_action = np.argmax(all_discrete_action)

    # Keep only the leading entries that belong to the chosen discrete action
    # and zero the rest.
    n_params = action_parameter_sizes[discrete_action]
    parameter_action = np.zeros(all_parameter_action.shape)
    parameter_action[:n_params] = all_parameter_action[:n_params]

    action = pad_action(discrete_action, parameter_action, action_parameter_sizes,
                        envname=mpc.envname)
    return all_discrete_action, all_parameter_action, action, pred_s


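# Runs the MPC policy without exploration noise for `episodes` episodes and
# returns the average return and the average episode length.
# No seeding happens here; the caller decides how the eval environment is seeded.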
def evaluate(env, mpc, dm, rm, cm, episodes=100, vis=False, action_rep=None,
             max_steps=25):
    """Roll out the MPC policy without exploration noise and report averages.

    Args:
        env: Environment to evaluate in; it is reset at the start of every episode.
        mpc: Planner; ``mpc.reset()`` is called at the start of every episode.
        dm, rm, cm: Dynamics, reward and termination models used for planning.
        episodes: Number of evaluation episodes.
        vis: Render after every reset and every step.
        action_rep: Optional action representation forwarded to the planner.
        max_steps: ``simple_catch`` only. The episode is cut off once
            ``t >= max_steps - 1``, which with the default is the original
            24-step cap.

    Returns:
        ``(mean_return, mean_episode_length)`` over ``episodes`` episodes.
    """
    returns = []
    epioside_steps = []

    for epi in range(episodes):
        print(f"test: {epi}")

        if mpc.envname == "simple_catch":
            state = env.reset()[0]
        else:
            state, _ = env.reset()

        if vis:
            env.render()

        terminal = False
        t = 0
        total_reward = 0.
        mpc.reset()

        while not terminal:
            t += 1
            # np.asarray rather than np.array(..., copy=False): the latter
            # raises under NumPy 2 whenever the float32 cast needs a copy.
            state = np.asarray(state, dtype=np.float32)
            _, _, action, _ = unwrap_action(mpc, dm, rm, cm, state, False, False,
                                            action_rep=action_rep,
                                            action_parameter_sizes=mpc.par_size)

            if mpc.envname == "simple_catch":
                obs_n, reward_n, terminal_n, _ = env.step(action)
                state = obs_n[0]
                reward = reward_n[0]
                # The episode ends when every agent is done, on a reward > 4 or
                # exactly 0 (treated as end-of-episode signals), or at the cap.
                terminal = (all(terminal_n) or reward > 4 or reward == 0
                            or t >= max_steps - 1)
            else:
                (state, _), reward, terminal, _ = env.step(action)

            if vis:
                env.render()

            # The final step's reward is counted as well. simple_catch used to
            # break out before accumulating it, which dropped the catch reward.
            total_reward += reward

        # Recorded once per episode; simple_catch used to append it twice.
        epioside_steps.append(t)
        returns.append(total_reward)

    mean_return = np.array(returns).mean()
    mean_steps = np.array(epioside_steps).mean()
    print("---------------------------------------")
    print(
        f"Evaluation over {episodes} episodes_rewards: {mean_return:.3f} epioside_steps: {mean_steps:.3f}")
    print("---------------------------------------")
    return mean_return, mean_steps


def save_local(redir, Reward_100, Test_Reward_100, Test_epioside_step_100, seed=None):
    """Write the training and evaluation curves to CSV files under ``redir``.

    Each curve goes to ``<prefix><seed>.csv`` (one value per line). The
    prefixes are ``Reward_100_td3_platform_``, ``Test_Reward_100_td3_platform_``
    and ``Test_epioside_step_100_td3_platform_``.

    Args:
        redir: Existing output directory.
        Reward_100: Rolling mean of recent training returns, one per evaluation.
        Test_Reward_100: Mean evaluation returns.
        Test_epioside_step_100: Mean evaluation episode lengths.
        seed: Seed used in the file names. ``None`` keeps the legacy
            behaviour of reading ``args.seed`` from the script's global
            ``args``, which only exists when run as ``__main__``.
    """
    print("save txt || redir:", redir)
    if seed is None:
        seed = args.seed  # legacy: global created under __main__
    curves = (
        ("Reward_100_td3_platform_", Reward_100),
        ("Test_Reward_100_td3_platform_", Test_Reward_100),
        ("Test_epioside_step_100_td3_platform_", Test_epioside_step_100),
    )
    for prefix, values in curves:
        np.savetxt(os.path.join(redir, prefix + str(seed) + ".csv"), values, delimiter=',')

def _make_env(args, verbose=True):
    """Build the environment selected by ``args.env`` and describe its spaces.

    Args:
        args: Parsed command-line namespace; only ``args.env`` is read.
        verbose: Print the state and action dimensions.

    Returns:
        ``(env, state_dim, discrete_action_dim, parameter_action_dim,
        action_parameter_sizes, max_action)``.

    Raises:
        ValueError: if ``args.env`` is not ``Platform-v0``, ``Goal-v0`` or
            ``simple_catch``. This used to surface later as a NameError.
    """
    if args.env == "Platform-v0":
        env = gym.make(args.env)
        env = ScaledStateWrapper(env)
        env = PlatformFlattenedActionWrapper(env)
        env = ScaledParameterisedActionWrapper(env)
    elif args.env == "Goal-v0":
        env = gym.make('Goal-v0')
        env = GoalObservationWrapper(env)
        env = GoalFlattenedActionWrapper(env)
        env = ScaledParameterisedActionWrapper(env)
        env = ScaledStateWrapper(env)
    elif args.env == "simple_catch":
        from multiagent.environment import MultiAgentEnv
        import multiagent.scenarios as scenarios
        scenario = scenarios.load('simple_catch' + ".py").Scenario()
        world = scenario.make_world()
        env = MultiAgentEnv(world, scenario.reset_world, scenario.reward, scenario.observation)
    else:
        raise ValueError(
            f"unsupported env {args.env!r}; expected Platform-v0, Goal-v0 or simple_catch")

    if args.env == "simple_catch":
        # One observation space per agent; this script only drives agent 0.
        state_dim = env.observation_space[0].shape[0]
        discrete_action_dim = 2
        parameter_action_dim = 1
        action_parameter_sizes = np.array([1, 1])
    else:
        # state_dim used to be left unset for these envs, so printing it raised
        # UnboundLocalError. NOTE(review): this assumes a Tuple observation
        # space whose first entry is the state Box (reset() returns
        # (state, _) for these envs); confirm against the wrappers.
        state_dim = env.observation_space.spaces[0].shape[0]
        discrete_action_dim = env.action_space.spaces[0].n
        action_parameter_sizes = np.array(
            [env.action_space.spaces[i].shape[0] for i in range(1, discrete_action_dim + 1)])
        parameter_action_dim = int(action_parameter_sizes.sum())
    max_action = 1.0

    if verbose:
        print("state_dim", state_dim)
        print("discrete_action_dim", discrete_action_dim)
        print("parameter_action_dim", parameter_action_dim)
    return env, state_dim, discrete_action_dim, parameter_action_dim, action_parameter_sizes, max_action


def _reset_env(env, envname):
    """Reset ``env`` and return the observation of the (first) agent."""
    if envname == "simple_catch":
        return env.reset()[0]
    obs, _ = env.reset()
    return obs


def _step_env(env, action, envname):
    """Step ``env`` and return ``(next_obs, reward, terminal)`` for the (first) agent.

    For ``simple_catch``, a reward > 4 or exactly 0 also ends the episode,
    as does every agent being done.
    """
    if envname == "simple_catch":
        obs_n, reward_n, terminal_n, _ = env.step(action)
        reward = reward_n[0]
        terminal = bool(all(terminal_n) or reward > 4 or reward == 0)
        return obs_n[0], reward, terminal
    (next_obs, _), reward, terminal, _ = env.step(action)
    return next_obs, reward, terminal


def _build_models(args, state_dim, discrete_action_dim, parameter_action_dim):
    """Create the state, reward and termination models and their replay buffer.

    Returns:
        ``(model, reward_model, continuous, replay_buffer)``.

    Raises:
        ValueError: for an unknown ``args.which_model``. The old
            ``assert f"..."`` always passed because a non-empty string is truthy.
    """
    if args.which_model == 'normal':
        model = nDynamicModel(args, label='state')
        reward_model = nDynamicModel(args, label='reward')
        continuous = nDynamicModel(args, label='continuous')
        replay_buffer = ReplayBuffer(state_dim=state_dim,
                                     discrete_action_dim=discrete_action_dim,
                                     all_parameter_action_dim=parameter_action_dim,
                                     batch_size=args.train_bs,
                                     max_size=int(1e5))
    elif args.which_model == 'h':
        model = hmodel(args, label='state')
        reward_model = hmodel(args, label='reward')
        continuous = hmodel(args, label='continuous')
        replay_buffer = hreplayBuffer(state_dim=state_dim,
                                      discrete_action_dim=discrete_action_dim,
                                      all_parameter_action_dim=parameter_action_dim,
                                      max_size=int(1e5))
    else:
        raise ValueError(f"not implemented with {args.which_model}, only h or normal")
    return model, reward_model, continuous, replay_buffer


def _train_models(args, model, reward_model, continuous, replay_buffer, action_rep):
    """Fit the state and reward models (and the termination model if ``args.use_terminal``)."""
    train_models = [['state', model], ['reward', reward_model]]
    if args.use_terminal:
        train_models.append(['terminal', continuous])
    train(models=train_models,
          data=replay_buffer,
          logger=args.save_points,
          model_type=args.which_model, action_rep=action_rep)


def _log_eval(args, test_reward, test_steps):
    """Send evaluation results to W&B when ``args.save_points`` is set."""
    if args.save_points:
        wandb.log({"Test_Reward_100": test_reward,
                   "Test_epioside_step_100": test_steps})


def _collect_random_episodes(args, env, envname, replay_buffer, action_parameter_sizes):
    """Fill ``replay_buffer`` with ``args.pretrain_episodes`` episodes of random actions."""
    for _ in range(args.pretrain_episodes):
        obs = _reset_env(env, envname)
        terminal = False
        t = 0

        while not terminal:
            t += 1

            if envname == "simple_catch":
                da = np.zeros(2)
                da[np.random.randint(0, 2)] = 1.
                pa = np.random.random() * 2 - 1
                action = pad_action(da.argmax(), pa, action_parameter_sizes, envname=args.env)
            else:
                raw_action = env.action_space.sample()
                # Use the parameter of the *sampled* discrete action. Platform
                # used to take raw_action[1] (action 0's slot) regardless.
                r1 = raw_action[1:][raw_action[0]]
                action = pad_action(raw_action[0], r1, action_parameter_sizes, envname=args.env)
                da = np.zeros(args.discrete_action_dim)
                da[action[0]] = 1
                # Flatten the per-action slots into one parameter vector. This
                # was previously done for Goal-v0 only.
                pa = np.concatenate(action[1])

            obs_next, reward, terminal = _step_env(env, action, envname)

            if envname == "simple_catch":
                obs = np.array(obs)
                obs_next = np.array(obs_next)

            # The terminal transition is stored too. simple_catch used to break
            # out before this add, so no terminal examples reached the buffer.
            replay_buffer.add(obs,
                              discrete_action=da,
                              all_parameter_action=pa,
                              next_state=obs_next - obs,
                              reward=reward,
                              terminal=terminal)
            obs = obs_next

            # simple_catch: keep the original 24-step cap on random episodes.
            if envname == "simple_catch" and t >= 25 - 1:
                break


def run(args):
    """Train the MPC agent on ``args.env`` and evaluate it periodically.

    Steps:
      1. Build the training and evaluation envs and the learned models.
      2. Evaluate once.
      3. Collect ``args.pretrain_episodes`` random-action episodes.
      4. Optionally pre-train the action VAE (``args.embed``).
      5. Fit the models and evaluate again.
      6. Alternate MPC rollouts with model/VAE training and evaluation until
         ``args.max_timesteps`` environment steps have been taken.

    Curves are written to ``result/MPC/platform/<args.save_dir>`` and, with
    ``args.save_points``, logged to W&B.

    Side effect: ``args`` gains ``sparse``, ``state_dim``,
    ``discrete_action_dim``, ``parameter_action_dim``,
    ``action_parameter_sizes`` and ``max_action``. These are presumably read
    by the model and MPC constructors.
    """
    print("---------------------------------------")
    print(f"Env: {args.env}, Seed: {args.seed}")
    print("---------------------------------------")

    if args.save_points:
        save_points(vars(args))

    os.makedirs("./results", exist_ok=True)

    (env, state_dim, discrete_action_dim, parameter_action_dim,
     action_parameter_sizes, max_action) = _make_env(args)
    # Evaluation gets its own env. evaluate() resets and plays its env to the
    # end, which used to clobber the in-progress training episode.
    eval_env = _make_env(args, verbose=False)[0]

    # Set seeds
    env.seed(args.seed)
    eval_env.seed(args.seed + 100)  # offset so eval episodes differ from training ones
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    args.sparse = 1 if args.env == "Goal-v0" else 0
    args.state_dim = state_dim
    args.discrete_action_dim = discrete_action_dim
    args.parameter_action_dim = parameter_action_dim
    args.action_parameter_sizes = action_parameter_sizes
    args.max_action = max_action

    model, reward_model, continuous, replay_buffer = _build_models(
        args, state_dim, discrete_action_dim, parameter_action_dim)

    if args.embed:
        discrete_emb_dim = discrete_action_dim * 2
        parameter_emb_dim = parameter_action_dim * 2
        action_rep = ActionRepresentation_vae.Action_representation(state_dim=state_dim,
                                                                  action_dim=discrete_action_dim,
                                                                  parameter_action_dim=action_parameter_sizes.max(),
                                                                  reduced_action_dim=discrete_emb_dim,
                                                                  reduce_parameter_action_dim=parameter_emb_dim
                                                                  )
    else:
        action_rep = None

    print('load mpc')
    if args.mpc_type not in ("CEM", "Random"):
        # The old ``raise "NOOO"`` itself failed with TypeError (str is not an exception).
        raise ValueError(f"unsupported mpc_type {args.mpc_type!r}; expected CEM or Random")
    mpc = MPC(args, action_rep)
    envname = mpc.envname

    Test_Reward_100 = []
    Test_epioside_step_100 = []
    Reward_100 = []

    Test_Reward, Test_epioside_step = evaluate(eval_env, mpc, model, reward_model, continuous,
                                               episodes=1, vis=args.visualise, action_rep=action_rep)
    Test_Reward_100.append(Test_Reward)
    _log_eval(args, Test_Reward, Test_epioside_step)

    print("---------- START PRE-TRAIN ----------")
    _collect_random_episodes(args, env, envname, replay_buffer, action_parameter_sizes)

    if args.embed:
        recon_s_loss = []
        VAE_batch_size = 64
        save_dir = os.path.join("result/platform_model/mix/1.0/0526", "66")
        print("save_dir", save_dir)
        os.makedirs(save_dir, exist_ok=True)
        vae_save_model = True
        vae_train(action_rep=action_rep, train_step=1000,
                  replay_buffer=replay_buffer,
                  batch_size=VAE_batch_size,
                  save_dir=save_dir, vae_save_model=vae_save_model, embed_lr=1e-4,
                  par_size=action_parameter_sizes)

    _train_models(args, model, reward_model, continuous, replay_buffer, action_rep)

    print("---------- DONE PRE-TRAIN ----------")

    Test_Reward, Test_epioside_step = evaluate(eval_env, mpc, model, reward_model, continuous,
                                               episodes=5, vis=args.visualise, action_rep=action_rep)
    Test_Reward_100.append(Test_Reward)
    _log_eval(args, Test_Reward, Test_epioside_step)

    returns = []
    max_steps = 25
    total_timesteps = 0
    epi = 0

    redir = os.path.join("result/MPC/platform", args.save_dir)
    os.makedirs(redir, exist_ok=True)

    while total_timesteps < args.max_timesteps:
        epi += 1
        # np.asarray: np.array(..., copy=False) raises under NumPy 2 when the
        # float32 cast needs a copy.
        state = np.asarray(_reset_env(env, envname), dtype=np.float32)

        # Reset the planner's (mean, var) before planning the episode's first
        # action, as evaluate() does. It used to be reset only after that plan.
        mpc.reset()
        discrete_action, parameter_action, action, _ = unwrap_action(
            mpc, model, reward_model, continuous, state,
            action_rep=action_rep, action_parameter_sizes=action_parameter_sizes)
        episode_reward = 0.

        for _ in range(max_steps):
            total_timesteps += 1
            print(f"total_timesteps: {total_timesteps}, episode: {epi}")

            next_state, reward, terminal = _step_env(env, action, envname)

            if args.change_r:
                reward = -1 if terminal else reward

            next_state = np.asarray(next_state, dtype=np.float32)
            episode_reward += reward

            replay_buffer.add(state,
                              discrete_action=discrete_action,
                              all_parameter_action=parameter_action,
                              next_state=next_state - state,
                              reward=reward, terminal=terminal)

            discrete_action, parameter_action, action, _ = unwrap_action(
                mpc, model, reward_model, continuous, next_state,
                action_rep=action_rep, action_parameter_sizes=action_parameter_sizes)
            state = next_state

            if total_timesteps % args.train_every == 0:
                _train_models(args, model, reward_model, continuous, replay_buffer, action_rep)

            if args.embed and total_timesteps % 10 == 0 and total_timesteps >= 100:
                print("--------Training VAE---------")
                _, recon_s = vae_train(action_rep=action_rep, train_step=1,
                                       replay_buffer=replay_buffer,
                                       batch_size=VAE_batch_size, save_dir=save_dir,
                                       vae_save_model=vae_save_model,
                                       embed_lr=1e-4,
                                       par_size=action_parameter_sizes)
                recon_s_loss.append(recon_s)

            if total_timesteps % args.eval_freq == 0:
                # NaN (without a RuntimeWarning) until the first episode finishes.
                r100 = float(np.mean(returns[-100:])) if returns else float("nan")
                Reward_100.append(r100)
                Test_Reward, Test_epioside_step = evaluate(eval_env, mpc, model, reward_model, continuous,
                                                           episodes=5, vis=args.visualise,
                                                           action_rep=action_rep)
                Test_Reward_100.append(Test_Reward)
                Test_epioside_step_100.append(Test_epioside_step)

                print(f"Step: {total_timesteps} ||R100: {r100} || "
                      f"Test_Reward_100: {Test_Reward} || Test_epioside_step_100 {Test_epioside_step}")

                _log_eval(args, Test_Reward, Test_epioside_step)
                save_local(redir, Reward_100, Test_Reward_100, Test_epioside_step_100)

            if terminal:
                break

        returns.append(episode_reward)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", default='simple_catch')  # Platform-v0, Goal-v0, simple_catch
    parser.add_argument("--seed", default=0, type=int)  # Sets Gym, PyTorch and Numpy seeds

    parser.add_argument("--max_timesteps", default=10_000_000, type=float)  # Max time steps to run environment for
    parser.add_argument("--eval_freq", default=100, type=int)  # How often (time steps) we evaluate
    # type=float: without it a CLI value arrives as a str and breaks the noise computation.
    parser.add_argument("--expl_noise", default=0.1, type=float)  # Std of Gaussian exploration noise

    parser.add_argument("--train_every", default=50, type=int)  # How often (time steps) the models are re-trained

    parser.add_argument("--pretrain_episodes", default=50, type=int)  # Random-action episodes collected before training

    # Layer sizes use nargs='+' with type=int: type=list would split a CLI
    # string into single characters. Int-typed defaults are written as ints
    # because argparse does not apply `type` to non-string defaults.
    parser.add_argument("--dm_hard", default=0, type=int)
    parser.add_argument("--dm_lr", default=1e-4, type=float)
    parser.add_argument("--dm_saveflag", default=0, type=int)
    parser.add_argument("--dm_savepath", default='mpc_model/models/dm.pth', type=str)
    parser.add_argument("--dm_loadpath", default='mpc_model/models/dm.pth', type=str)
    parser.add_argument("--dm_valflag", default=0, type=int)
    parser.add_argument("--dm_valfreq", default=500, type=int)  # validation frequency
    parser.add_argument("--dm_valrati", default=0., type=float)
    parser.add_argument("--dm_loadmodel", default=0, type=int)
    parser.add_argument("--dm_layers", default=[64, 64], type=int, nargs='+')
    parser.add_argument("--dm_datasetlen", default=100_000, type=int)

    parser.add_argument("--dm_onepa", default=1, type=int)
    parser.add_argument("--use_terminal", default=1, type=int)
    parser.add_argument("--change_r", default=0, type=int)

    parser.add_argument("--r_hard", default=0, type=int)
    parser.add_argument("--r_lr", default=1e-4, type=float)
    parser.add_argument("--r_saveflag", default=0, type=int)
    parser.add_argument("--r_savepath", default='mpc_model/models/r.pth', type=str)
    parser.add_argument("--r_loadpath", default='mpc_model/models/r.pth', type=str)
    parser.add_argument("--r_valflag", default=0, type=int)
    parser.add_argument("--r_valfreq", default=500, type=int)  # validation frequency
    parser.add_argument("--r_valrati", default=0., type=float)
    parser.add_argument("--r_loadmodel", default=0, type=int)
    parser.add_argument("--r_layers", default=[64, 64], type=int, nargs='+')
    parser.add_argument("--r_datasetlen", default=100_000, type=int)

    parser.add_argument("--c_hard", default=0, type=int)
    parser.add_argument("--c_lr", default=1e-4, type=float)
    parser.add_argument("--c_saveflag", default=0, type=int)
    parser.add_argument("--c_savepath", default='mpc_model/models/c.pth', type=str)
    parser.add_argument("--c_loadpath", default='mpc_model/models/c.pth', type=str)
    parser.add_argument("--c_valflag", default=0, type=int)
    parser.add_argument("--c_valfreq", default=500, type=int)  # validation frequency
    parser.add_argument("--c_valrati", default=0., type=float)
    parser.add_argument("--c_loadmodel", default=0, type=int)
    parser.add_argument("--c_layers", default=[64, 64], type=int, nargs='+')
    parser.add_argument("--c_datasetlen", default=100_000, type=int)

    parser.add_argument("--mpc_horizon", default=10, type=int)
    parser.add_argument("--mpc_gamma", default=1., type=float)
    parser.add_argument("--mpc_popsize", default=2000, type=int)
    parser.add_argument("--mpc_num_elites", default=400, type=int)
    parser.add_argument("--mpc_patrical", default=1, type=int)
    parser.add_argument("--mpc_init_mean", default=0., type=float)
    parser.add_argument("--mpc_init_var", default=1., type=float)
    parser.add_argument("--mpc_epsilon", default=0.001, type=float)
    parser.add_argument("--mpc_alpha", default=0.1, type=float)
    parser.add_argument("--mpc_max_iters", default=1000, type=int)
    parser.add_argument("--mpc_type", default="CEM", type=str)  # CEM, Random
    parser.add_argument("--train_bs", default=256, type=int)
    parser.add_argument("--n_epochs", default=1000, type=int)

    parser.add_argument('--which_model', default="normal", type=str)  # normal, h
    parser.add_argument('--embed', default=1, type=int)  # 1: learn and plan in a VAE action embedding
    parser.add_argument("--oup_param", default='determine', type=str)

    parser.add_argument('--save_dir', default="070901", type=str)
    parser.add_argument('--visualise', default=0, type=int)
    parser.add_argument("--save_points", default=0, type=int)

    # Kept at module level: unwrap_action and save_local fall back to this global.
    args = parser.parse_args()

    run(args)
