import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from flash_attn.flash_attn_interface import flash_attn_varlen_func
from torch.utils.checkpoint import checkpoint

from tabicl.models.foundation_flash.embedding import (FoundationEmbeddingX, FoundationEmbeddingYFloat,
                                                      FoundationEmbeddingYInteger, FoundationQuantileEmbeddingX)
from tabicl.models.foundation_flash.padder import Padder


class FoundationFlashTransformer(nn.Module):
    """In-context transformer that predicts query outputs from a labelled support set.

    Support/query features are embedded (optionally through a quantile embedding first),
    label embeddings from ``y_embedding`` are added, and the tokens are converted with
    ``Padder.base_to_obs`` into the packed layout used by the flash-attention layers
    (presumably dropping padded observations — TODO confirm against ``Padder``). After
    ``n_layers`` blocks, a LayerNorm + Linear/GELU/Linear head maps every query token to
    ``dim_output`` values, and ``Padder.obs_to_base`` restores the original layout.
    """

    def __init__(
        self,
        dim_model: int,
        dim_embedding: int,
        dim_output: int,
        n_layers: int,
        n_heads: int,
        y_as_float_embedding: bool,
        quantile_embedding: bool,
        feature_count_scaling: bool,
        use_pretrained_weights: bool,
        path_to_weights: str,
    ) -> None:
        """
        Args:
            dim_model: width of the token representations.
            dim_embedding: forwarded to the feature embedding modules.
            dim_output: number of outputs per query token; also passed to the integer
                label embedding when ``y_as_float_embedding`` is False.
            n_layers: number of transformer blocks.
            n_heads: attention heads per block.
            y_as_float_embedding: embed labels with ``FoundationEmbeddingYFloat`` instead
                of ``FoundationEmbeddingYInteger``.
            quantile_embedding: apply ``FoundationQuantileEmbeddingX`` before the feature
                embedding.
            feature_count_scaling: forwarded to the quantile embedding.
            use_pretrained_weights: load a state dict from ``path_to_weights``.
            path_to_weights: path of the checkpoint read by ``torch.load``.
        """
        super().__init__()

        self.dim_model = dim_model
        self.dim_embedding = dim_embedding
        self.dim_output = dim_output
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.quantile_embedding = quantile_embedding
        self.feature_count_scaling = feature_count_scaling
        self.y_as_float_embedding = y_as_float_embedding

        # Submodules are created in the same order as before so seeded random
        # initialisation is unchanged.
        if self.quantile_embedding:
            self.x_quantile = FoundationQuantileEmbeddingX(dim_model, dim_embedding, feature_count_scaling)

        self.x_embedding = FoundationEmbeddingX(dim_model, dim_embedding)

        self.y_embedding: torch.nn.Module
        if self.y_as_float_embedding:
            self.y_embedding = FoundationEmbeddingYFloat(dim_model)
        else:
            self.y_embedding = FoundationEmbeddingYInteger(dim_output, dim_model)

        self.layers = nn.ModuleList(Layer(dim_model, self.n_heads) for _ in range(self.n_layers))

        self.final_layer_norm = nn.LayerNorm(dim_model)
        self.final_layer1 = nn.Linear(dim_model, dim_model*4)
        self.final_layer2 = nn.Linear(dim_model*4, dim_output)

        if use_pretrained_weights:
            # Load onto CPU: load_state_dict copies into the existing parameters anyway, and
            # this avoids materialising a GPU-saved checkpoint on its original device first.
            # NOTE(review): torch.load unpickles the file — only load trusted checkpoints
            # (consider weights_only=True where the installed torch supports it).
            state_dict = torch.load(path_to_weights, map_location="cpu")
            self.load_state_dict(state_dict)

    def forward(
            self,
            x_support: torch.Tensor,
            y_support: torch.Tensor,
            x_query__: torch.Tensor,
            padding_features: torch.Tensor,
            padding_obs_support: torch.Tensor,
            padding_obs_query__: torch.Tensor,
        ) -> torch.Tensor:
        """Predict outputs for the query observations given the labelled support set.

        x_support is (batch_size, n_observations_support, n_features)
        y_support is (batch_size, n_observations_support)
        x_query is (batch_size, n_observations_query, n_features)

        padding_features, padding_obs_support and padding_obs_query__ are presumably
        masks marking padded features / observations — TODO confirm against the
        embedding modules and ``Padder``.

        returns:

        y_query is (batch_size, n_observations_query, n_classes), where
        n_classes == dim_output.

        syntax:
        b = batch size
        s = number of observations
        d = dimension of embedding
        c = number of classes
        """
        n_obs_query__ = x_query__.shape[1]

        if self.quantile_embedding:
            x_support, x_query__ = self.x_quantile(x_support, x_query__, padding_obs_support, padding_features)
        x_support, x_query__ = self.x_embedding(x_support, x_query__)
        # The label embedding also produces an embedding for the (unlabelled) query rows.
        y_support, y_query__ = self.y_embedding(y_support, n_obs_query__)

        support = x_support + y_support
        query__ = x_query__ + y_query__

        padder_support = Padder(support, padding_obs_support)
        padder_query__ = Padder(query__, padding_obs_query__)

        # Switch to the packed layout expected by flash_attn_varlen_func.
        support = padder_support.base_to_obs(support)
        query__ = padder_query__.base_to_obs(query__)

        # Activation checkpointing trades recomputation for memory in every block.
        for layer in self.layers:
            support, query__ = checkpoint(layer, support, query__, padder_support, padder_query__, use_reentrant=False)

        query__ = self.final_layer_norm(query__)
        query__ = self.final_layer1(query__)
        query__ = F.gelu(query__)
        query__ = self.final_layer2(query__)

        # Restore the (batch, n_obs_query, dim_output) layout.
        query__ = padder_query__.obs_to_base(query__)

        return query__



