import os
import sys
import argparse

import random
import numpy as np
import torch

import gymnasium as gym
import gym_envs

from gflownet_traj_balance import GFlowNet
from gflownet_sub_traj_balance import SubTBGFlowNet
from gafn import GAFlowNet
from stable_baselines3 import PPO, SAC, TD3, DQN

import zipfile
import json

import cProfile
import pstats

def get_arguments(argv):
      parser = argparse.ArgumentParser(description='Adversarial driving maneuvers training')

      # Models
      parser.add_argument('-m', '--method', type=str, default='gfn', help='The method to use for training, options: ppo, sac, td3, dqn, gfn, gfnsub')
      parser.add_argument('-s', '--sim', type=str, default='gmm', help='The simulation environment to use, options: hypergrid, grid, gmm, bitgen, pusher, hypergrid_simple')
      parser.add_argument('-se', '--seed', type=int, default=42, help='The random seed to use')
      parser.add_argument('-d', '--discrete', action='store_true', help='Whether to use discrete action space')
      # Loader for stored models
      parser.add_argument('-md', '--model_dir', type=str, default='', help='The model to load in')
      parser.add_argument('-o', '--overwrite', action='store_true', help='Whether to overwrite the model directory')

      # Hyperparameters for training algs
      parser.add_argument('-t', '--total_iterations', type=int, default=100000, help='The total iterations to train for')
      parser.add_argument('-lr', '--learning_rate', type=float, default=5e-4, help='The learning rate')
      parser.add_argument('-b', '--batch_size', type=int, default=64, help='The batch size')

      # For PPO, SAC, TD3, DQN
      parser.add_argument('--gamma', type=float, default=0.99, help='The discount factor for RL')

      # for PPO
      parser.add_argument('-e', '--ent_coef', type=float, default=0, help='The entropy coefficient')
      parser.add_argument('-tp', '--timesteps_per_epoch', type=int, default=1024, help='The timesteps per training epoch')
      parser.add_argument('--clip_range', type=float, default=0.2, help='The clip range')

      # for SAC, TD3, DQN, and GFlowNet
      parser.add_argument('--learning_starts', type=int, default=100, help='The number of timesteps to start learning')
      parser.add_argument('--train_freq', type=int, default=5, help='The training frequency')
      parser.add_argument('--buffer_size', type=int, default=100000, help='The buffer size')
      parser.add_argument('--gradient_steps', type=int, default=10, help='The number of gradient steps')

      # for GFlowNet
      parser.add_argument('-em', '--explorative_model', type=str, default='', help='The explorative policy to load in')
      parser.add_argument('-en', '--explorative_num', type=int, default=0, help='The number sample to generate from the explorative policy')
      parser.add_argument('-te', '--temperature', type=float, default=0, help='The original temperature for searching the space')
      parser.add_argument('-ntd', '--no_decay',action = 'store_true', help='The temperatre will not decay, for understanding the effect of temperature')
      parser.add_argument('-sm', '--sample_method', type=int, default=0, help='The method to sample from the buffer, 1 for reward-prioritized, 2 for loss-prioritized, 3 for mixed')      
      parser.add_argument('-uf', '--use_filter', action='store_true', help='Whether to use the filter in loss calculation')
      parser.add_argument('-fu', '--filter_upper', type=float, default = 3, help='The upper threshold for batch filtering')
      parser.add_argument('-fl', '--filter_lower', type=float, default = 2, help='The lower threshold for batch filtering')
      parser.add_argument('-er', '--epsilon_random', type=float, default = 0, help='Initial probability of performaing epsilon random actions for better exploration')
      parser.add_argument('-iz', '--initial_z', type=float, default=0.1, help='The initial Z value for log Z')

      # For GFlowNet with SubTB objectives
      parser.add_argument('-w', '--weighting', type=str, default='geometricwithin', help='The weighting for the sub objectives')
      parser.add_argument('-l', '--lambda', type=float, default=0.9, help='The lambda for subTB')

      # For pessimistic GFN
      parser.add_argument('-pu', '--pessimistic_update', type=int, default = 0, help="How many pessimistic updates to do, corresponding to the notation N in their paper")

      # for printing the training logs
      parser.add_argument('-v', '--verbose', action='store_true', help='Whether to print out the training process')
      
      # for validation during training and saving intermediate models
      parser.add_argument('-nvs', '--num_val_samples', type=int, default=0, help='Number of samples to use for validation during training and saving intermediate models. No validation is done if set to 0.')
      
      # Neural network architecture
      parser.add_argument('--nn_hidden_sizes', type=int, nargs='+', default=[64, 64], help='Sizes of hidden layers in the neural network, e.g. 64 64 for two layers of size 64')
      parser.add_argument('--activation', type=str, default='relu', help='Activation function to use, options: relu, tanh, sigmoid, leakyr')

      parser.add_argument('-p', '--parallel', action='store_true', help='Whether to use parallel environments for training')
      parser.add_argument('-hf', '--high_fidelity', action='store_true', help='Whether to generate new trajectories using the current model in validation')

      parser.add_argument('-tr', '--temperature_rate', type=float, default = 40, help='temperature decay rate, lower values means earlier decay but slower decay rate')
      parser.add_argument('-mt', '--multiply_temperature', action='store_true', help='Whether to multiply the temperature instead of adding it')

      return parser.parse_args(argv)

