#!/usr/bin/env python3

import os
import random

import numpy as np
import dgl
import torch 
from tensorboardX import SummaryWriter

import gym

def str2bool(s):
    """Convert a command-line flag value to a bool.

    Intended for use as an argparse ``type=`` callable. Matching is
    case-insensitive and ignores surrounding whitespace.

    Args:
        s: The value to convert. Non-string values are passed through
            ``bool()`` unchanged.

    Returns:
        True for 'true', 't', 'yes', 'y', 'on', '1'; False for 'false', 'f',
        'no', 'n', 'off', '0' and the empty string.

    Raises:
        ValueError: if ``s`` is a string that is not one of the spellings
            above. argparse reports this as an invalid argument value.
    """
    if not isinstance(s, str):
        return bool(s)

    token = s.strip().lower()
    if token in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    # The empty string stays falsy, as it was with the plain bool() fallback.
    if token in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False

    # Falling back to bool(s) here would make every typo ("flase") truthy.
    raise ValueError(f'Cannot interpret {s!r} as a boolean')

def set_seed(seed):
    """Seed every random number generator used by this script.

    Applies the same ``seed`` to Python's ``random``, NumPy, PyTorch and DGL,
    and additionally to the CUDA generators when CUDA is available.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.random.manual_seed,
        dgl.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def gpu_setup(use_gpu, gpu_id):
    """Pick the torch device to run on and log the choice.

    Args:
        use_gpu: Whether the caller wants to run on a GPU.
        gpu_id: Index of the GPU to use; only read when CUDA is used.

    Returns:
        ``torch.device("cuda:<gpu_id>")`` when CUDA is available and
        ``use_gpu`` is truthy, otherwise ``torch.device("cpu")``.
    """
    want_cuda = torch.cuda.is_available() and use_gpu
    if not want_cuda:
        # Also reached when CUDA exists but the caller did not ask for a GPU.
        print('cuda not available')
        return torch.device("cpu")

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    # NOTE(review): the name printed below is queried for index 0 while the
    # returned device uses index gpu_id, right after CUDA_VISIBLE_DEVICES was
    # set to that same id -- confirm which device numbering is intended.
    print('cuda available with GPU:',torch.cuda.get_device_name(0))
    return torch.device(f"cuda:{gpu_id}")

def train(args, seed, writer=None):
    """Build the environment and the selected RL agent, then train/generate.

    Args:
        args: Parsed command-line namespace (see ``molecule_arg_parser``),
            extended by the caller with ``docking_config`` and ``ratios``.
        seed: Seed forwarded to the agent constructor. The global RNGs and
            the environment are seeded from ``args.seed`` instead.
        writer: Summary writer forwarded to the agent as its first argument.

    Raises:
        ValueError: if ``args.rl_model`` or, for SAC, ``args.active_learning``
            is not one of the supported values.
        NotImplementedError: if ``args.rl_model`` is 'vpg', for which no agent
            is constructed in this function.
    """
    # Resolve the agent constructor first, so an unsupported configuration
    # fails here with a clear message instead of an unbound name later on,
    # after the environment has already been created.
    if args.rl_model == 'sac':
        if args.active_learning == "freed_bu":
            from sac_motif_freed_bu import sac
        elif args.active_learning == "freed_pe":
            from sac_motif_freed_pe import sac
        elif args.active_learning == "per":
            from sac_motif_per import sac
        elif args.active_learning is None:
            from sac_motif import sac
        else:
            raise ValueError(
                f'Unsupported active_learning "{args.active_learning}" for sac; '
                'expected "freed_bu", "freed_pe", "per" or None')
    elif args.rl_model == 'ppo':
        from ppo_motif import ppo
    elif args.rl_model == 'vpg':
        raise NotImplementedError(
            'rl_model "vpg" is not wired up: no vpg agent is constructed in train()')
    else:
        raise ValueError(
            f'Unsupported rl_model "{args.rl_model}"; expected "sac" or "ppo"')

    # The on-policy agent uses the value-based actor-critic network.
    if args.rl_model == 'ppo':
        from core_motif_vbased import GCNActorCritic
    else:
        from core_motif import GCNActorCritic

    workerseed = args.seed
    set_seed(workerseed)

    # Called for its side effects (CUDA environment variables and log line);
    # the returned device is not used in this function.
    gpu_use = args.gpu_id is not None
    gpu_id = int(args.gpu_id) if gpu_use else None
    gpu_setup(gpu_use, gpu_id)

    env = gym.make('molecule-v0')
    env.init(
        docking_config=args.docking_config,
        ratios=args.ratios,
        reward_step_total=args.reward_step_total,
        is_normalize=args.normalize_adj,
        has_feature=bool(args.has_feature),
        max_action=args.max_action,
        min_action=args.min_action,
    )
    try:
        env.seed(workerseed)

        if args.rl_model == 'sac':
            model = sac(writer, args, env, actor_critic=GCNActorCritic, ac_kwargs=dict(), seed=seed,
                steps_per_epoch=args.steps_per_epoch, epochs=args.epochs, replay_size=int(1e6), gamma=0.99,
                polyak=0.995, lr=args.init_lr, alpha=args.init_alpha, batch_size=args.batch_size, start_steps=args.start_steps,
                update_after=args.update_after, update_every=args.update_every, update_freq=args.update_freq,
                expert_every=5, num_test_episodes=8, max_ep_len=args.max_action,
                save_freq=args.save_freq, train_alpha=True, load=args.load, checkpoint=args.checkpoint, replay_checkpoint=args.replay_checkpoint)
        else:
            from mpi_tools import mpi_fork
            mpi_fork(args.n_cpus)
            # The PPO path uses fixed epochs=200 and save_freq=2000 rather
            # than args.epochs / args.save_freq.
            model = ppo(writer, args, env, actor_critic=GCNActorCritic, ac_kwargs=dict(), seed=seed,
                steps_per_epoch=args.steps_per_epoch, epochs=200, replay_size=int(1e6), gamma=0.99,
                polyak=0.995, lr=args.init_lr, alpha=args.init_alpha, batch_size=args.batch_size, start_steps=args.start_steps,
                update_after=args.update_after, update_every=args.update_every, update_freq=args.update_freq,
                expert_every=5, num_test_episodes=8, max_ep_len=args.max_action,
                save_freq=2000, train_alpha=True)

        # args.mode may be 'train', 'gen' or 'train+gen'.
        if 'train' in args.mode:
            model.train()
        if 'gen' in args.mode:
            model.generate_mols(num_mols=args.num_mols, dump=True)
    finally:
        # Release the environment even when training or generation fails.
        env.close()

def arg_parser():
    """
    Return a fresh argparse.ArgumentParser with no arguments registered.

    The parser uses ArgumentDefaultsHelpFormatter as its formatter class.
    """
    from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

    empty_parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    return empty_parser

def molecule_arg_parser():
    """Build the command-line parser for the molecule experiments.

    Every option is declared as a ``(flag, add_argument kwargs)`` pair in a
    single table and registered in table order.

    Returns:
        The populated ``argparse.ArgumentParser``.
    """
    option_table = (
        # RL algorithm; train() dispatches on 'sac', 'ppo' and 'vpg'.
        ('--rl_model', dict(type=str, default='sac')),

        ('--gpu_id', dict(type=int, default=None)),
        ('--mode', dict(type=str, default='train', choices=['train', 'gen', 'train+gen'])),

        # env
        ('--env', dict(type=str, help='environment name: molecule; graph', default='molecule')),
        ('--seed', dict(help='RNG seed', type=int, default=666)),
        ('--num_steps', dict(type=int, default=int(5e7))),
        ('--pocket_id', dict(type=int, default=0)),

        ('--exp_root', dict(type=str, default='.')),
        ('--name', dict(type=str, default='')),

        # rewards
        ('--reward_step_total', dict(type=float, default=0.5)),
        ('--target', dict(type=str, default='fa7', help='fa7, parp1, 5ht1b')),

        ('--intr_rew', dict(type=str, default=None)),  # intr, mc
        ('--intr_rew_ratio', dict(type=float, default=5e-1)),

        ('--tau', dict(type=float, default=1)),

        # model update
        ('--batch_size', dict(type=int, default=256)),
        ('--init_lr', dict(type=float, default=1e-3)),
        ('--weight_decay', dict(type=float, default=1e-6)),
        ('--update_every', dict(type=int, default=256)),
        ('--update_freq', dict(type=int, default=256)),
        ('--update_after', dict(type=int, default=2000)),
        ('--start_steps', dict(type=int, default=3000)),

        # model save and load
        ('--save_freq', dict(type=int, default=500)),
        ('--load', dict(type=str2bool, default=False)),
        ('--checkpoint', dict(type=str, default='')),
        ('--replay_checkpoint', dict(type=str, default='')),

        # graph embedding
        ('--gcn_type', dict(type=str, default='GCN')),
        ('--gcn_aggregate', dict(type=str, default='sum')),
        ('--graph_emb', dict(type=int, default=0)),
        ('--emb_size', dict(type=int, default=64)),
        ('--has_residual', dict(type=int, default=0)),
        ('--has_feature', dict(type=int, default=0)),

        ('--normalize_adj', dict(type=int, default=0)),
        ('--bn', dict(type=int, default=0)),

        ('--layer_num_g', dict(type=int, default=3)),

        # action
        ('--max_action', dict(type=int, default=4)),
        ('--min_action', dict(type=int, default=1)),

        # SAC
        ('--target_entropy', dict(type=float, default=1.)),
        ('--init_alpha', dict(type=float, default=1.)),
        ('--desc', dict(type=str, default='ecfp')),  # ecfp
        ('--init_pi_lr', dict(type=float, default=1e-4)),
        ('--init_q_lr', dict(type=float, default=1e-4)),
        ('--init_alpha_lr', dict(type=float, default=5e-4)),
        ('--alpha_max', dict(type=float, default=20.)),
        ('--alpha_min', dict(type=float, default=.05)),

        # Active learning; for sac, train() recognises "freed_bu",
        # "freed_pe", "per" and None.
        ('--active_learning', dict(type=str, default=None)),
        ('--dropout', dict(type=float, default=0.3)),
        ('--n_samples', dict(type=int, default=5)),

        # On-policy
        ('--n_cpus', dict(type=int, default=1)),
        ('--steps_per_epoch', dict(type=int, default=500)),
        ('--epochs', dict(type=int, default=100)),

        # Docking
        ('--exhaustiveness', dict(type=int, default=1)),
        ('--num_modes', dict(type=int, default=10)),
        ('--num_sub_proc', dict(type=int, default=10)),
        ('--n_conf', dict(type=int, default=1)),
        ('--error_val', dict(type=float, default=99.9)),

        ('--local_rank', dict(type=int, default=0)),

        ('--num_mols', dict(type=int, default=1000)),
    )

    parser = arg_parser()
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    return parser

# Docking search box per target: name -> (box_center, box_size).
_DOCKING_BOXES = {
    'fa7': ((10.131, 41.879, 32.097), (20.673, 20.198, 21.362)),
    'parp1': ((26.413, 11.282, 27.238), (18.521, 17.479, 19.995)),
    '5ht1b_corrupted': ((-26.602, 5.277, 17.898), (22.5, 22.5, 22.5)),
    '5ht1b': ((-4.744, -6.133, 49.087), (13.859, 16.393, 13.166)),
    'abl1': ((16.496, 14.747, 3.999), (14.963, 8.151, 5.892)),
    'fkb1a': ((-35.137, 39.04, 32.495), (8.453, 13.483, 8.112)),
    'lck': ((25.903, 38.689, 84.247), (15.643, 14.847, 11.009)),
    'drd3': ((8.376, 21.355, 23.542), (12.568, 16.409, 11.768)),
    'nram': ((31.892, -9.637, 64.004), (13.350, 13.799, 12.355)),
    'bace1': ((24.351, 12.447, 22.695), (16.166, 12.445, 11.690)),
    'thrb': ((17.236, -12.412, 21.805), (19.476, 10.785, 15.041)),
    'reni': ((-8.134, -14.897, -9.606), (18.09, 18.028, 13.46)),
    'prgr': ((-3.148, -8.529, 24.597), (10.814, 13.208, 17.051)),
    'sahh': ((46.916, -19.937, 102.612), (13.213, 11.925, 12.324)),
    'urok': ((22.480, 18.194, 35.305), (15.291, 11.693, 13.631)),
    'casp3': ((34.900, 35.518, 32.738), (9.851, 17.590, 15.041)),
    'aa2ar': ((-8.907, -7.141, 55.526), (13.356, 12.255, 20.929)),
    'ace': ((45.273, 45.822, 45.149), (13.37, 15.604, 19.619)),
    'aces': ((4.656, 68.843, 65.883), (11.674, 13.377, 13.742)),
    'ada': ((50.309, 55.878, 20.114), (12.465, 14.72, 13.097)),
    'ada17': ((43.786, 29.154, 2.555), (14.295, 14.009, 11.417)),
    'adrb1': ((26.372, 5.064, 1.884), (12.706, 11.004, 15.455)),
    'adrb2': ((2.292, 4.772, 51.381), (11.007, 14.483, 12.627)),
    'akt1': ((5.912, 2.767, 17.236), (12.094, 11.619, 11.643)),
    'akt2': ((22.499, -18.735, 8.479), (11.548, 16.804, 16.464)),
}

# usp7 has one box per pocket, selected with --pocket_id.
_USP7_POCKET_BOXES = {
    0: ((2.860, 4.819, 92.848), (17.112, 17.038, 14.958)),
    1: ((27.413, 1.55, 29.902), (16.221, 16.995, 17.858)),
}


def _docking_box(target, pocket_id):
    """Return ``(box_center, box_size)`` for a docking target.

    Raises:
        ValueError: if ``target`` is unknown, or ``target`` is 'usp7' and
            ``pocket_id`` has no box defined.
    """
    if target == 'usp7':
        if pocket_id not in _USP7_POCKET_BOXES:
            raise ValueError(
                f'Wrong pocket_id {pocket_id!r} for target "usp7"; '
                f'expected one of {sorted(_USP7_POCKET_BOXES)}')
        return _USP7_POCKET_BOXES[pocket_id]

    if target not in _DOCKING_BOXES:
        known = sorted(list(_DOCKING_BOXES) + ['usp7'])
        raise ValueError(f'Wrong target type "{target}"; expected one of {known}')
    return _DOCKING_BOXES[target]


def main():
    """Parse the command line, prepare the experiment directory and train.

    Steps: change into this script's directory, parse the arguments, build the
    docking configuration and reward ratios, create the experiment
    sub-directories, select the device, and call ``train``.

    Raises:
        ValueError: if the target (or the usp7 pocket id) is unknown, or if
            the experiment directory already exists and ``--load`` is false.
    """
    # Relative paths below (receptor file, vina binary, exp_root) are resolved
    # against the directory containing this script.
    os.chdir(os.path.dirname(os.path.realpath(__file__)))
    args = molecule_arg_parser().parse_args()
    exp_dir = os.path.join(args.exp_root, args.name)
    args.exp_dir = exp_dir
    print(args)

    # Explicit check rather than assert, so it still runs under ``python -O``.
    box_center, box_size = _docking_box(args.target, args.pocket_id)

    docking_config = dict()
    docking_config['receptor_file'] = os.path.join('ReLeaSE_Vina', 'docking', args.target, 'receptor.pdbqt')
    docking_config['vina_program'] = 'bin/qvina02'
    docking_config['box_center'] = box_center
    docking_config['box_size'] = box_size
    docking_config['temp_dir'] = os.path.join(exp_dir, 'tmp')
    docking_config['exhaustiveness'] = args.exhaustiveness
    docking_config['num_sub_proc'] = args.num_sub_proc
    docking_config['num_modes'] = args.num_modes
    docking_config['timeout_gen3d'] = None
    docking_config['timeout_dock'] = None
    docking_config['seed'] = args.seed
    docking_config['n_conf'] = args.n_conf
    docking_config['error_val'] = args.error_val

    # Only the docking term contributes to the reward.
    ratios = dict()
    ratios['logp'] = 0
    ratios['qed'] = 0
    ratios['sa'] = 0
    ratios['mw'] = 0
    ratios['filter'] = 0
    ratios['docking'] = 1

    args.docking_config = docking_config
    args.ratios = ratios

    # NOTE(review): with the defaults --exp_root '.' and --name '' exp_dir is
    # './', which always exists, so a run without --load needs an explicit
    # --name or --exp_root.
    if os.path.exists(exp_dir) and not args.load:
        raise ValueError(f'Experiment directory "{exp_dir}" already exist!')

    args.mol_dir = os.path.join(args.exp_dir, 'mols')
    args.model_dir = os.path.join(args.exp_dir, 'ckpt')
    args.logs_dir = os.path.join(args.exp_dir, 'logs')
    args.metrics_dir = os.path.join(args.exp_dir, 'metrics')
    args.docking_dir = os.path.join(args.exp_dir, 'docking')
    for directory in (args.mol_dir, args.model_dir, args.logs_dir,
                      args.metrics_dir, args.docking_dir):
        os.makedirs(directory, exist_ok=True)

    writer = SummaryWriter(args.logs_dir)
    try:
        # device
        gpu_use = args.gpu_id is not None
        gpu_id = int(args.gpu_id) if gpu_use else None
        args.device = gpu_setup(gpu_use, gpu_id)

        if args.gpu_id is None:
            torch.set_num_threads(256)
            print(torch.get_num_threads())

        train(args,seed=args.seed,writer=writer)
    finally:
        # Close the writer even when training fails.
        writer.close()

# Run the experiment only when this file is executed as a script, not when
# it is imported.
if __name__ == '__main__':
    main()
