import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Literal, List, Callable
import warnings
# Activation modules selectable by name (used by MLP when `activation` is a str).
# NOTE: the values are module *instances*, so every lookup of the same key
# returns the same object. "ELU" is kept for backward compatibility; "elu" is a
# lowercase alias so that all activations can be addressed in lowercase.
act_dict = {
    "relu": nn.ReLU(inplace=True),
    "ELU": nn.ELU(inplace=True),
    "elu": nn.ELU(inplace=True),
    "silu": nn.SiLU(inplace=True),
    "softplus": nn.Softplus(),
    "softsign": nn.Softsign(),
    "softshrink": nn.Softshrink(),
}
from torch_geometric.nn import Sequential as PygSeq
from torch_geometric.nn.norm import GraphNorm as PygGN, InstanceNorm as PygIN
from torch import Tensor, BoolTensor
'''
Shape conventions for the masked reductions below:
    x:    (B, N, d)
    mask: (B, N, 1), boolean; True marks the real (kept) elements
'''
def maskedSum(x: Tensor, mask: BoolTensor, dim: int):
    '''
    Sum x along `dim`, counting only the positions where mask is True.

    Positions where mask is False contribute 0 to the sum. mask is broadcast
    against x by torch.where.
    '''
    kept = torch.where(mask, x, 0)
    return kept.sum(dim=dim)

def maskedMean(x: Tensor, mask: BoolTensor, dim: int, gsize: Tensor = None):
    '''
    Mean of x along `dim`, counting only the positions where mask is True.

    x: values to average.
    mask: boolean tensor broadcastable against x; True marks kept elements.
    dim: dimension to reduce.
    gsize: optional divisor (number of kept elements per slice). When omitted
        it is computed as the number of True entries of mask along `dim`.

    A slice that contains no True entry yields 0 instead of NaN: the
    internally computed count is clamped to at least 1, and the masked sum of
    such a slice is already 0. A caller-supplied gsize is used unchanged.
    '''
    if gsize is None:
        # clamp avoids 0/0 = NaN for fully masked (e.g. padding-only) slices
        gsize = torch.sum(mask, dim=dim).clamp(min=1)
    return torch.sum(torch.where(mask, x, 0), dim=dim) / gsize

def maskedMax(x: Tensor, mask: BoolTensor, dim: int):
    '''
    Maximum of x along `dim` over the positions where mask is True.

    Masked-out positions are replaced by -inf before reducing. Only the
    maximum values are returned, not their indices.
    '''
    filled = torch.where(mask, x, -torch.inf)
    return filled.max(dim=dim).values

def maskedMin(x: Tensor, mask: BoolTensor, dim: int):
    '''
    Minimum of x along `dim` over the positions where mask is True.

    Masked-out positions are replaced by +inf before reducing. Only the
    minimum values are returned, not their indices.
    '''
    filled = torch.where(mask, x, torch.inf)
    return filled.min(dim=dim).values

def maskednone(x: Tensor, mask: BoolTensor, dim: int):
    '''
    No-op reduction: hand x back untouched.

    mask and dim are accepted only so that the signature matches the other
    masked reductions; neither is read.
    '''
    return x

# Masked reductions selectable by name. Every entry is callable as
# f(x, mask, dim); "none" returns x unchanged. "min" is registered so that
# maskedMin is reachable by name like the other reductions.
reduce_dict = {
    "sum": maskedSum,
    "mean": maskedMean,
    "max": maskedMax,
    "min": maskedMin,
    "none": maskednone,
}


def expandbatch(x: Tensor, batch: Tensor):
    '''
    Merge the first two dimensions of x and build a matching batch vector.

    x: tensor whose dims 0 and 1 are flattened into one.
    batch: None, or an index vector aligned with dim 1 of x.

    Returns (flattened x, expanded batch). With batch=None the second item is
    None. Otherwise slice r along dim 0 of x gets the indices of `batch`
    shifted by r * N, where N = batch[-1] + 1, so that different slices never
    share an index.
    NOTE(review): N is read from the last entry of batch, so this presumably
    expects batch to be sorted ascending and non-empty - confirm with callers.
    '''
    flat_x = x.flatten(0, 1)
    if batch is None:
        return flat_x, None

    num_rep = x.shape[0]
    num_group = batch[-1] + 1
    # one shift per slice, as a column so it broadcasts over the batch row
    shift = num_group * torch.arange(num_rep, device=x.device).reshape(-1, 1)
    shifted = batch.unsqueeze(0) + shift
    return flat_x, shifted.flatten()


