import argparse
import random
import json
import os
os.environ['MUJOCO_GL'] = 'egl'

import torch
torch.set_num_threads(4)
import numpy as np
import gymnasium as gym

from preforl.mujoco_trainer import MuJoCoTrainer


def main(args):
    """Construct a ``MuJoCoTrainer`` from parsed CLI options and run it.

    Args:
        args: Namespace produced by this script's argument parser. The
            attributes read are ``seed``, ``device``, ``env_name``,
            ``net_arch``, ``num_algo_iters``, the four ``PREFORL_*``
            options, ``alpha``, ``contrastive_bias`` and ``lr``.

    Returns:
        None. Training is started via ``trainer.train()``.
    """
    # Values that need conversion before being handed to the trainer.
    seed = int(args.seed)
    device = torch.device(args.device)

    # Options forwarded unchanged; the keyword name matches the attribute name.
    forwarded_names = (
        'env_name',
        'net_arch',
        'num_algo_iters',
        'PREFORL_num_samples',
        'PREFORL_num_epochs',
        'PREFORL_batch_size',
        'PREFORL_segment_length',
        'alpha',
        'contrastive_bias',
        'lr',
    )
    trainer_kwargs = {name: getattr(args, name) for name in forwarded_names}

    trainer = MuJoCoTrainer(**trainer_kwargs, seed=seed, device=device)
    trainer.train()


if __name__ == '__main__':

    # Environment identifiers accepted by --env_name; argparse rejects others.
    ENV_NAMES = [
        'halfcheetah-medium',
        'halfcheetah-medium-expert',
        'walker2d-medium',
        'walker2d-medium-expert',
        'hopper-medium',
        'hopper-medium-expert',
    ]

    parser = argparse.ArgumentParser(
        description='Train a MuJoCoTrainer on one of the supported environments.',
    )
    parser.add_argument('--device', type=str, default='cpu',
                        help='string passed to torch.device (default: %(default)s)')
    parser.add_argument('--seed', type=int, default=123,
                        help='seed for torch, numpy and random (default: %(default)s)')
    parser.add_argument('--env_name', type=str, required=True, choices=ENV_NAMES)
    parser.add_argument('--net_arch', type=json.loads, default=[1024, 1024, 1024],
                        help='JSON list of positive integers, e.g. "[256, 256]" '
                             '(default: %(default)s)')
    parser.add_argument('--num_algo_iters', type=int, default=500)
    parser.add_argument('--PREFORL_num_samples', type=int, default=50)
    parser.add_argument('--PREFORL_num_epochs', type=int, default=50)
    parser.add_argument('--PREFORL_batch_size', type=int, default=64)
    parser.add_argument('--PREFORL_segment_length', type=int, default=100)
    parser.add_argument('--alpha', type=float, default=0.1)
    parser.add_argument('--contrastive_bias', type=float, default=0.5)
    parser.add_argument('--lr', type=float, default=1e-4)
    # NOTE(review): --exp_name is parsed but main() never reads it; kept so
    # existing command lines continue to work.
    parser.add_argument('--exp_name', type=str, default='train_mujoco')

    args = parser.parse_args()

    # json.loads accepts any JSON value, so `--net_arch 256` would produce an
    # int and `--net_arch '{}'` a dict. Reject such values here with a normal
    # CLI error (exit status 2) instead of forwarding them to the trainer.
    # bool is excluded explicitly because it is a subclass of int.
    net_arch = args.net_arch
    if not isinstance(net_arch, list) or not all(
        isinstance(width, int) and not isinstance(width, bool) and width > 0
        for width in net_arch
    ):
        parser.error(
            f'--net_arch must be a JSON list of positive integers, got {net_arch!r}'
        )

    print(args)

    seed = args.seed

    # Seed every RNG source used in this process before training starts.
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    main(args)
