from typing import Any, Dict
import numpy as np
import torch
from torch import nn
from torch.nn import TransformerEncoderLayer, TransformerEncoder
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from temporal_task_planner.trainer.slot_attention import SlotAttention


def generate_square_subsequent_mask(sz: int, diagonal=1):
    r"""Build an additive ``(sz, sz)`` attention mask.

    Entry ``(i, j)`` is ``float('-inf')`` when ``j - i >= diagonal`` (i.e. on
    or above the ``diagonal``-th diagonal) and ``0.0`` everywhere else.

    Args:
        sz: side length of the square mask.
        diagonal: offset of the first masked diagonal; ``1`` leaves the main
            diagonal unmasked, ``0`` masks it as well.

    Returns:
        A float tensor of shape ``(sz, sz)``.
    """
    row_idx = torch.arange(sz).unsqueeze(1)
    col_idx = torch.arange(sz).unsqueeze(0)
    # Same region that torch.triu(..., diagonal=diagonal) would keep.
    blocked = (col_idx - row_idx) >= diagonal
    return torch.zeros(sz, sz).masked_fill(blocked, float("-inf"))


class TransformerTaskPlannerSingleModel(PreTrainedModel):
    """Transformer encoder over per-instance tokens with a pick decoding head.

    Every instance is embedded by four sub-encoders (reality marker, timestep,
    category, pose); the embeddings are concatenated into one token and the
    token sequence is passed through a causally masked ``TransformerEncoder``.
    Pick scores are dot products between encoder outputs and the input tokens.
    No place head is implemented here: ``forward`` always returns ``None`` for
    the ``"place"`` entry.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        category_encoder: torch.nn.Module,
        pose_encoder: torch.nn.Module,
        temporal_encoder: torch.nn.Module,
        reality_marker_encoder: torch.nn.Module,
    ):
        """Store the sub-encoders and build the transformer encoder stack.

        Args:
            config: must provide ``d_model``, ``nhead``, ``d_hid``, ``dropout``,
                ``batch_first`` and ``num_encoder_layers`` (read here), plus
                ``temporal_embed_size``, ``category_embed_size``,
                ``pose_embed_size``, ``marker_embed_size`` and ``n_input_dim``
                (read in ``instance_encoder``).
            category_encoder: module applied to the flattened category input.
            pose_encoder: module applied to the flattened pose input.
            temporal_encoder: module applied to the flattened timesteps.
            reality_marker_encoder: module applied to the flattened
                ``is_real`` flags (cast to long).
        """
        super().__init__(config)
        self.category_encoder = category_encoder
        self.pose_encoder = pose_encoder
        self.temporal_encoder = temporal_encoder
        self.reality_marker_encoder = reality_marker_encoder
        encoder_layers = TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.d_hid,
            dropout=config.dropout,
            batch_first=config.batch_first,
        )
        self.transformer_encoder = TransformerEncoder(
            encoder_layers, config.num_encoder_layers
        )

    def instance_encoder(
        self, device, timestep, category, pose, is_real, **kwargs
    ) -> torch.Tensor:
        """Embed each instance and concatenate the per-attribute embeddings.

        Args:
            device: device the inputs are moved to before encoding.
            timestep: flattened to one column per instance.
            category: flattened to three columns per instance.
            pose: ``(B, N, >= n_input_dim)``; only the first
                ``config.n_input_dim`` features of the last axis are used.
            is_real: cast to long and flattened to one column per instance.
            **kwargs: accepted and ignored.

        Returns:
            Tensor of shape ``(B, N, marker + temporal + category + pose
            embed sizes)``, concatenated in that order. The result is fed
            straight into the transformer encoder, so the total width has to
            match ``config.d_model``.
        """
        bs = pose.shape[0]
        msl = pose.shape[1]
        x_t = self.temporal_encoder(timestep.reshape(-1, 1).to(device))
        x_t = x_t.view(bs, msl, self.config.temporal_embed_size)

        x_c = self.category_encoder(category.reshape(-1, 3).to(device))
        x_c = x_c.view(bs, msl, self.config.category_embed_size)

        position = pose[:, :, : self.config.n_input_dim]
        x_p = self.pose_encoder(
            position.reshape(-1, self.config.n_input_dim).to(device)
        )
        x_p = x_p.view(bs, msl, self.config.pose_embed_size)

        x_m = self.reality_marker_encoder(is_real.long().reshape(-1, 1).to(device))
        x_m = x_m.view(bs, msl, self.config.marker_embed_size)

        x = torch.cat([x_m, x_t, x_c, x_p], dim=-1)
        return x

    def forward(
        self,
        timestep,
        category,
        pose,
        is_real,
        action_masks,
        category_token,
        instance_token,
        src_key_padding_mask=None,
        device=None,
    ):
        """Score, for every action token, which instance to pick.

        Args:
            timestep: (B, N, 1) long tensor
            category (bounding box max extents): (B, N, 3) float tensor
            pose: (B, N, 7) float tensor
            is_real: (B, N, 1) bool tensor
            action_masks: bool tensor used to index the first two axes of the
                (B, N, N) score tensor — presumably (B, N); TODO confirm
                against the data collator.
            category_token: unused by this method.
            instance_token: unused by this method.
            src_key_padding_mask: optional bool padding mask forwarded to the
                transformer encoder — presumably (B, N); TODO confirm. ``None``
                disables padding masking.
            device: target device for the inputs and masks.

        Return:
            dict with
                "pick": (A, N) scores, where A is the number of positions
                    selected by ``action_masks`` across the batch and N is
                    the number of instances. Position ``i`` can only score
                    instances ``j < i``; all others are ``-inf``.
                "place": always ``None``.
        """
        # Second axis of pose is the sequence length (number of instances).
        seq_len = pose.shape[1]
        src_mask = generate_square_subsequent_mask(seq_len).to(device)
        x = self.instance_encoder(
            device,
            timestep,
            category,
            pose,
            is_real,
        )
        # The signature advertises None as the default, so do not call .to()
        # on a missing padding mask.
        if src_key_padding_mask is not None:
            src_key_padding_mask = src_key_padding_mask.to(device)
        memory = self.transformer_encoder(
            x, mask=src_mask, src_key_padding_mask=src_key_padding_mask
        )
        # Dot-product score between every encoder output and every input
        # token; this matmul treats axis 0 as the batch axis.
        # diagonal=0 also masks the main diagonal, so a position cannot pick
        # itself or any later instance.
        pick_mask = generate_square_subsequent_mask(seq_len, diagonal=0).to(device)
        raw_out_pick = memory @ x.permute(0, 2, 1) + pick_mask
        out_pick = raw_out_pick[action_masks, :]
        out_place = None
        out = {"pick": out_pick, "place": out_place}
        return out


class PreferenceClassifier(TransformerTaskPlannerSingleModel):
    """Sequence-level preference classifier on top of the task-planner encoder.

    Reuses the parent's instance encoder and transformer encoder, pools the
    encoded sequence with slot attention and maps the flattened slots to
    ``config.num_pref`` logits.
    """

    def __init__(self, *args, **kwargs):
        """Build the parent model, then add slot attention and the logit layer.

        Arguments are forwarded unchanged to the parent constructor. The
        config must additionally provide ``num_slots``, ``slot_iters`` and
        ``num_pref``.
        """
        super().__init__(*args, **kwargs)
        self.slot_attention = SlotAttention(
            num_slots=self.config.num_slots,
            dim=self.config.d_model,
            iters=self.config.slot_iters,
            eps=1e-8,
            hidden_dim=self.config.d_hid,
        )
        # Input width assumes slot attention yields num_slots vectors of
        # size d_model per sample.
        self.final_logits = nn.Linear(
            self.config.num_slots * self.config.d_model, self.config.num_pref
        )

    def forward(
        self,
        timestep,
        category,
        pose,
        is_real,
        action_masks,
        category_token,
        instance_token,
        src_key_padding_mask=None,
        device=None,
    ):
        """Return preference logits for each sequence in the batch.

        Args:
            timestep, category, pose, is_real: per-instance inputs passed to
                ``instance_encoder``; ``pose`` supplies batch size (axis 0)
                and sequence length (axis 1).
            action_masks: unused; kept so the signature matches the parent.
            category_token: unused; kept so the signature matches the parent.
            instance_token: unused; kept so the signature matches the parent.
            src_key_padding_mask: optional padding mask forwarded to the
                transformer encoder; ``None`` disables padding masking.
            device: target device for the inputs and masks.

        Returns:
            Tensor of shape ``(B, config.num_pref)`` with unnormalized logits.
        """
        batch_size = pose.shape[0]
        max_seq_len = pose.shape[1]
        src_mask = generate_square_subsequent_mask(max_seq_len).to(device)
        x = self.instance_encoder(
            device,
            timestep,
            category,
            pose,
            is_real,
        )
        # The signature advertises None as the default, so do not call .to()
        # on a missing padding mask.
        if src_key_padding_mask is not None:
            src_key_padding_mask = src_key_padding_mask.to(device)
        memory = self.transformer_encoder(
            x, mask=src_mask, src_key_padding_mask=src_key_padding_mask
        )
        slots = self.slot_attention(memory)
        # reshape (not view) so a non-contiguous slot tensor is still accepted.
        out = self.final_logits(slots.reshape(batch_size, -1))
        return out