# coding=utf-8
# Adapted from Ravens - Transporter Networks, Zeng et al., 2021
# https://github.com/google-research/ravens

"""Attention module."""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from ravens_torch.utils import utils, compute_martingale, MeanMetrics, to_device
from ravens_torch.utils.text import bold
from ravens_torch.utils.utils import apply_rotations_to_tensor
from ravens_torch.models.resnet import SimpleNet_attention

from einops.layers.torch import Rearrange
import IPython as ipy

# REMOVE BELOW
# NOTE(review): temporary workaround flagged for removal. It sets the
# KMP_DUPLICATE_LIB_OK environment variable to "TRUE" as an import-time side
# effect of this module; presumably this silences the OpenMP "duplicate
# runtime library" abort on some local setups -- confirm it is still needed
# before shipping, and delete these lines otherwise.
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = "TRUE"


class Attention:
    """Attention module.

    Bundles a ``SimpleNet_attention`` network with its Adam optimizer,
    cross-entropy loss and a running metric, and exposes forward, training,
    evaluation and checkpoint helpers around it.

    Attributes:
        n_rotations: Number of rotation bins used for inputs and labels.
        preprocess: Callable applied to the padded image(s) before they are
            turned into a tensor.
        padding: ``(3, 2)`` int array of (before, after) pad widths that make
            the first two image axes equal in size.
        padding_batch: ``padding`` with an extra un-padded leading axis, used
            for batched inputs.
    """

    def __init__(self, in_shape, n_rotations, preprocess, lite=False, verbose=False):
        """Build the network, optimizer, loss and padding tables.

        Args:
            in_shape: Input image shape; the first two entries are the spatial
                sizes and the third is passed to the network as its input
                channel count.
            n_rotations: Number of rotation bins.
            preprocess: Callable applied to padded images in the forward paths.
            lite: Accepted for interface compatibility; not used here.
            verbose: Forwarded to ``to_device``.
        """
        self.n_rotations = n_rotations
        self.preprocess = preprocess

        self.padding = self._square_padding(in_shape)
        # Padding for batched inputs: the leading (batch) axis is not padded.
        self.padding_batch = np.concatenate(([[0, 0]], self.padding), 0)

        # Initialize the network with a single output channel.
        self.model = SimpleNet_attention(in_shape[2], 1)

        self.device = to_device([self.model], "Attention", verbose=verbose)

        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        self.loss = nn.CrossEntropyLoss(reduction="mean")

        self.metric = MeanMetrics()

    @staticmethod
    def _square_padding(in_shape):
        """Return (3, 2) int pad widths that square the first two axes.

        The size difference is split as floor(diff / 2) before and the
        remainder after, so an odd difference still yields a square image
        (truncating both halves would leave it one pixel short). For even
        differences both sides are equal.
        """
        padding = np.zeros((3, 2), dtype=int)
        spatial = np.array(in_shape[:2])
        deficit = np.max(spatial) - spatial
        padding[:2, 0] = deficit // 2
        padding[:2, 1] = deficit - deficit // 2
        return padding

    @staticmethod
    def _flatten_logits(logits):
        """Flatten 4-D logits into a single row of shape (1, N).

        Elements are ordered by (axis 1, axis 2, axis 0, axis 3). With the
        rotation copies on axis 0 this matches the (row, col, rotation) layout
        of the training label, and with a single copy on axis 0 it is the
        plain row-major flattening of the remaining axes.
        """
        return logits.permute(1, 2, 0, 3).reshape(1, logits.numel())

    def forward(self, in_img, softmax=True):
        """Forward pass.

        Args:
            in_img: Single image array; its first two axes are padded with
                ``self.padding`` and cropped back after the network runs.
            softmax: When true, return probabilities as a NumPy array;
                otherwise return the raw flattened logits tensor.

        Returns:
            If ``softmax`` is false, a tensor of shape ``(1, N)`` holding all
            logits in (row, col, rotation, channel) order. If true, a float32
            NumPy array of shape ``(rows, cols, K)`` where ``K`` is the number
            of rotation copies times the channel count, normalised over all
            entries.
        """
        in_data = np.pad(in_img, self.padding, mode='constant')
        in_data = self.preprocess(in_data)
        in_data = in_data.reshape((1,) + in_data.shape)
        in_tens = torch.tensor(in_data, dtype=torch.float32).to(self.device)

        # Rotate input.
        in_tens = apply_rotations_to_tensor(in_tens, self.n_rotations)

        # Forward pass, one slice along the leading axis at a time.
        logits = torch.cat(
            [self.model(x) for x in torch.split(in_tens, 1, dim=0)], dim=0)

        # Rotate back output.
        logits = apply_rotations_to_tensor(
            logits, self.n_rotations, reverse=True)

        # Crop away the padding added above (indexing assumes the spatial
        # axes of the logits are axes 1 and 2).
        c0 = self.padding[:2, 0]
        c1 = c0 + in_img.shape[:2]
        logits = logits[:, c0[0]:c1[0], c0[1]:c1[1], :]

        # One row over every location and rotation, so that the softmax and
        # the cross-entropy label both range over the full volume.
        output = self._flatten_logits(logits)

        if softmax:
            rows, cols = logits.shape[1], logits.shape[2]
            depth = logits.shape[0] * logits.shape[3]
            output = torch.softmax(output, dim=1)
            output = output.detach().cpu().numpy()
            output = np.float32(output).reshape((rows, cols, depth))
        return output

    def get_features(self, in_img_batch):
        """Forward pass where we get features from model.

        Args:
            in_img_batch: A single image (3 axes) or a batch of images
                (4 axes); a single image is promoted to a batch of one.

        Returns:
            Whatever the model returns when called with
            ``return_features=True``.

        Raises:
            NotImplementedError: If ``self.n_rotations > 1``.
        """
        # Reject unsupported configurations before doing any work.
        if self.n_rotations > 1:
            raise NotImplementedError("get_features not implemented for n_rotations > 1")

        # Check if input is a single image or a batch
        if len(in_img_batch.shape) == 3:
            in_img_batch = np.expand_dims(in_img_batch, axis=0)

        in_data = np.pad(in_img_batch, self.padding_batch, mode='constant')
        in_data = self.preprocess(in_data)
        in_tens = torch.tensor(in_data, dtype=torch.float32).to(self.device)

        return self.model(in_tens, return_features=True)

    def _collect_features(self, images):
        """Stack ``get_features`` of every image into one (len, dim) tensor.

        The feature dimension is read from the first image's features, so
        ``images`` must contain at least one entry.
        """
        first = self.get_features(images[0])
        count = images.shape[0]
        features = torch.zeros((count, first.shape[0]), device=self.device)
        features[0, :] = first
        for i in range(1, count):
            features[i] = self.get_features(images[i])
        return features

    def train_block(self, in_img, p, theta):
        """Compute the cross-entropy loss for one (image, pixel, angle) sample.

        Args:
            in_img: Single input image.
            p: Indexable pair giving the target (row, col) pixel.
            theta: Target angle in radians; binned into ``n_rotations`` bins.

        Returns:
            Scalar loss tensor.
        """
        output = self.forward(in_img, softmax=False)

        # Get label: nearest rotation bin for theta.
        theta_i = theta / (2 * np.pi / self.n_rotations)
        theta_i = np.int32(np.round(theta_i)) % self.n_rotations

        # Resolve the flat class index on the host instead of uploading a
        # full one-hot volume and taking argmax on the device. Marking a
        # boolean volume keeps NumPy's indexing rules (bounds checks,
        # negative indices) for p.
        label_size = in_img.shape[:2] + (self.n_rotations,)
        one_hot = np.zeros(label_size, dtype=bool)
        one_hot[p[0], p[1], theta_i] = True
        flat_index = int(np.argmax(one_hot))
        label = torch.tensor([flat_index], dtype=torch.long, device=self.device)

        # Get loss.
        return self.loss(output, label)

    def _imitation_loss(self, in_img_batch, p_batch, theta_batch):
        """Return the mean of ``train_block`` losses over a batch."""
        batch_size = len(in_img_batch)
        loss = torch.tensor(0.0, device=self.device)
        for i, in_img in enumerate(in_img_batch):
            # Train on each image in the batch
            loss += self.train_block(in_img, p_batch[i], theta_batch[i]) / batch_size
        return loss

    def train(self, in_img_batch, p_batch, theta_batch):
        """Train.

        Runs one optimizer step on the batch-averaged imitation loss.

        Returns:
            The loss value as ``np.float32``.
        """
        self.metric.reset()
        self.train_mode()
        self.optimizer.zero_grad()

        loss = self._imitation_loss(in_img_batch, p_batch, theta_batch)
        loss.backward()
        self.optimizer.step()
        self.metric(loss)

        return np.float32(loss.detach().cpu().numpy())

    def train_drm(self, in_img_batch, p_batch, theta_batch, detect_dataset, martingale_penalty, temperature, softrank_type, softrank_factor, imitation_weight=0.001):
        """Train with DRM.

        Runs one optimizer step on
        ``imitation_weight * imitation_loss + martingale_penalty * martingale_loss``,
        where the martingale terms are averaged over the detection batches.

        Args:
            in_img_batch, p_batch, theta_batch: Imitation batch, as in
                ``train``.
            detect_dataset: Mapping whose ``'obs'`` entry is indexed as
                ``[batch][sample]`` and exposes ``.shape``.
            martingale_penalty: Weight of the martingale loss.
            temperature, softrank_type, softrank_factor: Forwarded unchanged
                to ``compute_martingale``.
            imitation_weight: Weight of the imitation loss (defaults to the
                previously hard-coded 0.001).

        Returns:
            Tuple of (imitation loss as ``np.float32``, batch-averaged
            ``martingale_hard_av``, batch-averaged ``martingale_hard_max``);
            the last two are the accumulator tensors.
        """
        self.metric.reset()
        self.train_mode()
        self.optimizer.zero_grad()

        # Imitation loss
        loss_imitation = self._imitation_loss(in_img_batch, p_batch, theta_batch)

        # Load detection set
        detect_datasets_all = detect_dataset['obs']
        num_detect_batches = detect_datasets_all.shape[0]

        # Initialize martingale losses
        loss_martingale_batch_av = torch.tensor(0.0, device=self.device)
        martingale_hard_batch_av = torch.tensor(0.0, device=self.device)
        martingale_hard_max_batch_av = torch.tensor(0.0, device=self.device)

        for batch in range(num_detect_batches):
            features_attention = self._collect_features(detect_datasets_all[batch])

            # Compute martingale. Only the first three results are used here;
            # compute_martingale_all unpacks a fourth from the same call, so
            # tolerate trailing values instead of failing on the unpack.
            loss_martingale, martingale_hard_av, martingale_hard_max, *_ = compute_martingale(
                features_attention, self.device, temperature, softrank_type, softrank_factor)

            # Accumulate losses
            loss_martingale_batch_av += loss_martingale[0] / num_detect_batches
            martingale_hard_batch_av += martingale_hard_av / num_detect_batches
            martingale_hard_max_batch_av += martingale_hard_max / num_detect_batches

        # Compute total loss
        loss = imitation_weight * loss_imitation + martingale_penalty * loss_martingale_batch_av

        loss.backward()
        self.optimizer.step()
        self.metric(loss)

        return np.float32(loss_imitation.detach().cpu().numpy()), martingale_hard_batch_av, martingale_hard_max_batch_av

    def compute_martingale_all(self, detect_dataset, temperature, softrank_type, softrank_factor):
        """Compute martingale for all data points in the detection dataset.

        Runs in eval mode without gradients.

        Returns:
            The fourth value returned by ``compute_martingale`` for the
            stacked features of ``detect_dataset``.
        """
        self.eval_mode()
        with torch.no_grad():
            features_attention = self._collect_features(detect_dataset)

            # Compute martingale
            _, _, _, martingale_values_all = compute_martingale(
                features_attention, self.device, temperature, softrank_type, softrank_factor)

        return martingale_values_all

    def test(self, in_img, p, theta):
        """Test.

        Evaluates ``train_block`` in eval mode without gradients and returns
        the loss as ``np.float32``.
        """
        self.eval_mode()

        with torch.no_grad():
            loss = self.train_block(in_img, p, theta)

        return np.float32(loss.detach().cpu().numpy())

    def train_mode(self):
        """Put the wrapped model in training mode."""
        self.model.train()

    def eval_mode(self):
        """Put the wrapped model in evaluation mode."""
        self.model.eval()

    def load(self, path, verbose=False):
        """Load model weights from ``path`` onto ``self.device``."""
        if verbose:
            device = "GPU" if self.device.type == "cuda" else "CPU"
            print(
                f"Loading {bold('attention')} model on {bold(device)} from {bold(path)}")
        self.model.load_state_dict(torch.load(path, map_location=self.device))

    def save(self, filename, verbose=False):
        """Save the model's state dict to ``filename``."""
        if verbose:
            print(f"Saving attention model to {bold(filename)}")
        torch.save(self.model.state_dict(), filename)
