import argparse
import random
import os
os.environ['MUJOCO_GL'] = 'egl'

import dill
import numpy as np
import torch
from torch.utils.data import DataLoader
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.utils import get_schedule_fn
from imitation.algorithms import bc
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.util.util import make_vec_env, save_policy

from envs.metaworld import MetaWorldSawyerEnv
from preforl.layers import MetaWorldResNet


def _flatten_trajectories(trajectories):
    """Flatten demonstration trajectories into per-step ``{'obs', 'acts'}`` dicts.

    Each returned dict pairs one observation with the action at the same
    index. This is the sample format this script feeds to ``bc.BC`` through a
    ``DataLoader``.

    Args:
        trajectories: iterable of objects exposing indexable ``obs`` and
            ``acts`` (presumably imitation ``Trajectory`` objects; confirm
            against the script that wrote the dataset).

    Returns:
        list of dicts with keys ``'obs'`` and ``'acts'``.
    """
    samples = []
    for traj in trajectories:
        # imitation's Trajectory layout keeps one more observation than
        # actions (the final observation has no action). Pair only the steps
        # that have an action instead of indexing ``acts`` out of range.
        # When the lengths already match, this is identical to pairing every
        # observation.
        num_steps = min(traj.obs.shape[0], len(traj.acts))
        samples.extend(
            {'obs': traj.obs[i], 'acts': traj.acts[i]} for i in range(num_steps)
        )
    return samples


def main(args):
    """Train a behavior-cloning policy on expert demos, evaluate it, optionally save it.

    Steps:
      1. Load the first ``args.num_demos`` trajectories from
         ``saved_datasets/expert_demo/<env_name>.pt``.
      2. Train an ``ActorCriticPolicy`` with a ``MetaWorldResNet`` feature
         extractor using imitation's BC.
      3. Evaluate for 100 episodes on 10 parallel environments.
      4. Print per-episode rewards, episode lengths and the success rate.
      5. If ``args.save_bc`` is set, save the policy under
         ``saved_models/`` with the success rate in the filename.

    Args:
        args: parsed command-line namespace. Reads ``device``, ``seed``,
            ``env_name``, ``num_demos``, ``num_epochs``, ``lr``,
            ``batch_size``, ``net_arch``, ``resnet_mode`` and ``save_bc``.

    Raises:
        ValueError: if saving is requested and ``resnet_mode`` is neither
            ``'train'`` nor ``'eval'``.
    """
    device = torch.device(args.device)
    seed = args.seed
    env_name = args.env_name
    num_demos = args.num_demos
    num_epochs = args.num_epochs
    lr = args.lr
    batch_size = args.batch_size
    net_arch = args.net_arch
    resnet_mode = args.resnet_mode
    save_bc = args.save_bc

    rng = np.random.default_rng(seed)
    # Below, this env is only queried for its observation/action spaces.
    env = make_vec_env(
        env_name=f'meta_image_{env_name}',
        max_episode_steps=250,
        rng=rng,
        n_envs=1,
        post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],  # for computing rollouts
        env_make_kwargs={
            'early_termination': True,
            'sparse': True
        }
    )

    # NOTE(review): dill unpickling can execute arbitrary code stored in the
    # file. Only load datasets from a trusted source.
    trajectories = torch.load(f'saved_datasets/expert_demo/{env_name}.pt', pickle_module=dill)[:num_demos]
    demonstrations = DataLoader(
        dataset=_flatten_trajectories(trajectories),
        batch_size=batch_size,
        drop_last=True,  # every batch has exactly `batch_size` samples, matching BC's batch_size
        shuffle=True,
    )

    bc_policy = ActorCriticPolicy(
        observation_space=env.observation_space,
        action_space=env.action_space,
        lr_schedule=get_schedule_fn(lr),
        net_arch=net_arch,
        features_extractor_class=MetaWorldResNet,
        features_extractor_kwargs={
            'mode': resnet_mode,
        },
    )

    bc_trainer = bc.BC(
        policy=bc_policy,
        observation_space=env.observation_space,
        action_space=env.action_space,
        demonstrations=demonstrations,
        batch_size=batch_size,
        rng=rng,
        # NOTE: per the original author, non-CPU devices needed a local patch
        # to imitation/util/util.py (~L259: ``th.as_tensor(array, **kwargs)``).
        device=device,
        ent_weight=0.001,
        l2_weight=0.0001,
        optimizer_cls=torch.optim.Adam,
        optimizer_kwargs={"lr": lr}
    )

    bc_trainer.train(n_epochs=num_epochs)
    env.close()

    # Use a separate seed so evaluation episodes differ from the training env's.
    rng = np.random.default_rng(seed + 1)
    eval_env = make_vec_env(
        env_name=f'meta_image_{env_name}',
        max_episode_steps=250,
        rng=rng,
        n_envs=10,
        parallel=True,
        env_make_kwargs={
            'early_termination': True,
            'sparse': True
        },
    )

    try:
        print('-' * 50)
        bc_policy = bc_policy.to(device=device)
        episode_rewards, episode_lengths = evaluate_policy(bc_policy, eval_env, 100, return_episode_rewards=True)
    finally:
        # Parallel vec envs run in worker processes; release them even on error.
        eval_env.close()

    # With sparse rewards, any episode with a positive return counts as a success.
    success_rate = sum(1 for reward in episode_rewards if reward > 0) / len(episode_rewards)

    print("Episode rewards:", episode_rewards)
    print("Episode lengths", episode_lengths)
    print("Success rate:", success_rate)

    if save_bc:
        save_dirs = {'train': 'bc', 'eval': 'bc_freeze'}
        if resnet_mode not in save_dirs:
            raise ValueError(f'Unsupported resnet_mode: {resnet_mode!r}')
        policy_path = f'saved_models/{save_dirs[resnet_mode]}/{num_demos}/{env_name}_{success_rate}.zip'
        # Create the output directory in case it does not exist yet.
        os.makedirs(os.path.dirname(policy_path), exist_ok=True)
        save_policy(bc_trainer.policy, policy_path)


