import os
import numpy as np
import gym
import torch as th
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter

from rl.utils.utils import eval_test_tasks, hypervolume, moving_average, random_weights, policy_evaluation_mo, get_tau, dsf_policy_evaluation_mo
from rl.successor_features.sf_dqn import SFDQN
from rl.successor_features.dsf_dqn_reacher import RDSFOLS
from rl.successor_features.gpi import GPI
from rl.successor_features.dgpi import DGPI
from rl.successor_features.ols import OLS
from rl.successor_features.dols import DOLS

import envs
import matplotlib.pyplot as plt
import seaborn as sns
import argparse


# Algorithms trained with SFDQN policies combined through GPI and driven by OLS.
_SF_ALGOS = ('SFOLS', 'WCPI', 'Random')
# Algorithms trained with RDSFOLS policies combined through DGPI and driven by DOLS.
_DSF_ALGOS = ('RDSFOLS', 'WCDPI')
# Value passed as ``m`` / ``dim`` to the OLS objects and the weight samplers.
_REWARD_DIM = 4
# Maximum number of outer (weight-selection) iterations.
_MAX_ITER = 15
# Environment steps spent learning each selected weight vector.
_STEPS_PER_ITER = 200000


def _select_weight(algo, ols):
    """Return the weight vector that ``algo`` trains on next."""
    if algo in ('SFOLS', 'RDSFOLS'):
        return ols.next_w()
    if algo in ('WCPI', 'WCDPI'):
        return ols.worst_case_weight()
    if algo == 'Random':
        return random_weights(dim=_REWARD_DIM)
    raise ValueError(f"Unknown algo {algo!r}")


def _log_test_metrics(writer, ols, returns, returns_ccs, test_tasks):
    """Log the per-iteration test metrics and return the ones reused for padding."""
    step = ols.iteration

    mean_test = np.mean(
        [np.dot(psi, w_test) for psi, w_test in zip(returns, test_tasks)], axis=0)
    writer.add_scalar('eval/mean_value_test_tasks', mean_test, step)

    mean_test_smp = np.mean([ols.max_scalarized_value(w_test) for w_test in test_tasks])
    writer.add_scalar('eval/mean_value_test_tasks_SMP', mean_test_smp, step)

    hv_ccs = hypervolume(np.zeros(_REWARD_DIM), ols.ccs)
    writer.add_scalar('eval/hypervolume', hv_ccs, step)
    hv_gpi = hypervolume(np.zeros(_REWARD_DIM), returns + returns_ccs)
    writer.add_scalar('eval/hypervolume_GPI', hv_gpi, step)

    true_values = [ols.max_value_lp(w_test) for w_test in test_tasks]
    print('true_values', true_values)
    writer.add_scalar('eval/true_value', np.mean(true_values), step)

    # Insertion order matches the order in which the padding scalars are written.
    return {
        'eval/mean_value_test_tasks': mean_test,
        'eval/mean_value_test_tasks_SMP': mean_test_smp,
        'eval/hypervolume': hv_ccs,
        'eval/hypervolume_GPI': hv_gpi,
    }


def _pad_remaining_steps(writer, metrics, first_step):
    """Repeat the final metrics for every step from ``first_step`` to ``_MAX_ITER``."""
    for step in range(first_step, _MAX_ITER + 1):
        for tag, value in metrics.items():
            writer.add_scalar(tag, value, step)


def _run_sf(algo, env, eval_env, writer):
    """Run the SFDQN + GPI + OLS training/evaluation loop."""
    def make_agent():
        return SFDQN(env,
                     gamma=0.9,
                     net_arch=[256, 256],
                     learning_rate=1e-3,
                     batch_size=256,
                     initial_epsilon=0.05,
                     final_epsilon=0.05,
                     epsilon_decay_steps=1,
                     per=True,
                     min_priority=0.01,
                     buffer_size=int(4e6),
                     gradient_updates=1,
                     tau=1.0,
                     target_net_update_freq=500)

    agent = GPI(env,
                make_agent,
                project_name='Reacher-SFOLS',
                experiment_name=algo)

    ols = OLS(m=_REWARD_DIM, epsilon=0.001, reverse_extremum=True)
    test_tasks = random_weights(dim=_REWARD_DIM, seed=42, n=60) + ols.extrema_weights()

    for iteration in range(_MAX_ITER):
        w = _select_weight(algo, ols)
        print('next w', w)

        agent.learning_starts = agent.num_timesteps  # reset epsilon exploration
        agent.learn(total_timesteps=_STEPS_PER_ITER,
                    writer=writer,
                    w=w,
                    eval_env=eval_env,
                    eval_freq=1000,
                    reuse_value_ind=None)

        value = policy_evaluation_mo(agent, eval_env, w, rep=5)
        remove_policies = ols.add_solution(value, w, gpi_agent=agent, env=eval_env)
        agent.delete_policies(remove_policies)

        mean_train = policy_evaluation_mo(agent, eval_env, w, rep=5, return_scalarized_value=True)
        writer.add_scalar('eval/mean_value_train_tasks', mean_train, ols.iteration)

        returns = [policy_evaluation_mo(agent, eval_env, w_test, rep=5, return_scalarized_value=False)
                   for w_test in test_tasks]
        returns_ccs = [policy_evaluation_mo(agent, eval_env, w_ccs, rep=5, return_scalarized_value=False)
                       for w_ccs in ols.ccs_weights]
        metrics = _log_test_metrics(writer, ols, returns, returns_ccs, test_tasks)

        if ols.ended():
            print("ended at iteration", iteration)
            # Extend the curves to _MAX_ITER with the last logged values.
            _pad_remaining_steps(writer, metrics, ols.iteration + 1)
            break


