import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from nvib_layer import Nvib

from utils import (
    EncoderTransformerConfig,
    DecoderTransformerConfig,
    TransformerLayerConfig,
    TransformerLayer,
    DenoisingCrossAttentionTransformerLayer
)

class EncoderTransformer(nn.Module):
    """PyTorch implementation of the Encoder Transformer.

    Embeds a two-channel token sequence plus a CLS token, runs it through a
    stack of ``TransformerLayer`` modules and passes the result to an ``Nvib``
    layer; the dictionary produced by that layer is returned by ``forward``.
    """
    config: EncoderTransformerConfig

    def __init__(self, config: EncoderTransformerConfig):
        super().__init__()
        self.config = config

        # Token embeddings plus a learned embedding for each of the two channels.
        self.colors_embed = nn.Embedding(config.vocab_size, config.emb_dim, dtype=config.dtype)
        self.channels_embed = nn.Embedding(2, config.emb_dim, dtype=config.dtype)

        self.pos_embed = nn.Embedding(config.seq_len, config.emb_dim, dtype=config.dtype)

        # Single-row table: the CLS token always uses index 0.
        self.cls_token_embed = nn.Embedding(1, config.emb_dim, dtype=config.dtype)

        self.embed_dropout = nn.Dropout(config.transformer_layer.dropout_rate)

        self.transformer_layers = nn.ModuleList(
            [TransformerLayer(config.transformer_layer) for _ in range(config.num_layers)]
        )

        # NOTE(review): cls_layer_norm, latent_mu_proj and latent_logvar_proj are
        # not used by forward(); presumably kept for checkpoint compatibility.
        self.cls_layer_norm = nn.LayerNorm(
            config.emb_dim,
            eps=1e-6,
            elementwise_affine=config.transformer_layer.use_bias,
            dtype=config.dtype
        )

        self.latent_mu_proj = nn.Linear(
            config.emb_dim, config.latent_dim, bias=config.latent_projection_bias, dtype=config.dtype
        )
        if config.variational:
            self.latent_logvar_proj = nn.Linear(
                config.emb_dim, config.latent_dim, bias=config.latent_projection_bias, dtype=config.dtype
            )
        else:
            self.register_parameter('latent_logvar_proj', None)

        self.nvib_layer = Nvib(config.emb_dim, config.latent_dim, disable_sampling=config.disable_sampling)

    def _get_pos_embed(self, device) -> torch.Tensor:
        """Return the positional embeddings for positions ``0..seq_len-1``, shape ``(seq_len, emb_dim)``."""
        seq_idx = torch.arange(self.config.seq_len, device=device)
        return self.pos_embed(seq_idx)

    def embed_seq(self, pair: torch.Tensor) -> torch.Tensor:
        """Embed ``pair`` and prepend a CLS token.

        Args:
            pair: Token ids of shape ``(*base, 2, seq_len)``. The two channels
                share ``colors_embed`` and are distinguished by ``channels_embed``.

        Returns:
            Tensor of shape ``(*base, 1 + 2 * seq_len, emb_dim)``: the CLS
            embedding followed by the channel-0 tokens, then the channel-1
            tokens, after dropout.
        """
        config = self.config
        base_dims = pair.shape[:-2]
        Seq = config.seq_len
        H = config.emb_dim
        device = pair.device

        pair_long = pair.long()

        # Positional embeddings broadcast over the channel axis.
        pos_embed = self._get_pos_embed(device)
        pos_embed_bcast = pos_embed.view((1,) * len(base_dims) + (1, Seq, H))

        colors_embed = self.colors_embed(pair_long)

        # Channel embeddings broadcast over the sequence axis.
        channel_indices = torch.arange(2, device=device)
        channels_embed = self.channels_embed(channel_indices)
        channels_embed_bcast = channels_embed.view((1,) * len(base_dims) + (2, 1, H))

        x = colors_embed + pos_embed_bcast + channels_embed_bcast
        # Concatenate the two channels into one sequence of length 2 * seq_len.
        x = x.reshape(*base_dims, Seq * 2, H)

        cls_indices = torch.zeros(base_dims + (1,), dtype=torch.long, device=device)
        cls_token = self.cls_token_embed(cls_indices)

        x = torch.cat([cls_token, x], dim=-2)

        x = self.embed_dropout(x)
        return x

    def forward(
        self,
        pair: torch.Tensor,
    ) -> dict:
        """Encode ``pair`` into the dictionary produced by the NVIB layer.

        Args:
            pair: Token ids of shape ``(*base, 2, seq_len)``.

        Returns:
            The dictionary returned by ``self.nvib_layer``, with an added
            ``'original_base_shape'`` entry holding ``*base`` (a ``torch.Size``).
            The NVIB layer itself sees all leading dimensions flattened into
            one batch axis.

        Raises:
            ValueError: if the embedded input contains NaNs.
        """
        x = self.embed_seq(pair)

        original_base_shape = x.shape[:-2]
        # Collapse all leading dimensions into a single batch axis.
        x = x.reshape(-1, x.shape[-2], x.shape[-1])
        if torch.isnan(x).any():
            raise ValueError("NaN detected in input tensor x before transformer layers.")
        for layer in self.transformer_layers:
            x = layer(x, pad_mask=None)

        B, T, _ = x.shape
        # Nvib receives a (T, B, H) tensor; presumably it expects sequence-first input.
        x = x.transpose(1, 0)
        # All-False mask (no position flagged). It is built on x's device rather than a
        # hard-coded CUDA device, so the encoder also runs on CPU and non-default GPUs.
        pad_mask = torch.zeros((B, T), dtype=torch.bool, device=x.device)
        latent_dict = self.nvib_layer(x, pad_mask)
        latent_dict['original_base_shape'] = original_base_shape

        return latent_dict