# MetaWorld task names accepted by ``--env_name``. Each name is turned into a
# ``meta_image_<name>`` environment id inside ``main``.
ENV_NAMES = [
    'assembly-v2',
    'basketball-v2',
    'box-close-v2',
    'coffee-push-v2',
    'door-open-v2',
    'drawer-open-v2',
    'hammer-v2',
    'handle-pull-v2',
    'lever-pull-v2',
    'pick-place-v2',
    'peg-insert-side-v2',
    'peg-unplug-side-v2',
    'soccer-v2',
    'push-v2',
    'push-wall-v2',
    'sweep-into-v2',
    'sweep-v2',
    'window-open-v2',
]


def build_parser():
    """Build the command-line parser for the BC training script.

    Returns:
        argparse.ArgumentParser: parser whose namespace is consumed by ``main``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--seed', type=int, default=123)
    parser.add_argument('--env_name', type=str, choices=ENV_NAMES, required=True)
    parser.add_argument('--num_demos', type=int, default=50)
    parser.add_argument('--num_epochs', type=int, default=200)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--batch_size', type=int, default=200)
    # Layer sizes for ActorCriticPolicy's ``net_arch``, e.g. ``--net_arch 256 256``.
    # ``type=list`` would split a CLI value into single characters, so parse
    # one int per token instead.
    parser.add_argument('--net_arch', type=int, nargs='+', default=[512, 512, 512])
    parser.add_argument('--resnet_mode', type=str, default='train', choices=['train', 'eval'])
    parser.add_argument('--save_bc', action='store_true')
    # Not read by main(); kept so existing launch commands continue to parse.
    parser.add_argument('--exp_name', type=str, default='train_bc')
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()

    # Seed the torch, NumPy and stdlib global RNGs before any training code runs.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    main(args)
