"""Plot martingale values across training dataset."""

####################################################################################################   
from absl import app, flags
import numpy as np
import torch
import matplotlib.pyplot as plt

from ravens_torch import agents, tasks
from ravens_torch.constants import EXPERIMENTS_DIR, ENV_ASSETS_DIR
from ravens_torch.environments.environment import EnvironmentOODBackground, Environment
from ravens_torch.dataset import load_data
from ravens_torch.utils import SummaryWriter
from ravens_torch.utils.initializers import set_seed
from ravens_torch.utils.text import bold
from ravens_torch.utils.loss import compute_martingale_hard


# Command-line flags.
#
# Only `task`, `agent`, `n_demos`, `train_dir` and `verbose` are read directly
# in this file. The complete FLAGS object is also passed to `load_data`, which
# presumably consumes some of the remaining flags -- verify against
# ravens_torch.dataset before removing any of them.

# --- Directories -------------------------------------------------------------
flags.DEFINE_string(name='train_dir', default=EXPERIMENTS_DIR, help='')
flags.DEFINE_string(name='root_dir', default=EXPERIMENTS_DIR, help='Location of test data')
flags.DEFINE_string(name='data_dir', default=EXPERIMENTS_DIR, help='')

# --- Experiment selection (task, agent and n_demos build the checkpoint name) -
flags.DEFINE_string(name='task', default='place-red-in-green-ood', help='')
flags.DEFINE_string(name='agent', default='no_transport', help='')
flags.DEFINE_float(name='hz', default=240, help='')
flags.DEFINE_integer(name='n_demos', default=300, help='')

# --- Runtime options ---------------------------------------------------------
flags.DEFINE_integer(name='gpu', default=0, help='')
flags.DEFINE_integer(name='gpu_limit', default=None, help='')
flags.DEFINE_boolean(name='verbose', default=True, help='')
flags.DEFINE_integer(name='save_interval', default=10, help='')
flags.DEFINE_integer(name='test_interval', default=10, help='')
flags.DEFINE_string(
    name='assets_root',
    default=ENV_ASSETS_DIR,
    help='Location of assets directory to build the environment',
)
flags.DEFINE_bool(name='shared_memory', default=False, help='')

# Global flag registry; values become available once absl has parsed argv.
FLAGS = flags.FLAGS
####################################################################################################   
def _build_agent(name, train_run):
    """Seed the RNGs with `train_run` and construct the agent selected by FLAGS.

    The checkpoint is NOT loaded here; the caller invokes `agent.load(...)`
    itself so that it controls what happens between construction and loading.

    Args:
        name: Run name handed to the agent constructor.
        train_run: Integer passed to `set_seed` before construction.

    Returns:
        The freshly constructed agent from `agents.names[FLAGS.agent]`.
    """
    set_seed(train_run)
    return agents.names[FLAGS.agent](
        name, FLAGS.task, FLAGS.train_dir, verbose=FLAGS.verbose)


def _extract_features(transport_agent, observations):
    """Collect `transport_agent.get_features(obs)` for every observation.

    Args:
        transport_agent: Object exposing `eval_mode()`, `get_features(obs)`
            and a `device` attribute.
        observations: Indexable collection whose length is `shape[0]`.

    Returns:
        A tensor of shape `(observations.shape[0], feature_dim)` allocated on
        `transport_agent.device`, where `feature_dim` is the first dimension
        of the feature returned for the first observation.
    """
    num_observations = observations.shape[0]

    transport_agent.eval_mode()
    with torch.no_grad():
        # The first observation is processed separately because its feature
        # size is needed to allocate the output buffer.
        first_features = transport_agent.get_features(observations[0])
        feature_dim = first_features.shape[0]

        features = torch.zeros(
            (num_observations, feature_dim), device=transport_agent.device)
        features[0, :] = first_features
        for i in range(1, num_observations):
            features[i] = transport_agent.get_features(observations[i])

    return features


def _martingale_as_numpy(samples, device):
    """Run `compute_martingale_hard` and return its third output as numpy.

    Args:
        samples: Tensor with one row per data point.
        device: Device forwarded to `compute_martingale_hard`.

    Returns:
        The third value returned by `compute_martingale_hard`, detached and
        moved to the CPU as a numpy array (the first two values are ignored).
    """
    _, _, martingale_all = compute_martingale_hard(samples, device)
    return martingale_all.detach().cpu().numpy()