class NormMomentumScheduler:
    '''
    Rescales the `momentum` attribute of a model's normalization modules
    according to a schedule.

    mfunc: callable mapping the current step count (starting at 0) to a
        multiplicative ratio.
    initmomentum: base momentum; the momentum written to the modules is
        initmomentum * mfunc(epoch).
    normtype: exact module class whose instances are updated.
    '''

    def __init__(self, mfunc: Callable, initmomentum: float, normtype=nn.BatchNorm1d) -> None:
        super().__init__()
        self.normtype = normtype
        self.mfunc = mfunc
        self.epoch = 0
        self.initmomentum = initmomentum

    def step(self, model: nn.Module):
        '''
        Advance the schedule by one step and apply the new momentum to `model`.

        Returns the momentum for this step. When the ratio is within 1e-6 of 1
        the modules are left untouched and initmomentum is returned.
        '''
        ratio = self.mfunc(self.epoch)
        # Count the step before the early return below. Otherwise a schedule
        # with mfunc(0) == 1 would be evaluated at epoch 0 forever and the
        # momentum would never change.
        self.epoch += 1
        if 1 - 1e-6 < ratio < 1 + 1e-6:
            return self.initmomentum
        curm = self.initmomentum * ratio
        for mod in model.modules():
            # exact type match (not isinstance): subclasses are skipped, and a
            # normtype of None simply matches nothing
            if type(mod) is self.normtype:
                mod.momentum = curm
        return curm


class NoneNorm(nn.Module):
    '''
    Placeholder "normalization" that performs no normalization at all.

    It accepts the same (dim, normparam) constructor arguments as the other
    norm wrappers; dim is only recorded as num_features and normparam is
    ignored.
    '''

    def __init__(self, dim=0, normparam=0) -> None:
        super().__init__()
        self.num_features = dim

    def forward(self, x):
        # identity: the very same tensor object is returned
        return x


class BatchNorm(nn.Module):
    '''
    nn.BatchNorm1d over the last dimension for inputs with any number of
    leading dimensions.

    dim: number of features (size of the last dimension of the input).
    normparam: forwarded to nn.BatchNorm1d as `momentum`.
    '''

    def __init__(self, dim, normparam=0.1) -> None:
        super().__init__()
        self.num_features = dim
        self.norm = nn.BatchNorm1d(dim, momentum=normparam)

    def forward(self, x: Tensor):
        '''
        Normalize x. 2-D input goes straight to BatchNorm1d; higher-rank input
        has its leading dimensions collapsed into one, is normalized, and is
        then restored to its original shape. Rank below 2 raises
        NotImplementedError.
        '''
        rank = x.dim()
        if rank < 2:
            raise NotImplementedError
        if rank == 2:
            return self.norm(x)
        original_shape = x.shape
        collapsed = x.flatten(0, -2)
        return self.norm(collapsed).reshape(original_shape)


class LayerNorm(nn.Module):
    '''
    Thin wrapper around nn.LayerNorm(dim) exposing the same (dim, normparam)
    constructor as the other norm wrappers.

    dim: size of the last dimension to normalize over.
    normparam: accepted for interface compatibility; not used.
    '''

    def __init__(self, dim, normparam=0.1) -> None:
        super().__init__()
        self.num_features = dim
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: Tensor):
        normalized = self.norm(x)
        return normalized


class InstanceNorm(nn.Module):
    '''
    torch_geometric InstanceNorm over the last dimension for inputs with any
    number of leading dimensions.

    dim: number of features (size of the last dimension of the input).
    normparam: forwarded to the PyG InstanceNorm as `momentum`.

    NOTE(review): the wrapped norm is always called without a batch vector,
    so all rows are presumably treated as one instance - confirm against the
    torch_geometric documentation.
    '''

    def __init__(self, dim, normparam=0.1) -> None:
        super().__init__()
        self.norm = PygIN(dim, momentum=normparam)
        self.num_features = dim

    def forward(self, x: Tensor):
        '''
        Normalize x. 2-D input is passed through directly; higher-rank input
        has its leading dimensions collapsed into one, is normalized, and is
        then restored to its original shape. Rank below 2 raises
        NotImplementedError.
        '''
        rank = x.dim()
        if rank < 2:
            raise NotImplementedError
        if rank == 2:
            return self.norm(x)
        original_shape = x.shape
        collapsed = x.flatten(0, -2)
        return self.norm(collapsed).reshape(original_shape)


# Norm wrapper classes selectable by name; MLP instantiates the chosen class
# as cls(dim, normparam).
normdict = {
    "bn": BatchNorm,
    "ln": LayerNorm,
    "in": InstanceNorm,
    "none": NoneNorm,
}

# Underlying norm module class for each name, or None where there is none.
# NOTE(review): presumably meant as the `normtype` of NormMomentumScheduler -
# confirm against callers. "gn" appears here but has no entry in normdict.
basenormdict = {
    "bn": nn.BatchNorm1d,
    "ln": None,
    "in": PygIN,
    "gn": None,
    "none": None,
}


