# Step 1 — Train
# Last edited: 2025-09-14
import pykep as pk
import env_creating.write_space_objects_to_env_file
from env_creating.write_space_objects_to_env_file import write_space_objects_to_env_file
from models.PPO_18multi import TrainPPO
from models.LSTM_multi_1 import TrainLSTM
from models.CNN_multi_1 import TrainCNN
from models.FC_multi_1 import TrainFC
from utils3 import generate_1_collision, read_environment, Environment, generate_collisions
from Collision_screening import apogee_perigee_screening, check_collision, node_time_screening
from datetime import datetime
import time

# Start timer: wall-clock timestamp (seconds since the Unix epoch) taken right
# after the imports, so the runtime printed at the end of the script covers
# environment loading, screening and training.
# NOTE(review): time.time() follows the system clock and can jump if the clock
# is adjusted during a long training run; reading time.perf_counter() here and
# at the end-of-script measurement would give a monotonic duration.
start_time: float = time.time()


# Simulator propagation step size, in days.
PROPAGATION_STEP = 5e-4

# Run name handed to model.run() at the end of training setup.
env_name = 'debris_collision_T_STANdone0_agent1_strict19d'

# Collision position, in metres. Only consumed by the commented-out
# generate_1_collision / generate_collisions alternatives below.
position = [6_890_140, 200, 0]

# Collision epoch in MJD2000 days (also only used by the commented-out
# generators). It can be derived from a UTC timestamp as sketched below.
# NOTE(review): pykep's MJD2000 counts days from 2000-01-01 00:00; the
# commented-out conversion subtracts a 12:00 (J2000) reference, which is half a
# day off -- confirm before re-enabling it.
# utc_time = datetime.strptime('16 5 2024 04:00:00', '%d %m %Y %H:%M:%S')
# epoch_start = datetime(2000, 1, 1, 12, 0, 0)
# epoch_time = (utc_time - epoch_start).total_seconds()/86400  # unit: day
epoch_time = 8901

# Build the environment. The active path loads a pre-generated .env file; the
# commented lines are alternatives kept for experiments (one synthetic
# collision, several synthetic collisions, or a pickled Environment).
env_path = "environments/ICLR_strict_multistage_19d.env"
env_outer = read_environment(env_path)
# env_outer = generate_1_collision(position, epoch_time,  epoch_time-0.04, epoch_time+0.01, distance=5)
# env_outer = generate_collisions(position, epoch_time,  epoch_time-0.04, epoch_time+0.01, quantity=10, stddev=100, distance=100)
# env_outer = Environment.load(path="environments/P_agent_env_1+1d.pkl")

# Osculating elements of the protected object at the environment's configured
# start time; passed to the trainer constructor further down.
start_epoch = pk.epoch(env_outer.init_params["start_time"])
initial_oscelement = env_outer.protected.osculating_elements(start_epoch)

# Optional: persist the environment for later reuse/inspection.
# env_outer.save(path="environments/Q_agent_env_1d.pkl")
# write_space_objects_to_env_file(env_outer, "environments/Q_agent_env_1d.txt")

# Two screening passes, applied in order: the apogee/perigee filter first, then
# the dynamic node/time filter. Each returns an environment that replaces
# env_outer for the rest of the script.
# NOTE(review): the meaning of the 3600 / 60 arguments is defined in
# Collision_screening (they look like a time window and step in seconds) --
# confirm there before changing them.
# pass_time_information, collision_information, collision_list = check_collision(env_outer, 3600, 60, result_disp=True)
env_outer = node_time_screening(apogee_perigee_screening(env_outer), 3600, 60)

# Size of the observation vector handed to the trainer.
# NOTE(review): presumably 8 features for each of up to 10 objects, plus 6 for
# the protected object's state and 3 extra entries -- verify against the
# trainer's observation builder. The commented line is an older variant that
# scaled with the number of debris objects.
# state_dim = 7 * len(env_outer.debris) + 6
state_dim = 8 * 10 + 6 + 3

# Hyperparameters shared by every trainer variant below, kept in one place so
# switching architectures only means swapping the class being constructed.
train_kwargs = dict(
    action_std=0.003,
    lr=0.0003,
    betas=(0.9, 0.999),
    K_epochs=3,
    clip_eps=0.2,
    gamma=0.99,
    PROPAGATION_STEP=PROPAGATION_STEP,
    max_episodes=1500,
    max_timesteps=10000,
    update_timestep=1000,
    log_interval=5,
    load_path=None,
    batch_size=300,
)

model = TrainPPO(env_outer, state_dim, initial_oscelement, **train_kwargs)

# Alternative architectures, constructed with the same arguments:
# model = TrainLSTM(env_outer, state_dim, initial_oscelement, **train_kwargs)
# model = TrainCNN(env_outer, state_dim, initial_oscelement, **train_kwargs)
# model = TrainFC(env_outer, state_dim, initial_oscelement, **train_kwargs)

# Setting reconstruct_eachturn=True resets the environment every round; it is
# False here, so that per-round reset is disabled.
model.run(env_name=env_name, reconstruct_eachturn=False)

# Stop the timer and report the script's total runtime as whole
# hours / minutes / seconds (each component truncated to an int).
end_time = time.time()
elapsed_time = end_time - start_time

# Float divmod yields (elapsed // 3600, elapsed % 3600), i.e. the same
# floor-division / modulo values the components are derived from.
whole_hours, within_hour = divmod(elapsed_time, 3600)
hours = int(whole_hours)
minutes = int(within_hour // 60)
seconds = int(elapsed_time % 60)

print(f"time-consumed: {hours}hours {minutes}min {seconds}sec")