def plot_martingale(unused_argv):
    """Plot martingale curves for DRM features, BC features and raw images.

    Steps:
      1. Load the datasets and restrict the training set to `FLAGS.n_demos`
         randomly chosen episodes.
      2. Load the pre-trained agent named
         `{task}-{agent}-{n_demos}-{train_run}` (labelled "DRM" in the plot)
         and build a detection set from the training data with it.
      3. Load the pre-trained agent with the `-bc` name suffix (labelled
         "ERM" in the plot).
      4. Compute a martingale curve from the transport features of each agent
         and from the flattened raw detection observations.
      5. Draw the three curves, save the figure as `martingales.pdf` in the
         current working directory and display it.

    Args:
        unused_argv: Positional command-line arguments supplied by
            `absl.app.run`; ignored.
    """
    ############################################################################################
    num_train_steps = 120  # 900
    detect_batch_size = FLAGS.n_demos
    num_detect_batches = 1
    train_run = 0
    ############################################################################################

    # Load train and test datasets (the test split is not needed here).
    train_dataset, _ = load_data(FLAGS)

    #############################################################################################
    # Load pre-trained DRM agent
    agent = _build_agent(
        f'{FLAGS.task}-{FLAGS.agent}-{FLAGS.n_demos}-{train_run}', train_run)

    # Limit random sampling during training to a fixed dataset.
    # NOTE: this draw happens after seeding + agent construction and before
    # the checkpoint load; keep it there so the selected episodes stay the same.
    max_demos = train_dataset.n_episodes
    episodes = np.random.choice(range(max_demos), FLAGS.n_demos, False)
    train_dataset.set(episodes)

    # Load pre-trained agent
    agent.load(num_train_steps, FLAGS.verbose)

    # Create detection set
    print("Creating detection set.")
    detect_dataset = agent.create_detect_set(
        train_dataset, detect_batch_size, num_detect_batches)
    # Only the observations of the first detection batch are used.
    detect_dataset = detect_dataset['obs'][0]
    print("Done.")

    # Martingale from the DRM agent's transport features
    transport_agent = agent.transport
    martingale_all_drm = _martingale_as_numpy(
        _extract_features(transport_agent, detect_dataset),
        transport_agent.device)

    ###############################################################################################
    # Load pre-trained BC agent
    agent_bc = _build_agent(
        f'{FLAGS.task}-{FLAGS.agent}-{FLAGS.n_demos}-{train_run}-bc', train_run)
    agent_bc.load(num_train_steps, FLAGS.verbose)

    # Martingale from the BC agent's transport features (same detection set)
    transport_agent_bc = agent_bc.transport
    martingale_all_bc = _martingale_as_numpy(
        _extract_features(transport_agent_bc, detect_dataset),
        transport_agent_bc.device)
    ###############################################################################################

    ###############################################################################################
    # Now compute martingale from raw images: one flattened row per observation
    raw_images = detect_dataset.reshape(detect_dataset.shape[0], -1)
    raw_images = torch.tensor(raw_images, device=transport_agent.device)
    martingale_all_images = _martingale_as_numpy(
        raw_images, transport_agent.device)
    ###############################################################################################

    ###############################################################################################
    # Plot martingales
    # Set figure size to match 4:1 aspect ratio (e.g., 12 wide x 3 high)
    plt.figure(figsize=(12, 3))

    # Plot each line with the specified color
    plt.plot(martingale_all_images, color="#bacee4", label='Raw images', linewidth=5)
    plt.plot(martingale_all_bc, color='#72acea', label='ERM features', linewidth=5)
    plt.plot(martingale_all_drm, color='#db4042', label='DRM features (ours)', linewidth=5)

    # Set labels with larger font size
    plt.xlabel('Environment #', fontsize=20)
    plt.ylabel('Martingale', fontsize=20)

    # Set y-axis limits (values above 100 are cut off in the figure)
    plt.ylim(0, 100)

    # Font size for ticks
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)

    # Add legend
    plt.legend(fontsize=20)

    # Save image (must happen before show(), which may clear the figure on close)
    plt.savefig('martingales.pdf', bbox_inches='tight')

    # Show plot
    plt.show()
    ###############################################################################################

####################################################################################################  
if __name__ == "__main__":
    # absl parses the command-line flags, then calls plot_martingale with the
    # remaining positional arguments.
    app.run(main=plot_martingale)