"""Train with deceptive risk minimization (DRM) and behavior cloning (BC)."""

####################################################################################################   
from absl import app, flags
import numpy as np

from ravens_torch import agents, tasks
from ravens_torch.constants import EXPERIMENTS_DIR, ENV_ASSETS_DIR
from ravens_torch.environments.environment import EnvironmentOODBackground, Environment
from ravens_torch.dataset import load_data
from ravens_torch.utils import SummaryWriter
from ravens_torch.utils.initializers import set_seed
from ravens_torch.utils.text import bold


# --- Filesystem locations ---------------------------------------------------
# `train_dir` is handed to the agent constructors in train(). `root_dir` and
# `data_dir` are not read directly in this file; presumably load_data(FLAGS)
# consumes them -- TODO confirm.
flags.DEFINE_string('train_dir', EXPERIMENTS_DIR, help='')
flags.DEFINE_string('root_dir', EXPERIMENTS_DIR, help='Location of test data')
flags.DEFINE_string('data_dir', EXPERIMENTS_DIR, help='')

# --- Experiment selection ---------------------------------------------------
# `task` and `agent` are lookup keys into tasks.names / agents.names.
flags.DEFINE_string('task', 'place-red-in-green-ood', help='')
flags.DEFINE_string('agent', 'no_transport', help='')

# --- Simulation / data / hardware -------------------------------------------
# `hz` is passed to the OOD evaluation environment; `n_demos` is the number of
# training episodes sampled in train(). `gpu` and `gpu_limit` are not read in
# this file (NOTE(review): possibly used through load_data(FLAGS) -- confirm).
flags.DEFINE_float('hz', 240, help='')
flags.DEFINE_integer('n_demos', 300, help='')
flags.DEFINE_integer('gpu', 0, help='')
flags.DEFINE_integer('gpu_limit', None, help='')
flags.DEFINE_boolean('verbose', True, help='')
flags.DEFINE_string(
    'assets_root',
    ENV_ASSETS_DIR,
    help='Location of assets directory to build the environment',
)
flags.DEFINE_boolean('shared_memory', False, help='')

# Module-wide handle to the parsed flag values.
FLAGS = flags.FLAGS
####################################################################################################   
def train(train_run):
    """Train a DRM agent and a vanilla BC agent for one run and evaluate both.

    Both agents are built from ``agents.names[FLAGS.agent]``, seeded with
    ``train_run``, trained on the same subset of ``FLAGS.n_demos`` training
    episodes until ``total_steps`` reaches the step budget, saved, and then
    scored with ``test_policy`` in the out-of-distribution (``ood=True``) and
    training-style (``ood=False``) environment setups.

    Args:
        train_run: Integer run index. Used as the random seed and as part of
            the agent names.

    Returns:
        Tuple ``(train_success_drm, train_success_bc, test_success_drm,
        test_success_bc)`` holding the values returned by ``test_policy``.

    Raises:
        ValueError: Propagated from the episode sampling when
            ``FLAGS.n_demos`` exceeds the number of available episodes.
    """
    # ------------------------------------------------------------------
    # Hyperparameters
    num_steps_total = 120
    martingale_penalty = 10
    temperature = 1.0
    softrank_regularization_type = "l2"
    softrank_regularization_factor = 0.001

    num_detect_batches = 3
    detect_batch_size = 200

    bc_batch_size = 64

    # ------------------------------------------------------------------
    # Load datasets
    train_dataset, test_dataset = load_data(FLAGS)

    # Limit random sampling during training to a fixed dataset.
    # The subset is drawn from a generator seeded with `train_run` so that the
    # selection is reproducible: the global NumPy RNG has not yet been seeded
    # for this run at this point (set_seed is only called further down).
    max_demos = train_dataset.n_episodes
    episode_rng = np.random.RandomState(train_run)
    episodes = episode_rng.choice(range(max_demos), FLAGS.n_demos, replace=False)
    train_dataset.set(episodes)

    # ------------------------------------------------------------------
    # DRM training
    drm_name = f'{FLAGS.task}-{FLAGS.agent}-{FLAGS.n_demos}-{train_run}'

    set_seed(train_run)
    drm_agent = agents.names[FLAGS.agent](
        drm_name, FLAGS.task, FLAGS.train_dir, verbose=FLAGS.verbose)

    # Create detection set from training dataset
    print("Creating detection sets from training dataset...")
    detect_dataset = drm_agent.create_detect_set(
        train_dataset, detect_batch_size, num_detect_batches)
    print(f"Done creating detection sets with {num_detect_batches} batch(es) of {detect_batch_size} each.")

    # Keep calling train_drm until the step budget is reached; this relies on
    # train_drm advancing drm_agent.total_steps.
    while drm_agent.total_steps < num_steps_total:
        drm_agent.train_drm(train_dataset, detect_dataset, bc_batch_size,
                            martingale_penalty,
                            temperature,
                            softrank_regularization_type,
                            softrank_regularization_factor)

    # Once training is done, save model
    drm_agent.save(FLAGS.verbose)

    # Evaluate in the OOD test environment.
    test_success_drm = test_policy(train_run, drm_agent, test_dataset, ood=True)

    # Evaluate in a training-style environment. test_dataset is still passed
    # because test_policy only uses it for the episode count and seeds.
    train_success_drm = test_policy(train_run, drm_agent, test_dataset, ood=False)

    # ------------------------------------------------------------------
    # Vanilla BC training
    bc_name = f'{FLAGS.task}-{FLAGS.agent}-{FLAGS.n_demos}-{train_run}-bc'

    # Re-seed so the BC agent starts from the same seed as the DRM agent.
    set_seed(train_run)
    bc_agent = agents.names[FLAGS.agent](
        bc_name, FLAGS.task, FLAGS.train_dir, verbose=FLAGS.verbose)

    # Same step budget as DRM; relies on train() advancing total_steps.
    while bc_agent.total_steps < num_steps_total:
        bc_agent.train(train_dataset, bc_batch_size)

    # Once training is done, save model
    bc_agent.save(FLAGS.verbose)

    # Same two evaluations as for the DRM agent.
    test_success_bc = test_policy(train_run, bc_agent, test_dataset, ood=True)
    train_success_bc = test_policy(train_run, bc_agent, test_dataset, ood=False)

    return train_success_drm, train_success_bc, test_success_drm, test_success_bc
