# env
from morl_baselines.common.pareto import get_non_pareto_dominated_inds
from morl_baselines.common.weights import equally_spaced_weights
from utils.vectorize import MOAsyncVectorEnv
import mo_gymnasium as mo_gym
import tasks.register

# algorithm
from algos import algo_dict

# utils
from utils import setSeed, cprint
from utils.logger import Logger

# base
import matplotlib.pyplot as plt
from ruamel.yaml import YAML
from copy import deepcopy
import numpy as np
import argparse
import pickle
import torch
import wandb
import time
import os

def getParser():
    """Build the command-line parser used by this script.

    Returns:
        argparse.ArgumentParser with the common run options (logging,
        device selection, checkpointing, seed and configuration paths).
    """
    cli = argparse.ArgumentParser(description='RL')

    # common options, registered in the order they appear in `--help`
    option_table = (
        ('--wandb', dict(action='store_true', help='use wandb?')),
        ('--test', dict(action='store_true', help='test or train?')),
        ('--device', dict(type=str, default='gpu', help='gpu or cpu.')),
        ('--gpu_idx', dict(type=int, default=0, help='GPU index.')),
        ('--model_num', dict(type=int, default=0, help='num model.')),
        ('--save_freq', dict(type=int, default=int(1e6), help='# of time steps for save.')),
        ('--wandb_freq', dict(type=int, default=int(1e3), help='# of time steps for wandb logging.')),
        ('--seed', dict(type=int, default=1, help='seed number.')),
        ('--task_cfg_path', dict(type=str, help='cfg.yaml file location for task.')),
        ('--algo_cfg_path', dict(type=str, help='cfg.yaml file location for algorithm.')),
        ('--comment', dict(type=str, default=None, help='wandb comment saved in run name.')),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli

def _evaluate(agent, eval_env, args):
    """Roll out the agent on `eval_env` once per evaluation preference.

    Every variable used here is local, so calling this in the middle of data
    collection cannot disturb the training rollout state (observations,
    rewards, ...) held by the caller.

    Args:
        agent: object exposing `getAction(observations, preferences, flag)`;
            the flag is passed as True here (presumably evaluation /
            deterministic mode -- confirm against the agent implementation).
        eval_env: vectorized environment with `args.n_envs` sub-environments.
        args: namespace providing `eval_prefer_list`, `reward_dim`,
            `cost_dim`, `n_envs` and `n_eval_episodes`.

    Returns:
        Tuple `(scores_list, consts_list)` with shapes
        `(len(args.eval_prefer_list), args.reward_dim)` and
        `(len(args.eval_prefer_list), args.cost_dim)`: summed episode rewards
        and costs, averaged over environments and over repeats.
    """
    n_prefers = len(args.eval_prefer_list)
    scores_list = np.zeros((n_prefers, args.reward_dim))
    consts_list = np.zeros((n_prefers, args.cost_dim))
    # at least one repeat, otherwise the averages below would be 0/0 = NaN
    n_eval_repeats = max(1, int(args.n_eval_episodes/args.n_envs))

    for _ in range(n_eval_repeats):
        for prefer_idx, preference in enumerate(args.eval_prefer_list):
            scores = np.zeros((args.n_envs, args.reward_dim))
            consts = np.zeros((args.n_envs, args.cost_dim))
            finished = np.zeros(args.n_envs, dtype=bool)

            eval_preferences = np.tile(preference, (args.n_envs, 1))
            eval_observations, _ = eval_env.reset()

            while not finished.all():
                eval_actions = agent.getAction(eval_observations, eval_preferences, True)
                eval_observations, eval_rewards, terminates, truncates, _ = eval_env.step(eval_actions)
                # the environment returns rewards and costs concatenated on the last axis
                eval_costs = eval_rewards[..., args.reward_dim:]
                eval_rewards = eval_rewards[..., :args.reward_dim]

                # only the first episode of each sub-environment is accumulated
                active = ~finished
                scores[active] += eval_rewards[active]
                consts[active] += eval_costs[active]
                finished |= np.logical_or(terminates, truncates)

            scores_list[prefer_idx, :] += np.mean(scores, axis=0)
            consts_list[prefer_idx, :] += np.mean(consts, axis=0)

    scores_list /= n_eval_repeats
    consts_list /= n_eval_repeats
    return scores_list, consts_list


def train(args, task_cfg, algo_cfg):
    """Train the configured agent and periodically log, save and evaluate it.

    Args:
        args: argparse namespace; it is extended in place with task and
            environment settings. `task_name`, `algo_name`, `name` and
            `save_dir` must already be set by the caller.
        task_cfg: mapping loaded from the task configuration file. Keys read
            here: `n_envs`, `eval_freq`, `sample_freq`, `n_total_steps`,
            `max_episode_len`, `n_prefer_eval_samples`, `n_eval_episodes`,
            `is_earlystop`, `rewards`, `costs`.
        algo_cfg: mapping loaded from the algorithm configuration file.
            `n_steps` is read here and every key is copied onto the agent's
            argument namespace.
    """
    # set seed
    setSeed(args.seed)

    # set arguments
    args.n_envs = task_cfg['n_envs']
    args.n_steps = algo_cfg['n_steps']
    args.eval_freq = task_cfg['eval_freq']
    args.sample_freq = task_cfg['sample_freq']
    args.n_total_steps = task_cfg['n_total_steps']
    args.max_episode_len = task_cfg['max_episode_len']
    args.n_prefer_eval_samples = task_cfg['n_prefer_eval_samples']
    args.n_eval_episodes = task_cfg['n_eval_episodes']

    # create environments
    env_id = lambda: mo_gym.make(
        args.task_name, 
        max_episode_length=args.max_episode_len, 
        is_earlystop=task_cfg['is_earlystop']
    )
    vec_env = MOAsyncVectorEnv([env_id for _ in range(args.n_envs)])
    # evaluation environments use fixed velocity command ranges
    test_env_id = lambda: mo_gym.make(
        args.task_name, 
        max_episode_length=args.max_episode_len, 
        is_earlystop=task_cfg['is_earlystop'],
        lin_vel_cmd_range=[1.0, 1.0],
        ang_vel_cmd_range=[0.0, 0.0],
    )
    test_vec_env = MOAsyncVectorEnv([test_env_id for _ in range(args.n_envs)])
    args.obs_dim = vec_env.single_observation_space.shape[0]
    args.action_dim = vec_env.single_action_space.shape[0]
    args.reward_dim = vec_env.single_reward_space.shape[0]
    args.cost_dim = vec_env.single_cost_space.shape[0]
    args.action_bound_min = vec_env.single_action_space.low
    args.action_bound_max = vec_env.single_action_space.high
    args.reward_names = task_cfg["rewards"]
    args.cost_names = task_cfg["costs"]
    assert len(args.reward_names) == args.reward_dim
    assert len(args.cost_names) == args.cost_dim
    args.eval_prefer_list = equally_spaced_weights(
        args.reward_dim, args.n_prefer_eval_samples)

    # declare agent (algorithm settings are layered on top of a copy of args)
    agent_args = deepcopy(args)
    for key in algo_cfg.keys():
        agent_args.__dict__[key] = algo_cfg[key]
    agent = algo_dict[args.algo_name.lower()](agent_args)
    initial_step = agent.load(args.model_num)

    # wandb
    if args.wandb:
        project_name = '[CMORL] Legged Robot'
        wandb.init(project=project_name, config=args)
        if args.comment is not None:
            wandb.run.name = f"{args.name}/{args.comment}"
        else:
            wandb.run.name = f"{args.name}"

    # logger: reward/cost dependent entries get one log per reward/cost name
    log_name_list = deepcopy(agent_args.logging['task_indep'])
    for log_name in agent_args.logging['reward_dep']:
        log_name_list += [f"{log_name}_{reward_name}" for reward_name in args.reward_names]
    for log_name in agent_args.logging['cost_dep']:
        log_name_list += [f"{log_name}_{cost_name}" for cost_name in args.cost_names]
    logger = Logger(log_name_list, f"{args.save_dir}/logs")

    # set train parameters
    reward_sums = np.zeros((args.n_envs, args.reward_dim))
    cost_sums = np.zeros((args.n_envs, args.cost_dim))
    env_cnts = np.zeros(args.n_envs)
    total_step = initial_step
    wandb_step = initial_step
    save_step = initial_step
    eval_step = initial_step

    # initialize environments
    # one evenly spaced offset in [0, max_episode_len) per environment, handed to
    # the custom reset (presumably to stagger episode phases -- confirm in MOAsyncVectorEnv)
    n_actions_per_env = (args.max_episode_len*np.arange(args.n_envs)/args.n_envs).astype(int)
    observations, infos = vec_env.reset(n_actions_per_env=n_actions_per_env)
    preferences = np.tile(np.random.dirichlet(np.ones(args.reward_dim)).reshape(1, -1), (args.n_envs, 1))

    # start training
    for _ in range(int(initial_step/args.n_steps), int(args.n_total_steps/args.n_steps)):
        start_time = time.time()

        for _ in range(int(args.n_steps/args.n_envs)):
            env_cnts += 1
            total_step += args.n_envs

            # ======= collect trajectories & training ======= #
            actions = agent.getAction(observations, preferences, False)
            observations, rewards, terminates, truncates, infos = vec_env.step(actions)
            # the environment returns rewards and costs concatenated on the last axis
            costs = rewards[..., args.reward_dim:]
            rewards = rewards[..., :args.reward_dim]

            reward_sums += rewards
            cost_sums += costs
            temp_fails = []
            temp_dones = []
            temp_observations = []

            for env_idx in range(args.n_envs):
                # a failure is a termination that is not also a truncation
                fail = (not truncates[env_idx]) and terminates[env_idx]
                done = terminates[env_idx] or truncates[env_idx]
                # on episode end use the stored final observation instead of
                # the observation returned for the finished environment
                temp_observations.append(
                    infos['final_observation'][env_idx] 
                    if done else observations[env_idx])
                temp_fails.append(fail)
                temp_dones.append(done)

                if done:
                    eplen = env_cnts[env_idx]
                    if 'eplen' in logger.log_name_list: 
                        logger.write('eplen', [eplen, eplen])
                    for reward_idx in range(args.reward_dim):
                        log_name = f'reward_sum_{args.reward_names[reward_idx]}'
                        if log_name in logger.log_name_list:
                            logger.write(log_name, [eplen, reward_sums[env_idx, reward_idx]])
                    for cost_idx in range(args.cost_dim):
                        log_name = f'cost_sum_{args.cost_names[cost_idx]}'
                        if log_name in logger.log_name_list: 
                            logger.write(log_name, [eplen, cost_sums[env_idx, cost_idx]])
                    reward_sums[env_idx, :] = 0
                    cost_sums[env_idx, :] = 0
                    env_cnts[env_idx] = 0

            temp_dones = np.array(temp_dones)
            temp_fails = np.array(temp_fails)
            temp_observations = np.array(temp_observations)
            agent.step(rewards, costs, temp_dones, temp_fails, temp_observations)
            # =============================================== #

            # update preferences
            # NOTE: total_step grows by n_envs per iteration, so this only fires
            # when total_step lands exactly on a multiple of sample_freq.
            if total_step % args.sample_freq == 0:
                if np.random.rand() < 0.5:
                    # one-hot preference on a random objective
                    preference = np.zeros(args.reward_dim)
                    preference[np.random.randint(args.reward_dim)] = 1.0
                else:
                    preference = np.random.dirichlet(np.ones(args.reward_dim))
                preferences = np.tile(preference.reshape(1, -1), (args.n_envs, 1))

            # wandb logging
            if total_step - wandb_step >= args.wandb_freq and args.wandb:
                wandb_step += args.wandb_freq
                log_data = {"step": total_step}
                # averaging window lengths for per-episode and per-update logs
                print_len_episode = max(int(args.wandb_freq/args.max_episode_len), args.n_envs)
                print_len_step = max(int(args.wandb_freq/args.n_steps), args.n_envs)
                for reward_name in args.reward_names:
                    for log_name in agent_args.logging['reward_dep']:
                        if 'sum' in log_name:                        
                            log_data[f'{log_name}/{reward_name}'] = logger.get_avg(f'{log_name}_{reward_name}', print_len_episode)
                        else:
                            log_data[f'{log_name}/{reward_name}'] = logger.get_avg(f'{log_name}_{reward_name}', print_len_step)
                for cost_name in args.cost_names:
                    for log_name in agent_args.logging['cost_dep']:
                        if 'sum' in log_name:                        
                            log_data[f'{log_name}/{cost_name}'] = logger.get_avg(f'{log_name}_{cost_name}', print_len_episode)
                        else:
                            log_data[f'{log_name}/{cost_name}'] = logger.get_avg(f'{log_name}_{cost_name}', print_len_step)
                for log_name in agent_args.logging['task_indep']:
                    if 'eplen' in log_name:                        
                        log_data[f"metric/{log_name}"] = logger.get_avg(log_name, print_len_episode)
                    else:
                        log_data[f"metric/{log_name}"] = logger.get_avg(log_name, print_len_step)
                wandb.log(log_data)
                print(log_data)

            # save
            if total_step - save_step >= args.save_freq:
                save_step += args.save_freq
                agent.save(total_step)
                logger.save()

            # evaluation
            if total_step - eval_step >= args.eval_freq:
                eval_step += args.eval_freq

                # evaluation keeps its own rollout variables, so the training
                # `observations` fed to the next getAction call stay untouched
                scores_list, consts_list = _evaluate(agent, test_vec_env, args)
                pareto_inds = get_non_pareto_dominated_inds(scores_list)
                pareto_scores_list = scores_list[pareto_inds]

                save_pareto_dir = f"{args.save_dir}/pareto"
                os.makedirs(save_pareto_dir, exist_ok=True)
                with open(f"{save_pareto_dir}/{total_step}.pkl", 'wb') as f:
                    pickle.dump([scores_list, consts_list, args.eval_prefer_list], f)

                # scatter of the first two objectives; non-dominated points in red
                plt.scatter(scores_list[:, 0], scores_list[:, 1], c='b')
                plt.scatter(pareto_scores_list[:, 0], pareto_scores_list[:, 1], c='r')
                plt.grid()
                plt.savefig(f"{save_pareto_dir}/{total_step}.png")
                plt.close()

        # train
        if agent.readyToTrain():
            train_results = agent.train()
            # entries skipped below are written elsewhere in this function
            # ('fps', 'eplen', 'reward_sum', 'cost_sum') or not taken from train_results ('num_cv')
            for log_name in agent_args.logging['task_indep']:
                if log_name in ['fps', 'eplen']: continue
                logger.write(log_name, [args.n_steps, train_results[log_name]])
            for reward_idx, reward_name in enumerate(args.reward_names):
                for log_name in agent_args.logging['reward_dep']:
                    if log_name in ['reward_sum']: continue
                    logger.write(f"{log_name}_{reward_name}", [args.n_steps, train_results[log_name][reward_idx]])
            for cost_idx, cost_name in enumerate(args.cost_names):
                for log_name in agent_args.logging['cost_dep']:
                    if log_name in ['cost_sum', 'num_cv']: continue
                    logger.write(f"{log_name}_{cost_name}", [args.n_steps, train_results[log_name][cost_idx]])

        # calculate FPS
        end_time = time.time()
        fps = args.n_steps/(end_time - start_time)
        if 'fps' in agent_args.logging['task_indep']:
            logger.write('fps', [args.n_steps, fps])

    # final save
    agent.save(total_step)
    logger.save()

    # terminate (both vectorized environments own worker resources)
    vec_env.close()
    test_vec_env.close()


def test(args, task_cfg, algo_cfg):
    """Evaluate a loaded agent over the evaluation preferences and plot the result.

    The agent is loaded with `agent.load(args.model_num)`, rolled out once per
    evaluation preference (repeated `n_eval_episodes // n_envs` times, at
    least once), and a scatter of the first two objectives is written to
    `<save_dir>/pareto/eval.png`.

    Args:
        args: argparse namespace; it is extended in place with task and
            environment settings. `task_name`, `algo_name` and `save_dir`
            must already be set by the caller.
        task_cfg: mapping loaded from the task configuration file.
        algo_cfg: mapping loaded from the algorithm configuration file; every
            key is copied onto the agent's argument namespace.
    """
    # create environment from the configuration file
    args.n_envs = task_cfg['n_envs']
    args.n_steps = algo_cfg['n_steps']
    args.eval_freq = task_cfg['eval_freq']
    args.sample_freq = task_cfg['sample_freq']
    args.n_total_steps = task_cfg['n_total_steps']
    args.max_episode_len = task_cfg['max_episode_len']
    args.n_prefer_eval_samples = task_cfg['n_prefer_eval_samples']
    args.n_eval_episodes = task_cfg['n_eval_episodes']

    # create environments
    env_id = lambda: mo_gym.make(
        args.task_name, 
        max_episode_length=args.max_episode_len, 
        is_earlystop=task_cfg['is_earlystop']
    )
    vec_env = MOAsyncVectorEnv([env_id for _ in range(args.n_envs)])

    # everything below runs under try/finally so the vectorized environment
    # is closed even when agent construction or the rollout raises
    try:
        args.obs_dim = vec_env.single_observation_space.shape[0]
        args.action_dim = vec_env.single_action_space.shape[0]
        args.reward_dim = vec_env.single_reward_space.shape[0]
        args.cost_dim = vec_env.single_cost_space.shape[0]
        args.action_bound_min = vec_env.single_action_space.low
        args.action_bound_max = vec_env.single_action_space.high
        args.reward_names = task_cfg["rewards"]
        args.cost_names = task_cfg["costs"]
        assert len(args.reward_names) == args.reward_dim
        assert len(args.cost_names) == args.cost_dim
        args.eval_prefer_list = equally_spaced_weights(
            args.reward_dim, args.n_prefer_eval_samples)

        # declare agent (algorithm settings are layered on top of a copy of args)
        agent_args = deepcopy(args)
        for key in algo_cfg.keys():
            agent_args.__dict__[key] = algo_cfg[key]
        agent = algo_dict[args.algo_name.lower()](agent_args)
        agent.load(args.model_num)

        # rollout
        scores_list = np.zeros((len(args.eval_prefer_list), args.reward_dim))
        consts_list = np.zeros((len(args.eval_prefer_list), args.cost_dim))
        # at least one repeat, otherwise the averages below would be 0/0 = NaN
        n_eval_repeats = max(1, int(args.n_eval_episodes/args.n_envs))

        for _ in range(n_eval_repeats):
            for prefer_idx, preference in enumerate(args.eval_prefer_list):
                scores = np.zeros((args.n_envs, args.reward_dim))
                consts = np.zeros((args.n_envs, args.cost_dim))
                finished = np.zeros(args.n_envs, dtype=bool)

                eval_preferences = np.tile(preference, (args.n_envs, 1))
                observations, infos = vec_env.reset()

                while not finished.all():
                    actions = agent.getAction(observations, eval_preferences, True)
                    observations, rewards, terminates, truncates, infos = vec_env.step(actions)
                    # the environment returns rewards and costs concatenated on the last axis
                    costs = rewards[..., args.reward_dim:]
                    rewards = rewards[..., :args.reward_dim]

                    # only the first episode of each sub-environment is accumulated
                    active = ~finished
                    scores[active] += rewards[active]
                    consts[active] += costs[active]
                    finished |= np.logical_or(terminates, truncates)

                scores_list[prefer_idx, :] += np.mean(scores, axis=0)
                consts_list[prefer_idx, :] += np.mean(consts, axis=0)
                print(np.mean(scores, axis=0), np.mean(consts, axis=0))

        scores_list /= n_eval_repeats
        consts_list /= n_eval_repeats
        pareto_inds = get_non_pareto_dominated_inds(scores_list)
        pareto_scores_list = scores_list[pareto_inds]

        save_pareto_dir = f"{args.save_dir}/pareto"
        os.makedirs(save_pareto_dir, exist_ok=True)

        # scatter of the first two objectives; non-dominated points in red
        plt.scatter(scores_list[:, 0], scores_list[:, 1], c='b')
        plt.scatter(pareto_scores_list[:, 0], pareto_scores_list[:, 1], c='r')
        plt.grid()
        plt.savefig(f"{save_pareto_dir}/eval.png")
        plt.close()
    finally:
        # terminate
        vec_env.close()


if __name__ == "__main__":
    parser = getParser()
    args = parser.parse_args()

    # ==== processing args ==== #
    # both configuration files are mandatory; without this check a missing
    # path would only surface as an opaque TypeError from open(None)
    missing_flags = [
        flag for flag, path in (
            ('--task_cfg_path', args.task_cfg_path),
            ('--algo_cfg_path', args.algo_cfg_path),
        ) if path is None
    ]
    if missing_flags:
        parser.error(f"missing required argument(s): {', '.join(missing_flags)}")

    # load configuration file
    with open(args.task_cfg_path, 'r') as f:
        task_cfg = YAML().load(f)
    args.task_name = task_cfg['name']
    with open(args.algo_cfg_path, 'r') as f:
        algo_cfg = YAML().load(f)
    args.algo_name = algo_cfg['name']
    # run name: <task>_<algo>[_<postfix>], all lower-cased
    args.postfix = algo_cfg.get('postfix', None)
    if args.postfix is not None:
        args.name = f"{(args.task_name.lower())}_{(args.algo_name.lower())}_{(args.postfix.lower())}"
    else:
        args.name = f"{(args.task_name.lower())}_{(args.algo_name.lower())}"
    # save_dir
    args.save_dir = f"results/{args.name}/seed_{args.seed}"
    # device: fall back to cpu when cuda is unavailable or not requested
    if torch.cuda.is_available() and args.device == 'gpu':
        device = torch.device(f'cuda:{args.gpu_idx}')
        cprint('[torch] cuda is used.', bold=True, color='cyan')
    else:
        device = torch.device('cpu')
        cprint('[torch] cpu is used.', bold=True, color='cyan')
    # the string option is replaced by the resolved torch.device
    args.device = device
    # ========================= #

    if args.test:
        test(args, task_cfg, algo_cfg)
    else:
        train(args, task_cfg, algo_cfg)
