# coding=utf-8
# Adapted from Ravens - Transporter Networks, Zeng et al., 2021
# https://github.com/google-research/ravens

"""Transporter Agent."""

import os

import numpy as np
from ravens_torch.models.attention import Attention
from ravens_torch.models.transport import Transport
from ravens_torch.models.transport_ablation import TransportPerPixelLoss
from ravens_torch.models.transport_goal import TransportGoal
from ravens_torch.tasks import cameras
from ravens_torch.utils import utils
import torch
import IPython as ipy


class TransporterAgent:
    """Agent that uses Transporter Networks."""

    def __init__(self, name, task, root_dir, n_rotations=36):
        """Store the configuration shared by every Transporter variant.

        The attention and transport models themselves are not created here;
        subclasses are expected to assign `self.attention` and
        `self.transport`.

        Args:
          name: agent name; also the name of the checkpoint sub-directory.
          task: task identifier, stored as-is.
          root_dir: directory under which `checkpoints/<name>` is located.
          n_rotations: number of rotation bins stored for the transport model.
        """
        # Identity and training progress.
        self.name, self.task = name, task
        self.total_steps = 0

        # Model hyper-parameters.
        self.crop_size = 64
        self.n_rotations = n_rotations

        # Workspace geometry. The x and y extents of `bounds` divided by
        # `pix_size` give 160 and 320, matching the two spatial sizes of
        # `in_shape` (units are presumably metres - TODO confirm).
        self.pix_size = 0.003125
        self.in_shape = (320, 160, 3)
        self.cam_config = cameras.RealSenseD415.CONFIG

        self.models_dir = os.path.join(root_dir, 'checkpoints', name)
        self.bounds = np.array([[0.25, 0.75], [-0.5, 0.5], [0, 0.28]])

    def get_image(self, obs):
        """Return the color image fed to the networks plus the height map.

        Only the color map is used as network input; the height map is
        returned separately and is not stacked into the image.

        Args:
          obs: observation forwarded to `utils.get_fused_heightmap`
            (presumably the RGB-D camera images - TODO confirm).

        Returns:
          (img, hmap): the color map, asserted to have shape `self.in_shape`,
          and the height map produced by the same fusion call.
        """
        # Get color and height maps from RGB-D images.
        fused = utils.get_fused_heightmap(
            obs, self.cam_config, self.bounds, self.pix_size)
        color_map, height_map = fused

        # Color image only: guard against a mismatch with the model input.
        assert color_map.shape == self.in_shape, color_map.shape
        return color_map, height_map

    def get_sample(self, dataset, augment=False):
        """Draw one sample from a dataset and convert it to training labels.

        Args:
          dataset: a ravens_torch.Dataset (train or validation); only its
            `sample()` method is used.
          augment: if True, pass the image and both pixels through
            `utils.perturb` for data augmentation.

        Returns:
          tuple (img, p0, p0_theta, p1, p1_theta) where p0/p1 are the pick and
          place pixels, p0_theta is always 0 and p1_theta is the place angle
          expressed relative to the pick angle.
        """

        def to_pixel_and_angle(pose):
            # Map a (position, quaternion) pose to a pixel and the negated
            # third Euler angle of its orientation.
            xyz, xyzw = pose
            pixel = utils.xyz_to_pix(xyz, self.bounds, self.pix_size)
            angle = -np.float32(utils.quatXYZW_to_eulerXYZ(xyzw)[2])
            return pixel, angle

        (obs, act, _, _), _ = dataset.sample()
        img, _ = self.get_image(obs)

        # Get training labels from data sample.
        p0, pick_angle = to_pixel_and_angle(act['pose0'])
        p1, place_angle = to_pixel_and_angle(act['pose1'])

        # The pick angle is folded into the place angle and then zeroed.
        p1_theta = place_angle - pick_angle
        p0_theta = 0

        # Data augmentation.
        if augment:
            img, _, (p0, p1), _ = utils.perturb(img, [p0, p1])

        return img, p0, p0_theta, p1, p1_theta

    def create_detect_set(self, train_dataset, detect_batch_size, num_detect_batches=1, augment=False):
        """Create batches of images and pick pixels for detection.

        Every entry returned by `train_dataset.fetch_detect_set()` is converted
        to an image and a pick pixel once; each batch is then drawn from those
        entries without replacement, in ascending index order.

        Args:
          train_dataset: dataset exposing `fetch_detect_set()`, which must
            return a sized, indexable collection of `(obs, act, _, _)` tuples.
          detect_batch_size: number of entries per batch.
          num_detect_batches: number of batches to draw.
          augment: if True, pass each image and its pick/place pixels through
            `utils.perturb` before storing them.

        Returns:
          dict with
            'obs': float array (num_detect_batches, detect_batch_size, H, W, C)
            'p0': int32 array (num_detect_batches, detect_batch_size, 2)

        Raises:
          ValueError: if a batch is requested that is larger than the number
            of available entries (sampling is without replacement).
        """
        entries = train_dataset.fetch_detect_set()
        n_entries = len(entries)

        # Fail before the expensive image fusion rather than after it.
        if num_detect_batches > 0 and detect_batch_size > n_entries:
            raise ValueError(
                f'detect_batch_size ({detect_batch_size}) exceeds the number '
                f'of available detection samples ({n_entries})')

        height, width, channels = self.in_shape
        all_obs = np.zeros((n_entries, height, width, channels))
        all_p0 = np.zeros((n_entries, 2), dtype=np.int32)

        for idx in range(n_entries):
            obs, act, _, _ = entries[idx]
            img, _ = self.get_image(obs)
            p0 = utils.xyz_to_pix(act['pose0'][0], self.bounds, self.pix_size)

            # Data augmentation. The place pixel is only needed here, because
            # the perturbation is applied to both pixels together.
            if augment:
                p1 = utils.xyz_to_pix(
                    act['pose1'][0], self.bounds, self.pix_size)
                img, _, (p0, _), _ = utils.perturb(img, [p0, p1])

            all_obs[idx] = img
            all_p0[idx] = p0

        obs_batches = np.zeros(
            (num_detect_batches, detect_batch_size, height, width, channels))
        p0_batches = np.zeros(
            (num_detect_batches, detect_batch_size, 2), dtype=np.int32)

        for batch in range(num_detect_batches):
            # Sample sorted, distinct indices and copy the batch in one step.
            picks = np.sort(np.random.choice(
                n_entries, detect_batch_size, replace=False))
            obs_batches[batch] = all_obs[picks]
            p0_batches[batch] = all_p0[picks]

        return {'obs': obs_batches, 'p0': p0_batches}
    
    def train(self, dataset, bc_batch_size, writer=None):
        """Run one training iteration on a freshly sampled batch.

        Args:
          dataset: a ravens_torch.Dataset to sample from (no augmentation).
          bc_batch_size: number of samples drawn for this iteration.
          writer: optional summary writer exposing `add_scalars`; when given,
            both losses are logged under `train_loss/`.

        Returns:
          (loss0, loss1): the attention and transport training losses.
        """
        self.attention.train_mode()
        self.transport.train_mode()

        # Draw the batch, then split it into one list per field.
        samples = [self.get_sample(dataset) for _ in range(bc_batch_size)]
        img_batch = [sample[0] for sample in samples]
        p0_batch = [sample[1] for sample in samples]
        p0_theta_batch = [sample[2] for sample in samples]
        p1_batch = [sample[3] for sample in samples]
        p1_theta_batch = [sample[4] for sample in samples]

        # Get training losses.
        step = self.total_steps + 1
        loss0 = self.attention.train(img_batch, p0_batch, p0_theta_batch)
        if isinstance(self.transport, Attention):
            # An Attention model used for placing takes no pick pixel.
            transport_args = (img_batch, p1_batch, p1_theta_batch)
        else:
            transport_args = (img_batch, p0_batch, p1_batch, p1_theta_batch)
        loss1 = self.transport.train(*transport_args)

        if writer is not None:
            # Log losses to Tensorboard
            writer.add_scalars([
                ('train_loss/attention', loss0, step),
                ('train_loss/transport', loss1, step),
            ])

        print(
            f'Train Iter: {step} \t Attention Loss: {loss0:.4f} \t Transport Loss: {loss1:.4f}')
        self.total_steps = step

        return loss0, loss1

        
    def train_drm(self, train_dataset, detect_dataset, bc_batch_size, martingale_penalty, temperature, softrank_type, softrank_factor, writer=None):
        """Run one training iteration with the martingale loss on the transport model.

        The attention model is trained with its regular loss here, so its
        martingale statistics are reported as 0.0.

        Args:
          train_dataset: a ravens_torch.Dataset to sample the batch from.
          detect_dataset: detection data forwarded unchanged to
            `self.transport.train_drm`.
          bc_batch_size: number of samples drawn for this iteration.
          martingale_penalty: forwarded unchanged to `self.transport.train_drm`.
          temperature: forwarded unchanged to `self.transport.train_drm`.
          softrank_type: forwarded unchanged to `self.transport.train_drm`.
          softrank_factor: forwarded unchanged to `self.transport.train_drm`.
          writer: optional summary writer exposing `add_scalars`.

        Returns:
          5-tuple. When the transport model is an Attention instance:
            (martingale_attention_av, martingale_transport_av, 0.0, loss0, loss1)
          otherwise:
            (martingale_attention_av, martingale_query_av, martingale_key_av,
             loss0, loss1)
        """
        self.attention.train_mode()
        self.transport.train_mode()

        # Draw the batch, then split it into one list per field.
        samples = [self.get_sample(train_dataset) for _ in range(bc_batch_size)]
        img_batch = [sample[0] for sample in samples]
        p0_batch = [sample[1] for sample in samples]
        p0_theta_batch = [sample[2] for sample in samples]
        p1_batch = [sample[3] for sample in samples]
        p1_theta_batch = [sample[4] for sample in samples]

        # Arguments shared by both transport variants.
        drm_args = (detect_dataset, martingale_penalty, temperature,
                    softrank_type, softrank_factor)
        transport_is_attention = isinstance(self.transport, Attention)

        # Get training losses.
        step = self.total_steps + 1
        loss0 = self.attention.train(img_batch, p0_batch, p0_theta_batch)
        martingale_attention_av, martingale_attention_max = 0.0, 0.0
        if transport_is_attention:
            loss1, martingale_transport_av, martingale_transport_max = self.transport.train_drm(
                img_batch, p1_batch, p1_theta_batch, *drm_args)
        else:
            (loss1, martingale_query_av, martingale_key_av,
             martingale_query_max, martingale_key_max) = self.transport.train_drm(
                 img_batch, p0_batch, p1_batch, p1_theta_batch, *drm_args)

        if writer is not None:
            # Log losses to Tensorboard
            writer.add_scalars([
                ('train_loss/attention', loss0, step),
                ('train_loss/transport', loss1, step),
            ])

        self.total_steps = step

        if transport_is_attention:
            print(
                f'Train Iter: {step} \t Attn. Loss: {loss0:.4f} \t Transp. Loss: {loss1:.4f} \t Martingale (A): {martingale_attention_av:.4f} / {martingale_attention_max:.4f} \t Martingale (T): {martingale_transport_av:.4f} / {martingale_transport_max:.4f}')
            return martingale_attention_av, martingale_transport_av, 0.0, loss0, loss1

        print(
            f'Train Iter: {step} \t Attn. Loss: {loss0:.4f} \t Transp. Loss: {loss1:.4f} \t Martingale (A): {martingale_attention_av:.4f} / {martingale_attention_max:.4f} \t Martingale (Q): {martingale_query_av:.4f} / {martingale_query_max:.4f} \t Martingale (K): {martingale_key_av:.4f} / {martingale_key_max:.4f}')
        return martingale_attention_av, martingale_query_av, martingale_key_av, loss0, loss1
        
        

    def validate(self, dataset, writer=None):
        """Report the mean losses over 10 un-augmented validation samples.

        Args:
          dataset: a ravens_torch.Dataset to draw validation samples from.
          writer: optional summary writer exposing `add_scalars`. When None,
            the losses are only printed.
        """

        n_iter = 10
        loss0, loss1 = 0, 0
        for _ in range(n_iter):
            img, p0, p0_theta, p1, p1_theta = self.get_sample(dataset, False)

            # Get validation losses. Do not backpropagate.
            loss0 += self.attention.test(img, p0, p0_theta)
            if isinstance(self.transport, Attention):
                loss1 += self.transport.test(img, p1, p1_theta)
            else:
                loss1 += self.transport.test(img, p0, p1, p1_theta)
        loss0 /= n_iter
        loss1 /= n_iter

        # `writer` defaults to None, so logging must be optional here just as
        # it is in the training methods.
        if writer is not None:
            writer.add_scalars([
                ('test_loss/attention', loss0, self.total_steps),
                ('test_loss/transport', loss1, self.total_steps),
            ])

        print(
            f'Validation: \t Attention Loss: {loss0:.4f} \t Transport Loss: {loss1:.4f}')

    def act(self, obs, info=None, goal=None):  # pylint: disable=unused-argument
        """Run inference and return the best pick-and-place action.

        Args:
          obs: observation forwarded to `get_image`.
          info: unused.
          goal: unused; goal-conditioned inference is not implemented here.

        Returns:
          dict with 'pose0' (pick) and 'pose1' (place), each a tuple of
          (position array, quaternion array).
        """

        def best_pixel_and_angle(conf):
            # The first two indices of the peak are the pixel; the third is a
            # rotation bin, scaled to an angle by 2*pi / number of bins.
            peak = np.unravel_index(np.argmax(conf), shape=conf.shape)
            return peak[:2], peak[2] * (2 * np.pi / conf.shape[2])

        self.attention.eval_mode()
        self.transport.eval_mode()

        # Get heightmap from RGB-D images.
        img, hmap = self.get_image(obs)

        # Attention model forward pass.
        p0_pix, p0_theta = best_pixel_and_angle(self.attention.forward(img))

        # Transport model forward pass, conditioned on the chosen pick pixel.
        p1_pix, p1_theta = best_pixel_and_angle(
            self.transport.forward(img, p0_pix))

        # Pixels to end effector poses.
        p0_xyz = utils.pix_to_xyz(p0_pix, hmap, self.bounds, self.pix_size)
        p1_xyz = utils.pix_to_xyz(p1_pix, hmap, self.bounds, self.pix_size)
        p0_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p0_theta))
        p1_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p1_theta))

        return {
            'pose0': (np.asarray(p0_xyz), np.asarray(p0_xyzw)),
            'pose1': (np.asarray(p1_xyz), np.asarray(p1_xyzw))
        }

    def get_checkpoint_names(self, n_iter):
        """Return the attention and transport checkpoint paths for an iteration.

        Args:
          n_iter: training iteration number embedded in the file names.

        Returns:
          (attention_fname, transport_fname), both inside `self.models_dir`.
        """
        templates = ('attention-ckpt-%d.pth', 'transport-ckpt-%d.pth')
        basenames = [template % n_iter for template in templates]
        return tuple(
            os.path.join(self.models_dir, basename) for basename in basenames)

    def load(self, n_iter, verbose=False):
        """Load both models from the checkpoints of a given iteration.

        Args:
          n_iter: iteration whose checkpoints are loaded; also becomes the
            current `total_steps`.
          verbose: forwarded to each model's `load`.
        """
        checkpoint_paths = self.get_checkpoint_names(n_iter)

        # Attention first, then transport, matching the order of the paths.
        for model_attr, path in zip(('attention', 'transport'), checkpoint_paths):
            getattr(self, model_attr).load(path, verbose)

        self.total_steps = n_iter

    def save(self, verbose=False):
        """Save both models as checkpoints named after the current step count.

        Args:
          verbose: forwarded to each model's `save`.
        """
        # exist_ok avoids the check-then-create race when several processes
        # write checkpoints under the same directory.
        os.makedirs(self.models_dir, exist_ok=True)
        attention_fname, transport_fname = self.get_checkpoint_names(
            self.total_steps)

        self.attention.save(attention_fname, verbose)
        self.transport.save(transport_fname, verbose)


