import os
import argparse
from datetime import datetime
import gym
import torch
import numpy as np
import random
import time
import wandb

from environments import reacher_v3_L2_005, reacher_v3_N, reacher_pos_v3_L2_005
from environments import half_cheetah_v3_O_20, half_cheetah_v3_O_10, half_cheetah_v3_O_10_goal_vel1, half_cheetah_v3_M_10_goal_vel1, half_cheetah_v3_O_10_goal_vel5 
from environments import half_cheetah_v3_O_10_goal_vel0, half_cheetah_v3_O_10_goal_vel5_ccw1, half_cheetah_v3_O_10_goal_vel5_ccw05, half_cheetah_v3_O_10_goal_vel5_ccw01, half_cheetah_v3_O_10_goal_vel0_ccw_1
from environments import hopper_v3_M_10, hopper_v3_M_10_goal_vel3, hopper_v3_M_10_goal_vel0, hopper_v3_M_10_goal_vel0_ccw_01, hopper_v3_M_10_goal_vel3_ccw0001, hopper_v3_M_10_goal_vel0_ccw_0001, hopper_v3_M_10_goal_vel0_ccw1, hopper_v3_M_10_goal_freeze_ccw0
from environments import walker2d_v3_M_10, walker2d_v3_M_10_goal_vel3, walker2d_v3_M_10_goal_vel1, walker2d_v3_M_5_goal_vel3, walker2d_v3_M_10_goal_vel0, walker2d_v3_M_10_goal_vel3_ccw0
from environments import walker2d_v3_M_10_goal_vel3_ccw2, walker2d_v3_M_10_goal_freeze, walker2d_v3_M_10_goal_vel3_ccw05, walker2d_v3_M_10_goal_vel3_ccw1, walker2d_v3_M_10_goal_vel3_ccw2, walker2d_v3_M_10_goal_vel1_ccw01
from environments import ant_v3_O_20_goal_vel3, ant_v3_O_30_goal_vel3, ant_v3_O_20_goal_vel1, ant_v3_O_30_goal_vel1, ant_v3_L2_2_goal_vel1, ant_v3_O_30_goal_vel1_ccw_1, ant_v3_L2_2_goal_vel1_ccw_1, ant_v3_L2_2_goal_vel0_ccw_1, ant_v3_L2_2_goal_vel2_ccw0

from environments import push_v1_L2_08, pickandplace_v1_L2_08, slide_v1_L2_08, reach_v1_L2_08, slide_v1_O_001, pickandplace_v1_O_001, push_v1_N, slide_v1_N, pickandplace_v1_N
from environments import gym_BSS_3zone, gym_BSS_5zone, test_N
from environments.NSFnet.NSFnet_multiV2 import SimulatedNetworkEnv
import safety_gym

from agentsosac import SacAgent
import os
# Pin the CUDA-related environment for this process: enumerate devices in
# PCI bus order and expose GPUs 0-3. Both values are overwritten
# unconditionally, exactly as a pair of item assignments would do.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="0,1,2,3",
)

