from typing import Iterable
import torch
import torch.nn as nn
from torch.nn import functional as F
from metabeta.models.feedforward import MLP

# --------------------------------------------------------------------------
# Base Class
class Summarizer(nn.Module):
    ''' takes batch of sequential data x (batch, *, seq_len, d_data)
        and summarizes the sequences to h (batch, *, d_output) '''
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        raise NotImplementedError

# -----------------------------------------------------------------------------
# DeepSet
class InvariantBlock(nn.Module):
    def __init__(self,
                 d_model: int,
                 d_hidden: int | Iterable,
                 dropout: float = 0.01,
                 activation: str = 'Mish',
                 ):
        super().__init__()
        self.phi = MLP(d_input=d_model,
                       d_hidden=d_hidden,
                       d_output=d_model,
                       dropout=dropout,
                       activation=activation)
        self.rho = MLP(d_input=d_model,
                       d_hidden=d_hidden,
                       d_output=d_model,
                       dropout=dropout,
                       activation=activation)

    def pool(self, h: torch.Tensor, mask=None) -> torch.Tensor:
        if mask is None:
            out = h.mean(dim=-2)
        else:
            out = h * mask.unsqueeze(-1)
            out = out.sum(dim=-2)
            denominator = mask.sum(dim=-1, keepdim=True)
            out = out / (denominator.expand_as(out) + 1e-12)
        return out

    def forward(self, x, mask=None):
        # x (batch, seq_len, d_model)
        h = self.phi(x)
        h = self.pool(h, mask)
        h = self.rho(h)
        return h


class EquivariantBlock(nn.Module):
    '''
    Steps:
    1. Invariant module (my Deepset) (b, n, emb) -> (b, emb)
    2. tile -> (b, n, emb)
    3. concatenate with input -> (b, n, 2*emb)
    4. project through MLP -> (b, n, emb)
    5. add to initial input -> (b, n, emb)
    6. layer_norm -> (b, n, emb)
    '''
    def __init__(self,
                 d_model: int,
                 d_hidden: int | Iterable,
                 dropout: float = 0.01,
                 activation: str = 'Mish',
                 ):
        super().__init__()
        self.ib = InvariantBlock(d_model, d_hidden, dropout, activation)
        self.proj = MLP(d_input=2*d_model,
                        d_hidden=d_hidden,
                        d_output=d_model,
                        dropout=dropout,
                        activation=activation)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None):
        h = self.ib(x, mask)
        h = h.unsqueeze(-2).expand_as(x)
        h = torch.cat([x, h], dim=-1)
        h = x + self.proj(h)
        h = self.norm(h)
        return h


class DeepSet(Summarizer):
    ''' DeepSet summarizer.

        The input is optionally projected to d_model, refined by n_blocks
        EquivariantBlocks, collapsed along the sequence axis by a final
        InvariantBlock and optionally projected to d_output.

        d_input: input feature size; if falsy the input is used as is
        d_ff, depth: every internal MLP gets `depth` hidden layers of size d_ff
        kwargs: accepted and ignored '''

    def __init__(self,
                 d_model: int,
                 d_ff: int,
                 d_output: int,
                 d_input: int | None = None,
                 depth: int = 2,
                 n_blocks: int = 2,
                 dropout: float = 0.01,
                 activation: str = 'GELU',
                 **kwargs
                 ):
        super().__init__()
        hidden = (d_ff,) * depth

        # projectors (identity whenever no resizing is required)
        self.proj = self._linear_or_identity(d_input, d_model, bool(d_input))
        self.out = self._linear_or_identity(d_model, d_output, d_model != d_output)

        # deepset blocks
        self.blocks = nn.ModuleList([
            EquivariantBlock(d_model, hidden, dropout, activation)
            for _ in range(n_blocks)
        ])
        self.ib = InvariantBlock(d_model, hidden, dropout, activation)

    @staticmethod
    def _linear_or_identity(d_in, d_out, needed: bool) -> nn.Module:
        ''' Linear layer with zero-initialized bias, or Identity if not needed. '''
        if not needed:
            return nn.Identity()
        layer = nn.Linear(d_in, d_out)
        nn.init.zeros_(layer.bias)
        return layer

    def forward(self, x, mask=None):
        h = self.proj(x)
        for block in self.blocks:
            h = block(h, mask)
        summary = self.ib(h, mask)
        return self.out(summary)
    
    
# -----------------------------------------------------------------------------
# PoolFormer

