import argparse
import os
import time
import gymnasium as gym
from gymnasium.wrappers import FrameStack
import numpy as np
import torch
from PASRL.PASRL import Agent


def train_online(RL_agent, env, eval_env, args):
    """Run the online interaction/training loop for ``args.max_timesteps + 1`` steps.

    At every step the function (1) lets ``maybe_evaluate_and_print`` decide
    whether an evaluation is due, (2) asks the agent for an action, (3) steps
    ``env``, (4) stores the transition in ``RL_agent.replay_buffer`` and
    (5) optionally trains the agent.

    Args:
        RL_agent: agent object; this function uses its ``select_action``,
            ``train``, ``maybe_train_and_checkpoint``, ``init_episode_noise``,
            ``replay_buffer.add`` and ``fixed_encoder.reset_hx`` members.
        env: training environment following the Gymnasium 5-tuple ``step`` API.
        eval_env: environment handed to ``maybe_evaluate_and_print``.
        args: namespace providing ``history_length``, ``max_timesteps``,
            ``max_episode_length``, ``use_checkpoints`` and
            ``timesteps_before_training`` (plus whatever the evaluation
            routine reads).
    """
    evals = []
    action_smoothness_evals = []
    main_task_reward = []
    start_time = time.time()
    allow_train = False

    action_dim = env.action_space.shape[0]
    state = env.reset()[0]
    ep_total_reward, ep_timesteps, ep_num = 0, 0, 1

    # Rolling windows of the last `history_length` actions: `actions` includes
    # the action just taken, `prev_actions` is the window before that action.
    actions = np.zeros((args.history_length, action_dim))
    prev_actions = np.zeros((args.history_length, action_dim))
    hx = RL_agent.fixed_encoder.reset_hx()
    RL_agent.init_episode_noise()

    for t in range(int(args.max_timesteps + 1)):
        maybe_evaluate_and_print(RL_agent, eval_env, evals, action_smoothness_evals, main_task_reward, t,
                                 start_time, args)

        action, hx, nhx = RL_agent.select_action(np.array(state), prev_actions, timestep=ep_timesteps, hx=hx)

        next_state, reward, terminated, truncated, _ = env.step(action)

        # Shift the action window left by one and append the newest action.
        actions[:-1] = actions[1:]
        actions[-1] = action

        ep_total_reward += reward
        ep_timesteps += 1

        # The episode ends on termination, on truncation by the environment
        # (previously ignored, which kept stepping an already-truncated env
        # when max_episode_length exceeded the env's own limit), or when the
        # configured episode length is reached.
        # NOTE: as before, time-limit endings are stored as done=1 in the buffer.
        ep_finished = float(terminated or truncated or ep_timesteps >= args.max_episode_length)
        RL_agent.replay_buffer.add(np.array(state), actions, prev_actions, np.array(next_state), reward,
                                   ep_finished, hx, nhx)

        state = next_state
        prev_actions[:] = actions
        hx = nhx

        # Without checkpoints the agent is trained after every environment step.
        if allow_train and not args.use_checkpoints:
            RL_agent.train()

        if not ep_finished:
            continue

        print(f"Total T: {t + 1} Episode Num: {ep_num} Episode T: {ep_timesteps} Reward: {ep_total_reward:.3f}")

        # With checkpoints, training is driven once per finished episode.
        if allow_train and args.use_checkpoints:
            RL_agent.maybe_train_and_checkpoint(ep_timesteps, ep_total_reward)

        # Training is only enabled at an episode boundary.
        if t >= args.timesteps_before_training:
            allow_train = True

        # Reset all per-episode state.
        state = env.reset()[0]
        actions.fill(0)
        prev_actions.fill(0)
        ep_total_reward, ep_timesteps = 0, 0
        ep_num += 1
        hx = RL_agent.fixed_encoder.reset_hx()
        RL_agent.init_episode_noise()


