import torch
import torch.nn.functional as F
from torch import nn, einsum
import models.polynomials as poly
from einops import rearrange


# Registry mapping a polynomial-basis name (the ``polynomial`` argument of
# CheAtt / Transformer_CheAtt) to the convolution class from models.polynomials
# that PolyConvFrame is built with.
POLYNOMIAL = dict(
    chebyshev=poly.ChebyshevConv,
)

def exists(val):
    """Return True for any value other than ``None`` (falsy values count as present)."""
    return not (val is None)

def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if val is None:
        return d
    return val

# classes
class Residual(nn.Module):
    """Wrap ``fn`` with a skip connection, computing ``fn(x, **kwargs) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # Keyword arguments go to the wrapped callable only; the skip path is plain x.
        branch = self.fn(x, **kwargs)
        return branch + x

class PreNorm(nn.Module):
    """Apply ``nn.LayerNorm(dim)`` to the input, then call the wrapped ``fn``.

    Args:
        dim: Size of the trailing feature axis that LayerNorm normalises.
        fn: Callable applied to the normalised input; any keyword arguments
            given to ``forward`` are forwarded to it unchanged.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)

# attention
class GEGLU(nn.Module):
    """GELU-gated linear unit.

    Splits the last axis into two halves ``(value, gate)`` and returns
    ``value * gelu(gate)``, so the output's last axis is half the input's.
    """

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return F.gelu(gate) * value

class FeedForward(nn.Module):
    """Position-wise GEGLU feed-forward block.

    Projects ``dim -> 2 * mult * dim``, gates down to ``mult * dim`` with
    :class:`GEGLU`, applies dropout, then projects back to ``dim``.

    Args:
        dim: Input/output feature width.
        mult: Hidden-width multiplier.
        dropout: Dropout rate applied after the GEGLU gate.
    """

    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden = dim * mult
        # Order matters: it fixes the Sequential indices (state_dict keys) and
        # the parameter-initialisation order.
        layers = [
            nn.Linear(dim, hidden * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x, **kwargs):
        # Extra keyword arguments are accepted for interface compatibility and ignored.
        return self.net(x)


class CheAtt(nn.Module):
    """Multi-head attention whose value aggregation is a polynomial graph filter.

    Per head, a scaled query-key similarity is row-softmaxed and handed to
    ``poly.PolyConvFrame`` as ``adj`` together with the per-head values. The
    filter output is merged across heads and projected back to ``dim``.

    Args:
        dim: Input/output feature width.
        heads: Number of attention heads.
        dim_head: Width of each head; q/k/v each use ``heads * dim_head`` features.
        dropout: Rate for ``self.dropout``. NOTE(review): that module is never
            applied in ``forward``, so this argument currently has no effect.
        K: Passed to ``PolyConvFrame`` as ``depth``.
        polynomial: Key into ``POLYNOMIAL`` selecting the polynomial conv class.
    """

    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 16,
        dropout = 0.,
        K = 5,
        polynomial = 'chebyshev'
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner = heads * dim_head

        # Submodules are created in a fixed order; that order determines both
        # parameter initialisation under a seeded RNG and the state_dict layout.
        self.to_qkv = nn.Linear(dim, inner * 3, bias = False)
        self.to_out = nn.Linear(inner, dim)
        self.dropout = nn.Dropout(dropout)
        # alpha / beta / fixed are hard-coded; their meaning lives in
        # models.polynomials and is not visible from this file.
        conv_cls = POLYNOMIAL[polynomial]
        self.chebnet = poly.PolyConvFrame(conv_cls, depth = K, alpha = 1.0, beta = 0.3, fixed = False)

    def forward(self, x):
        """Attend over the token axis of ``x`` of shape (batch, n, dim).

        Returns:
            A pair ``(projected, [weights, merged])``:

            - ``projected`` is the ``to_out`` projection of the merged heads.
            - ``weights`` is the post-softmax attention of shape (batch, heads, n, n).
            - ``merged`` is the filter output after merging heads, taken before ``to_out``.
        """
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in qkv)

        scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        weights = scores.softmax(dim = -1)

        # NOTE(review): self.dropout is not applied to `weights` (or anywhere
        # else), so the configured attention dropout is silently ignored.
        mixed = self.chebnet(x=v, adj=weights)
        mixed = rearrange(mixed, 'b h n d -> b n (h d)', h = self.heads)

        return self.to_out(mixed), [weights, mixed]

