import numpy as np
import gym
# from gym.wrappers import Monitor
# import wandb as wb
import os
import torch as th
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter

from rl.utils.utils import eval_test_tasks, hypervolume, moving_average, random_weights, reward_evaluation_mo, get_tau, dsf_reward_evaluation_mo
from rl.successor_features.sf_dqn_mujoco import SFDQN
from rl.successor_features.dsf_dqn_mujoco import RDSFOLS
from rl.successor_features.gpi import GPI
from rl.successor_features.dgpi import DGPI
from rl.successor_features.reward_ols import OLS
from rl.successor_features.reward_dols import DOLS

import envs
import matplotlib.pyplot as plt
import seaborn as sns
import argparse


# Number of objectives (reward components) used throughout the experiment:
# the OLS/DOLS ``m``, the dimension of sampled weight vectors, and the
# dimension of the hypervolume reference point.
_NUM_OBJECTIVES = 6
# Maximum number of OLS iterations (one ``agent.learn`` call per iteration);
# also the last step written when padding curves after OLS ends.
_MAX_ITER = 20
# Environment steps passed to ``agent.learn`` in every iteration.
_TIMESTEPS_PER_ITER = 200000
# Algorithms handled by the SFDQN + GPI + OLS branch.
_SF_ALGOS = ('SFOLS', 'WCPI', 'Random')
# Algorithms handled by the RDSFOLS + DGPI + DOLS branch.
_DSF_ALGOS = ('RDSFOLS', 'WCDPI')


def _log_ols_metrics(writer, ols, test_tasks, returns, returns_ccs):
    """Write the per-iteration evaluation scalars to TensorBoard.

    Args:
        writer: SummaryWriter receiving the scalars.
        ols: OLS/DOLS instance; ``ols.iteration`` is used as the global step.
        test_tasks: weight vectors that ``returns`` was evaluated on.
        returns: vector returns of the GPI agent on ``test_tasks``.
        returns_ccs: vector returns of the GPI agent on ``ols.ccs_weights``.

    Returns:
        ``(mean_test, mean_test_smp)`` so the caller can pad the curves
        once OLS has ended.
    """
    ref_point = np.zeros(_NUM_OBJECTIVES)
    step = ols.iteration

    mean_test = np.mean([np.dot(psi, w) for psi, w in zip(returns, test_tasks)], axis=0)
    writer.add_scalar('eval/mean_value_test_tasks', mean_test, step)
    mean_test_smp = np.mean([ols.max_scalarized_value(w_test) for w_test in test_tasks])
    writer.add_scalar('eval/mean_value_test_tasks_SMP', mean_test_smp, step)
    writer.add_scalar('eval/hypervolume', hypervolume(ref_point, ols.ccs), step)
    writer.add_scalar('eval/hypervolume_GPI', hypervolume(ref_point, returns + returns_ccs), step)

    true_values = [ols.max_value_lp(w) for w in test_tasks]
    print('true_values', true_values)
    writer.add_scalar('eval/true_value', np.mean(true_values), step)
    return mean_test, mean_test_smp


def _pad_metrics_after_end(writer, ols, mean_test, mean_test_smp, returns, returns_ccs):
    """Repeat the final evaluation scalars at steps ``ols.iteration + 1`` .. ``_MAX_ITER``.

    This way a run whose OLS terminates early still logs a value at every step.
    """
    ref_point = np.zeros(_NUM_OBJECTIVES)
    # Both hypervolumes are invariant across the padded steps: compute once.
    hv = hypervolume(ref_point, ols.ccs)
    hv_gpi = hypervolume(ref_point, returns + returns_ccs)
    for step in range(ols.iteration + 1, _MAX_ITER + 1):
        writer.add_scalar('eval/mean_value_test_tasks', mean_test, step)
        writer.add_scalar('eval/mean_value_test_tasks_SMP', mean_test_smp, step)
        writer.add_scalar('eval/hypervolume', hv, step)
        writer.add_scalar('eval/hypervolume_GPI', hv_gpi, step)


