from typing import Any, Dict
import numpy as np
import torch
from torch import nn
from torch.nn import (
    TransformerEncoderLayer,
    TransformerEncoder,
    TransformerDecoderLayer,
    TransformerDecoder,
)
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from temporal_task_planner.trainer.transformer.single_model import (
    TransformerTaskPlannerSingleModel,
    generate_square_subsequent_mask,
)
from temporal_task_planner.trainer.slot_attention import SlotAttention


class TransformerTaskPlannerDualModel(TransformerTaskPlannerSingleModel):
    """Encoder/decoder task planner with a slot-attention memory bottleneck.

    The prompt sequence is embedded from category, pose, timestep and
    reality-marker inputs via ``self.instance_encoder``. That attribute is not
    defined in this class; presumably it is set up by the parent constructor.
    The embedded prompt is passed through a transformer encoder and compressed
    by ``SlotAttention`` into the decoder memory.

    The situation sequence is embedded the same way and decoded against that
    memory. Pick scores are dot products between decoded tokens and the
    situation embeddings. No place head is computed yet.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        category_encoder: torch.nn.Module,
        pose_encoder: torch.nn.Module,
        temporal_encoder: torch.nn.Module,
        reality_marker_encoder: torch.nn.Module,
    ):
        """Build the encoder, slot-attention and decoder stacks.

        Args:
            config: must provide ``d_model``, ``nhead``, ``d_hid``,
                ``dropout``, ``batch_first``, ``num_encoder_layers``,
                ``num_decoder_layers``, ``num_slots`` and ``slot_iters``.
            category_encoder, pose_encoder, temporal_encoder,
            reality_marker_encoder: forwarded unchanged to the parent
                constructor.
        """
        super().__init__(
            config,
            category_encoder,
            pose_encoder,
            temporal_encoder,
            reality_marker_encoder,
        )
        # NOTE(review): if the parent constructor already builds
        # ``self.encoder`` / ``self.decoder``, they are replaced here — confirm.
        encoder_layers = TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.d_hid,
            dropout=config.dropout,
            batch_first=config.batch_first,
        )
        self.encoder = TransformerEncoder(encoder_layers, config.num_encoder_layers)
        self.slot_attention = SlotAttention(
            num_slots=config.num_slots,
            dim=config.d_model,
            iters=config.slot_iters,
            eps=1e-8,
            hidden_dim=config.d_hid,
        )
        decoder_layers = TransformerDecoderLayer(
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.d_hid,
            dropout=config.dropout,
            batch_first=config.batch_first,
        )
        self.decoder = TransformerDecoder(decoder_layers, config.num_decoder_layers)

    def _encode_instances(self, inputs: Dict[str, torch.Tensor], device):
        """Embed one sequence (prompt or situation) with ``self.instance_encoder``."""
        return self.instance_encoder(
            device,
            inputs["timestep"],
            inputs["category"],
            inputs["pose"],
            inputs["is_real"],
        )

    def forward(
        self,
        prompt: Dict[str, torch.Tensor],
        situation: Dict[str, torch.Tensor],
        device,
    ) -> Dict[str, Any]:
        """Score which situation instance each ACT token should pick.

        Args:
            prompt: mapping with tensor entries ``"timestep"``,
                ``"category"``, ``"pose"``, ``"is_real"`` and
                ``"src_key_padding_mask"``.
            situation: mapping with the same entries plus ``"action_masks"``,
                a bool tensor that selects the tokens (ACT tokens) whose pick
                scores are returned.
            device: target device, e.g. ``'cpu'`` or ``'cuda'``.

        Original per-key notes (B = batch, N = sequence length):
            timestep: (B, N, 1) long tensor
            category (bounding box max extents): (B, N, 3) float tensor
            pose: (B, N, 7) float tensor
            action_masks: (B, N, 1) bool tensor
            is_real: (B, N, 1) bool tensor
            category_token: (B, N, 1), long tensor
            instance_token: (B, N, 1), long tensor
            src_key_padding_mask: (B, N, 1) bool tensor

        NOTE(review): ``raw_out_pick[action_masks, :]`` on a (B, T, T)
        tensor and the ``nn.Transformer*`` key-padding masks both require
        (B, N)-shaped masks. The trailing ``1`` in the notes above may be
        stale — confirm. ``category_token`` / ``instance_token`` are not
        read by this method.

        Returns:
            A dict with two entries:

            ``"pick"``: scores of shape (A, T), where A is the number of
            True entries in ``situation["action_masks"]`` and T is the
            situation length. This assumes batch-first (B, T, d_model)
            encodings.

            ``"place"``: always ``None``, because no place head is
            implemented.
        """
        num_situation_tokens = situation["timestep"].shape[1]

        # Move each mask to the device once; the original moved it twice.
        src_mask = generate_square_subsequent_mask(prompt["timestep"].shape[1]).to(
            device
        )
        prompt_encodings = self._encode_instances(prompt, device)
        memory = self.encoder(
            prompt_encodings,
            mask=src_mask,
            src_key_padding_mask=prompt["src_key_padding_mask"].to(device),
        )
        # Compress the encoded prompt into slots that serve as decoder memory.
        memory = self.slot_attention(memory)

        situation_encodings = self._encode_instances(situation, device)
        tgt_mask = generate_square_subsequent_mask(num_situation_tokens).to(device)
        tgt = self.decoder(
            situation_encodings,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=situation["src_key_padding_mask"].to(device),
        )

        # Dot product of every decoded token with every situation embedding.
        # The additive mask from the project helper is built with
        # diagonal=0, which presumably also hides each token's own position,
        # so a token can only pick earlier instances — confirm against helper.
        pick_mask = generate_square_subsequent_mask(
            num_situation_tokens, diagonal=0
        ).to(device)
        raw_out_pick = tgt @ situation_encodings.permute(0, 2, 1) + pick_mask

        out_pick = raw_out_pick[situation["action_masks"], :]
        return {"pick": out_pick, "place": None}