# -----------------------------------------------------------------------------
# Other Transporter Variants
# -----------------------------------------------------------------------------


class OriginalTransporterAgent(TransporterAgent):
    """Transporter agent pairing a lite Attention model with a Transport model."""

    def __init__(self, name, task, root_dir, n_rotations=36, verbose=False):
        """Build the attention and transport models.

        Args:
          name: agent name, forwarded to the base constructor.
          task: task identifier, forwarded to the base constructor.
          root_dir: checkpoint root, forwarded to the base constructor.
          n_rotations: number of rotation bins for the transport model.
          verbose: forwarded to both models.
        """
        super().__init__(name, task, root_dir, n_rotations)

        # Keyword arguments common to both models.
        shared = dict(preprocess=utils.preprocess, verbose=verbose)

        # Picking uses a single rotation; only placing is rotation-aware.
        self.attention = Attention(
            in_shape=self.in_shape, n_rotations=1, lite=True, **shared)
        self.transport = Transport(
            in_channels=self.in_shape[2],
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared)


class NoTransportTransporterAgent(TransporterAgent):
    """Ablation that uses a second Attention model in place of Transport."""

    def __init__(self, name, task, root_dir, n_rotations=1, verbose=False):
        """Build two Attention models, one for picking and one for placing.

        Args:
          name: agent name, forwarded to the base constructor.
          task: task identifier, forwarded to the base constructor.
          root_dir: checkpoint root, forwarded to the base constructor.
          n_rotations: number of rotation bins for the placing model.
          verbose: forwarded to both models.
        """
        super().__init__(name, task, root_dir, n_rotations)

        # Keyword arguments common to both models.
        shared = dict(
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            verbose=verbose)

        # Picking always uses one rotation; placing uses `self.n_rotations`.
        self.attention = Attention(n_rotations=1, **shared)
        self.transport = Attention(n_rotations=self.n_rotations, **shared)