def _run_sf_ols(algo, env, eval_env, writer):
    """Run the SFDQN + GPI + OLS loop for ``algo`` in ``_SF_ALGOS``.

    The training weight is chosen per iteration by ``ols.next_w()`` (SFOLS),
    ``ols.worst_case_weight()`` (WCPI) or ``random_weights`` (Random).
    """
    def agent_constructor():
        return SFDQN(env,
                     gamma=0.9,
                     net_arch=[256, 256],
                     learning_rate=1e-3,
                     batch_size=256,
                     initial_epsilon=0.05,
                     final_epsilon=0.05,
                     epsilon_decay_steps=1,
                     per=True,
                     min_priority=0.01,
                     buffer_size=int(4e6),
                     gradient_updates=1,
                     tau=1.0,
                     target_net_update_freq=500)

    agent = GPI(env,
                agent_constructor,
                project_name='DiscreteWalker2dEnv-SFOLS',
                experiment_name=algo)

    ols = OLS(m=_NUM_OBJECTIVES, epsilon=0.001, reverse_extremum=True)
    test_tasks = random_weights(dim=_NUM_OBJECTIVES, seed=42, n=60) + ols.extrema_weights()

    for it in range(_MAX_ITER):
        if algo == 'SFOLS':
            w = ols.next_w()
        elif algo == 'WCPI':
            w = ols.worst_case_weight()
        else:  # 'Random' (``run`` has already validated ``algo``)
            w = random_weights(dim=_NUM_OBJECTIVES)
        print('next w', w)

        agent.learning_starts = agent.num_timesteps  # reset epsilon exploration
        agent.learn(total_timesteps=_TIMESTEPS_PER_ITER,
                    writer=writer,
                    w=w,
                    eval_env=eval_env,
                    eval_freq=1000,
                    reuse_value_ind=None)

        value = reward_evaluation_mo(agent, eval_env, w, rep=5)
        remove_policies = ols.add_solution(value, w, gpi_agent=agent, env=eval_env)
        agent.delete_policies(remove_policies)

        mean_train = reward_evaluation_mo(agent, eval_env, w, rep=5, return_scalarized_value=True)
        writer.add_scalar('eval/mean_value_train_tasks', mean_train, ols.iteration)

        returns = [reward_evaluation_mo(agent, eval_env, w_test, rep=5, return_scalarized_value=False)
                   for w_test in test_tasks]
        print('returns', returns)
        returns_ccs = [reward_evaluation_mo(agent, eval_env, w_ccs, rep=5, return_scalarized_value=False)
                       for w_ccs in ols.ccs_weights]
        print('returns_ccs', returns_ccs)

        mean_test, mean_test_smp = _log_ols_metrics(writer, ols, test_tasks, returns, returns_ccs)

        if ols.ended():
            print("ended at iteration", it)
            _pad_metrics_after_end(writer, ols, mean_test, mean_test_smp, returns, returns_ccs)
            break