def run():
    """Parse command-line options, set up logging, and train a SacAgent.

    The function:
      1. parses the CLI flags declared below,
      2. assembles the keyword configuration passed to ``SacAgent``,
      3. creates the environment with ``gym.make(args.env_id)``,
      4. starts a Weights & Biases run whose config is ``vars(args)``,
      5. constructs the agent and calls ``agent.run()``.

    Returns:
        None.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_id', type=str, default='MO_half_cheetah-v0')
    parser.add_argument('--cuda', action='store_true', default=False)
    parser.add_argument('--cuda_device', type=int, default=0)
    parser.add_argument('--ver_number', type=int, default=0)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--training_steps', type=int, default=1500000)
    parser.add_argument('--eval_interval', type=int, default=10000)
    parser.add_argument('--start_steps', type=int, default=10000)
    parser.add_argument('--model_saved_step', type=int, default=10000)
    parser.add_argument('--action_sample_number', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--augement_action_sample_number', type=int, default=100)
    parser.add_argument('--augment_ratio', type=float, default=0.2)
    parser.add_argument('--augment_ratio_decay', type=float, default=0.99)
    parser.add_argument('--augment_ratio_decay_freq', type=int, default=10000)
    parser.add_argument('--penalty_weight', type=float, default=0.2)
    parser.add_argument("--prob_id", action="store", default="")
    parser.add_argument("--log_dir", action="store", default="tmp")
    parser.add_argument('--entropy_tuning', action='store_true', default=False)
    parser.add_argument('--eval_episode', type=int, default=10)
    parser.add_argument("--wandb-project-name", type=str, default="sb3", help="the wandb's project name")
    parser.add_argument("--wandb-entity", type=str, default="wandb_ent", help="the entity (team) of wandb's project")
    parser.add_argument("--wandb-info", type=str, default="wandb_info", help="the info of wandb's project")
    parser.add_argument("-tags", "--wandb-tags", type=str, default=[], nargs="+", help="Tags for wandb run, e.g.: -tags optimized pr-123")
    parser.add_argument('--pref', type=float, nargs='+', default=[0.9, 0.1])
    args = parser.parse_args()

    # NOTE(review): --action_sample_number, --augment_ratio,
    # --augment_ratio_decay, --augment_ratio_decay_freq and --penalty_weight
    # are parsed but not forwarded to the agent below; they only end up in
    # the wandb config via vars(args). Confirm whether SacAgent should
    # receive them.

    # You can define configs in the external json or yaml file.
    configs = {
        'num_steps': args.training_steps,
        # Honour the --batch_size flag (default 256) instead of a literal.
        'batch_size': args.batch_size,
        'lr': 0.0003,
        'hidden_units': [256, 256],
        'memory_size': 1e6,
        'gamma': 0.99,
        'tau': 0.005,
        'entropy_tuning': args.entropy_tuning,
        'ent_coef': 0.2,  # It's ignored when entropy_tuning=True.
        'multi_step': 1,
        'per': False,  # prioritized experience replay
        'alpha': 0.6,  # It's ignored when per=False.
        'beta': 0.4,  # It's ignored when per=False.
        'beta_annealing': 0.0001,  # It's ignored when per=False.
        'grad_clip': None,
        'updates_per_step': 1,
        'start_steps': args.start_steps,
        'log_interval': 10,
        'target_update_interval': 1,
        'eval_interval': args.eval_interval,
        'eval_episode': args.eval_episode,
        'cuda': args.cuda,
        'seed': args.seed,
        'cuda_device': args.cuda_device,
        'augement_action_sample_number': args.augement_action_sample_number,
        'prob_id': args.prob_id,
        'model_saved_step': args.model_saved_step,
        'preference': args.pref,
    }
    env = gym.make(args.env_id)

    # Per-environment, per-seed output directory handed to the agent.
    log_dir = os.path.join(
        args.log_dir, args.env_id,
        f'SOSAC-seed{args.seed}')

    # The run name embeds a wall-clock timestamp so repeated launches with
    # the same settings get distinct names.
    run_name = f"{args.env_id}__seed{args.seed}__{int(time.time())}__{args.prob_id}__SOSAC__{args.pref}__{args.augement_action_sample_number}__ver{args.ver_number}__{args.wandb_info}"
    tags = args.wandb_tags

    # An empty string is not a usable API key; call login() without a key so
    # wandb resolves credentials from its default sources (for example the
    # WANDB_API_KEY environment variable or a previous `wandb login`).
    wandb.login()
    # The returned run handle is not needed here, so it is not bound (the
    # previous local name also shadowed this function).
    wandb.init(
        name=run_name,
        project=args.wandb_project_name,
        entity=args.wandb_entity,
        tags=tags,
        config=vars(args),
        sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
        monitor_gym=True,  # auto-upload the videos of agents playing the game
        save_code=True,  # optional
    )
    agent = SacAgent(env=env, log_dir=log_dir, **configs)
    agent.run()

# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    run()