class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        d_proj: int,
        n_heads: int,
        d_input: int | tuple[int, int] | None = None,
        d_output: int | None = None,
        dropout: float = 0.01,
        use_bias: bool = True,
        share_heads: bool = False,
        weight_init: tuple[str, str] = ('kaiming', 'uniform'),
    ):
        super().__init__()
        # assumed inputs: query | query and key
        
        self.d_query = d_proj
        self.d_key = d_proj
        self.d_value = d_proj 
        self.d_input = d_input or (self.d_query, self.d_key)
        self.single = isinstance(d_input, int)
        self.d_output = d_output or self.d_value
        self.n_heads = n_heads
        self.dropout = dropout
        self.use_bias = use_bias
        self.share_heads = share_heads
        if self.share_heads:
            assert d_proj % n_heads == 0, 'projection dims must be divisible by heads'
            self.multiplier = 1
            self.divisor = n_heads
        else:
            self.multiplier = n_heads
            self.divisor = 1
        
        # build attention projectors
        if self.single:
            self.W_qkv = nn.Linear(
                in_features  = self.d_input,
                out_features = self.multiplier * (self.d_query + self.d_key + self.d_value),
                bias         = use_bias)
        else:
            
            self.W_q = nn.Linear(
                in_features  = self.d_input[0],
                out_features = self.multiplier * self.d_query,
                bias         = use_bias)
            self.W_kv = nn.Linear(
                in_features  = self.d_input[1],
                out_features = self.multiplier * (self.d_key + self.d_value),
                bias         = use_bias)
        
        # build output projector
        self.out_proj = nn.Linear(
                in_features  = self.multiplier * self.d_value,
                out_features = self.d_output,
                bias         = use_bias)
        
    
    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor | None = None,
                mask: torch.Tensor | None = None):
        # prepare
        assert (key is None and self.single) or (key is not None and not self.single), 'unexpected number of inputs'
        b, n_q, _ = query.shape
        n_k = n_q if key is None else key.shape[1]
        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(1)
        
        # project input sequences
        if self.single:
            query, key, value = self.W_qkv(query).split(
                [self.multiplier * self.d_query,
                 self.multiplier * self.d_key,
                 self.multiplier * self.d_value], dim=-1)
        else:
            query = self.W_q(query)
            key, value = self.W_kv(key).split([self.multiplier * self.d_key,
                                               self.multiplier * self.d_value], dim=-1)
        
        # give separate dim for n_heads [bn(hz) -> bhnz]
        d_query = self.d_query // self.divisor
        d_key = self.d_key // self.divisor
        d_value = self.d_value // self.divisor
        query = query.view(b, n_q, self.n_heads, d_query).transpose(1, 2)
        key = key.view(b, n_k, self.n_heads, d_key).transpose(1, 2)
        value = value.view(b, n_k, self.n_heads, d_value).transpose(1, 2)
        
        # calculate attention outputs
        dropout_p = self.dropout if self.training else 0.
        attn_outputs = F.scaled_dot_product_attention(
            query, key, value,
            dropout_p=dropout_p,
            attn_mask=mask, 
            ).transpose(1, 2).reshape(b, n_q, -1)
        
        # project out
        output = self.out_proj(attn_outputs)
        return output
        

class MultiheadAttentionBlock(nn.Module):
    ''' Post-norm attention block:

            h   = LayerNorm(MHA(query, key) + proj(query))
            out = LayerNorm(h + MLP(h))

        d_input is an int for a single input sequence or a pair
        (query size, key size). proj maps the query to d_model for the
        residual connection and is the identity if the sizes already agree. '''

    def __init__(self,
                 d_model: int,
                 d_input: int | tuple[int, int],
                 d_hidden: int | Iterable = (128, 128),
                 n_heads: int = 4,
                 activation: str = 'GELU',
                 dropout: float = 0.01,
                 eps: float = 1e-3,
                 ):
        super().__init__()
        # feature size of the query sequence
        d_query_in = d_input if isinstance(d_input, int) else d_input[0]

        # attention: custom implementation with separate heads
        self.mha = MultiHeadAttention(
            d_proj=d_model,
            d_input=d_input,
            d_output=d_model,
            n_heads=n_heads,
            dropout=dropout,
            use_bias=True)

        # residual projection of the query
        needs_projection = d_query_in != d_model
        self.proj = nn.Linear(d_query_in, d_model) if needs_projection else nn.Identity()

        # position-wise feedforward network
        self.mlp = MLP(
            d_input=d_model,
            d_hidden=d_hidden,
            d_output=d_model,
            activation=activation,
            dropout=dropout)

        # layer norms
        self.ln0 = nn.LayerNorm(d_model, eps=eps)
        self.ln1 = nn.LayerNorm(d_model, eps=eps)

    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor | None = None,
                mask: torch.Tensor | None = None):
        attended = self.mha(query, key, mask=mask)
        h = self.ln0(attended + self.proj(query))
        return self.ln1(h + self.mlp(h))


