import copy
import pdb

import numpy as np
import torch
from einops.layers.torch import Rearrange
from torch import nn

# Registry mapping an activation name (the ``activation=`` string accepted by
# FeedForward) to a ready-made module instance.
# NOTE: each entry is a single module-level object. Anything that places an
# entry into a model directly shares that exact object with every other user
# -- for PReLU that includes its learnable slope parameter.
activation_functions = dict(
    ReLU=nn.ReLU(),
    Sigmoid=nn.Sigmoid(),
    Tanh=nn.Tanh(),
    LeakyReLU=nn.LeakyReLU(negative_slope=0.01),
    ELU=nn.ELU(alpha=1.0),
    PReLU=nn.PReLU(num_parameters=1, init=0.25),
    Mish=nn.Mish(),
)

class FeedForward(nn.Module):
    """Two-layer MLP: ``Linear(dim -> hidden_dim) -> activation -> Linear(hidden_dim -> dim)``.

    ``nn.Linear`` acts on the last axis of the input, so the output has the
    same shape as the input.

    Args:
        dim: Size of the input/output (last) feature axis.
        hidden_dim: Width of the hidden layer.
        activation: Either a key of ``activation_functions`` (e.g. ``'ELU'``,
            ``'PReLU'``) or an ``nn.Module`` instance. In both cases this layer
            uses its own deep copy, so no two FeedForward layers share an
            activation module. This matters for ``PReLU``, whose slope is a
            learnable parameter.

    Raises:
        KeyError: If ``activation`` is a name not present in
            ``activation_functions``.
    """

    def __init__(self, dim, hidden_dim, activation='ELU'):
        super().__init__()
        if isinstance(activation, nn.Module):
            template = activation
        else:
            try:
                template = activation_functions[activation]
            except KeyError:
                raise KeyError(
                    f"Unknown activation {activation!r}; "
                    f"expected one of {sorted(activation_functions)}"
                ) from None
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            # Deep copy: the registry keeps one shared instance per name, and
            # registering it directly would tie PReLU weights across every
            # layer (and every model) that uses it.
            copy.deepcopy(template),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return self.net(x)


class MixerBlock(nn.Module):
    """One MLP-Mixer layer: a token-mixing MLP followed by a channel-mixing MLP.

    Input layout is ``[bs, dim_psi, dim_xi]``. Each sub-layer applies
    LayerNorm over the last axis, then its MLP, and is added back to its
    input (residual connection), so the output shape equals the input shape.

    Args:
        dim_xi: Size of the last (channel) axis.
        dim_psi: Size of the middle (token) axis.
        psi_hidden: Hidden width of the token-mixing MLP.
        xi_hidden: Hidden width of the channel-mixing MLP.
        activation: Activation name forwarded to both FeedForward MLPs.
    """

    def __init__(self, dim_xi, dim_psi, psi_hidden, xi_hidden, activation='ELU'):
        super().__init__()

        # Token mixing: normalise over channels, move the token axis last so
        # the MLP acts across tokens, then restore the original layout.
        token_layers = [
            nn.LayerNorm(dim_xi),
            Rearrange('b n d -> b d n'),
            FeedForward(dim_psi, psi_hidden, activation=activation),
            Rearrange('b d n -> b n d'),
        ]
        self.token_mix = nn.Sequential(*token_layers)

        # Channel mixing: the MLP acts directly on the last (channel) axis.
        channel_layers = [
            nn.LayerNorm(dim_xi),
            FeedForward(dim_xi, xi_hidden, activation=activation),
        ]
        self.channel_mix = nn.Sequential(*channel_layers)

    def forward(self, x):
        # x: [bs, dim_psi, dim_xi]
        after_tokens = x + self.token_mix(x)
        return after_tokens + self.channel_mix(after_tokens)

class KronMixier(nn.Module):
    """Stack of MixerBlocks applied to a per-sample outer product.

    Each input row is split into a leading slice of ``dim_x`` entries and the
    remaining ``dim_c + 1`` entries. Their outer product, of shape
    ``[batch, dim_x, dim_c + 1]``, is passed through ``depth`` MixerBlocks.
    The output has that same shape.

    Args:
        dim_c: Width of the trailing slice minus one. NOTE(review): presumably
            the conditioning vector plus one extra (e.g. constant) entry;
            confirm with callers.
        dim_x: Width of the leading slice; becomes the mixer's token axis.
        psi_hidden: Hidden width of each token-mixing MLP.
        xi_hidden: Hidden width of each channel-mixing MLP.
        activation: Activation name forwarded to every MixerBlock.
        depth: Number of stacked MixerBlocks.
    """

    def __init__(self, dim_c, dim_x, psi_hidden, xi_hidden, activation='ELU', depth=6):
        super().__init__()
        # `activation` was previously accepted but never forwarded, so every
        # block silently fell back to the MixerBlock default ('ELU').
        self.MLP_Mixer = nn.Sequential(*[
            MixerBlock(dim_xi=dim_c + 1, dim_psi=dim_x, psi_hidden=psi_hidden,
                       xi_hidden=xi_hidden, activation=activation)
            for _ in range(depth)
        ])

        self.data_dim = dim_x
        self.conditional_dim = dim_c

    def forward(self, psi_xi_xi):
        """Split each row, take the outer product, and run the mixer stack.

        Args:
            psi_xi_xi: 2-D tensor ``[batch, dim_x + dim_c + 1]``.

        Returns:
            Tensor of shape ``[batch, dim_x, dim_c + 1]``.

        Raises:
            ValueError: If ``psi_xi_xi`` is not 2-D with the expected width.
        """
        expected_width = self.data_dim + self.conditional_dim + 1
        if psi_xi_xi.dim() != 2 or psi_xi_xi.shape[1] != expected_width:
            raise ValueError(
                f"expected input of shape [batch, {expected_width}] "
                f"(dim_x={self.data_dim} + dim_c+1={self.conditional_dim + 1}), "
                f"got {tuple(psi_xi_xi.shape)}"
            )
        # Split at dim_x rather than a hard-coded 2, so models with
        # dim_x != 2 get the [batch, dim_x, dim_c + 1] layout the blocks expect.
        psi_xi = psi_xi_xi[:, :self.data_dim]
        xi = psi_xi_xi[:, self.data_dim:]
        outer = torch.einsum('bi, bj -> bij', psi_xi, xi)

        return self.MLP_Mixer(outer)



if __name__ == "__main__":
    # Smoke test: run one MixerBlock on a dummy batch laid out as
    # [bs=100, dim_psi=2, dim_xi=2], then report its size and output shape.
    img = torch.ones([100, 2, 2])
    model = MixerBlock(dim_xi=2, dim_psi=2, psi_hidden=4, xi_hidden=8)

    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Trainable Parameters: {n_trainable / 1_000_000:.3f}M')

    out_img = model(img)

    # MixerBlock preserves its input layout: [bs, dim_psi, dim_xi] -> [100, 2, 2]
    print("Shape of out :", out_img.shape)