class Layer(torch.nn.Module):
    """Pre-LayerNorm transformer block for packed support/query tokens.

    Support tokens self-attend to the support tokens of their own sequence. Query tokens
    cross-attend to those same support tokens and never attend to each other. Both streams
    share the LayerNorms, the attention module and the feed-forward MLP.
    """

    def __init__(self, dim: int, n_heads: int) -> None:

        super().__init__()

        self.layer_norm1 = nn.LayerNorm(dim)
        self.attention = MultiheadAttention(dim, n_heads)
        self.layer_norm2 = nn.LayerNorm(dim)
        self.linear1 = nn.Linear(dim, dim*4)
        self.linear2 = nn.Linear(dim*4, dim)

    def forward(
            self,
            support: torch.Tensor,
            query__: torch.Tensor,
            padder_support: Padder,
            padder_query__: Padder,
        ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the attention and feed-forward sublayers (each residual) to both streams.

        Args:
            support: packed support tokens, ``(t_support, dim)``.
            query__: packed query tokens, ``(t_query, dim)``.
            padder_support: supplies ``cu_seqlens_o`` / ``max_seqlen_in_batch_o`` for the
                packed support tokens.
            padder_query__: supplies ``cu_seqlens_o`` / ``max_seqlen_in_batch_o`` for the
                packed query tokens.

        Returns:
            The updated ``(support, query__)``, with the same shapes as the inputs.
        """
        support_normed = self.layer_norm1(support)
        query___normed = self.layer_norm1(query__)

        # Keep the normalised support separate from its attention output so the query
        # cross-attention reads exactly the same support keys/values as the support
        # self-attention does.
        support_att = self.attention(
            support_normed, support_normed, support_normed,
            cu_seqlens_q=padder_support.cu_seqlens_o, max_seqlen_q=padder_support.max_seqlen_in_batch_o,
            cu_seqlens_k=padder_support.cu_seqlens_o, max_seqlen_k=padder_support.max_seqlen_in_batch_o,
        )
        query___att = self.attention(
            query___normed, support_normed, support_normed,
            cu_seqlens_q=padder_query__.cu_seqlens_o, max_seqlen_q=padder_query__.max_seqlen_in_batch_o,
            cu_seqlens_k=padder_support.cu_seqlens_o, max_seqlen_k=padder_support.max_seqlen_in_batch_o,
        )

        support = support + support_att
        query__ = query__ + query___att

        support = support + self._feed_forward(support)
        query__ = query__ + self._feed_forward(query__)

        return support, query__

    def _feed_forward(self, x: torch.Tensor) -> torch.Tensor:
        """LayerNorm -> Linear(dim, 4*dim) -> GELU -> Linear(4*dim, dim); no residual."""
        x = self.layer_norm2(x)
        x = self.linear1(x)
        x = F.gelu(x)
        return self.linear2(x)



class MultiheadAttention(torch.nn.Module):
    """Multi-head attention over packed variable-length sequences via FlashAttention.

    Tokens of all sequences in the batch are concatenated along the first axis.
    ``cu_seqlens_*`` give the cumulative sequence boundaries expected by
    ``flash_attn_varlen_func``. Queries and keys/values may come from different token sets
    (cross-attention), each with its own boundaries. No causal mask or dropout is applied
    (the library defaults are used).
    """

    def __init__(self, dim: int, n_heads: int) -> None:
        """
        Args:
            dim: model width; must be divisible by ``n_heads``.
            n_heads: number of attention heads.

        Raises:
            ValueError: if ``n_heads`` is not positive or does not divide ``dim``.
        """
        super().__init__()

        # Fail at construction instead of inside einops.rearrange on the first forward.
        if n_heads <= 0 or dim % n_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by a positive n_heads ({n_heads})")

        # NOTE(review): not read anywhere in this module; forward always uses FlashAttention.
        self.use_flash_attention = False
        self.dim = dim
        self.n_heads = n_heads

        self.q = nn.Linear(dim, dim, bias=True)
        self.k = nn.Linear(dim, dim, bias=True)
        self.v = nn.Linear(dim, dim, bias=True)
        self.o = nn.Linear(dim, dim, bias=True)

    def forward(
            self,
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            cu_seqlens_q: torch.Tensor,
            cu_seqlens_k: torch.Tensor,
            max_seqlen_q: int,
            max_seqlen_k: int,
        ) -> torch.Tensor:
        """Attend from packed ``query`` tokens to packed ``key``/``value`` tokens.

        Notation:
            t = flashattention-compressed sequences of (batch, observations)
            h = heads
            d = dimension of embedding

        Args:
            query: ``(t_q, dim)`` packed query tokens.
            key: ``(t_k, dim)`` packed key tokens.
            value: ``(t_k, dim)`` packed value tokens.
            cu_seqlens_q: cumulative sequence lengths delimiting each sequence in ``query``.
            cu_seqlens_k: cumulative sequence lengths delimiting each sequence in ``key``/``value``.
            max_seqlen_q: longest query sequence in the batch.
            max_seqlen_k: longest key/value sequence in the batch.

        Returns:
            ``(t_q, dim)`` attention output after the output projection.

        NOTE: flash_attn_varlen_func only supports fp16/bf16 CUDA tensors.
        """
        q = self.q(query)
        k = self.k(key)
        v = self.v(value)

        # Split the model width into heads: (t, h*d) -> (t, h, d).
        q = einops.rearrange(q, 't (h d) -> t h d', h=self.n_heads)
        k = einops.rearrange(k, 't (h d) -> t h d', h=self.n_heads)
        v = einops.rearrange(v, 't (h d) -> t h d', h=self.n_heads)

        output = flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
        )

        output = einops.rearrange(output, 't h d -> t (h d)')
        output = self.o(output)

        return output
    