def _run_dsf(algo, env, eval_env, writer):
    """Run the RDSFOLS + DGPI + DOLS training/evaluation loop."""
    def make_agent():
        return RDSFOLS(env,
                       gamma=0.9,
                       hidden_sizes=[256, 1024],
                       learning_rate=1e-3,
                       batch_size=256,
                       initial_epsilon=0.05,
                       final_epsilon=0.05,
                       epsilon_decay_steps=1,
                       min_priority=0.01,
                       buffer_size=int(4e6),
                       gradient_updates=1,
                       tau_target=1.0,
                       target_net_update_freq=500)

    agent = DGPI(env,
                 make_agent,
                 project_name='Reacher-DSFOLS',
                 experiment_name=algo)

    ols = DOLS(m=_REWARD_DIM, epsilon=0.001, reverse_extremum=True)
    test_tasks = random_weights(dim=_REWARD_DIM, seed=42, n=60) + ols.extrema_weights()

    for iteration in range(_MAX_ITER):
        w = _select_weight(algo, ols)
        print('next w', w)

        agent.learning_starts = agent.num_timesteps
        target_fp = agent.learn(total_timesteps=_STEPS_PER_ITER,
                                writer=writer,
                                w=w,
                                eval_env=eval_env,
                                eval_freq=1000,
                                reuse_value_ind=None)

        # get_tau is fed a freshly reset observation and a sampled action; only
        # tau_hat and presum_tau are used by the evaluations below.
        obs = env.reset()
        action = env.action_space.sample()
        _, tau_hat, presum_tau = get_tau(obs, action, tau_type='fqf', num_quantiles=32, fp=target_fp)

        value = dsf_policy_evaluation_mo(agent, eval_env, w, tau_hat, presum_tau, rep=5)
        remove_policies = ols.add_solution(value, w, gpi_agent=agent, env=eval_env, fp=target_fp)
        agent.delete_policies(remove_policies)

        returns = [dsf_policy_evaluation_mo(agent, eval_env, w_test, tau_hat, presum_tau, rep=5,
                                            return_scalarized_value=False)
                   for w_test in test_tasks]
        returns_ccs = [dsf_policy_evaluation_mo(agent, eval_env, w_ccs, tau_hat, presum_tau, rep=5,
                                                return_scalarized_value=False)
                       for w_ccs in ols.ccs_weights]
        metrics = _log_test_metrics(writer, ols, returns, returns_ccs, test_tasks)

        if ols.ended():
            print("ended at iteration", iteration)
            # Extend the curves to _MAX_ITER with the last logged values.
            _pad_remaining_steps(writer, metrics, ols.iteration + 1)
            break


def run(algo):
    """Train and evaluate ``algo`` on ``ReacherMultiTask-v0``, logging to TensorBoard.

    Event files are written under
    ``logs/Reacher-v0/<algo>/fqf-neutral/<timestamp>/summary`` and a sibling
    ``model`` directory is created.

    Args:
        algo: One of ``'SFOLS'``, ``'WCPI'``, ``'Random'`` (SFDQN/GPI/OLS loop)
            or ``'RDSFOLS'``, ``'WCDPI'`` (RDSFOLS/DGPI/DOLS loop).

    Raises:
        ValueError: If ``algo`` is not one of the supported names. Raised
            before any directory, writer or environment is created.
    """
    supported = _SF_ALGOS + _DSF_ALGOS
    if algo not in supported:
        raise ValueError(f"Unknown algo {algo!r}; expected one of {supported}")

    timestamp = datetime.now().strftime("%Y%m%d-%H%M")
    # Use the ``algo`` argument (not a module-level ``args``) so that run()
    # also works when this module is imported instead of executed.
    log_dir = os.path.join('logs', 'Reacher-v0', algo, 'fqf-neutral', timestamp)

    # Log setting.
    summary_dir = os.path.join(log_dir, 'summary')
    writer = SummaryWriter(log_dir=summary_dir)
    try:
        model_dir = os.path.join(log_dir, 'model')
        os.makedirs(model_dir, exist_ok=True)

        env = gym.make("ReacherMultiTask-v0")
        eval_env = gym.make("ReacherMultiTask-v0")

        if algo in _DSF_ALGOS:
            _run_dsf(algo, env, eval_env, writer)
        else:
            _run_sf(algo, env, eval_env, writer)
    finally:
        # Flush pending events even when training is interrupted by an error.
        writer.close()


if __name__ == '__main__':
    # Command-line entry point: choose the algorithm and start the experiment.
    cli_parser = argparse.ArgumentParser(description='Reacher experiment.')
    cli_parser.add_argument(
        '-algo',
        type=str,
        choices=['SFOLS', 'WCPI', 'Random', 'RDSFOLS', 'WCDPI'],
        default='RDSFOLS',
        help='Algorithm.',
    )
    # ``args`` stays a module-level name because run() may look it up globally.
    args = cli_parser.parse_args()
    run(args.algo)

