import argparse
import random
import os
os.environ['MUJOCO_GL'] = 'egl'

import torch
torch.set_num_threads(4)
import dill
import gymnasium as gym
import numpy as np
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.util.util import make_vec_env
from imitation.algorithms.bc import reconstruct_policy
from metaworld.policies import *

from envs.metaworld import MetaWorldSawyerEnv


# Maps each supported Meta-World task name to the scripted policy class that is
# instantiated as the expert for that task (see the 'expert' branch of `main`).
# The keys also serve as the allowed values of the `--env_name` CLI option.
ENV_MAPS = {
    'assembly-v2': SawyerAssemblyV2Policy,
    'basketball-v2': SawyerBasketballV2Policy,
    'box-close-v2': SawyerBoxCloseV2Policy,
    'coffee-push-v2': SawyerCoffeePushV2Policy,
    'door-open-v2': SawyerDoorOpenV2Policy,
    'drawer-open-v2': SawyerDrawerOpenV2Policy,
    'hammer-v2': SawyerHammerV2Policy,
    'handle-pull-v2': SawyerHandlePullV2Policy,
    'lever-pull-v2': SawyerLeverPullV2Policy,
    'pick-place-v2': SawyerPickPlaceV2Policy,
    'peg-insert-side-v2': SawyerPegInsertionSideV2Policy,
    'peg-unplug-side-v2': SawyerPegUnplugSideV2Policy,
    'soccer-v2': SawyerSoccerV2Policy,
    'push-v2': SawyerPushV2Policy,
    'push-wall-v2': SawyerPushWallV2Policy,
    'sweep-into-v2': SawyerSweepIntoV2Policy,
    'sweep-v2': SawyerSweepV2Policy,
    'window-open-v2': SawyerWindowOpenV2Policy,
}


class CallablePolicy:
    """Adapt a single-observation scripted policy to a batched policy callable.

    The wrapped ``policy`` only needs a ``get_action(observation)`` method.
    Calling the wrapper applies it to every observation of a batch and adds
    scaled random noise drawn from ``action_space``.

    Args:
        policy: Object exposing ``get_action(observation)``.
        action_space: Object exposing ``sample()``; each sample is multiplied
            by ``noise`` and added to the corresponding action.
        noise: Scale factor applied to the sampled noise (``0.0`` disables it).
    """

    def __init__(self, policy, action_space, noise):
        self.policy = policy
        self.action_space = action_space
        self.noise = noise

    def __call__(self, observations, states, dones):
        """Return one action per observation, plus the unchanged ``states``.

        Args:
            observations: Batch of observations, iterated along its first axis.
            states: Opaque recurrent state; passed through untouched.
            dones: Unused; accepted to match the call signature used by callers.

        Returns:
            A tuple ``(acts, states)`` where ``acts`` stacks the per-observation
            actions along axis 0.
        """
        # One action and one independent noise sample per environment. The
        # previous version only looked at observations[0], so the action batch
        # was too short whenever more than one environment was stepped.
        acts = np.stack(
            [
                self.policy.get_action(obs) + self.action_space.sample() * self.noise
                for obs in observations
            ],
            axis=0,
        )
        return acts, states


def _is_successful(traj):
    """Return True when any step of *traj* reports a truthy ``info['success']``."""
    return any(bool(info['success']) for info in traj.infos)


