
from typing import Any

import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from flash_attn.bert_padding import pad_input, unpad_input
from torch.utils.checkpoint import checkpoint

from tabicl.models.foundation.embedding import (FoundationEmbeddingX, FoundationEmbeddingYFloat,
                                                FoundationEmbeddingYInteger, FoundationQuantileEmbeddingX)


class FoundationTransformer(nn.Module):
    """Transformer that predicts outputs for query observations from a labelled support set.

    Support features and support targets are embedded and summed; query
    features are embedded and summed with the query-side tensor returned by
    the y-embedding. Both sets are packed along the observation axis, passed
    through ``n_layers`` ``Layer`` blocks, and projected to ``dim_output``
    values per observation by a LayerNorm + two-layer GELU head. Only the
    query part of the result is returned.

    Args:
        dim_model: Width of the token representation used by every layer.
        dim_embedding: Forwarded to the x-embedding modules.
        dim_output: Size of the last output dimension; also forwarded to the
            integer y-embedding.
        n_layers: Number of ``Layer`` blocks.
        n_heads: Number of attention heads in each block.
        y_as_float_embedding: Use ``FoundationEmbeddingYFloat`` when true,
            ``FoundationEmbeddingYInteger`` otherwise.
        quantile_embedding: When true, inputs first pass through
            ``FoundationQuantileEmbeddingX``.
        feature_count_scaling: Forwarded to the quantile embedding.
        use_pretrained_weights: When true, load a state dict from
            ``path_to_weights`` at construction time.
        path_to_weights: Checkpoint path; only read when
            ``use_pretrained_weights`` is true.
    """

    def __init__(
        self,
        dim_model: int,
        dim_embedding: int,
        dim_output: int,
        n_layers: int,
        n_heads: int,
        y_as_float_embedding: bool,
        quantile_embedding: bool,
        feature_count_scaling: bool,
        use_pretrained_weights: bool,
        path_to_weights: str,
    ) -> None:

        super().__init__()

        self.dim_model = dim_model
        self.dim_embedding = dim_embedding
        self.dim_output = dim_output
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.quantile_embedding = quantile_embedding
        self.feature_count_scaling = feature_count_scaling
        self.y_as_float_embedding = y_as_float_embedding

        if self.quantile_embedding:
            self.x_quantile = FoundationQuantileEmbeddingX(dim_model, dim_embedding, feature_count_scaling)

        self.x_embedding = FoundationEmbeddingX(dim_model, dim_embedding)

        self.y_embedding: torch.nn.Module
        if self.y_as_float_embedding:
            self.y_embedding = FoundationEmbeddingYFloat(dim_model)
        else:
            self.y_embedding = FoundationEmbeddingYInteger(dim_output, dim_model)

        self.layers = nn.ModuleList(
            [Layer(dim_model, self.n_heads) for _ in range(self.n_layers)]
        )

        self.final_layer_norm = nn.LayerNorm(dim_model)
        self.final_layer1 = nn.Linear(dim_model, dim_model*4)
        self.final_layer2 = nn.Linear(dim_model*4, dim_output)

        if use_pretrained_weights:
            # Map to CPU so that a checkpoint saved on a GPU can still be
            # loaded on a host without CUDA; load_state_dict copies the values
            # into the already-allocated parameters in either case.
            # NOTE: torch.load unpickles the file - only load trusted checkpoints.
            state_dict = torch.load(path_to_weights, map_location="cpu")
            self.load_state_dict(state_dict)

    def forward(
        self,
        x_support: torch.Tensor,
        y_support: torch.Tensor,
        x_query__: torch.Tensor,
        padding_features: torch.Tensor,
        padding_obs_support: torch.Tensor,
        padding_obs_query__: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute outputs for the query observations.

        x_support is (batch_size, n_observations_support, n_features)
        y_support is (batch_size, n_observations_support)

        x_query is (batch_size, n_observations_query, n_features)

        padding_features is only forwarded to the quantile embedding
        padding_obs_support is (batch_size, n_observations_support), boolean,
            True marks padded support observations; it is used as the
            attention key-padding mask inside every layer
        padding_obs_query__ is accepted but not used by this method

        returns:

        y_query is (batch_size, n_observations_query, n_classes)

        syntax:
        b = batch size
        s = number of observations
        d = dimension of embedding
        c = number of classes
        """

        n_obs_query__ = x_query__.shape[1]

        if self.quantile_embedding:
            x_support, x_query__ = self.x_quantile(x_support, x_query__, padding_obs_support, padding_features)
        x_support, x_query__ = self.x_embedding(x_support, x_query__)
        y_support, y_query__ = self.y_embedding(y_support, n_obs_query__)

        support = x_support + y_support
        query__ = x_query__ + y_query__

        # Concatenate support and query along the observation axis; `pack`
        # remembers the individual lengths so they can be split again later.
        x, pack = einops.pack((support, query__), 'b * d')

        for layer in self.layers:
            # Activation checkpointing: recompute the layer in the backward
            # pass instead of storing its intermediate activations.
            x = checkpoint(layer, x, pack, padding_obs_support, use_reentrant=False)

        x = self.final_layer_norm(x)
        x = self.final_layer1(x)
        x = F.gelu(x)
        x = self.final_layer2(x)

        # Only the query part of the output is returned.
        _, query__ = einops.unpack(x, pack, 'b * c')

        return query__



class Layer(torch.nn.Module):
    """Pre-norm transformer block in which every token attends to the support set only.

    The block applies LayerNorm -> attention -> residual, followed by
    LayerNorm -> Linear(dim, 4*dim) -> GELU -> Linear(4*dim, dim) -> residual.
    Support tokens attend to the support tokens, and query tokens attend to
    the support tokens as well; query tokens never attend to each other.

    Args:
        dim: Width of the token representation.
        n_heads: Number of attention heads.
    """

    def __init__(self, dim: int, n_heads: int) -> None:

        super().__init__()

        self.layer_norm1 = nn.LayerNorm(dim)
        self.attention = MultiheadAttention(dim, n_heads)
        self.layer_norm2 = nn.LayerNorm(dim)
        self.linear1 = nn.Linear(dim, dim*4)
        self.linear2 = nn.Linear(dim*4, dim)

    def forward(
        self,
        x: torch.Tensor,
        pack: list[Any],
        padding_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply the block to the packed support + query sequence.

        x is (b, s_support + s_query, d), support tokens first
        pack is the packed-shape list returned by einops.pack, used to split
            x back into its support and query parts
        padding_mask is (b, s_support), boolean, True marks padded support
            observations

        returns a tensor with the same shape as x
        """

        batch_size, n_obs = x.shape[:2]

        # Query observations are treated as never padded: extend the support
        # mask with an all-False block covering the query positions.
        query_mask = torch.zeros(
            (batch_size, n_obs - padding_mask.shape[1]),
            dtype=torch.bool,
            device=padding_mask.device,
        )
        padding_mask_total = torch.cat([padding_mask, query_mask], dim=1)

        x_residual = x
        x = self.layer_norm1(x)

        support, query__ = einops.unpack(x, pack, 'b * d')
        # Keys and values are always the support tokens; padded support
        # observations are excluded through the key padding mask.
        att_support = self.attention(support, support, support, key_padding_mask=padding_mask)
        att_query__ = self.attention(query__, support, support, key_padding_mask=padding_mask)
        x = einops.pack((att_support, att_query__), 'b * d')[0]
        x = x_residual + x

        x_residual = x
        x = self.layer_norm2(x)
        # Run the feed-forward network on the non-padded tokens only. Only the
        # first two return values are needed; the number of trailing values
        # returned by unpad_input differs between flash-attn releases, so the
        # rest is discarded instead of being unpacked by position.
        x, x_indices, *_ = unpad_input(x, ~padding_mask_total)
        x = self.linear1(x)
        x = torch.nn.functional.gelu(x)
        x = self.linear2(x)
        # x_indices address the flattened (batch_size * n_obs) grid, so the
        # padded length must be n_obs. Using the longest non-padded length in
        # the batch would misplace tokens whenever every row contains padding.
        x = pad_input(x, x_indices, batch_size, n_obs)
        x = x_residual + x

        return x



class MultiheadAttention(torch.nn.Module):
    """Wrapper around ``nn.MultiheadAttention`` that returns only the attention output.

    The wrapped module is batch-first and uses no dropout.
    """

    def __init__(self, dim: int, n_heads: int) -> None:

        super().__init__()

        self.dim, self.n_heads = dim, n_heads
        self.use_flash_attention = False

        self.att = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=n_heads,
            dropout=0.0,
            batch_first=True,
        )

    def init_weights(self):
        """No-op hook: the wrapped module keeps the initialization it was constructed with."""
        return None

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_padding_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Attend from ``query`` to ``key``/``value``.

        b = batch size
        n_q = number of query samples
        n_k = number of key/value samples
        d = dimension of embedding

        query is (b, n_q, d)
        key is (b, n_k, d)
        value is (b, n_k, d)
        key_padding_mask is (b, n_k); True marks keys that are ignored

        output will be (b, n_q, d); the attention weights computed by the
        wrapped module are discarded
        """

        attended, _weights = self.att(query, key, value, key_padding_mask=key_padding_mask)
        return attended
    