class PoolingBlock(nn.Module):
    ''' Attention pooling with learned seed vectors.

        The input sequence is encoded by an MLP; n_seeds learned vectors then
        act as queries that attend over the encoded sequence. The
        second-to-last axis of the result is squeezed, which removes it
        only for n_seeds == 1.

        NOTE(review): d_hidden is indexed with [-1], so it has to be a sequence
        here; the MLP is presumably emitting d_hidden[-1] features - confirm. '''

    def __init__(self,
                 d_model: int,
                 d_hidden: int | Iterable = (128, 128),
                 n_heads: int = 4,
                 dropout: float = 0.01,
                 activation: str = 'GELU',
                 n_seeds: int = 1,
                 ):
        super().__init__()
        d_output = d_hidden[-1]

        # sequence encoder
        self.mlp = MLP(
            d_input=d_model,
            d_hidden=d_hidden,
            activation=activation,
            dropout=dropout)

        # cross attention: seeds (queries) attend to the encoded sequence (keys)
        self.mab = MultiheadAttentionBlock(
            d_model=d_model,
            d_input=(d_output, d_output),
            d_hidden=d_hidden,
            n_heads=n_heads,
            activation=activation,
            dropout=dropout)

        # learned seeds, shared by all batch entries
        self.s = nn.Parameter(torch.Tensor(1, n_seeds, d_output))
        nn.init.xavier_uniform_(self.s)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None):
        encoded = self.mlp(x)
        batch_size = x.size(0)
        seeds = self.s.repeat(batch_size, 1, 1)
        pooled = self.mab(seeds, encoded, mask=mask)
        return pooled.squeeze(-2)