def maybe_evaluate_and_print(RL_agent, eval_env, evals, evals_smoothness, main_task_reward, t, start_time, args, d4rl=False):
    """Evaluate the agent when ``t`` is a multiple of ``args.eval_freq``.

    Runs ``args.eval_eps`` episodes on ``eval_env`` without exploration,
    prints a summary, appends the results to the three result lists and saves
    the lists and the agent under ``./results/{args.file_name}``. Does nothing
    when ``t`` is not a multiple of ``args.eval_freq``.

    Args:
        RL_agent: agent providing ``select_action``, ``save`` and
            ``checkpoint_encoder.reset_hx``.
        eval_env: evaluation environment (Gymnasium 5-tuple ``step`` API).
        evals: list; the per-episode total rewards of this evaluation are appended.
        evals_smoothness: list; the scalar action-smoothness value is appended.
        main_task_reward: list; the per-episode main-task rewards are appended.
        t: current global training step.
        start_time: ``time.time()`` value taken when training started.
        args: namespace providing ``eval_freq``, ``eval_eps``,
            ``max_episode_length``, ``history_length``, ``use_checkpoints``
            and ``file_name``.
        d4rl: if true, rewards are converted with
            ``eval_env.get_normalized_score`` before being stored.
    """
    if t % args.eval_freq != 0:
        return

    print("---------------------------------------")
    print(f"Evaluation at {t} time steps")
    print(f"Total time passed: {round((time.time() - start_time) / 60., 2)} min(s)")

    # Use the evaluation environment passed in; the previous code read an
    # undeclared global `env` when building the action history.
    action_dim = eval_env.action_space.shape[0]

    total_reward = np.zeros(args.eval_eps)
    main_reward = np.zeros(args.eval_eps)
    total_steps = np.zeros(args.eval_eps, dtype=int)

    # Every action taken, per episode and step. Steps after an early episode
    # end stay zero and are still included in the smoothness statistic below.
    total_action_smoothness = np.zeros((args.eval_eps, args.max_episode_length, action_dim))

    for ep in range(args.eval_eps):
        eval_state, done = eval_env.reset()[0], False
        actions_eval = np.zeros((args.history_length, action_dim))
        ehx = RL_agent.checkpoint_encoder.reset_hx()

        while not done and total_steps[ep] < args.max_episode_length:
            action, _, ehx = RL_agent.select_action(np.array(eval_state), np.array(actions_eval),
                                                    use_checkpoint=args.use_checkpoints, use_exploration=False,
                                                    hx=ehx)
            eval_state, reward, terminated, truncated, info = eval_env.step(action)
            # Stop on truncation as well, so a truncated env is never stepped again.
            done = terminated or truncated

            total_reward[ep] += reward
            # NOTE: these info keys must be reported by the environment;
            # an environment without them raises KeyError here.
            main_reward[ep] += info["reward_forward"] + info["reward_survive"]
            total_action_smoothness[ep, total_steps[ep]] = action
            total_steps[ep] += 1

            # Shift the action window left by one and append the newest action.
            actions_eval[:-1] = actions_eval[1:]
            actions_eval[-1] = action

    # Second-order difference of the actions along the time axis, averaged
    # over action dimensions, then mean absolute value over episodes and steps.
    action_smoothness = np.mean(np.abs(np.mean(np.diff(total_action_smoothness, n=2, axis=1), axis=2)))

    # score and feedback
    print(f"Average total reward over {args.eval_eps} episodes: {total_reward.mean():.3f},"
          f" Average main task reward {main_reward.mean()},"
          f" Average action smoothness {action_smoothness},"
          f"Average total steps {total_steps.mean()}")

    if d4rl:
        total_reward = eval_env.get_normalized_score(total_reward) * 100
        print(f"D4RL score: {total_reward.mean():.3f}")

    evals.append(total_reward)
    main_task_reward.append(main_reward)
    evals_smoothness.append(action_smoothness)

    # Make sure the output directory exists before anything is written to it.
    os.makedirs("./results", exist_ok=True)
    np.save(f"./results/{args.file_name}", evals)
    np.save(f"./results/{args.file_name}_action_smoothness", evals_smoothness)
    np.save(f"./results/{args.file_name}_main_task_reward", main_task_reward)
    RL_agent.save(f"./results/{args.file_name}")
    print("---------------------------------------")


if __name__ == "__main__":
    # Command-line entry point: parse the options, then train one agent per seed.
    parser = argparse.ArgumentParser()
    # RL
    parser.add_argument("--env", default="Ant-v4", type=str)  # 'HalfCheetah-v4'
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_seeds", default=1, type=int)
    parser.add_argument("--offline", default=False, action=argparse.BooleanOptionalAction)
    parser.add_argument('--use_checkpoints', default=True, action=argparse.BooleanOptionalAction)
    parser.add_argument('--history_length', default=10, type=int)
    # Evaluation
    parser.add_argument("--timesteps_before_training", default=0, type=int)
    # argparse does not apply `type` to non-string defaults, so the defaults
    # are given as ints explicitly instead of float literals.
    parser.add_argument("--eval_freq", default=int(5e3), type=int)
    parser.add_argument("--eval_eps", default=10, type=int)
    parser.add_argument("--max_timesteps", default=int(3e6), type=int)
    parser.add_argument("--max_episode_length", default=1000, type=int)
    # File
    parser.add_argument('--file_name', default="PASRL_stable_")
    parser.add_argument('--d4rl_path', default="./d4rl_datasets", type=str)
    args = parser.parse_args()

    # All results are written below ./results, so that is the directory to create.
    os.makedirs("./results", exist_ok=True)

    # Remember the user-supplied values: each run derives its own seed and
    # file name from them instead of accumulating onto the previous run's.
    base_seed = args.seed
    base_file_name = args.file_name

    for i in range(args.num_seeds):
        # Set the seed and file name of this run before anything is seeded.
        args.seed = base_seed + i
        args.file_name = f"{base_file_name}{i}"

        # define the environments
        env = gym.make(args.env)
        eval_env = gym.make(args.env)

        # seed the environments and the random number generators
        env.reset(seed=args.seed)
        env.action_space.seed(args.seed)
        eval_env.reset(seed=args.seed + 100)
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(args.seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

        # Dimensions are read before the environments are wrapped in FrameStack.
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]
        max_action = float(env.action_space.high[0])

        RL_agent = Agent(state_dim, action_dim, max_action, args.history_length, args.offline)

        print("---------------------------------------")
        print(f"Algorithm: PASRL_stable, Env: {args.env}, Seed: {args.seed}")
        print("---------------------------------------")

        # wrap the environments
        env = FrameStack(env, num_stack=args.history_length)
        eval_env = FrameStack(eval_env, num_stack=args.history_length)

        train_online(RL_agent, env, eval_env, args)

        # Release the environments before the next run creates new ones.
        env.close()
        eval_env.close()
