# coding=utf-8
# Adapted from Ravens - Transporter Networks, Zeng et al., 2021
# https://github.com/google-research/ravens
"""Data collection script with multiple training distributions."""

import os
import numpy as np
from absl import app, flags

from ravens_torch import tasks
from ravens_torch.constants import EXPERIMENTS_DIR, ENV_ASSETS_DIR
from ravens_torch.dataset import Dataset
from ravens_torch.environments.environment import EnvironmentOODBackground, Environment


# Command-line flags. Each help string describes how `main` consumes the flag.
flags.DEFINE_string(
    'assets_root', ENV_ASSETS_DIR,
    'Assets root passed as the first argument to the environment constructor.')
flags.DEFINE_string(
    'data_dir', EXPERIMENTS_DIR,
    "Parent directory of the dataset folder, which is named '<task>-<mode>'.")
# NOTE(review): `disp` presumably toggles the simulator GUI -- confirm against
# the environment class.
flags.DEFINE_bool(
    'disp', False,
    "Value forwarded to the environment's `disp` argument.")
flags.DEFINE_bool(
    'shared_memory', False,
    "Value forwarded to the environment's `shared_memory` argument.")
flags.DEFINE_string(
    'task', 'place-red-in-green-ood',
    'Key into `ravens_torch.tasks.names` selecting the task to demonstrate.')
flags.DEFINE_string(
    'mode', 'test',
    "Which data split to collect: 'train' or 'test'.")
flags.DEFINE_integer(
    'n', 300,
    'Target number of episodes in the dataset; collection stops once reached.')

FLAGS = flags.FLAGS


# (background URDF, bowl URDF) pairs, one per training distribution.
_TRAIN_DISTRIBUTIONS = (
    ('ur5/workspace_train1.urdf', 'bowl/bowl_train1.urdf'),
    ('ur5/workspace_train2.urdf', 'bowl/bowl_train2.urdf'),
    ('ur5/workspace_train3.urdf', 'bowl/bowl_train3.urdf'),
)

# Single (background URDF, bowl URDF) pair used in test mode.
_TEST_DISTRIBUTIONS = (
    ('ur5/workspace_test.urdf', 'bowl/bowl_test.urdf'),
)


def _make_env(background_urdf):
    """Create an environment that uses `background_urdf` as its background."""
    return EnvironmentOODBackground(
        FLAGS.assets_root,
        disp=FLAGS.disp,
        shared_memory=FLAGS.shared_memory,
        hz=480,
        background_urdf=background_urdf)


def _make_task(bowl_urdf):
    """Instantiate the task named by --task with `bowl_urdf` and set its mode."""
    task = tasks.names[FLAGS.task](bowl_urdf)
    task.mode = FLAGS.mode
    return task


def _build_distributions(urdf_pairs):
    """Return a list of (agent, task, env) triples, one per URDF pair.

    All environments are created first, then all tasks, then all oracle
    agents, so side effects happen in the same order as when each object
    is constructed by hand one after another.
    """
    envs = [_make_env(background) for background, _ in urdf_pairs]
    task_list = [_make_task(bowl) for _, bowl in urdf_pairs]
    agents = [task.oracle(env) for task, env in zip(task_list, envs)]
    return list(zip(agents, task_list, envs))


def _distribution_index(n_episodes, n_total, n_distributions):
    """Map the current episode count to the distribution that should be used.

    The range [0, n_total) is cut into `n_distributions` consecutive parts at
    the thresholds `n_total * k // n_distributions`; anything at or beyond
    the last threshold falls into the final distribution.
    """
    for idx in range(n_distributions - 1):
        if n_episodes < n_total * (idx + 1) // n_distributions:
            return idx
    return n_distributions - 1


def _rollout(agent, task, env):
    """Run one oracle episode and return `(episode, total_reward)`.

    `episode` is a list of `(obs, act, reward, info)` tuples; the final entry
    holds the last observation with `act` set to None.
    """
    episode, total_reward = [], 0
    env.set_task(task)
    obs = env.reset()
    info = None
    reward = 0
    for _ in range(task.max_steps):
        act = agent.act(obs, info)
        # Each stored step pairs the observation with the action chosen from
        # it and the reward/info produced by the previous step.
        episode.append((obs, act, reward, info))
        obs, reward, done, info = env.step(act)
        total_reward += reward
        print(f'Total Reward: {total_reward} Done: {done}')
        if done:
            break
    episode.append((obs, None, reward, info))
    return episode, total_reward


def main(unused_argv):
    """Collect oracle demonstrations until the dataset holds --n episodes.

    In 'train' mode three environment/task/agent triples are created and the
    episode range is split into thirds, one per triple. In 'test' mode a
    single triple is used throughout. Only episodes whose total reward
    exceeds 0.99 are added to the dataset, so the loop keeps running until
    enough successful episodes exist.

    Args:
        unused_argv: positional command-line arguments supplied by
            `app.run`; ignored.

    Raises:
        ValueError: if --mode is neither 'train' nor 'test'.
    """
    if FLAGS.mode == 'train':
        distributions = _build_distributions(_TRAIN_DISTRIBUTIONS)
    elif FLAGS.mode == 'test':
        distributions = _build_distributions(_TEST_DISTRIBUTIONS)
    else:
        raise ValueError(f"Invalid mode: {FLAGS.mode}. Choose 'train' or 'test'.")

    # Initialize dataset
    dataset = Dataset(os.path.join(
        FLAGS.data_dir, f'{FLAGS.task}-{FLAGS.mode}'))

    # Train seeds are even and test seeds are odd: an empty dataset starts
    # from -2 (train) or -1 (test) and every episode advances the seed by 2.
    seed = dataset.max_seed
    if seed < 0:
        seed = -1 if (FLAGS.mode == 'test') else -2

    # Collect data from oracle demonstrations.
    while dataset.n_episodes < FLAGS.n:
        idx = _distribution_index(
            dataset.n_episodes, FLAGS.n, len(distributions))
        agent, task, env = distributions[idx]

        print(f'Oracle demonstration: {dataset.n_episodes + 1}/{FLAGS.n}')
        seed += 2
        np.random.seed(seed)
        episode, total_reward = _rollout(agent, task, env)

        # Only save completed demonstrations.
        if total_reward > 0.99:
            dataset.add(seed, episode)


# Script entry point: hand `main` to absl's `app.run` when executed directly.
if __name__ == '__main__':
    app.run(main)
