import numpy as np
import os
from stable_baselines3 import PPO
import gymnasium as gym
from GameInfo import TwoPlayerGames
import argparse
import pickle
from utils import set_random_seed


def _unwrap_singleton(values):
    """Return the only element of a one-item list/tuple, else ``values`` unchanged.

    Scalars are passed through untouched, so a CLI option whose default was
    never wrapped in a list (e.g. a bare float) can still be formatted
    without calling ``len`` on it.
    """
    if isinstance(values, (list, tuple)) and len(values) == 1:
        return values[0]
    return values


def main(args):
    """Train a PPO policy on the 'MetaEnv' environment and persist the results.

    Args:
        args: Parsed command-line namespace, as produced by ``get_parser``.

    Side effects:
        * Registers the Gymnasium environment id 'MetaEnv'
          (entry point ``MetaEnv:Meta_Environment``).
        * Creates the output directory
          ``./model/<game>_<obj>_<model_type>[_Eps<epsilon>]/seed_<seed>``.
        * Saves the trained model to ``<model_save_path>final``.
        * Pickles the object returned by
          ``env.reset(options='return_statistics')`` into a file named
          ``beta<beta>_mu<mu>_seed<seed>`` in the current working directory.
    """
    set_random_seed(args.seed)

    # Only games listed in TwoPlayerGames carry a predefined target state.
    if args.game in TwoPlayerGames:
        target_state = TwoPlayerGames[args.game]['target_state']
    else:
        target_state = None

    # Register the environment; every CLI option the env needs is forwarded
    # through the constructor kwargs ('max_obs' and 'gamma' are fixed here).
    gym.envs.register(
        id='MetaEnv',
        entry_point='MetaEnv:Meta_Environment',
        kwargs={
            'game': args.game,
            'num_players': args.num_players,
            'obj': args.obj,
            'max_steer_reward': args.max_steer_reward,
            'max_obs': 1000,
            'T': args.T,
            'lr': args.lr,
            'time_interval': args.time_interval,
            'beta': args.beta,
            'target_state': target_state,
            'init_clip_threshold': args.init_clip_threshold,
            'gamma': 0.99,
            'act_dim': args.act_dim,
            'obs_dim': args.obs_dim,
            'epsilon': args.epsilon,
            'distance_type': args.distance_type,
            'model_type': args.model_type,
            'mu': args.mu,
            'sigma': args.sigma,
        },
    )
    env = gym.make('MetaEnv')

    # Optional file-name suffix describing mu/sigma; skipped when --mu is unset.
    suffix = ''
    if args.mu:
        suffix += '_mu{}_sigma{}'.format(_unwrap_singleton(args.mu), args.sigma)

    prefix = './model/{}_{}_{}'.format(args.game, args.obj, args.model_type)
    if args.epsilon:
        prefix += '_Eps{}'.format(args.epsilon)
    prefix += '/seed_{}'.format(args.seed)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(prefix, exist_ok=True)

    # args.beta may be a list (given on the CLI) or a bare scalar (parser
    # default); _unwrap_singleton handles both instead of calling len() on a
    # float, which raised TypeError when --beta was omitted.
    beta = _unwrap_singleton(args.beta)
    model_save_path = prefix + '/PPO_steering_K{}_beta{}_T{}_lr{}'.format(
        args.K, beta, args.T, args.lr) + suffix

    print(prefix, suffix)

    model = PPO(
        "MlpPolicy",
        env,
        policy_kwargs={'net_arch': dict(pi=[256, 256], vf=[256, 256])},
        verbose=1,
        ent_coef=args.ent,
    )
    # gym.make returns the env wrapped in Gymnasium's default wrappers.
    # Custom methods are reached through `unwrapped`, because attribute
    # forwarding through wrappers is deprecated/removed in newer Gymnasium.
    env.unwrapped.set_model(model, model_save_path)

    print('start train')
    model.learn(total_timesteps=args.K, log_interval=4)
    print('finished train')
    model.save(model_save_path + 'final')

    # 'print_bp' and 'return_statistics' are custom reset options handled by
    # the environment itself (its implementation is not visible in this file).
    env.reset(options='print_bp')
    statistics = env.reset(options='return_statistics')
    with open('beta{}_mu{}_seed{}'.format(args.beta, args.mu, args.seed), 'wb') as f:
        pickle.dump(statistics, f)


def get_parser(argv=None):
    """Build the command-line parser and return the parsed arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are taken from ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--game', type=str, default='', help='what game to run?')
    parser.add_argument('--obj', type=str, default='Nash',
                        choices=['Nash', 'MaxUtility', 'Explore', 'MinGap', 'Explore_Success'],
                        help='what is the objective? "Nash" means steering to a fixed (Nash) policy; "MaxUtility" means steering to maximize the total utility')
    parser.add_argument('-K', type=int, default=3000000, help='training iteration')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('-T', type=int, default=500, help='total time step')
    parser.add_argument('--lr', type=float, default=0.01, help='agent learning rate')
    parser.add_argument('--time-interval', type=float, default=0.01, help='the actual time each steering step corresponding to')
    # NOTE(review): with nargs='+' a value given on the CLI arrives as a list,
    # while this default stays a bare float; left unchanged because the
    # consumer of this option is not visible here -- confirm it accepts both.
    parser.add_argument('--init-clip-threshold', type=float, nargs='+', default=0.01, help='clipping threshold for initialization')
    parser.add_argument('--num-players', type=int, default=2, help='number of players')

    parser.add_argument('--algo', type=str, default='PPO', help='training algorithm')
    # The default is a one-element list so that args.beta always has the same
    # type as a value supplied on the command line (callers take len() of it).
    parser.add_argument('--beta', type=float, default=[0.0], nargs='+', help='weights on distance loss')
    parser.add_argument('--max-steer-reward', type=float, default=10., help='the maximal steering reward')

    parser.add_argument('--distance-type', type=str, default='policy', choices=['policy', 'dual_variable'], help='how to compute the distance')

    parser.add_argument('--act-dim', type=int, default=2, help='action dimension')
    parser.add_argument('--obs-dim', type=int, default=2, help='obs dimension per agent')

    parser.add_argument('--model-type', default=None, type=str)  # , choices=['Normal', 'Gaussian_lr', 'ValueAware'], help='model type')
    parser.add_argument('--mu', default=None, type=float, nargs='+', help='the mu list')
    parser.add_argument('--sigma', default=0.3, type=float, help='the sigma value')

    parser.add_argument('--epsilon', type=float, default=0.01, help='weights on adaptive beta')

    parser.add_argument('--ent', type=float, default=0.0, help='entropy coefficient (ent_coef) for PPO')

    return parser.parse_args(argv)


if __name__ == '__main__':
    # Script entry point: parse the command-line options, then run training.
    main(get_parser())