import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Dict

from utils import (
    DecoderTransformerConfig,
    TransformerLayerConfig,
    DenoisingCrossAttentionTransformerLayer
)


class DecoderTransformer(nn.Module):
    """PyTorch implementation of the Decoder Transformer using Denoising Cross Attention.

    Embeds the sequence ``[BOS, input tokens, output tokens]``, runs it through
    a stack of ``DenoisingCrossAttentionTransformerLayer`` modules that receive
    the encoder's NVIB output, and projects the hidden states at the positions
    preceding each output token to output-vocabulary logits.
    """
    config: DecoderTransformerConfig

    def __init__(self, config: DecoderTransformerConfig):
        super().__init__()
        self.config = config
        H = config.emb_dim
        R, C = config.max_rows, config.max_cols

        # Single-row table: the BOS token id must be 0.
        self.bos_token_embed = nn.Embedding(1, H, dtype=config.dtype)

        # One position per token of [BOS, input (seq_len), output (seq_len)].
        self.pos_seq_embed = nn.Embedding(1 + 2 * config.seq_len, H, dtype=config.dtype)

        # Row 0 marks input tokens, row 1 marks output tokens.
        self.input_output_embed = nn.Embedding(2, H, dtype=config.dtype)

        self.colors_embed = nn.Embedding(config.vocab_size, H, dtype=config.dtype)

        self.embed_dropout = nn.Dropout(config.transformer_layer.dropout_rate)

        self.transformer_layers = nn.ModuleList(
            [DenoisingCrossAttentionTransformerLayer(config.transformer_layer) for _ in range(config.num_layers)]
        )

        ln_affine = config.transformer_layer.use_bias
        ln_dtype = config.dtype
        self.row_logits_layer_norm = nn.LayerNorm(H, eps=1e-6, elementwise_affine=ln_affine, dtype=ln_dtype)
        self.col_logits_layer_norm = nn.LayerNorm(H, eps=1e-6, elementwise_affine=ln_affine, dtype=ln_dtype)
        self.grid_logits_layer_norm = nn.LayerNorm(H, eps=1e-6, elementwise_affine=ln_affine, dtype=ln_dtype)

        proj_bias = config.logits_projection_bias
        self.shape_row_logits_proj = nn.Linear(H, R, bias=proj_bias, dtype=config.dtype)
        self.shape_col_logits_proj = nn.Linear(H, C, bias=proj_bias, dtype=config.dtype)
        self.grid_logits_proj = nn.Linear(H, config.output_vocab_size, bias=proj_bias, dtype=config.dtype)

    def _get_pos_embed_pair(self, device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build flattened 2-D (row x column) positional embeddings.

        Returns:
            ``(input_pos_embed, output_pos_embed)``, each reshaped to
            ``(max_rows * max_cols, emb_dim)``.

        NOTE(review): this method is not called anywhere in this class. It reads
        attributes (``pos_row_embed``, ``pos_col_embed``, ``pos_row_embed_scale``,
        ``pos_col_embed_scale`` and their ``input_*`` variants) and config flags
        that ``__init__`` does not create or use, so on a plain instance it raises
        ``AttributeError``. Confirm whether it is dead code or whether a subclass
        is expected to provide them.
        """
        cfg = self.config
        R, C = cfg.max_rows, cfg.max_cols
        H = cfg.emb_dim
        rows_idx = torch.arange(R, device=device)
        cols_idx = torch.arange(C, device=device)

        if cfg.scaled_position_embeddings:
            row_factors = (rows_idx + 1).unsqueeze(1).to(cfg.dtype)
            col_factors = (cols_idx + 1).unsqueeze(1).to(cfg.dtype)

            def scaled_grid(row_scale, col_scale):
                # Row term broadcast over columns plus column term broadcast over rows.
                row_part = (row_factors * row_scale.unsqueeze(0)).unsqueeze(1)
                col_part = (col_factors * col_scale.unsqueeze(0)).unsqueeze(0)
                return row_part + col_part

            output_pos_embed_grid = scaled_grid(self.pos_row_embed_scale, self.pos_col_embed_scale)
            if cfg.next_position_embeddings and cfg.next_position_embeddings_new_input_embeds:
                input_pos_embed_grid = scaled_grid(self.input_pos_row_embed_scale, self.input_pos_col_embed_scale)
            else:
                input_pos_embed_grid = scaled_grid(self.pos_row_embed_scale, self.pos_col_embed_scale)
        else:
            def lookup_grid(row_layer, col_layer):
                # Row embedding broadcast over columns plus column embedding broadcast over rows.
                return row_layer(rows_idx).unsqueeze(1) + col_layer(cols_idx).unsqueeze(0)

            output_pos_embed_grid = lookup_grid(self.pos_row_embed, self.pos_col_embed)
            if cfg.next_position_embeddings and cfg.next_position_embeddings_new_input_embeds:
                input_pos_embed_grid = lookup_grid(self.input_pos_row_embed, self.input_pos_col_embed)
            else:
                input_pos_embed_grid = lookup_grid(self.pos_row_embed, self.pos_col_embed)

        input_pos_embed_flat = input_pos_embed_grid.reshape(R * C, H)
        output_pos_embed_flat = output_pos_embed_grid.reshape(R * C, H)

        return input_pos_embed_flat, output_pos_embed_flat

    def embed_inputs(
        self, input_seq: torch.Tensor, output_seq: torch.Tensor
    ) -> torch.Tensor:
        """Embed ``[BOS, input tokens, output tokens]`` into one sequence.

        Args:
            input_seq: Token ids of shape ``(*base, 1 + seq_len)``. Position 0
                is looked up in ``bos_token_embed``, which has a single row, so
                it must hold id 0. The remaining tokens are input tokens.
            output_seq: Token ids of shape ``(*base, seq_len)``.

        Returns:
            Tensor of shape ``(*base, 1 + 2 * seq_len, emb_dim)`` after dropout.
        """
        config = self.config
        base_dims = input_seq.shape[:-1]
        L = config.seq_len
        H = config.emb_dim
        device = input_seq.device

        bos_embed = self.bos_token_embed(input_seq[..., 0:1].long())
        input_colors_embed = self.colors_embed(input_seq[..., 1:].long())
        output_colors_embed = self.colors_embed(output_seq.long())

        # (1 + 2L, H); broadcasts against the leading batch dimensions when added.
        pos_seq_embeds = self.pos_seq_embed(torch.arange(1 + 2 * L, device=device))

        # Segment embeddings telling input tokens apart from output tokens.
        io_indices = torch.arange(2, device=device)
        input_embed, output_embed = self.input_output_embed(io_indices)
        input_embed = input_embed.view((1,) * len(base_dims) + (1, H))
        output_embed = output_embed.view((1,) * len(base_dims) + (1, H))

        x_input_colors = input_colors_embed + input_embed
        x_output_colors = output_colors_embed + output_embed

        x = torch.cat(
            [
                bos_embed + pos_seq_embeds[0:1, :],
                x_input_colors + pos_seq_embeds[1:1 + L, :],
                x_output_colors + pos_seq_embeds[1 + L:1 + 2 * L, :],
            ],
            dim=-2,
        )
        x = self.embed_dropout(x)
        return x

    @staticmethod
    def _repeat_batch(tensor: torch.Tensor, num_pairs: Optional[int], dim: int) -> torch.Tensor:
        """Repeat each batch entry along ``dim`` ``num_pairs`` times (no-op for None or 1-D tensors)."""
        if num_pairs is None or tensor.dim() <= 1:
            return tensor
        # Same ordering as unsqueeze/expand/reshape: b0 x P, b1 x P, ...
        return tensor.repeat_interleave(num_pairs, dim=dim)

    @staticmethod
    def _flatten_nvib_output(
        nvib_output: Dict[str, torch.Tensor], num_pairs: Optional[int]
    ) -> Dict[str, torch.Tensor]:
        """Copy ``'z'`` and ``'memory_key_padding_mask'``, repeating the latent batch per pair.

        Tensor values are treated as batch-first (dim 0). Members of tuple
        values are treated as batch-second (dim 1), matching a sequence-first
        layout. Values of any other type, and any other keys, are dropped.
        """
        flat: Dict[str, torch.Tensor] = {}
        for key in ('z', 'memory_key_padding_mask'):
            if key not in nvib_output:
                continue
            value = nvib_output[key]
            if isinstance(value, torch.Tensor):
                flat[key] = DecoderTransformer._repeat_batch(value, num_pairs, dim=0)
            elif isinstance(value, tuple):
                flat[key] = tuple(
                    DecoderTransformer._repeat_batch(member, num_pairs, dim=1) for member in value
                )
        return flat

    def forward(
        self,
        input_seq: torch.Tensor,
        output_seq: torch.Tensor,
        nvib_output: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Predict output-token logits conditioned on the encoder's NVIB output.

        Args:
            input_seq: Token ids ``(*base, 1 + seq_len)`` with BOS at position 0.
                When 3-D, ``input_seq.shape[1]`` is the number of pairs; if the
                latent batch is smaller than the decoder batch, each latent is
                repeated once per pair.
            output_seq: Token ids ``(*base, seq_len)``.
            nvib_output: Dictionary from the encoder. It must contain
                ``'original_base_shape'`` and ``'memory_key_padding_mask'``;
                ``'z'`` is forwarded to the decoder layers when present.

        Returns:
            float32 logits of shape ``(*base_shape, seq_len, output_vocab_size)``.
            ``base_shape`` is ``original_base_shape``, with ``num_pairs``
            appended when the latent batch was repeated.

        Raises:
            ValueError: if a required key is missing, or the latent batch
                cannot be broadcast to the decoder batch.
        """
        if 'original_base_shape' not in nvib_output:
            raise ValueError("Decoder requires 'original_base_shape' in nvib_output dictionary.")
        if 'memory_key_padding_mask' not in nvib_output:
            raise ValueError("'memory_key_padding_mask' missing from nvib_output/flat_nvib_output.")
        num_pairs = input_seq.shape[1] if input_seq.dim() == 3 else None

        x = self.embed_inputs(input_seq, output_seq)
        T, H = x.shape[-2], x.shape[-1]
        x_flat = x.reshape(-1, T, H)
        batch_total_flat = x_flat.shape[0]

        # The padding mask is batch-first, so its leading size is the latent batch.
        needs_flattening = nvib_output['memory_key_padding_mask'].shape[0] != batch_total_flat
        if needs_flattening and num_pairs is None:
            raise ValueError(
                "nvib_output batch does not match the decoder batch and input_seq "
                "has no pair dimension to broadcast over."
            )

        flat_nvib_output = self._flatten_nvib_output(
            nvib_output, num_pairs if needs_flattening else None
        )
        if needs_flattening and flat_nvib_output['memory_key_padding_mask'].shape[0] != batch_total_flat:
            raise ValueError(
                "nvib_output batch times num_pairs does not match the decoder batch."
            )

        # Normalise to a plain tuple. The encoder stores a torch.Size; a bare int
        # is also accepted.
        original_base_shape = nvib_output['original_base_shape']
        if isinstance(original_base_shape, int):
            base_shape = (original_base_shape,)
        else:
            base_shape = tuple(original_base_shape)
        if needs_flattening:
            base_shape = base_shape + (num_pairs,)

        for layer in self.transformer_layers:
            x_flat = layer(
                decoder_embeddings=x_flat,
                nvib_output=flat_nvib_output,
                decoder_key_padding_mask=None,
                causal_mask=None
            )

        grid_logits_flat = self.extract_logits(x_flat, input_seq.shape[-1])
        L_out, V_out = grid_logits_flat.shape[-2:]

        grid_logits = grid_logits_flat.reshape(*base_shape, L_out, V_out)
        return grid_logits.float()

    def extract_logits(self, x_flat: torch.Tensor, input_seq_len: int) -> torch.Tensor:
        """Project the hidden states that predict each output token to logits.

        The state at position ``input_seq_len - 1`` and each following position
        up to, but excluding, the last position are projected. With
        ``input_seq_len == 1 + seq_len`` this yields ``seq_len`` positions. The
        result is zero-padded or truncated to exactly ``config.seq_len`` positions.

        Args:
            x_flat: Decoder hidden states of shape ``(batch, T, emb_dim)``.
            input_seq_len: Length of ``input_seq``, including the BOS token.

        Returns:
            Logits of shape ``(batch, seq_len, output_vocab_size)``.
        """
        L = self.config.seq_len
        batch_total = x_flat.shape[0]

        grid_embeds = x_flat[:, input_seq_len - 1:-1, :]
        # NOTE(review): grid_logits_layer_norm is built in __init__ but not applied
        # here. Confirm whether that is intended before changing it, since applying
        # it would alter the outputs of trained weights.
        grid_logits = self.grid_logits_proj(grid_embeds)

        num_positions = grid_logits.shape[1]
        if num_positions < L:
            padding = torch.zeros(
                batch_total, L - num_positions, grid_logits.shape[-1],
                dtype=grid_logits.dtype, device=grid_logits.device,
            )
            grid_logits = torch.cat([grid_logits, padding], dim=1)
        elif num_positions > L:
            grid_logits = grid_logits[:, :L, :]

        return grid_logits


