import argparse
import os
import time

import gymnasium as gym
import numpy as np
import torch
from utils.Low_pass_filter import LowPassFilter
from utils.PD_controller import PD
from Algorithms.TD7 import Agent
from Environment.Ant_ac import AntEnv_AC


def train_online(RL_agent, env, eval_env, args):
    """Train ``RL_agent`` by interacting with ``env`` for ``args.max_timesteps`` steps.

    Random actions are sampled from ``env.action_space`` until the first episode
    that ends at or after ``args.timesteps_before_training`` total steps; from
    then on actions come from the agent's policy. Every action is passed through
    the module-level ``env_low_pass_filter`` and then ``env_PD_controller``
    before it reaches the environment, and that processed action is what gets
    stored in the replay buffer.

    Args:
        RL_agent: TD7 ``Agent`` exposing ``select_action``, ``train``,
            ``maybe_train_and_checkpoint`` and ``replay_buffer``.
        env: Training environment using the gymnasium 5-tuple ``step`` API.
        eval_env: Separate environment handed to ``maybe_evaluate_and_print``.
        args: Parsed command-line arguments.

    Note:
        Relies on the module-level ``env_low_pass_filter`` and
        ``env_PD_controller`` created in the ``__main__`` section.
    """
    evals = []
    action_smoothness_evals = []
    main_task_reward = []
    start_time = time.time()
    allow_train = False

    state = env.reset()[0]
    ep_total_reward, ep_timesteps, ep_num = 0, 0, 1

    env_low_pass_filter.reset()
    env_PD_controller.reset()

    for t in range(int(args.max_timesteps + 1)):
        maybe_evaluate_and_print(RL_agent, eval_env, evals, action_smoothness_evals,
                                 main_task_reward, t, start_time, args)

        if allow_train:
            action = RL_agent.select_action(np.array(state))
        else:
            action = env.action_space.sample()

        # filter action, then track it with the PD controller
        action = env_low_pass_filter.apply(action)
        action, _ = env_PD_controller.action(action)

        next_state, reward, terminated, truncated, info = env.step(action)

        ep_total_reward += reward
        ep_timesteps += 1

        # Only a genuine termination stops bootstrapping. Hitting the episode
        # length cap (or an env-side truncation) is a time limit, not a
        # terminal state, so it must be stored with done = 0.
        done = float(terminated)
        RL_agent.replay_buffer.add(state, action, next_state, reward, done)

        state = next_state

        if allow_train and not args.use_checkpoints:
            RL_agent.train()

        # The episode ends on termination, env truncation, or our own length cap.
        ep_finished = terminated or truncated or ep_timesteps >= args.max_episode_length

        if ep_finished:
            print(f"Total T: {t + 1} Episode Num: {ep_num} Episode T: {ep_timesteps} Reward: {ep_total_reward:.3f}")

            if allow_train and args.use_checkpoints:
                RL_agent.maybe_train_and_checkpoint(ep_timesteps, ep_total_reward)

            if t >= args.timesteps_before_training:
                allow_train = True

            state = env.reset()[0]
            ep_total_reward, ep_timesteps = 0, 0
            ep_num += 1
            env_low_pass_filter.reset()
            env_PD_controller.reset()


def train_offline(RL_agent, env, eval_env, args):
    """Train ``RL_agent`` purely from a D4RL dataset, without stepping ``env``.

    The replay buffer is filled once from ``d4rl.qlearning_dataset(env)``; each
    of the ``args.max_timesteps + 1`` iterations then runs a possible
    evaluation (with D4RL score normalisation) followed by one training update.

    Args:
        RL_agent: TD7 ``Agent`` whose ``replay_buffer`` supports ``load_D4RL``.
        env: Environment handed to ``d4rl.qlearning_dataset``.
        eval_env: Environment used by ``maybe_evaluate_and_print``.
        args: Parsed command-line arguments.

    Note:
        ``d4rl`` is the module-level name bound by ``import d4rl`` in the
        ``__main__`` section, which only runs when ``--offline`` is given.
        NOTE(review): ``env`` is the custom ``AntEnv_AC``; presumably it must
        expose a D4RL dataset for ``qlearning_dataset`` to work — confirm.
    """
    dataset = d4rl.qlearning_dataset(env)
    RL_agent.replay_buffer.load_D4RL(dataset)

    evals, action_smoothness_evals, main_task_reward = [], [], []
    start_time = time.time()

    n_iterations = int(args.max_timesteps + 1)
    for t in range(n_iterations):
        maybe_evaluate_and_print(
            RL_agent,
            eval_env,
            evals,
            action_smoothness_evals,
            main_task_reward,
            t,
            start_time,
            args,
            d4rl=True,
        )
        RL_agent.train()