class PerPixelLossTransporterAgent(TransporterAgent):
    """Transporter variant whose transport model is `TransportPerPixelLoss`."""

    def __init__(self, name, task, root_dir, n_rotations=36, verbose=False):
        """Build the attention and per-pixel-loss transport models.

        Args:
          name: agent name, forwarded to the base constructor.
          task: task identifier, forwarded to the base constructor.
          root_dir: checkpoint root, forwarded to the base constructor.
          n_rotations: number of rotation bins for the transport model.
          verbose: forwarded to both models.
        """
        # The base constructor takes (name, task, root_dir, n_rotations);
        # root_dir must be passed explicitly or n_rotations would be
        # interpreted as the checkpoint root.
        super().__init__(name, task, root_dir, n_rotations)

        self.attention = Attention(
            in_shape=self.in_shape,
            n_rotations=1,
            preprocess=utils.preprocess,
            verbose=verbose)
        self.transport = TransportPerPixelLoss(
            in_channels=self.in_shape[2],
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            preprocess=utils.preprocess,
            verbose=verbose)


class GoalTransporterAgent(TransporterAgent):
    """Goal-Conditioned Transporters supporting a separate goal FCN."""

    def __init__(self, name, task, root_dir, n_rotations=36, verbose=False):
        """Build the attention and goal-conditioned transport models.

        Args:
          name: agent name, forwarded to the base constructor.
          task: task identifier, forwarded to the base constructor.
          root_dir: checkpoint root, forwarded to the base constructor.
          n_rotations: number of rotation bins for the transport model.
          verbose: forwarded to both models.
        """
        # The base constructor takes (name, task, root_dir, n_rotations);
        # root_dir must be passed explicitly or n_rotations would be
        # interpreted as the checkpoint root.
        super().__init__(name, task, root_dir, n_rotations)

        self.attention = Attention(
            in_shape=self.in_shape,
            n_rotations=1,
            preprocess=utils.preprocess,
            verbose=verbose)
        self.transport = TransportGoal(
            in_channels=self.in_shape[2],
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            preprocess=utils.preprocess,
            verbose=verbose)


