import os
import torch
import gym
import argparse

from model.sac import SAC
from utils_sac.utils import set_seed, get_vocab

import warnings
warnings.filterwarnings(action='ignore')


def main():
    target_reward_mapping = {
        'parp1': 5,
        'fa7': 16,
        '5ht1b': 8,
        'braf': 12,
        'jak2': 4,
        'fexo': 3,
        'amlo': 2,
        'osim': 4,
        'peri': 2,
        'rano': 4,
        'sita': 4,
        'zale': 2
    }
    
    parser = argparse.ArgumentParser()
    parser.add_argument('-g', '--gpu_id', type=int, default=0)
    parser.add_argument('-s', '--seed', type=int, default=0)
    parser.add_argument('-t', '--target', type=str, default='parp1',
                        choices=['parp1', 'fa7', '5ht1b', 'braf', 'jak2',
                                 'fexo', 'amlo','osim', 'peri', 'rano', 'sita', 'zale'])
    parser.add_argument('-v', '--vocab_path', type=str, required=True)
    
    parser.add_argument('--num_mols', type=int, default=10000)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--start_steps', type=int, default=4000, 
                        help='When to store samples using policy generated actions,\
                              reproduce offsping, and update fragment vocabulary')
    parser.add_argument('--update_after', type=int, default=3000,
                        help='When to optimize the network')
    parser.add_argument('--update_every', type=int, default=256)
    parser.add_argument('--emb_size', type=int, default=64)
    parser.add_argument('--num_layer', type=int, default=3)
    
    parser.add_argument('--tau', type=float, default=1e-1)
    parser.add_argument('--target_entropy', type=float, default=1.)
    parser.add_argument('--init_alpha', type=float, default=1.)
    parser.add_argument('--init_pi_lr', type=float, default=1e-4)
    parser.add_argument('--init_q_lr', type=float, default=1e-4)
    parser.add_argument('--init_alpha_lr', type=float, default=5e-4)
    parser.add_argument('--alpha_max', type=float, default=20.)
    parser.add_argument('--alpha_min', type=float, default=.05)

    parser.add_argument('--population_size', type=int, default=100)
    parser.add_argument('--mutation_rate', type=float, default=0.1)
    parser.add_argument('-m', '--gib_path', type=str, required=True)
    parser.add_argument('--max_vocab_update', type=int, default=50)
    parser.add_argument('-mv', '--max_vocab_size', type=int, default=300)
    parser.add_argument('--ga_steps', type=int, default=3)

    # ===== multi-objective related hyper-parameters =====
    parser.add_argument('--pref_num', type=int, default=5)
    parser.add_argument('--reward_dim', type=int, default=None,
                        help='Number of rewards. If not specified, will be automatically set based on target')
    parser.add_argument('--mean', type=str, default='geom', choices=['arithm', 'geom'])

    # ===== homotopy dynamic beta =====
    parser.add_argument('--beta', type=float, default=0.01, help='initial beta weight')
    parser.add_argument('--beta_uplim', type=float, default=1.00, help='beta upper limit')
    parser.add_argument('--beta_tau', type=float, default=1000.0, help='time constant')
    parser.add_argument('--episode_num', type=int, default=6000, help='total episode number')
    args = parser.parse_args()
    
    if args.reward_dim is None:
        args.reward_dim = target_reward_mapping.get(args.target, 5) 
        print(f"Auto-setting reward_dim to {args.reward_dim} for target {args.target}")
    
    print(args)
    
    if args.gpu_id >= 0:
        args.device = torch.device(f'cuda:{args.gpu_id}')
    else:
        args.device = torch.device('cpu')
        torch.set_num_threads(256)
    
    if not os.path.exists('results'):
        os.makedirs('results')

    gym.envs.registration.register(id='molecule-v0', entry_point='utils_sac.env:MoleculeEnv')
    set_seed(args.seed)

    vocab = get_vocab(args.vocab_path)

    env = gym.make('molecule-v0')
    env.init(vocab=vocab, target=args.target)
    env.seed(args.seed)

    sac = SAC(args, vocab, env)
    sac.run()
    env.close()


if __name__ == '__main__':
    main()