def _action_smoothness(actions, steps):
    """Return the mean absolute second difference of the executed actions.

    Args:
        actions: Array of shape ``(episodes, max_steps, action_dim)``. Rows at or
            beyond ``steps[ep]`` are unused padding and are ignored.
        steps: Number of valid steps recorded for each episode.

    Returns:
        Mean of ``|a[k+2] - 2*a[k+1] + a[k]|`` over every valid step and action
        dimension of every episode, or NaN if no episode has at least 3 steps
        (a second difference needs 3 samples).
    """
    # Slice each episode to its real length so the zero padding after an early
    # termination does not create fake jumps or dilute the average.
    second_diffs = [
        np.diff(actions[ep, :n_steps], n=2, axis=0)
        for ep, n_steps in enumerate(steps)
        if n_steps >= 3
    ]
    if not second_diffs:
        return float("nan")
    # abs is taken per action dimension, so opposite-sign jerks on different
    # joints cannot cancel each other out.
    return float(np.mean(np.abs(np.concatenate(second_diffs, axis=0))))


def maybe_evaluate_and_print(RL_agent, eval_env, evals, evals_smoothness, main_task_reward, t, start_time, args, d4rl=False):
    """Every ``args.eval_freq`` steps, evaluate the policy and persist the results.

    Runs ``args.eval_eps`` deterministic episodes (``use_exploration=False``),
    each capped at ``args.max_episode_length`` steps. Actions go through the
    module-level ``eval_env_low_pass_filter`` and ``eval_env_PD_controller``,
    which are reset at the start of every episode. The function prints a
    summary, appends to ``evals``, ``main_task_reward`` and ``evals_smoothness``
    (mutated in place), saves all three lists under ``./results/``, and saves
    the agent.

    Args:
        RL_agent: TD7 ``Agent`` exposing ``select_action`` and ``save``.
        eval_env: Evaluation environment using the gymnasium 5-tuple ``step`` API.
        evals: History of per-episode total rewards (one array per evaluation).
        evals_smoothness: History of the action-smoothness metric.
        main_task_reward: History of per-episode
            ``reward_forward + reward_survive`` sums.
        t: Current training step; evaluation runs only when divisible by
            ``args.eval_freq``.
        start_time: ``time.time()`` value at the start of training.
        args: Parsed command-line arguments.
        d4rl: If True, ``eval_env.get_normalized_score(...) * 100`` replaces the
            recorded total rewards.
    """
    if t % args.eval_freq != 0:
        return

    print("---------------------------------------")
    print(f"Evaluation at {t} time steps")
    print(f"Total time passed: {round((time.time() - start_time) / 60., 2)} min(s)")

    total_reward = np.zeros(args.eval_eps)
    main_reward = np.zeros(args.eval_eps)
    total_steps = np.zeros(args.eval_eps, dtype=int)
    # Executed (filtered + PD) actions per episode; rows past an episode's end stay zero.
    executed_actions = np.zeros((args.eval_eps, args.max_episode_length, eval_env.action_space.shape[0]))

    for ep in range(args.eval_eps):
        state, done = eval_env.reset()[0], False
        eval_env_low_pass_filter.reset()
        eval_env_PD_controller.reset()

        while not done and total_steps[ep] < args.max_episode_length:
            action = RL_agent.select_action(np.array(state), args.use_checkpoints, use_exploration=False)
            action = eval_env_low_pass_filter.apply(action)
            action, _ = eval_env_PD_controller.action(action)
            state, reward, terminated, truncated, info = eval_env.step(action)
            # Stop on env truncation too; stepping past it is invalid in gymnasium.
            done = terminated or truncated
            total_reward[ep] += reward
            # NOTE(review): assumes eval_env's info dict provides these keys — confirm.
            main_reward[ep] += info["reward_forward"] + info["reward_survive"]
            executed_actions[ep, total_steps[ep]] = action
            total_steps[ep] += 1

    # stability
    action_smoothness = _action_smoothness(executed_actions, total_steps)

    # score and feedback
    print(f"Average total reward over {args.eval_eps} episodes: {total_reward.mean():.3f},"
          f" Average main task reward {main_reward.mean()},"
          f" Average action smoothness {action_smoothness},"
          f" Average total steps {total_steps.mean()}")

    if d4rl:
        total_reward = eval_env.get_normalized_score(total_reward) * 100
        print(f"D4RL score: {total_reward.mean():.3f}")

    evals.append(total_reward)
    main_task_reward.append(main_reward)
    evals_smoothness.append(action_smoothness)
    np.save(f"./results/{args.file_name}", evals)
    np.save(f"./results/{args.file_name}_action_smoothness", evals_smoothness)
    np.save(f"./results/{args.file_name}_main_task_reward", main_task_reward)
    RL_agent.save(f"./results/{args.file_name}")
    print("---------------------------------------")


