# step2-test&show: run a trained policy on a test scenario and save the results
# last edit: 2025.6.6
import torch
import pykep as pk
import json
import numpy as np
import matplotlib.pyplot as plt
import time
from sympy.physics.optics import deviation

from utils3 import generate_1_collision, plot_actions, Environment, read_environment, to_serializable, generate_collisions
from datetime import datetime
from simulator8_for_test import Simulator, propagate
from scipy.interpolate import interp1d
from models.PPO_traAttention_multi_1 import Memory, Actor, Critic
# from models.PPO_18multi import Memory, Actor, Critic
# from models.LSTM_multi_1 import Memory, LSTMActor
# from models.CNN_multi_1 import Memory, CNNActor
# from models.FC_multi_1 import Memory, FCActor
from Collision_screening import apogee_perigee_screening, node_time_screening
from mpl_toolkits.mplot3d import Axes3D
from env_creating.write_space_objects_to_env_file import write_space_objects_to_env_file

# Torch device for the actor and its tensors: first GPU if CUDA is usable, else CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Length of one propagation step (same time unit as the environment epochs).
PROPAGATION_STEP = 5e-4
# Upper bound on the number of steps in the test loop.
max_timesteps = 3_000


def calculate_distances(coord):
    """Return the distance from the protected object to every other object.

    Row 0 of ``coord`` is the protected object and rows 1..N are the other
    objects. Only the first three columns of each row are treated as the
    position; any further columns are ignored.

    Args:
        coord: 2-D array-like with at least three columns. NumPy arrays and
            plain lists of rows are both accepted.

    Returns:
        A list of N Euclidean distances (``numpy.float64``), in row order.
        The list is empty when ``coord`` contains only the protected object.
    """
    positions = np.asarray(coord, dtype=float)[:, :3]
    # Broadcast the protected position against all other rows at once
    # instead of looping over them in Python.
    return list(np.linalg.norm(positions[1:] - positions[0], axis=1))


# read environment########################################################################
# Build the test scenario, screen its debris, and derive the simulator inputs.
# simulator initialization
# env_name = 'debris_collision_I_random1d'

# simulation
# NOTE(review): passed as the first argument of generate_collisions(); presumably
# the protected object's position — confirm against utils3.
position = [6890140, 200, 0]  # unit: m
# Per-step delta-v limit used by the actor below.
# NOTE(review): presumably 0.17 / 750 is thrust [N] / mass [kg], PROPAGATION_STEP
# is in days (so * 86400 converts to seconds), and / sqrt(3) spreads the budget
# evenly over three axes — confirm.
dV_limit = 0.17 / 750 * PROPAGATION_STEP * 86400/3**0.5
# Not used anywhere below in this script.
ddV_limit = dV_limit * 0.1

# Hard-coded epoch. The commented-out lines show how it was derived: days
# elapsed since 2000-01-01 12:00 for a given UTC date.
# utc_time = datetime.strptime('16 5 2024 04:00:00', '%d %m %Y %H:%M:%S')
# epoch_start = datetime(2000, 1, 1, 12, 0, 0)
# epoch_time = (utc_time - epoch_start).total_seconds() / 86400
epoch_time = 8901

# Alternative scenario sources are kept commented out; exactly one env_outer
# assignment should be active.
# env_outer = read_environment("environments\ICLR_strict_multistage_19d.env")
# env_outer = generate_1_collision(position, epoch_time, epoch_time - 0.04, epoch_time + 0.01, distance=10)
# Generated scenario with quantity=10 objects around `position` at `epoch_time`.
# NOTE(review): the two other times presumably bound the simulated window, and
# stddev/distance presumably control the spread of encounters — confirm in utils3.
env_outer = generate_collisions(position, epoch_time,  epoch_time-0.04, epoch_time+0.01, quantity=10, stddev=100, distance=100)
# env_outer = Environment.load(path="environments/P_agent_env_1d-1.pkl")

# write_space_objects_to_env_file(env_outer, "environments/R_agent_env_1d.env")