class NormedLinear(nn.Module):
    '''
    Linear layer followed by an optional normalization of its output.

    in_channels: input feature size.
    out_channels: output feature size.
    norm: 'ln' applies nn.LayerNorm(out_channels), 'none' applies nothing.
        'bn' is declared but not implemented.

    Raises NotImplementedError for norm='bn' and ValueError for any other
    unrecognized value.
    '''

    def __init__(self, in_channels: int,
                 out_channels: int,
                 norm: Literal['none', 'bn', 'ln'] = 'ln'):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.norm_type = norm

        self.lin = nn.Linear(in_channels, out_channels)
        if norm == 'none':
            self.norm = nn.Identity()
        elif norm == 'bn':
            raise NotImplementedError()
        elif norm == 'ln':
            self.norm = nn.LayerNorm(out_channels)
        else:
            # Without this branch self.norm stays unset and the failure only
            # shows up later as an unrelated-looking AttributeError.
            raise ValueError(f"unknown norm {norm!r}; expected 'none', 'bn' or 'ln'")

        self.reset_parameters()

    def reset_parameters(self):
        '''Re-initialize the linear layer and, if it has parameters, the norm.'''
        self.lin.reset_parameters()
        # nn.Identity has no reset_parameters
        if hasattr(self.norm, 'reset_parameters'):
            self.norm.reset_parameters()

    def forward(self, x):
        return self.norm(self.lin(x))

class NormedEmbedding(nn.Module):
    '''
    Embedding lookup followed by an optional normalization of the embeddings.

    num_emb: number of rows in the embedding table.
    out_channels: embedding size.
    norm: 'ln' applies nn.LayerNorm(out_channels), 'none' applies nothing.
        'bn' is declared but not implemented.

    Raises NotImplementedError for norm='bn' and ValueError for any other
    unrecognized value.
    '''

    def __init__(self, num_emb: int,
                 out_channels: int,
                 norm: Literal['none', 'bn', 'ln'] = 'ln'):
        super().__init__()
        self.num_emb = num_emb
        self.out_channels = out_channels
        self.norm_type = norm

        self.emb = nn.Embedding(num_emb, out_channels)
        if norm == 'none':
            self.norm = nn.Identity()
        elif norm == 'bn':
            raise NotImplementedError()
        elif norm == 'ln':
            self.norm = nn.LayerNorm(out_channels)
        else:
            # Without this branch self.norm stays unset and the failure only
            # shows up later as an unrelated-looking AttributeError.
            raise ValueError(f"unknown norm {norm!r}; expected 'none', 'bn' or 'ln'")

        self.reset_parameters()

    def reset_parameters(self):
        '''Re-initialize the embedding table and, if it has parameters, the norm.'''
        self.emb.reset_parameters()
        # nn.Identity has no reset_parameters
        if hasattr(self.norm, 'reset_parameters'):
            self.norm.reset_parameters()

    def forward(self, x):
        # indices are cast to int32 before the lookup, whatever dtype x has
        return self.norm(self.emb(x.to(torch.int32)))

class MLP2(nn.Module):
    '''
    Two-layer perceptron: Linear(in, 2*out) -> norm -> ELU -> Linear(2*out, out).

    in_channels: input feature size.
    out_channels: output feature size; the hidden width is 2 * out_channels.
    norm: 'ln' applies nn.LayerNorm(2 * out_channels) to the hidden features,
        'none' applies nothing. 'bn' is declared but not implemented.

    Raises NotImplementedError for norm='bn' and ValueError for any other
    unrecognized value.
    '''

    def __init__(self, in_channels: int,
                 out_channels: int,
                 norm: Literal['none', 'bn', 'ln'] = 'ln'):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.norm_type = norm

        self.lin1 = nn.Linear(in_channels, 2*out_channels)
        self.lin2 = nn.Linear(2*out_channels, out_channels)
        if norm == 'none':
            self.norm = nn.Identity()
        elif norm == 'bn':
            raise NotImplementedError()
        elif norm == 'ln':
            self.norm = nn.LayerNorm(2*out_channels)
        else:
            # Without this branch self.norm stays unset and the failure only
            # shows up later as an unrelated-looking AttributeError.
            raise ValueError(f"unknown norm {norm!r}; expected 'none', 'bn' or 'ln'")

        self.reset_parameters()

    def reset_parameters(self):
        '''Re-initialize both linear layers and, if it has parameters, the norm.'''
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()
        # nn.Identity has no reset_parameters
        if hasattr(self.norm, 'reset_parameters'):
            self.norm.reset_parameters()

    def forward(self, x):
        return self.lin2(F.elu(self.norm(self.lin1(x))))
    