class PoolFormer(Summarizer):
    ''' Attention-based summarizer.

        sparse=False: n_blocks MultiheadAttentionBlocks (self-attention)
            followed by a PoolingBlock (attention pooling).
        sparse=True: linear embedding, n_blocks nn.TransformerEncoderLayers
            and a (masked) average over the sequence axis.
        In both cases a final linear layer maps d_model to d_output
        (identity if the sizes agree).

        Args:
            d_model: width of the attention blocks
            d_ff: hidden width of the feedforward parts
            d_output: size of the summary vectors
            d_input: input feature size; None means the inputs already
                have d_model features
            depth: number of hidden layers per MLP (non-sparse variant only)
            n_heads: number of attention heads
            n_blocks: number of attention blocks
            dropout: dropout probability
            activation: activation name; the sparse variant passes it in
                lower case to nn.TransformerEncoderLayer
            sparse: selects the variant, see above
            kwargs: accepted and ignored
    '''
    def __init__(self,
                 d_model: int,
                 d_ff: int,
                 d_output: int,
                 d_input: int | None = None,
                 depth: int = 2,
                 n_heads: int = 4,
                 n_blocks: int = 2,
                 dropout: float = 0.01,
                 activation: str = 'GELU',
                 sparse: bool = False,
                 **kwargs
                 ):
        super().__init__()
        # resolve the default once and use the resolved size everywhere below
        # (passing the raw None on to the layers would raise a TypeError)
        self.d_input = d_input or d_model
        self.d_output = d_output
        self.sparse = sparse

        # Multihead Attention Blocks
        blocks = []
        if sparse:
            self.emb = nn.Linear(self.d_input, d_model)
            for _ in range(n_blocks):
                layer = nn.TransformerEncoderLayer(
                    d_model=d_model,
                    dim_feedforward=d_ff,
                    nhead=n_heads,
                    dropout=dropout,
                    layer_norm_eps=1e-3,
                    batch_first=True,
                    activation=activation.lower())
                blocks += [layer]
        else:
            for i in range(n_blocks):
                # only the first block sees the raw input width
                mab = MultiheadAttentionBlock(
                    d_model=d_model,
                    d_input=(self.d_input if i == 0 else d_model),
                    d_hidden=(d_ff,)*depth,
                    n_heads=n_heads,
                    dropout=dropout,
                    eps=1e-3,
                    activation=activation)
                blocks += [mab]
        self.blocks = nn.ModuleList(blocks)

        # Pooling Block
        if not sparse:
            self.pma = PoolingBlock(
                d_model=d_model,
                d_hidden=(d_ff,)*depth,
                n_heads=n_heads,
                dropout=dropout,
                activation=activation)

        # Output Projection
        if d_model != d_output:
            self.out = nn.Linear(d_model, d_output)
        else:
            self.out = nn.Identity()

    def prepare(self,
                x: torch.Tensor,
                mask: torch.Tensor | None = None
                ) -> tuple[torch.Tensor, torch.Tensor | None]:
        ''' Flatten all leading dimensions of x (and mask) into one batch axis.

            Returns x as (B, n, d) and the mask, cast to bool, as (B, n);
            the mask stays None if none was given. '''
        # if inputs have more than 3 dims, unify batch and group dimensions
        # (reshape instead of view so non-contiguous inputs work as well)
        if x.dim() > 3:
            x = x.reshape(-1, *x.shape[-2:])
        # do the same with the optional mask
        if mask is not None:
            mask = mask.bool()
            if mask.dim() > 2:
                mask = mask.reshape(-1, mask.shape[-1])
        return x, mask

    def postpare(self,
                 h: torch.Tensor,
                 x: torch.Tensor) -> torch.Tensor:
        ''' Restore the leading dimensions of the original input x on h. '''
        if x.dim() <= 3:
            return h
        # put everything back together
        out = h.reshape(*x.shape[:-2], -1)
        return out

    def forward(self, x, mask=None):
        ''' Summarize x (batch, *, n, d_input) to (batch, *, d_output).

            mask, if given, has the shape of x without the last axis;
            entries that evaluate to True mark valid sequence elements. '''
        h, mask = self.prepare(x, mask=mask)
        if self.sparse:
            h = self.emb(h)
            # the encoder layers expect True at the positions to ignore
            padding = ~mask if mask is not None else None
            for block in self.blocks:
                h = block(h, src_key_padding_mask=padding)
            # presumably guards against NaNs from fully padded sequences
            h[h.isnan()] = 0
            if mask is not None:
                counts = mask.sum(-1, keepdim=True) + 1e-12
                h = (h * mask.unsqueeze(-1)).sum(-2) / counts # masked average pooling
            else:
                h = h.mean(-2)
        else:
            for block in self.blocks:
                h = block(h, mask=mask)
            h = self.pma(h, mask=mask)
        h = self.out(h)
        out = self.postpare(h, x)
        return out



    
# =============================================================================
if __name__ == "__main__":

    # toy data: (batch, groups, sequence length, features) with a random mask
    b, m, n, d = 8, 30, 50, 3
    d_model = 64
    d_ff = 128
    d_output = 32
    n_heads = 4
    dropout = 0.05
    x = torch.randn(b, m, n, d)
    mask = torch.randint(0, 2, (b, m, n)).bool()
    x[~mask] = 0.

    # every summarizer is built lazily, right before it is checked
    candidates = (
        # deepset
        (lambda: DeepSet(d_model=d_model, d_ff=d_model, d_input=d,
                         d_output=d_model, n_blocks=2),
         "deepset not permutation invariant"),
        # poolformer
        (lambda: PoolFormer(d_model=d_model, d_ff=d_model, d_input=d,
                            d_output=d_output),
         "poolformer not permutation invariant"),
    )

    for build, failure in candidates:
        # smoke test: one forward pass in training mode
        model = build()
        output = model(x, mask)

        # invariance: shuffling the sequence axis must not change the summary
        model.eval()
        perm = torch.randperm(n)
        out1 = model(x, mask)
        out2 = model(x[:, :, perm], mask[:, :, perm])
        assert torch.allclose(out1, out2, atol=1e-5), failure

    # # speed profiling (disabled; MultiHeadAttention_ is not defined in this file)
    # from metabeta.utils import profile
    # mha1 = MultiHeadAttention_(n_heads=4, d_proj=32, d_input=[d,d])
    # mha2 = MultiHeadAttention(n_heads=4, d_proj=32, d_input=[d,d])
    # x_ = x.view(b*m,n,d)

    # profile(mha1, (x_,x_))
    # profile(mha2, (x_,x_))
    
    