# Osculating elements of the protected object at the scenario start time,
# taken before screening and passed to every Simulator below.
initial_oscelement = env_outer.protected.osculating_elements(pk.epoch(env_outer.init_params["start_time"]))

# screening
env_outer = apogee_perigee_screening(env_outer)

# dynamic screening
# NOTE(review): the meaning of 3600 and 60 is defined in Collision_screening
# (presumably a window and a tolerance in seconds) — confirm.
env_outer = node_time_screening(env_outer, 3600, 60)

# Copy of the screened environment; only the commented-out save below uses it.
env_prime = env_outer.copy()
# env_prime.save(path="environments/P_agent_env_1d.pkl")

# Hard-coded for 10 objects and only printed below (Actor is built without it).
# NOTE(review): it does not follow len(env_outer.debris) after screening.
state_dim = 8*10+9
# state_dim = 7 * len(env_outer.debris) + 6
# ppo_model = PPO(env_outer, state_dim, initial_oscelement, lr=0.003, max_episodes=20000, max_timesteps=3000,
#                 log_interval=20, load_path='PPO_continuous_debris_collision.pth')
# ppo_model = PPO(env_outer, state_dim, initial_oscelement, action_std=0.08*0.2, lr=0.003, max_episodes=1000, max_timesteps=3000, log_interval=20)
# ppo_model.run(env_name=env_name)


# load pre-trained###########################################################################
# Wall-clock timer; the reported inference time includes loading the checkpoint.
start_time = time.perf_counter()
action_dim = 3
# Action std handed to the Actor; tiny relative to dV_limit.
# NOTE(review): presumably this keeps sampled actions close to the policy mean
# during evaluation — confirm against Actor.
action_std = 0.0001*dV_limit
actor_old = Actor(action_std).to(device)

# actor_old.load_state_dict(torch.load('results/models-trained/PPO_multi_debris_collision_O_agent1_1d-3.pth'))
# map_location remaps the stored tensors onto `device`. Without it, a checkpoint
# saved on a GPU cannot be loaded on a CPU-only machine, even though `device`
# deliberately falls back to the CPU.
checkpoint = torch.load(
    "results/models-trained/PPO_multi_debris_collision_T_PPOtra_agent1_strict-27.pth",
    map_location=device,
)
actor_old.load_state_dict(checkpoint['actor'])
# critic.load_state_dict(checkpoint['critic'])
# optimizer.load_state_dict(checkpoint['optimizer'])
# step = checkpoint['step']

# evaluate mode
actor_old.eval()

# propagate and record##############################################################################
# Closed-loop test. At every step the actor is queried through a throw-away
# Simulator built on a copy of env_outer. The returned action is applied to the
# outer simulator only when the actor's done_prob is below 0.5, and then the
# outer simulator advances one PROPAGATION_STEP.
memory = Memory()
# Per-step states from simulator.run(). Only memory.states is saved to disk later.
states = []
print("state_dim", state_dim)

# logging variables
time_step = 1
running_reward = 0
avg_length = 0

# Number of steps the outer simulator has advanced so far.
time_tick = 0
# NOTE(review): overwritten by simulator.run() on the first iteration.
state = np.array(env_outer.get_state()['coord'], dtype=float)

env_start = env_outer.init_params["start_time"]
# The simulator that is actually advanced and acted on throughout the test.
simulator_outer = Simulator(env_outer, initial_oscelement, step=PROPAGATION_STEP)

