import argparse
import random
import json
import os
os.environ['MUJOCO_GL'] = 'egl'

import torch
torch.set_num_threads(4)
import numpy as np
import gymnasium as gym
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.utils import get_schedule_fn

from preforl.metaworld_trainer import MetaWorldTrainer
from preforl.layers import MetaWorldResNet
from envs.metaworld import MetaWorldSawyerEnv


def main(args):
    """Build the actor-critic policy and run ``MetaWorldTrainer.train()``.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``seed``, ``device``, ``env_name``, ``net_arch``,
            ``observation_mode``, ``num_algo_iters``, ``num_expert_demos``,
            ``num_random_demos``, ``PREFORL_num_samples``,
            ``PREFORL_epochs``, ``PREFORL_batch_size``,
            ``PREFORL_segment_length``, ``alpha``, ``contrastive_bias``,
            ``lr`` and ``resnet_mode``.

    Raises:
        ValueError: If ``args.observation_mode`` is not ``'image'``, the only
            mode for which a policy is constructed.
    """
    observation_mode = args.observation_mode
    if observation_mode != 'image':
        # Fail fast with a clear message: without this guard ``policy`` would
        # be unbound by the time the trainer is constructed below.
        raise ValueError(
            f"Unsupported observation_mode {observation_mode!r}; "
            "only 'image' is supported."
        )

    seed = int(args.seed)
    device = torch.device(args.device)
    env_name = args.env_name
    lr = args.lr

    # Single source of truth for the episode horizon used by both the
    # throwaway environment and the trainer.
    max_episode_steps = 250

    # Build the initial policy. The environment is only consulted for its
    # observation/action spaces, so it is closed as soon as the policy exists.
    env = gym.make(
        f'meta_image_{env_name}',
        max_episode_steps=max_episode_steps,
        sparse=True,
        early_termination=True,
    )
    try:
        policy = ActorCriticPolicy(
            observation_space=env.observation_space,
            action_space=env.action_space,
            lr_schedule=get_schedule_fn(lr),
            net_arch=args.net_arch,
            features_extractor_class=MetaWorldResNet,
            features_extractor_kwargs={'mode': args.resnet_mode},
        )
    finally:
        env.close()

    trainer = MetaWorldTrainer(
        env_name=env_name,
        observation_mode=observation_mode,
        max_episode_steps=max_episode_steps,
        policy=policy,
        expert_path=f'saved_datasets/expert_demo/{env_name}.pt',
        random_path=f'saved_datasets/random_demo/{env_name}.pt',
        num_expert_demos=args.num_expert_demos,
        num_random_demos=args.num_random_demos,
        num_algo_iters=args.num_algo_iters,
        PREFORL_num_samples=args.PREFORL_num_samples,
        PREFORL_epochs=args.PREFORL_epochs,
        PREFORL_batch_size=args.PREFORL_batch_size,
        PREFORL_segment_length=args.PREFORL_segment_length,
        alpha=args.alpha,
        contrastive_bias=args.contrastive_bias,
        lr=lr,
        seed=seed,
        device=device,
    )

    trainer.train()


if __name__ == '__main__':

    # Task names accepted by --env_name.
    ENV_NAMES = [
        'hammer-v2', 'peg-insert-side-v2', 'peg-unplug-side-v2',
        'soccer-v2', 'window-open-v2', 'sweep-into-v2',
        'sweep-v2', 'coffee-push-v2', 'box-close-v2',
        'door-open-v2', 'drawer-open-v2', 'lever-pull-v2',
        'pick-place-v2', 'push-v2', 'push-wall-v2',
    ]

    # Command-line options as (flag, add_argument keyword options), listed in
    # the order they are registered with the parser.
    cli_options = [
        ('--device', dict(type=str, default='cuda')),
        ('--seed', dict(type=int, default=123)),
        ('--env_name', dict(type=str, choices=ENV_NAMES, required=True)),
        # --net_arch is given as a JSON list, e.g. "[256, 256]".
        ('--net_arch', dict(type=json.loads, default=[1024, 1024, 1024])),
        ('--observation_mode', dict(type=str, default='image', choices=['image'])),
        ('--num_algo_iters', dict(type=int, default=100)),
        ('--num_expert_demos', dict(type=int, default=50)),
        ('--num_random_demos', dict(type=int, default=50)),
        ('--PREFORL_num_samples', dict(type=int, default=20)),
        ('--PREFORL_epochs', dict(type=int, default=30)),
        ('--PREFORL_batch_size', dict(type=int, default=20)),
        ('--PREFORL_segment_length', dict(type=int, default=64)),
        ('--alpha', dict(type=float, default=0.1)),
        ('--contrastive_bias', dict(type=float, default=0.25)),
        ('--lr', dict(type=float, default=3e-4)),
        ('--resnet_mode', dict(type=str, default='train', choices=['train'])),
        ('--exp_name', dict(type=str, default='train_metaworld')),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    print(args)

    # Seed the torch, NumPy and stdlib random generators with the same value
    # before handing control to main().
    seed = args.seed
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

    main(args)