class GoalNaiveTransporterAgent(TransporterAgent):
    """Naive version which stacks current and goal images through normal Transport."""

    def __init__(self, name, task, root_dir, n_rotations=36, verbose=False):
        """Build the attention model and a Transport model with doubled input channels.

        Args:
          name: agent name, forwarded to the base constructor.
          task: task identifier, forwarded to the base constructor.
          root_dir: checkpoint root, forwarded to the base constructor.
          n_rotations: number of rotation bins for the transport model.
          verbose: forwarded to both models.
        """
        # The base constructor takes (name, task, root_dir, n_rotations);
        # root_dir must be passed explicitly or n_rotations would be
        # interpreted as the checkpoint root.
        super().__init__(name, task, root_dir, n_rotations)

        # Stack the goal image for the vanilla Transport module: the channel
        # count is twice that of a single input image.
        t_shape = (self.in_shape[0], self.in_shape[1],
                   int(self.in_shape[2] * 2))

        self.attention = Attention(
            in_shape=self.in_shape,
            n_rotations=1,
            preprocess=utils.preprocess,
            verbose=verbose)
        self.transport = Transport(
            in_channels=t_shape[2],
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            preprocess=utils.preprocess,
            verbose=verbose,
            per_pixel_loss=False,
            use_goal_image=True)