def _run_dsf_ols(algo, env, eval_env, writer):
    """Run the RDSFOLS + DGPI + DOLS loop for ``algo`` in ``_DSF_ALGOS``.

    The training weight is chosen per iteration by ``ols.next_w()`` (RDSFOLS)
    or ``ols.worst_case_weight()`` (WCDPI).
    """
    def agent_constructor():
        return RDSFOLS(env,
                       gamma=0.9,
                       hidden_sizes=[256, 1024],
                       learning_rate=1e-3,
                       batch_size=256,
                       initial_epsilon=0.05,
                       final_epsilon=0.05,
                       epsilon_decay_steps=1,
                       min_priority=0.01,
                       buffer_size=int(4e6),
                       gradient_updates=1,
                       tau_target=1.0,
                       target_net_update_freq=500)

    agent = DGPI(env,
                 agent_constructor,
                 project_name='DiscreteWalker2dEnv-DSFOLS',
                 experiment_name=algo)

    ols = DOLS(m=_NUM_OBJECTIVES, epsilon=0.001, reverse_extremum=True)
    test_tasks = random_weights(dim=_NUM_OBJECTIVES, seed=42, n=60) + ols.extrema_weights()

    for it in range(_MAX_ITER):
        if algo == 'RDSFOLS':
            w = ols.next_w()
        else:  # 'WCDPI' (``run`` has already validated ``algo``)
            w = ols.worst_case_weight()
        print('next w', w)

        agent.learning_starts = agent.num_timesteps  # reset epsilon exploration
        target_fp = agent.learn(total_timesteps=_TIMESTEPS_PER_ITER,
                                writer=writer,
                                w=w,
                                eval_env=eval_env,
                                eval_freq=1000,
                                reuse_value_ind=None)

        # Sample quantile fractions from the fraction-proposal output of
        # ``learn``, using a freshly reset observation and a random action.
        obs = env.reset()[0].reshape(-1)
        action = env.action_space.sample()
        _, tau_hat, presum_tau = get_tau(obs, action, tau_type='iqn', num_quantiles=32, fp=target_fp)

        value = dsf_reward_evaluation_mo(agent, eval_env, w, tau_hat, presum_tau, rep=5)
        remove_policies = ols.add_solution(value, w, gpi_agent=agent, env=eval_env, fp=target_fp)
        agent.delete_policies(remove_policies)

        returns = [dsf_reward_evaluation_mo(agent, eval_env, w_test, tau_hat, presum_tau, rep=5,
                                            return_scalarized_value=False)
                   for w_test in test_tasks]
        returns_ccs = [dsf_reward_evaluation_mo(agent, eval_env, w_ccs, tau_hat, presum_tau, rep=5,
                                                return_scalarized_value=False)
                       for w_ccs in ols.ccs_weights]

        mean_test, mean_test_smp = _log_ols_metrics(writer, ols, test_tasks, returns, returns_ccs)

        if ols.ended():
            print("ended at iteration", it)
            _pad_metrics_after_end(writer, ols, mean_test, mean_test_smp, returns, returns_ccs)
            break


def run(algo):
    """Train and evaluate one multi-objective SF agent on ``discreteswimmer-v3``.

    Logs go to ``logs/DiscreteSwimmer-v3/<algo>/iqn-cvar/<YYYYmmdd-HHMM>/``.
    TensorBoard scalars are written under ``summary/`` and an empty
    ``model/`` directory is created.

    Args:
        algo: one of ``'SFOLS'``, ``'WCPI'``, ``'Random'`` (SFDQN + GPI + OLS)
            or ``'RDSFOLS'``, ``'WCDPI'`` (RDSFOLS + DGPI + DOLS).

    Raises:
        ValueError: if ``algo`` is not one of the supported algorithms.
    """
    supported = _SF_ALGOS + _DSF_ALGOS
    if algo not in supported:
        raise ValueError(f"Unknown algo {algo!r}; expected one of {supported}")

    timestamp = datetime.now().strftime("%Y%m%d-%H%M")
    # Use the ``algo`` parameter, not the CLI-global ``args``: the latter only
    # exists when this file is executed as a script.
    log_dir = os.path.join('logs', 'DiscreteSwimmer-v3', algo, 'iqn-cvar', timestamp)

    # The context manager closes (and flushes) the writer even if training fails.
    with SummaryWriter(log_dir=os.path.join(log_dir, 'summary')) as writer:
        os.makedirs(os.path.join(log_dir, 'model'), exist_ok=True)

        env = gym.make("discreteswimmer-v3")
        eval_env = gym.make("discreteswimmer-v3")
        try:
            if algo in _SF_ALGOS:
                _run_sf_ols(algo, env, eval_env, writer)
            else:
                _run_dsf_ols(algo, env, eval_env, writer)
        finally:
            eval_env.close()
            env.close()
    

if __name__ == '__main__':
    # Script entry point: pick the algorithm and the GPU, then launch the run.
    cli = argparse.ArgumentParser(description='DiscreteWalker2dEnv experiment.')
    cli.add_argument(
        '-algo',
        type=str,
        choices=['SFOLS', 'WCPI', 'Random', 'RDSFOLS', 'WCDPI'],
        default='RDSFOLS',
        help='Algorithm.',
    )
    cli.add_argument('-cuda_id', default=0, type=int)
    # Kept as the module-global name ``args``; code elsewhere in this file may read it.
    args = cli.parse_args()

    # Expose only the requested GPU to the process.
    # NOTE(review): this only takes effect if CUDA has not already been
    # initialized by one of the module-level imports — confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.cuda_id}"
    run(args.algo)