def seed_all(seed):
      random.seed(seed)
      np.random.seed(seed)
      torch.backends.cudnn.deterministic = True
      torch.backends.cudnn.benchmark = False
      torch.manual_seed(seed)

if __name__ == '__main__':
      # for profiling
      # profiler = cProfile.Profile()
      # profiler.enable()
      
      hyperparameters = get_arguments(sys.argv[1:])
      hyperparameters = vars(hyperparameters)
      hyperparameters['continuous'] = not hyperparameters['discrete']
      hyperparameters['timeout_mask'] = False

      print(hyperparameters)

      seed_all(hyperparameters['seed'])

      validation_env = None
      data_env = None
      if hyperparameters['sim'] == 'pusher-mild':
            env = gym.make('gfn_challenges/pusher-mild')
            if hyperparameters['parallel']:
                  data_env = gym.make_vec('gfn_challenges/pusher-mild', num_envs = hyperparameters['train_freq'], vectorization_mode='sync')
      elif hyperparameters['sim'] == 'pusher-simple':
            hyperparameters['timeout_mask'] = True
            env = gym.make('gfn_challenges/pusher-simple')
            if hyperparameters['parallel']:
                  data_env = gym.make_vec('gfn_challenges/pusher-simple', num_envs = hyperparameters['train_freq'], vectorization_mode='sync')
      elif hyperparameters['sim'] == 'hypergrid-mild':
            env = gym.make('gfn_challenges/hypergrid-mild')
            if hyperparameters['parallel']:
                  data_env = gym.make_vec('gfn_challenges/hypergrid-mild', num_envs = hyperparameters['train_freq'], vectorization_mode='sync')
            if hyperparameters['num_val_samples'] > 0 and hyperparameters['high_fidelity']:
                  validation_env = gym.make_vec('gfn_challenges/hypergrid-v0', num_envs = min(hyperparameters['num_val_samples'], 1000), vectorization_mode='sync')
      elif hyperparameters['sim'] == 'hypergrid-simple':
            env = gym.make('gfn_challenges/hypergrid-simple')
            if hyperparameters['parallel']:
                  data_env = gym.make_vec('gfn_challenges/hypergrid-simple', num_envs = hyperparameters['train_freq'], vectorization_mode='sync')
            # Please change the num_envs based your PC configuration
            if hyperparameters['num_val_samples'] > 0 and hyperparameters['high_fidelity']:
                  validation_env = gym.make_vec('gfn_challenges/hypergrid-simple', num_envs=min(hyperparameters['num_val_samples'], 1000), vectorization_mode='sync')
      elif hyperparameters['sim'] == 'hypergrid-hard':
            env = gym.make('gfn_challenges/hypergrid-hard')
            if hyperparameters['parallel']:
                  data_env = gym.make_vec('gfn_challenges/hypergrid-hard', num_envs = hyperparameters['train_freq'], vectorization_mode='sync')
            # Please change the num_envs based your PC configuration
            if hyperparameters['num_val_samples'] > 0 and hyperparameters['high_fidelity']:
                  validation_env = gym.make_vec('gfn_challenges/hypergrid-hard', num_envs=min(hyperparameters['num_val_samples'], 1000), vectorization_mode='sync')
      elif hyperparameters['sim'] == 'gmm-hard':
            env = gym.make('gfn_challenges/gmm-hard')
            if hyperparameters['parallel']:
                  data_env = gym.make_vec('gfn_challenges/gmm-hard', num_envs = hyperparameters['train_freq'], vectorization_mode='sync')
      elif hyperparameters['sim'] == 'molecular-generation':
            env = gym.make('gfn_challenges/molecular-generation')
            if hyperparameters['parallel']:
                  data_env = gym.make_vec('gfn_challenges/molecular-generation', num_envs = hyperparameters['train_freq'], vectorization_mode='sync')
      else:
            print("Invalid simulation environment.")
            sys.exit(0)

      if hyperparameters['method'] not in ['ppo', 'sac', 'td3', 'dqn', 'gfn', 'gfnsub', 'gafn','teacher']:
            print("Invalid method.")
            sys.exit(0)
      
      if hyperparameters['model_dir'] != '' and not hyperparameters['overwrite']:
            print("Model directory is not empty, please make sure this training won't cause any data loss then add -o.")
            sys.exit(0)

      model_name = f"{hyperparameters['method']}_{hyperparameters['total_iterations']}_{hyperparameters['nn_hidden_sizes']}_{hyperparameters['activation']}_{hyperparameters['initial_z']}_{hyperparameters['seed']}"

      if hyperparameters["activation"] == "relu":
            activation_fn = torch.nn.ReLU
      elif hyperparameters["activation"] == "tanh":
            activation_fn = torch.nn.Tanh
      elif hyperparameters["activation"] == "sigmoid":
            activation_fn = torch.nn.Sigmoid
      elif hyperparameters["activation"] == "leakyr":
            activation_fn = torch.nn.LeakyReLU

      t_start = 0
      if hyperparameters['method'] == 'ppo':
            policy_kwargs = dict(activation_fn=activation_fn,
                  net_arch=dict(pi=hyperparameters['nn_hidden_sizes'], vf=hyperparameters['nn_hidden_sizes']))
            
            model_name += f"_{hyperparameters['learning_rate']}_{hyperparameters['batch_size']}_{hyperparameters['gamma']}_{hyperparameters['ent_coef']}_{hyperparameters['clip_range']}_{hyperparameters['timesteps_per_epoch']}"
            
            log_dir = f"../output/{hyperparameters['sim']}/{model_name}"
            model_dir = f"../output/{hyperparameters['sim']}/{model_name}.zip"
            model = PPO("MlpPolicy", env, learning_rate = hyperparameters['learning_rate'], \
                        batch_size = hyperparameters['batch_size'], \
                        n_steps = hyperparameters['timesteps_per_epoch'], \
                        ent_coef = hyperparameters['ent_coef'], \
                        clip_range = hyperparameters['clip_range'],\
                        gamma = hyperparameters['gamma'], \
                        seed = hyperparameters['seed'],\
                        tensorboard_log = log_dir, \
                        policy_kwargs = policy_kwargs, \
                        verbose = hyperparameters['verbose'])
            
            if hyperparameters['model_dir'] != '':
                  print("Loading model from ", hyperparameters['model_dir'])
                  model.load(hyperparameters['model_dir'])
                  # load data in the hyperparameters['model_dir'], then read the num_timesteps
                  with zipfile.ZipFile(hyperparameters['model_dir'], 'r') as zip:
                        t_start = json.loads(zip.read('data'))['num_timesteps']
                        print("Resume training from time step: = ", t_start)
                  
            
      elif hyperparameters['method'] == 'sac':
            policy_kwargs = dict(activation_fn=activation_fn,
                  net_arch=dict(pi=hyperparameters['nn_hidden_sizes'], qf=[400, 300]))
            
            model_name += f"_{hyperparameters['learning_rate']}_{hyperparameters['batch_size']}_{hyperparameters['gamma']}_{hyperparameters['buffer_size']}_{hyperparameters['train_freq']}_{hyperparameters['gradient_steps']}_{hyperparameters['learning_starts']}"
            
            log_dir = f"../output/{hyperparameters['sim']}/{model_name}"
            model_dir = f"../output/{hyperparameters['sim']}/{model_name}.zip"

            model = SAC("MlpPolicy", env, learning_rate = hyperparameters['learning_rate'], \
                        batch_size = hyperparameters['batch_size'], \
                        buffer_size = hyperparameters['buffer_size'], \
                        train_freq = hyperparameters['train_freq'], \
                        gradient_steps = hyperparameters['gradient_steps'], \
                        learning_starts = hyperparameters['learning_starts'], \
                        gamma = hyperparameters['gamma'], \
                        seed = hyperparameters['seed'],\
                        tensorboard_log = log_dir, \
                        policy_kwargs = policy_kwargs, \
                        verbose = hyperparameters['verbose'])
            if hyperparameters['model_dir'] != '':
                  print("Loading model from ", hyperparameters['model_dir'])
                  model.load(hyperparameters['model_dir'])
                  model.load_replay_buffer(hyperparameters['model_dir'].replace('.zip', '') + '.pkl')
                  # load data in the hyperparameters['model_dir'], then read the num_timesteps
                  with zipfile.ZipFile(hyperparameters['model_dir'], 'r') as zip:
                        t_start = json.loads(zip.read('data'))['num_timesteps']
                        print("Resume training from time step: = ", t_start)
      
      elif hyperparameters['method'] == 'td3':
            policy_kwargs = dict(activation_fn=activation_fn,
                  net_arch=dict(pi=hyperparameters['nn_hidden_sizes'], qf=[400, 300]))
            
            model_name += f"_{hyperparameters['learning_rate']}_{hyperparameters['batch_size']}_{hyperparameters['gamma']}_{hyperparameters['buffer_size']}_{hyperparameters['train_freq']}_{hyperparameters['gradient_steps']}_{hyperparameters['learning_starts']}"
            
            log_dir = f"../output/{hyperparameters['sim']}/{model_name}"
            model_dir = f"../output/{hyperparameters['sim']}/{model_name}.zip"

            model = TD3("MlpPolicy", env, learning_rate = hyperparameters['learning_rate'], \
                        batch_size = hyperparameters['batch_size'], \
                        buffer_size = hyperparameters['buffer_size'], \
                        train_freq = hyperparameters['train_freq'], \
                        gradient_steps = hyperparameters['gradient_steps'], \
                        learning_starts = hyperparameters['learning_starts'], \
                        gamma = hyperparameters['gamma'], \
                        seed = hyperparameters['seed'],\
                        tensorboard_log = log_dir, \
                        policy_kwargs = policy_kwargs, \
                        verbose = hyperparameters['verbose'])
            
            if hyperparameters['model_dir'] != '':
                  print("Loading model from ", hyperparameters['model_dir'])
                  model.load(hyperparameters['model_dir'])
                  model.load_replay_buffer(hyperparameters['model_dir'].replace('.zip', '') + '.pkl')
                  # load data in the hyperparameters['model_dir'], then read the num_timesteps
                  with zipfile.ZipFile(hyperparameters['model_dir'], 'r') as zip:
                        t_start = json.loads(zip.read('data'))['num_timesteps']
                        print("Resume training from time step: = ", t_start)

      elif hyperparameters['method'] == 'dqn':
            policy_kwargs = dict(activation_fn=activation_fn,
                  net_arch=hyperparameters['nn_hidden_sizes'])
            
            model_name += f"_{hyperparameters['learning_rate']}_{hyperparameters['batch_size']}_{hyperparameters['gamma']}_{hyperparameters['ent_coef']}_{hyperparameters['clip_range']}_{hyperparameters['timesteps_per_epoch']}"
            
            log_dir = f"../output/{hyperparameters['sim']}/{model_name}"
            model_dir = f"../output/{hyperparameters['sim']}/{model_name}.zip"
            model = DQN("MlpPolicy", env, learning_rate = hyperparameters['learning_rate'], \
                        batch_size = hyperparameters['batch_size'], \
                        buffer_size = hyperparameters['buffer_size'], \
                        train_freq = hyperparameters['train_freq'], \
                        gradient_steps = hyperparameters['gradient_steps'], \
                        learning_starts = hyperparameters['learning_starts'], \
                        gamma = hyperparameters['gamma'], \
                        seed = hyperparameters['seed'],\
                        tensorboard_log = log_dir, \
                        policy_kwargs = policy_kwargs, \
                        verbose = hyperparameters['verbose'])
            
            if hyperparameters['model_dir'] != '':
                  print("Loading model from ", hyperparameters['model_dir'])
                  model.load(hyperparameters['model_dir'])
                  model.load_replay_buffer(hyperparameters['model_dir'].replace('.zip', '') + '.pkl')
                  # load data in the hyperparameters['model_dir'], then read the num_timesteps
                  with zipfile.ZipFile(hyperparameters['model_dir'], 'r') as zip:
                        t_start = json.loads(zip.read('data'))['num_timesteps']
                        print("Resume training from time step: = ", t_start)

      elif hyperparameters['method'] == 'gfn':
            # load a model trained by stable-baselines3, the idea is to get the policy and then use only the actor part
            # not used in the paper
            model_path = hyperparameters['explorative_model']
            if model_path == '':
                  explorative_policy = None
                  hyperparameters['explorative_num'] = 0
            elif 'ppo' in model_path:
                  explorative_policy = PPO.load(model_path)

            elif 'sac' in model_path:
                  explorative_policy = SAC.load(model_path)

            elif 'td3' in model_path:
                  explorative_policy = TD3.load(model_path)

            model_name += f"_{hyperparameters['learning_rate']}_{hyperparameters['batch_size']}_{hyperparameters['gamma']}_{hyperparameters['buffer_size']}_{hyperparameters['train_freq']}_{hyperparameters['gradient_steps']}_{hyperparameters['learning_starts']}_False_{hyperparameters['explorative_num']}_{hyperparameters['temperature']}_{hyperparameters['sample_method']}_0_0_{hyperparameters['use_filter']}_{hyperparameters['pessimistic_update']}_{hyperparameters['epsilon_random']}"
            if hyperparameters['use_filter']:
                  model_name += f"_{hyperparameters['filter_upper']}_{hyperparameters['filter_lower']}"

            if hyperparameters['temperature_rate']!=40:
                  model_name += f"_{hyperparameters['temperature_rate']}"

            log_dir = f"../output/{hyperparameters['sim']}/{model_name}"
            model_dir = f"../output/{hyperparameters['sim']}/{model_name}"

            # check if the model exists
            import os
            if os.path.exists(model_dir + '/forward_policy.pth'):
                  print("Model results exists, exit")
                  sys.exit(0)
            
            model = GFlowNet(env=env, learning_rate = hyperparameters['learning_rate'], \
                        batch_size = hyperparameters['batch_size'], \
                        buffer_size = hyperparameters['buffer_size'], \
                        train_freq = hyperparameters['train_freq'], \
                        gradient_steps = hyperparameters['gradient_steps'], \
                        learning_starts = hyperparameters['learning_starts'], \
                        explorative_policy = explorative_policy, \
                        explorative_num = hyperparameters['explorative_num'],\
                        temperature = hyperparameters['temperature'],\
                        sample_method = hyperparameters['sample_method'], \
                        continuous = hyperparameters['continuous'],\
                        tensorboard_log = log_dir, \
                        verbose = hyperparameters['verbose'],\
                        use_filter = hyperparameters['use_filter'],\
                        hidden_sizes = hyperparameters['nn_hidden_sizes'],\
                        activation_fn = activation_fn,\
                        initial_z = hyperparameters['initial_z'],\
                        num_val_samples = hyperparameters['num_val_samples'],\
                        pessimistic_updates = hyperparameters['pessimistic_update'],\
                        model_dir=model_dir,
                        validation_env = validation_env,
                        data_env = data_env,
                        no_decay = hyperparameters['no_decay'],
                        timeout_mask= hyperparameters['timeout_mask'],
                        filter_upper = hyperparameters['filter_upper'],
                        filter_lower = hyperparameters['filter_lower'],
                        epsilon_random = hyperparameters['epsilon_random'],
                        temperature_rate= hyperparameters['temperature_rate'],
                        multiply_temperature = hyperparameters['multiply_temperature']
                    )
            
            t_start, i_start, e_start = 0, 0, 0

            if hyperparameters['model_dir'] != '':
                  print("Loading model from ", hyperparameters['model_dir'])
                  t_start, i_start, e_start = model.load(hyperparameters['model_dir'], True)
                  print("Resume training from time step: = ", t_start)
                  model.load_replay_buffer(hyperparameters['model_dir'] + '.pkl')

      elif hyperparameters['method'] == 'gfnsub':
            model_name += f"_{hyperparameters['learning_rate']}_{hyperparameters['batch_size']}_{hyperparameters['gamma']}_{hyperparameters['buffer_size']}_{hyperparameters['train_freq']}_{hyperparameters['gradient_steps']}_{hyperparameters['learning_starts']}_{hyperparameters['temperature']}_{hyperparameters['sample_method']}_0_0_{hyperparameters['use_filter']}_{hyperparameters['lambda']}_{hyperparameters['weighting']}_{hyperparameters['pessimistic_update']}_{hyperparameters['epsilon_random']}"
            if hyperparameters['use_filter']:
                  model_name += f"_{hyperparameters['filter_upper']}_{hyperparameters['filter_lower']}"

            log_dir = f"../output/{hyperparameters['sim']}/{model_name}"
            model_dir = f"../output/{hyperparameters['sim']}/{model_name}"
            import os
            if os.path.exists(model_dir + '/forward_policy.pth'):
                  print("Model results exists, exit")
                  sys.exit(0)

            model = SubTBGFlowNet(env=env, learning_rate = hyperparameters['learning_rate'], \
                         batch_size = hyperparameters['batch_size'], \
                        buffer_size = hyperparameters['buffer_size'], \
                        train_freq = hyperparameters['train_freq'], \
                        gradient_steps = hyperparameters['gradient_steps'], \
                        learning_starts = hyperparameters['learning_starts'], \
                        temperature = hyperparameters['temperature'],\
                        sample_method = hyperparameters['sample_method'], \
                        continuous = hyperparameters['continuous'],\
                        tensorboard_log = log_dir, \
                        verbose = hyperparameters['verbose'],\
                        use_filter = hyperparameters['use_filter'],\
                        lamda = hyperparameters['lambda'],\
                        hidden_sizes = hyperparameters['nn_hidden_sizes'],\
                        activation_fn = activation_fn, \
                        initial_z = hyperparameters['initial_z'],\
                        num_val_samples = hyperparameters['num_val_samples'],\
                        pessimistic_updates = hyperparameters['pessimistic_update'],\
                        model_dir=model_dir,
                        validation_env = validation_env,
                        data_env = data_env,
                        no_decay = hyperparameters['no_decay'],
                        timeout_mask= hyperparameters['timeout_mask'],
                        filter_upper = hyperparameters['filter_upper'],
                        filter_lower = hyperparameters['filter_lower'],
                        epsilon_random = hyperparameters['epsilon_random']
                        # multiply_temperature = hyperparameters['multiply_temperature']
                    )
            
            t_start, i_start, e_start = 0, 0, 0

            if hyperparameters['model_dir'] != '':
                  print("Loading model from ", hyperparameters['model_dir'])
                  t_start, i_start, e_start = model.load(hyperparameters['model_dir'], True)
                  print("Resume training from time step: = ", t_start)
                  model.load_replay_buffer(hyperparameters['model_dir'] + '.pkl')

      elif hyperparameters['method'] == 'gafn':
            model_name += f"_{hyperparameters['learning_rate']}_{hyperparameters['batch_size']}_{hyperparameters['buffer_size']}_{hyperparameters['train_freq']}_{hyperparameters['gradient_steps']}_{hyperparameters['learning_starts']}_{hyperparameters['sample_method']}_{hyperparameters['epsilon_random']}"
            log_dir = f"../output/{hyperparameters['sim']}/{model_name}"
            model_dir = f"../output/{hyperparameters['sim']}/{model_name}"
            import os
            if os.path.exists(model_dir + '/forward_policy.pth'):
                  print("Model results exists, exit")
                  sys.exit(0)
            
            # check if the model exists
            model = GAFlowNet(env=env, learning_rate = hyperparameters['learning_rate'], \
                        batch_size = hyperparameters['batch_size'], \
                        buffer_size = hyperparameters['buffer_size'], \
                        train_freq = hyperparameters['train_freq'], \
                        gradient_steps = hyperparameters['gradient_steps'], \
                        learning_starts = hyperparameters['learning_starts'], \
                        sample_method = hyperparameters['sample_method'], \
                        continuous = hyperparameters['continuous'],\
                        tensorboard_log = log_dir, \
                        verbose = hyperparameters['verbose'],\
                        hidden_sizes = hyperparameters['nn_hidden_sizes'],\
                        activation_fn = activation_fn,\
                        initial_z = hyperparameters['initial_z'],\
                        num_val_samples = hyperparameters['num_val_samples'],\
                        model_dir=model_dir,
                        validation_env = validation_env,
                        data_env = data_env,
                        temperature = hyperparameters['temperature'],
                        timeout_mask= hyperparameters['timeout_mask'],
                        filter_upper = hyperparameters['filter_upper'],
                        filter_lower = hyperparameters['filter_lower'],
                        epsilon_random = hyperparameters['epsilon_random']
                        # multiply_temperature = hyperparameters['multiply_temperature']
                    )
            
            t_start, i_start, e_start = 0, 0, 0

            if hyperparameters['model_dir'] != '':
                  print("Loading model from ", hyperparameters['model_dir'])
                  t_start, i_start, e_start = model.load(hyperparameters['model_dir'], True)
                  print("Resume training from time step: = ", t_start)
                  model.load_replay_buffer(hyperparameters['model_dir'] + '.pkl')

      elif hyperparameters['method'] == 'teacher':
            # Import the adaptive teacher implementation
            from gflownet_adaptive_teacher import AdaptiveTeacherGFlowNet
            
            model_name += f"_{hyperparameters['learning_rate']}_{hyperparameters['batch_size']}_{hyperparameters['buffer_size']}_{hyperparameters['train_freq']}_{hyperparameters['gradient_steps']}_{hyperparameters['learning_starts']}_{hyperparameters['temperature']}_{hyperparameters['sample_method']}_{hyperparameters['epsilon_random']}"
            
            if hyperparameters['use_filter']:
                  model_name += f"_{hyperparameters['filter_upper']}_{hyperparameters['filter_lower']}"

            if hyperparameters['temperature_rate']!=40:
                  model_name += f"_{hyperparameters['temperature_rate']}"
            
            log_dir = f"../output/{hyperparameters['sim']}/{model_name}"
            model_dir = f"../output/{hyperparameters['sim']}/{model_name}"
            
            # Check if model already exists
            import os
            if os.path.exists(model_dir + '/forward_policy.pth'):
                  print("Model results exists, exit")
                  sys.exit(0)
            
            # Create adaptive teacher model
            model = AdaptiveTeacherGFlowNet(env=env, 
                        learning_rate = hyperparameters['learning_rate'],
                        batch_size = hyperparameters['batch_size'], 
                        buffer_size = hyperparameters['buffer_size'],
                        train_freq = hyperparameters['train_freq'],
                        gradient_steps = hyperparameters['gradient_steps'],
                        learning_starts = hyperparameters['learning_starts'],
                        temperature = hyperparameters['temperature'],
                        sample_method = hyperparameters['sample_method'],
                        continuous = hyperparameters['continuous'],
                        tensorboard_log = log_dir,
                        verbose = hyperparameters['verbose'],
                        use_filter = hyperparameters['use_filter'],
                        hidden_sizes = hyperparameters['nn_hidden_sizes'],
                        activation_fn = activation_fn,
                        initial_z = hyperparameters['initial_z'],
                        num_val_samples = hyperparameters['num_val_samples'],
                        pessimistic_updates = hyperparameters['pessimistic_update'],
                        model_dir=model_dir,
                        validation_env = validation_env,
                        data_env = data_env,
                        no_decay = hyperparameters['no_decay'],
                        timeout_mask= hyperparameters['timeout_mask'],
                        filter_upper = hyperparameters['filter_upper'],
                        filter_lower = hyperparameters['filter_lower'],
                        epsilon_random = hyperparameters['epsilon_random']
                    )
            
            t_start, i_start, e_start = 0, 0, 0

            if hyperparameters['model_dir'] != '':
                  print("Loading teacher model from ", hyperparameters['model_dir'])
                  t_start, i_start, e_start = model.load(hyperparameters['model_dir'], True)
                  print("Resume teacher training from time step: = ", t_start)
                  model.load_replay_buffer(hyperparameters['model_dir'] + '.pkl')

      # create log_dir if not exist
      if not os.path.exists(log_dir):
            os.makedirs(log_dir)

      # save hyperparameters
      with open(log_dir + '/hyperparameters.json', 'w') as f:
            json.dump(hyperparameters, f)
      # model directory
      if 'gfn' not in hyperparameters['method'] and hyperparameters["method"] != 'gafn' and hyperparameters["method"] != 'teacher':
            model.learn(hyperparameters['total_iterations'] - t_start, reset_num_timesteps=False) # for the models in stable-baselines3
      else:
            model.learn(hyperparameters['total_iterations'], t_start, i_start, e_start) # for our GFN model
      model.save(model_dir)
      if hyperparameters['method'] != 'ppo':
            model.save_replay_buffer(model_dir.replace('.zip', '') + '.pkl')

      # profiler.disable()
      # profiler.dump_stats("profile_results.prof")
      # stats = pstats.Stats(profiler)
      # stats.sort_stats(pstats.SortKey.TIME)  # Sort by time, can be changed to 'cumulative' or 'calls'
      # stats.print_stats(20)  # Print the top 20 results
      
      