def main(args: argparse.Namespace) -> None:
    """Roll out a policy in a Meta-World environment and optionally save the result.

    Depending on ``args.policy_mode`` the rollouts come from the scripted
    expert in ``ENV_MAPS`` ('expert'), a policy loaded with
    ``reconstruct_policy`` ('pretrained'), or random actions ('random', where
    ``None`` is passed as the policy).

    Args:
        args: Parsed command-line options. The attributes read are
            ``env_name``, ``policy_mode``, ``policy_path``,
            ``observation_mode``, ``seed``, ``num_demos``, ``save_demos`` and
            ``device``.

    Raises:
        ValueError: If ``args.policy_mode`` is not 'expert', 'pretrained' or
            'random'.
    """
    env_name = args.env_name
    policy_mode = args.policy_mode
    policy_path = args.policy_path
    observation_mode = args.observation_mode
    seed = args.seed
    num_demos = args.num_demos
    save_demos = args.save_demos
    device = torch.device(args.device)
    generate_image = observation_mode == 'image'

    # Fail before any worker process is spawned instead of hitting an unbound
    # `policy` further down.
    if policy_mode not in ('expert', 'pretrained', 'random'):
        raise ValueError(f'unknown policy_mode: {policy_mode!r}')

    rng = np.random.default_rng(seed)
    env = make_vec_env(
        env_name=f'meta_image_{env_name}' if policy_mode == 'pretrained' else f'meta_{env_name}',
        max_episode_steps=250,
        rng=rng,
        n_envs=10,
        post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],  # for computing rollouts
        parallel=True,
        env_make_kwargs={
            'early_termination': True,
            'sparse': True,
            'generate_image': generate_image,
        }
    )

    try:
        # More episodes than requested are collected for 'expert' and
        # 'pretrained' because the rollouts are filtered by success afterwards.
        if policy_mode == 'expert':
            policy = CallablePolicy(policy=ENV_MAPS[env_name](), action_space=env.action_space, noise=0.0)
            min_episodes = 2 * num_demos
        elif policy_mode == 'pretrained':
            policy = reconstruct_policy(policy_path=policy_path, device=device)
            min_episodes = 4 * num_demos
        else:  # 'random'
            policy = None
            min_episodes = num_demos

        rollouts = rollout.rollout(
            policy,
            env,
            rollout.make_sample_until(min_timesteps=None, min_episodes=min_episodes),
            rng=rng,
            exclude_infos=False,
        )[:min_episodes]
    finally:
        # The environment was created with parallel=True; close it so its
        # workers are released even when the rollout fails.
        env.close()

    failed_rollouts = None
    succeed_rollouts = None
    if policy_mode == 'expert':
        rollouts = [traj for traj in rollouts if _is_successful(traj)]
        if len(rollouts) < num_demos:
            print(f'[warning] only {len(rollouts)} successful expert rollouts collected, requested {num_demos}')
        rollouts = rollouts[:num_demos]
    elif policy_mode == 'pretrained':
        failed_rollouts = [traj for traj in rollouts if not _is_successful(traj)]  # collect failed trajectories
        print('[failed]', len(failed_rollouts))
        failed_rollouts = failed_rollouts[:num_demos]
        succeed_rollouts = [traj for traj in rollouts if _is_successful(traj)]  # collect succeed trajectories
        print('[success]', len(succeed_rollouts))
        succeed_rollouts = succeed_rollouts[:num_demos]
        print(len(rollouts))

    # rewrite images to observation space
    if generate_image:
        # object.__setattr__ is used so the attribute can be replaced in place;
        # the failed/succeed lists hold the same trajectory objects, so they
        # see the rewritten observations as well.
        # NOTE(review): `infos` presumably has one entry per action while `obs`
        # normally has one more — confirm consumers accept the shorter array.
        for traj in rollouts:
            object.__setattr__(traj, 'obs', np.stack([info['image'] for info in traj.infos], axis=0))

    if save_demos:
        dataset_base = 'saved_datasets'
        if policy_mode == 'expert':
            outputs = [(f'{dataset_base}/expert_demo/{env_name}.pt', rollouts)]
        elif policy_mode == 'pretrained':
            outputs = [
                (f'{dataset_base}/discriminator/{env_name}-failed.pt', failed_rollouts),
                (f'{dataset_base}/discriminator/{env_name}-succeed.pt', succeed_rollouts),
            ]
        else:  # 'random'
            outputs = [(f'{dataset_base}/random_demo/{env_name}.pt', rollouts)]

        for dataset_name, trajectories in outputs:
            # torch.save does not create missing parent directories.
            os.makedirs(os.path.dirname(dataset_name), exist_ok=True)
            torch.save(trajectories, dataset_name, pickle_module=dill)


if __name__ == '__main__':
    # Script entry point: parse the CLI options, seed the global RNGs and run `main`.

    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='box-close-v2', choices=list(ENV_MAPS))
    parser.add_argument('--policy_mode', type=str, default='expert', choices=['expert', 'pretrained', 'random'])
    parser.add_argument('--policy_path', type=str, default=None)
    parser.add_argument('--observation_mode', type=str, default='image', choices=['image'])
    parser.add_argument('--seed', type=int, default=123)
    parser.add_argument('--num_demos', type=int, default=50)
    parser.add_argument('--save_demos', action='store_true')
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--exp_name', type=str, default='generate_trajectories')

    args = parser.parse_args()

    # A pretrained policy cannot be loaded without a path; report it as a usage
    # error here rather than failing later inside `main`.
    if args.policy_mode == 'pretrained' and args.policy_path is None:
        parser.error('--policy_path is required when --policy_mode is pretrained')

    seed = args.seed

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # NOTE(review): ENV_NAMES is not referenced anywhere in the visible code;
    # kept in case it documents the task subset used for batch runs.
    ENV_NAMES = [
        'hammer-v2',
        'peg-insert-side-v2',
        'peg-unplug-side-v2',
        'soccer-v2',
        'window-open-v2',
        'sweep-into-v2',
        'sweep-v2',
        'coffee-push-v2',
        'box-close-v2',
        'drawer-open-v2',
        'lever-pull-v2',
        'pick-place-v2',
        'push-v2',
        'push-wall-v2',
    ]

    main(args)