####################################################################################################  

####################################################################################################  
def test_policy(train_run, agent, ds, ood=True):
    """Roll out ``agent`` in simulation once per episode of ``ds``.

    Args:
        train_run: Seed passed to ``set_seed`` before the rollouts start.
        agent: Object exposing ``act(obs, info)``; its return value is fed to
            ``env.step``.
        ds: Dataset exposing ``n_episodes`` and ``load(i)``. Only the second
            element returned by ``load(i)`` is used, as the episode seed.
        ood: If True, use the test background and test bowl assets with
            ``FLAGS.hz``. Otherwise pick one of the three training
            background/bowl pairs at random, use ``hz=480`` and put the task
            in ``'test'`` mode.

    Returns:
        The sum over episodes of the reward returned by the last ``env.step``
        call of each episode, divided by ``ds.n_episodes``.

    Raises:
        ValueError: If ``ds`` contains no episodes (previously this surfaced
            as a ZeroDivisionError after the simulator had been started).
    """
    # Fail fast, before starting the simulator, when there is nothing to run.
    if ds.n_episodes < 1:
        raise ValueError('test_policy requires a dataset with at least one episode.')

    # Pick the assets and simulation rate for the requested setup.
    # NOTE(review): the two branches differ in more than the assets -- only the
    # non-OOD branch uses hz=480 and switches the task to 'test' mode. Kept
    # as-is; confirm whether this asymmetry is intended.
    if ood:
        background_urdf = 'ur5/workspace_test.urdf'
        bowl_urdf = 'bowl/bowl_test.urdf'
        sim_hz = FLAGS.hz
    else:
        # Randomly select 1,2,3 for different training backgrounds / bowls
        idx = np.random.choice([1, 2, 3])
        background_urdf = f'ur5/workspace_train{idx}.urdf'
        bowl_urdf = f'bowl/bowl_train{idx}.urdf'
        sim_hz = 480

    # Initialize environment and task.
    env = EnvironmentOODBackground(
        FLAGS.assets_root,
        disp=False,
        shared_memory=FLAGS.shared_memory,
        hz=sim_hz,
        background_urdf=background_urdf)
    task = tasks.names[FLAGS.task](bowl_urdf)
    if not ood:
        task._set_mode('test')
        print(bold("=" * 20 + "\n" + f"TASK: {FLAGS.task}" + "\n" + "=" * 20))

    # Seed the run before any rollout.
    set_seed(train_run)

    num_successes = 0

    for i in range(ds.n_episodes):
        print(f'Test: {i + 1}/{ds.n_episodes}')

        # Re-seed NumPy and the environment with the seed stored for episode i.
        _, seed = ds.load(i)
        total_reward = 0
        np.random.seed(seed)
        env.seed(seed)
        env.set_task(task)
        obs = env.reset()
        info = None
        # Stays 0 if the step loop below never runs.
        reward = 0

        for _ in range(task.max_steps):
            act = agent.act(obs, info)
            obs, reward, done, info = env.step(act)
            total_reward += reward
            print(f'Total Reward: {total_reward} Done: {done}')
            if done:
                break

        # NOTE(review): this adds the reward of the last step only, not
        # total_reward. Kept as-is; confirm this is the intended success metric.
        num_successes += reward

        print('Success rate: ', num_successes/(i+1))

    return num_successes/ds.n_episodes
####################################################################################################  


####################################################################################################  
def main(unused_argv):
    """Run train() for ten seeds, print summary statistics and save the results.

    For each seed in ``range(10)`` the four success values returned by
    ``train`` are collected. Their mean and standard deviation are printed, and
    the per-seed values are written to ``results.npz``.

    Args:
        unused_argv: Positional command-line arguments; ignored.
    """
    num_seeds = 10

    # One row per seed, in the order returned by train():
    # (train DRM, train BC, test DRM, test BC).
    per_seed_rows = []
    for seed in range(num_seeds):
        print(f"Running seed {seed}")
        per_seed_rows.append(train(seed))

    # Transpose the rows into one list per metric.
    train_drm, train_bc, test_drm, test_bc = (
        list(column) for column in zip(*per_seed_rows)
    )

    # Print results mean and std
    summary = (
        ('Train Success DRM', train_drm),
        ('Train Success BC', train_bc),
        ('Test Success DRM', test_drm),
        ('Test Success BC', test_bc),
    )
    for label, values in summary:
        print(f"{label}: {np.mean(values):.4f} ± {np.std(values):.4f}")

    # Save results in results.npz file
    np.savez(
        'results.npz',
        train_success_drm=train_drm,
        train_success_bc=train_bc,
        test_success_drm=test_drm,
        test_success_bc=test_bc,
    )

####################################################################################################  

####################################################################################################  
# Script entry point: hand main() to absl's app runner when executed directly.
if __name__ == '__main__':
    app.run(main)