class Set2Set(nn.Module):
    '''
    Gates every element of a set with an aggregate over the whole set.

    One branch (NormedLinear + SiLU) is reduced over dimension 0 with `aggr`;
    the other branch (NormedLinear) is kept per element. The output is their
    product, with the aggregate broadcast over dimension 0.

    num_channels: feature size of both input and output.
    aggr: name of the reduction over dimension 0. 'min' and 'max' use the
        values of the corresponding tensor reduction, 'mul' is the product
        over dimension 0, and any other name is looked up as a tensor method
        and called with 0 (e.g. 'mean', 'sum').
    residual: stored on the instance but not used by forward.
    '''

    def __init__(self, num_channels: int,
                 aggr: Literal['mean', 'sum', 'max',
                               'min', 'mul'] = 'mean',
                 residual: bool = True):
        super().__init__()
        self.num_channels = num_channels
        self.aggr = aggr
        self.residual = residual

        self.lin1 = nn.Sequential(
            NormedLinear(num_channels, num_channels),
            nn.SiLU()
        )
        self.lin2 = NormedLinear(num_channels, num_channels)

    def forward(self, x):
        x1 = self.lin1(x)
        if self.aggr in ['min', 'max']:
            # min/max along a dim return (values, indices); keep the values
            x1 = getattr(x1, self.aggr)(0).values
        elif self.aggr == 'mul':
            # Tensor.mul(0) would multiply every entry by 0 instead of
            # reducing; the multiplicative aggregation is a product over dim 0.
            x1 = x1.prod(0)
        else:
            x1 = getattr(x1, self.aggr)(0)

        x2 = self.lin2(x)

        return x1 * x2


class MLP(nn.Module):
    '''
    Configurable multi-layer perceptron assembled as one nn.Sequential
    (stored in self.lin).

    indim, hiddim, outdim: input, hidden and output feature sizes.
    numlayer: number of Linear layers.
        0  -> no Linear at all, only norm(outdim, normparam); requires
              indim == outdim.
        1  -> Linear(indim, outdim).
        >1 -> Linear(indim, hiddim), then numlayer-2 Linear(hiddim, hiddim),
              then Linear(hiddim, outdim). Every non-final Linear is followed
              by norm, optional dropout, and the activation.
    tailact: if True the final Linear is also followed by norm, optional
        dropout and the activation.
    dropout: dropout probability; Dropout modules are added only when > 0.
    norm: key of normdict, or a callable invoked as norm(dim, normparam).
    activation: key of act_dict, or an activation module. The same module
        object is reused at every activation position.
    tailbias: bias flag of the final Linear only (and of the single Linear
        when numlayer == 1).
    normparam: second argument passed to the norm constructor.
    '''

    def __init__(self, indim: int, hiddim: int, outdim: int, numlayer: int=1, tailact: bool=True, dropout: float=0, norm: str="bn", activation: str="silu", tailbias=True, normparam: float=0.1) -> None:
        super().__init__()
        assert numlayer >= 0
        if isinstance(activation, str):
            activation = act_dict[activation]
        if isinstance(norm, str):
            norm = normdict[norm]
        if numlayer == 0:
            assert indim == outdim
            # `norm` is a class at this point, so it has to be compared with
            # NoneNorm; comparing with the string "none" was always unequal
            # and produced a spurious warning for norm="none".
            if norm is not NoneNorm:
                warnings.warn("not equivalent to Identity")
            lin0 = nn.Sequential(norm(outdim, normparam))
        elif numlayer == 1:
            lin0 = nn.Sequential(nn.Linear(indim, outdim, bias=tailbias))
            if tailact:
                lin0.append(norm(outdim, normparam))
                if dropout > 0:
                    lin0.append(nn.Dropout(dropout, inplace=True))
                lin0.append(activation)
        else:
            # The network is assembled back to front: the final layer first,
            # then each earlier stage is inserted at position 0.
            lin0 = nn.Sequential(nn.Linear(hiddim, outdim, bias=tailbias))
            if tailact:
                lin0.append(norm(outdim, normparam))
                if dropout > 0:
                    lin0.append(nn.Dropout(dropout, inplace=True))
                lin0.append(activation)
            # hidden-to-hidden stages
            for _ in range(numlayer-2):
                lin0.insert(0, activation)
                if dropout > 0:
                    lin0.insert(0, nn.Dropout(dropout, inplace=True))
                lin0.insert(0, norm(hiddim, normparam))
                lin0.insert(0, nn.Linear(hiddim, hiddim))
            # input stage
            lin0.insert(0, activation)
            if dropout > 0:
                lin0.insert(0, nn.Dropout(dropout, inplace=True))
            lin0.insert(0, norm(hiddim, normparam))
            lin0.insert(0, nn.Linear(indim, hiddim))
        self.lin = lin0
        # self.reset_parameters()

    def forward(self, x: Tensor):
        return self.lin(x)