def build_arg_parser():
    """Build the command-line parser for this training script.

    Integer options use integer-literal defaults. argparse applies ``type`` only
    to *string* defaults, so ``default=5e3, type=int`` yields a float when the
    flag is omitted. The same notation given on the command line
    (``--eval_freq 5e3``) is rejected, because ``int("5e3")`` raises.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()
    # RL
    # NOTE(review): --env is only printed; AntEnv_AC is always instantiated.
    parser.add_argument("--env", default="Ant-v4", type=str)  # Hopper, Humanoid
    parser.add_argument("--seed", default=4, type=int)
    parser.add_argument("--offline", default=False, action=argparse.BooleanOptionalAction)
    parser.add_argument('--use_checkpoints', default=True, action=argparse.BooleanOptionalAction)
    # Traditional control
    # NOTE(review): the default cutoff (20 Hz) equals the default sampling rate
    # (20 Hz), i.e. it lies above Nyquist (10 Hz). Whether that is meaningful
    # depends on LowPassFilter's implementation — confirm.
    parser.add_argument("--cutoff_frequency", default=20.0, type=float, help="Cutoff frequency in Hz")
    parser.add_argument("--sampling_frequency", default=20.0, type=float, help="Based on the environment")
    parser.add_argument("--Kp", default=1.0, type=float)
    parser.add_argument("--Kd", default=0.05, type=float)
    # Evaluation
    parser.add_argument("--timesteps_before_training", default=25_000, type=int)
    parser.add_argument("--eval_freq", default=5_000, type=int)
    parser.add_argument("--eval_eps", default=10, type=int)
    parser.add_argument("--max_timesteps", default=3_000_000, type=int)
    parser.add_argument("--max_episode_length", default=1000, type=int)
    # File
    parser.add_argument('--file_name', default="TD7_AC_PD_LPF_4")
    parser.add_argument('--d4rl_path', default="./d4rl_datasets", type=str)
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    if args.offline:
        # Binds ``d4rl`` at module scope; train_offline reads it from there.
        import d4rl

        d4rl.set_dataset_path(args.d4rl_path)
        args.use_checkpoints = False

    os.makedirs("./results", exist_ok=True)

    env = AntEnv_AC()
    eval_env = AntEnv_AC()

    print("---------------------------------------")
    print(f"Algorithm: TD7 + AC + LPF + PD, Env: {args.env}, Seed: {args.seed}")
    print("---------------------------------------")

    env.reset(seed=args.seed)
    env.action_space.seed(args.seed)
    eval_env.reset(seed=args.seed + 100)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])

    # Filters and PD controllers are module-level globals: train_online and
    # maybe_evaluate_and_print read them by name, so they must stay at module
    # scope (not inside a main() function).
    env_low_pass_filter = LowPassFilter(cutoff_freq=args.cutoff_frequency,
                                        sampling_freq=args.sampling_frequency,
                                        action_dim=action_dim
                                        )

    eval_env_low_pass_filter = LowPassFilter(cutoff_freq=args.cutoff_frequency,
                                             sampling_freq=args.sampling_frequency,
                                             action_dim=action_dim
                                             )

    # PD controllers
    env_PD_controller = PD(kp=args.Kp, kd=args.Kd, action_dim=action_dim)
    eval_env_PD_controller = PD(kp=args.Kp, kd=args.Kd, action_dim=action_dim)

    RL_agent = Agent(state_dim, action_dim, max_action, args.offline)

    if args.offline:
        train_offline(RL_agent, env, eval_env, args)
    else:
        train_online(RL_agent, env, eval_env, args)