# Previous action fed back into simulator.run(); starts at zero.
action = [0, 0, 0]
actions = []
done_probs = []
collision_probs = []
deviations = []
# smoothing
# Rolling window of the three most recent action tensors (oldest first),
# seeded with zeros and passed to simulator.run().
action_history = [
    torch.zeros(3, device=device) for _ in range(3)
]
# done = False
collision_now = 1e10
for t in range(max_timesteps):
    # if not done:
    time_step += 1

    # calculate collision_now
    # List of current distances from the protected object to each other object.
    collision_now = calculate_distances(simulator_outer.env.get_state()['coord'])

    # # Running policy_old:
    # action = select_action(states)
    # print("action:", action)

    # Query the actor through a fresh Simulator on a copy of env_outer, so
    # whatever run() does to its environment leaves env_outer untouched.
    # NOTE(review): time_now presumably equals simulator_outer.curr_time, and
    # simulator_outer.env is presumably env_outer itself (not a copy) — otherwise
    # env_tmp/state_outer would not follow the propagation below. Confirm.
    env_tmp = env_outer.copy()
    simulator = Simulator(env_tmp, initial_oscelement, step=PROPAGATION_STEP,
                          time_now=(time_tick * PROPAGATION_STEP + env_start))
    state_outer = np.array(env_outer.get_state()['coord'], dtype=float)

    state, reward, action_gpu, action_logprob, done_prob, done_label, Fail, coll_prob, dev = simulator.run(actor_old,
                                env_outer.state['epoch'], state_outer, dV_limit, action, action_history, collision_now=collision_now)
    # 1-D numpy copy of the chosen action, used for logging, doact() and the next run().
    action = action_gpu.detach().cpu().numpy().flatten()

    states.append(state)
    actions.append(action)
    done_probs.append(done_prob.detach().cpu().item())
    collision_probs.append(coll_prob)
    deviations.append(dev)

    # NOTE(review): the loop breaks at the bottom as soon as curr_time >= end_time,
    # so this branch can only fire on the very first iteration (unless run() or
    # doact() also move curr_time). The +100 success bonus is effectively never
    # recorded.
    if simulator_outer.curr_time >= simulator_outer.end_time:
        reward += 100  # succeed
    elif done_prob.detach().cpu() < 0.5:
        # The action is applied to the outer simulator only while done_prob < 0.5.
        simulator_outer.doact(action)
        print('action:', action)

    # simulator_outer.doact(action)


    # save in memory
    memory.store(state, reward, action_gpu, action_logprob, done_prob, done_label)
    # update
    # Slide the smoothing window: drop the oldest action, append the newest.
    action_history.pop(0)
    action_history.append(action_gpu.detach())

    # print("rewards: ", memory.rewards)


    # else:
    #     actions.append([0, 0, 0])
    #     states.append(simulator_outer.env.get_state()['coord'].tolist())

    # Advance the outer simulator one step and propagate its environment to the new time.
    simulator_outer.curr_time += PROPAGATION_STEP
    propagate(simulator_outer.env, simulator_outer.curr_time)  # to curr_time

    time_tick += 1

    if simulator_outer.curr_time >= simulator_outer.end_time:
        break
    # collision?
    # A failure is only reported; the loop keeps running because the break is disabled.
    elif Fail:
        print('Fail!\n')
        print(coll_prob)
        # break

# Index of the last executed step (the loop variable survives the loop).
avg_length += t

print('rewards:', memory.rewards)
print('sum rewards:', sum(memory.rewards))


# Timing started before the checkpoint was loaded, so this covers loading + the test loop.
end_time = time.perf_counter()
elapsed_time = end_time - start_time

print(f"inference time: {elapsed_time:.3f}sec")

# ##########################save test results##############################################
# Write each per-step record to its own JSON file under results/.
with open('results/test_rewards.json', 'w') as f:
    json.dump([float(r) for r in memory.rewards], f)
with open('results/test_states.json', 'w') as f:
    json.dump(to_serializable(memory.states), f)
# `actions` is a plain Python list of 1-D numpy arrays (one per step), so each
# element is converted on its own. Calling `.tolist()` on the list itself
# raised AttributeError and aborted the script before the remaining files
# were written.
with open('results/test_actions.json', 'w') as f:
    json.dump([np.asarray(a).tolist() for a in actions], f)
with open('results/test_doneprobs.json', 'w') as f:
    json.dump(done_probs, f)
# NOTE(review): coll_prob and dev come straight from Simulator.run(). If they are
# numpy float32/arrays or tensors, json.dump will reject them; consider passing
# them through to_serializable as is done for the states above.
with open('results/test_coll_probs.json', 'w') as f:
    json.dump(collision_probs, f)
with open('results/test_deviations.json', 'w') as f:
    json.dump(deviations, f)
# Each `with` block closes its own file, so no trailing f.close() is needed.