class Transformer_CheAtt(nn.Module):
    """Transformer encoder over categorical and continuous tabular features.

    Each of the ``depth`` layers applies a pre-normed :class:`CheAtt` sublayer
    and a pre-normed :class:`FeedForward` sublayer, both with residual
    connections.

    Args:
        num_tokens: Categorical vocabulary size; the embedding table has
            ``num_tokens + 1`` rows.
        dim: Model (token) width.
        n_cont: Number of continuous features. All of them are jointly
            projected by one linear layer to ``n_cont * dim`` values and
            reshaped into ``n_cont`` tokens of width ``dim``. May be 0.
        depth: Number of attention + feed-forward layers.
        heads: Passed to ``CheAtt``.
        dim_head: Passed to ``CheAtt``.
        attn_dropout: Passed to ``CheAtt`` as ``dropout``.
        ff_dropout: Passed to ``FeedForward`` as ``dropout``.
        K: Polynomial depth passed to ``CheAtt``.
        polynomial: Key into ``POLYNOMIAL`` passed to ``CheAtt``.
    """

    def __init__(self, num_tokens, dim, n_cont, depth, heads, dim_head, attn_dropout, ff_dropout, K=5, polynomial="chebyshev"):
        super().__init__()
        self.n_cont = n_cont
        self.num_tokens = num_tokens
        self.embeds = nn.Embedding(num_tokens+1, dim)
        self.cont_embeds = nn.Linear(n_cont, dim*n_cont)
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, CheAtt(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout, K = K, polynomial=polynomial)),
                PreNorm(dim, FeedForward(dim, dropout = ff_dropout)),
            ]))

    def forward(self, categ_x, cont_x, return_attn = False):
        """Encode a batch of tabular rows.

        Args:
            categ_x: Categorical token ids, cast to ``long`` before embedding.
                Presumably shaped (batch, n_categ). TODO: confirm with the caller.
            cont_x: Continuous features, cast to ``float``. Presumably shaped
                (batch, n_cont). Ignored when ``n_cont == 0``.
            return_attn: If True, also return the per-layer attention outputs.

        Returns:
            ``x``, the encoded tokens with the categorical tokens followed by the
            continuous ones along dim 1. If ``return_attn`` is True, returns
            ``(x, post_softmax_attns)`` instead. Each entry of
            ``post_softmax_attns`` is the second value returned by that layer's
            attention sublayer.
        """
        categ_tokens = self.embeds(categ_x.long())

        if self.n_cont > 0:
            cont_tokens = self.cont_embeds(cont_x.float())
            cont_tokens = cont_tokens.view(cont_tokens.size(0), self.n_cont, -1)
            # Concatenate along the token axis so attention mixes both kinds of token.
            x = torch.cat([categ_tokens, cont_tokens], dim=1)
        else:
            # No continuous tokens: use the categorical embeddings directly.
            # Concatenating a 1-D float32 empty placeholder instead would rely on
            # torch.cat's legacy empty-tensor skip. It can also type-promote
            # non-float32 embeddings, and it allocates a tensor on every call.
            x = categ_tokens

        post_softmax_attns = []

        for attn, ff in self.layers:
            attn_out, post_softmax_attn = attn(x)
            post_softmax_attns.append(post_softmax_attn)

            x = x + attn_out
            x = ff(x) + x

        if not return_attn:
            return x

        return x, post_softmax_attns
    
# mlp
class MLP(nn.Module):
    """Stack of ``nn.Linear`` layers with an activation between consecutive layers.

    Args:
        dims: Sequence of layer widths; linear layer ``i`` maps ``dims[i]`` to
            ``dims[i + 1]``. No activation follows the final linear layer.
        act: Activation module inserted after every linear layer except the
            last; defaults to ``nn.ReLU()``. The same module instance is reused
            at every position.
    """

    def __init__(self, dims, act = None):
        super().__init__()
        activation = act if act is not None else nn.ReLU()
        n_linear = len(dims) - 1

        modules = []
        for position, (fan_in, fan_out) in enumerate(zip(dims[:-1], dims[1:])):
            modules.append(nn.Linear(fan_in, fan_out))
            if position < n_linear - 1:
                modules.append(activation)

        self.mlp = nn.Sequential(*modules)

    def forward(self, x):
        return self.mlp(x)
