import torch
from torch import nn
from torch_geometric.utils import scatter
import numpy as np
from math import ceil
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder

# Default device for tensors this module creates outside of nn.Module
# parameters. Falls back to the CPU when CUDA is not available so the module
# stays usable on hosts without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

def zero_one(x, interval=None):
    """Min-max rescale every column of ``x`` into ``interval``.

    The minimum and maximum are taken along ``dim=0``. Columns whose minimum
    equals their maximum cannot be rescaled and are filled with the upper
    bound of the interval instead.

    The result is a new tensor; ``x`` itself is left untouched, so the
    function is safe to call on tensors that are reused by the caller or that
    are autograd leaves requiring grad.

    :param x: Tensor whose first dimension indexes the samples
    :param interval: ``[start, end]`` target range, defaults to ``[0, 1]``
    :return: Rescaled tensor with the same shape as ``x``
    """
    if interval is None:
        interval = [0, 1]
    start, end = interval
    length = end - start
    maxm, _ = torch.max(x, dim=0)
    minm, _ = torch.min(x, dim=0)
    scale = maxm - minm
    constant = scale == 0
    # Divide constant columns by 1 instead of 0; their values are replaced
    # below, so this only keeps inf/NaN out of the forward and backward pass.
    safe_scale = torch.where(constant, torch.ones_like(scale), scale)
    scaled = ((x - minm) / safe_scale) * length + start
    return torch.where(constant, torch.full_like(scaled, end), scaled)

class Attention(nn.Module):  # external attention
    """External-attention style block with a residual connection.

    Every scalar of the input is lifted to ``cof`` features, split into
    ``num_heads`` heads, passed through two per-head linear layers with a
    doubly normalised attention map in between, projected back to one scalar
    and added to the input.

    :param cof: Width of the lifted representation (split across the heads)
    :param num_heads: Number of heads
    :param attn_drop: Dropout probability applied to the attention map
    :param proj_drop: Dropout probability applied after the output projection
    """

    def __init__(self, cof=32, num_heads=8, attn_drop=0., proj_drop=0.):
        super().__init__()
        scalar_dim = 1
        head_width = cof // num_heads

        self.num_heads = num_heads
        self.coef = cof
        # Sub-modules are created in a fixed order so seeded initialisation
        # yields the same weights as before.
        self.trans_dims = nn.Linear(scalar_dim, cof)
        self.linear_0 = nn.Linear(head_width, head_width)
        self.linear_1 = nn.Linear(head_width, head_width)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(cof, scalar_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply the block to a 2-D tensor and return a tensor of the same shape."""
        expanded = x.unsqueeze(-1)
        batch, length, _ = expanded.shape

        # (batch, length, cof) -> (batch, heads, length, cof // heads)
        heads = self.trans_dims(expanded)
        heads = heads.view(batch, length, self.num_heads, -1).permute(0, 2, 1, 3)

        # Softmax along the sequence axis, then renormalise along the feature axis.
        weights = self.linear_0(heads).softmax(dim=-2)
        weights = weights / (1e-9 + weights.sum(dim=-1, keepdim=True))
        weights = self.attn_drop(weights)

        merged = self.linear_1(weights).permute(0, 2, 1, 3).reshape(batch, length, -1)
        projected = self.proj_drop(self.proj(merged)).squeeze(-1)
        return projected + x


class GaussianType(torch.nn.Module):
    """Bank of Gaussian membership functions, ``num_mf`` per input channel.

    Two parameters of shape ``(1, in_channel, num_mf)`` are held: ``C`` (the
    centres) and ``Sigma``. ``forward`` evaluates
    ``exp(-(x - C)^2 / (2 * Sigma))``, i.e. ``Sigma`` plays the role of the
    variance of each Gaussian.
    """

    def __init__(self,
                 in_channel: int,
                 num_mf: int,
                 values_intervals: list,
                 fix: bool = False,
                 close: bool = True,
                 cross: float = 0.7):
        '''
        Initiate Gaussian-Type membership function
        :param in_channel: Input dimension
        :param num_mf: Number of membership function
        :param values_intervals: Domain of definition ``[start, end]`` of the membership functions
        :param fix: When True the parameters are created with ``requires_grad=False``
        :param close: When True the outermost centres are moved inward by
            ``(end - start) / (num_mf + 1)`` before the centres are spaced evenly
        :param cross: Membership value at which two neighbouring functions intersect;
            must lie strictly between 0 and 1 when ``num_mf > 1``
        :raises ValueError: If ``num_mf > 1`` and ``cross`` is not in ``(0, 1)``
        '''

        super().__init__()
        # log(cross) must be finite and negative, otherwise the derived
        # variance is infinite, NaN or negative.
        if num_mf > 1 and not 0.0 < cross < 1.0:
            raise ValueError(f"cross must be in (0, 1), got {cross}")

        self.num_mf = num_mf
        self.intervals = values_intervals
        self.in_channel = in_channel
        self.close = close
        self.cross = cross
        self.fix = fix

        C, Sigma = self._initial_values()
        self.register_parameter("C", torch.nn.Parameter(C, requires_grad=not fix))
        self.register_parameter("Sigma", torch.nn.Parameter(Sigma, requires_grad=not fix))

    def _initial_values(self):
        """Return the initial ``(C, Sigma)`` tensors derived from the constructor arguments."""
        start, end = self.intervals
        if self.close:
            margin = (end - start) / (self.num_mf + 1)
            start += margin
            end -= margin
        centers = torch.linspace(start=start, end=end, steps=self.num_mf)
        if self.num_mf > 1:
            # Chosen so that two neighbouring Gaussians both evaluate to
            # `cross` halfway between their centres.
            variance = float(torch.pow(centers[1] - centers[0], 2) * float(1 / (-8 * np.log(self.cross))))
        else:
            variance = 0.1225
        C = centers.view(1, 1, -1).repeat(1, self.in_channel, 1)
        Sigma = torch.full((1, self.in_channel, self.num_mf), variance)
        return C, Sigma

    def reset(self):
        """Restore ``C`` and ``Sigma`` to their initial values (no-op when ``fix`` is set).

        The existing parameters are overwritten in place, so they keep their
        device and stay registered with any optimizer that already holds them.
        """
        if self.fix:
            return
        C, Sigma = self._initial_values()
        with torch.no_grad():
            self.C.copy_(C)
            self.Sigma.copy_(Sigma)

    def forward(self, x):
        """Evaluate every membership function; a trailing axis of size ``num_mf`` is appended to ``x``."""
        x = torch.exp(- torch.pow((x.unsqueeze(-1) - self.C), 2) / (2 * self.Sigma))
        return x


class Fuzzier(torch.nn.Module):
    """Fuzzification step: wraps a ``GaussianType`` membership-function bank.

    :param in_channels: Input dimension
    :param num_mf: Number of membership functions per input channel
    :param value_intervals: Domain of definition of the membership functions
    :param fix: Forwarded to ``GaussianType`` as its ``fix`` argument
    :param cross: Forwarded to ``GaussianType`` as its ``cross`` argument
    """

    def __init__(self,
                 in_channels: int,
                 num_mf: int,
                 value_intervals: list,
                 fix: bool = False,
                 cross: float = 0.7):
        super().__init__()
        membership_bank = GaussianType(in_channels, num_mf, value_intervals, fix, cross=cross)
        self.MFs = membership_bank

    def forward(self, x):
        """Return the membership values of ``x`` as computed by ``self.MFs``."""
        return self.MFs(x)

    def reset(self):
        """Delegate to the ``reset`` method of the wrapped membership functions."""
        self.MFs.reset()


class Ruler(torch.nn.Module):
    """Rule layer: builds firing strengths from fuzzified node features and defuzzifies them.

    Membership values are mean-aggregated over incoming edges, the input
    channels are cut into sliding windows, and for every window the product
    of one membership value per channel is formed for all combinations
    (``n_mf ** window_size`` rules per window). The resulting firing
    strengths are optionally normalised, passed through attention, refined
    and finally combined with the learnable ``coe_defuzzifize`` coefficients.
    """

    def __init__(self,
                 in_channel: int,
                 out_channel: int,
                 n_mf: int,
                 order: int,
                 window_size: int = 3,
                 stride_size: int = 3,
                 norm: bool = False,
                 A_P2: bool = True,
                 refine_ratio: float = 1.,
                 refiner: str = "pool",
                 attention: bool = True,
                 residual: bool = True,
                 ):

        """
        :param in_channel: Input dimension
        :param out_channel: Output dimension
        :param n_mf: Number of membership function
        :param order: The order of defuzzification (0 or 1)
        :param norm: Normalization defuzzification result
        :param window_size: The slide windows size
        :param stride_size: The slide stride step size
        :param A_P2: Use part-2 of FL-GNN-A
        :param refiner: "pool: MaxPooling1D"; any other value selects Top K
        :param refine_ratio: Firing  Strength refine ratio
        :param attention: Whether apply attention layer
        :param residual: Whether to add a linear projection of the input to the output
            (only applied on the ``A_P2`` path)
        """
        assert order in {0, 1}
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.n_mf = n_mf
        # Number of sliding windows that fit along the channel axis.
        self.num_blocks = torch.empty(size=(1, in_channel, n_mf)).unfold(1, window_size, stride_size).shape[1]
        self.rules_size = self.num_blocks * pow(n_mf, window_size)
        self.rules_size_refine = int(ceil(refine_ratio * self.rules_size))
        self.norm = norm
        self.order = order
        self.stride = stride_size
        self.window = window_size
        self.A_P2 = A_P2
        self.refiner_ratio = refine_ratio
        self.refiner_name = refiner
        self.refiner = None
        self.residual = None
        self.attention = None
        self.outer = None
        # Row indices of the coefficients that belong to the raw input
        # features in the top-k / order-1 layout. Kept as a non-persistent
        # buffer so it follows `.to(...)` without changing the state_dict.
        self.register_buffer("extract_index",
                             torch.arange(start=self.rules_size,
                                          end=self.rules_size + self.in_channel,
                                          step=1).unsqueeze(0),
                             persistent=False)
        self.init_()
        if attention:
            self.attention = nn.Sequential(Attention(cof=24, num_heads=4),
                                           torch.nn.BatchNorm1d(num_features=self.rules_size),
                                           torch.nn.GELU())
        if residual:
            self.residual = nn.Linear(in_features=self.in_channel, out_features=self.out_channel)

    @staticmethod
    def SourceSubIndex(N: int, M: int):
        """Enumerate all ``M ** N`` index combinations as a flat tensor.

        Each combination picks one value in ``[0, M)`` per position (first
        position varying fastest); position ``j`` is then offset by ``j * M``.
        The flattened result is moved to the module-level ``device``.
        """
        subindex = [[0] * N for _ in range(M ** N)]
        for i in range(1, M ** N):
            subindex[i] = subindex[i - 1][::]
            # Increment like a base-M counter with the lowest digit first.
            for j in range(N):
                if subindex[i][j] < M - 1:
                    subindex[i][j] += 1
                    break
                else:
                    subindex[i][j] = 0
        return torch.add(torch.LongTensor(subindex), torch.arange(start=0, end=M * N, step=M)).ravel().to(device)

    def init_(self):
        """(Re)create the refiner, the output head and the ``coe_defuzzifize`` parameter.

        Calling this again registers a freshly drawn ``coe_defuzzifize``
        parameter object in place of the previous one.
        """
        if self.A_P2:
            if self.refiner_name == "pool":
                self.refiner = torch.nn.AdaptiveMaxPool1d(output_size=self.rules_size_refine)
                n_rules = self.rules_size_refine
            else:
                # Top-k selects rows out of the full coefficient table at run time.
                self.refiner = torch.topk
                n_rules = self.rules_size
            rows = n_rules if self.order == 0 else self.in_channel + n_rules
            shape = (rows, self.out_channel)
        else:
            # Order 0 has a single bias row; order 1 adds one row per input channel.
            rows = 1 if self.order == 0 else self.in_channel + 1
            shape = (rows, self.out_channel * self.rules_size)

        self.register_parameter("coe_defuzzifize", torch.nn.Parameter(torch.randn(size=shape)))

        if self.A_P2:
            self.outer = nn.Sequential(torch.nn.BatchNorm1d(num_features=self.out_channel), torch.nn.GELU())

    def full_expand(self, x):
        """Form the product over axis 2 for every combination of last-axis entries.

        :param x: 4-D tensor; axis 2 is iterated, the last axis is combined
        :return: 3-D tensor whose last axis holds all ``M ** D`` products
            (``D``, ``M`` being the sizes of axes 2 and 3)
        """
        N, B, D, M = x.shape
        tmp = x[:, :, 0, :]
        for n in range(1, D):
            # Outer product of the accumulated strengths with the next slice
            # (cf. 10.1109/TFUZZ.2020.2992856), flattened back to one axis.
            tmp = torch.mul(tmp.unsqueeze(-2), x[:, :, n, :].unsqueeze(-1))
            tmp = tmp.view(N, B, -1)
        return tmp

    def forward(self, x, x_f, edge_index):
        """Compute the layer output.

        :param x: Node features, concatenated with the firing strengths along
            dim 1 when ``order == 1`` and fed to the residual projection
        :param x_f: Membership values, a 3-D tensor whose first axis indexes the nodes
        :param edge_index: Edge list; row 0 is gathered from, row 1 is scattered to
        :return: Tensor with one row per node
        """
        N, _, _ = x_f.shape

        # Mean of the neighbours' membership values. dim_size pins the number
        # of rows to N even when the highest-numbered nodes have no incoming edge.
        mu_AggregationG = scatter(x_f[edge_index[0]], edge_index[1], dim=0, dim_size=N, reduce="mean")

        res = torch.transpose(mu_AggregationG.unfold(1, self.window, self.stride), -1, -2)
        res = self.full_expand(res)

        if self.norm:
            # Normalise the firing strengths within every window. Nodes that
            # received no message have an all-zero row; clamping the divisor
            # keeps those rows at 0 instead of NaN.
            norm_fac = torch.sum(res, dim=-1, keepdim=True)
            res = res / norm_fac.clamp_min(torch.finfo(res.dtype).tiny)
        res = res.reshape(N, -1)

        if self.attention is not None:
            res = self.attention(res)

        if self.A_P2:
            if self.refiner_name == "pool":
                res = self.refiner(res)
                out = torch.cat([x, res], dim=1) if self.order == 1 else res
                out = torch.matmul(out, self.coe_defuzzifize)
            else:
                res, index = self.refiner(res, k=self.rules_size_refine, dim=1, largest=True)
                if self.order == 1:
                    extract_index = self.extract_index.to(index.device).expand((N, self.in_channel))
                    index = torch.cat([extract_index, index], dim=1)
                    out = torch.cat([x, res], dim=1)
                else:
                    out = res
                # Gather, per node, the coefficient rows of the selected rules.
                coe_defuzzifize = torch.index_select(self.coe_defuzzifize, dim=0,
                                                     index=index.reshape(-1)).view(N, -1, self.out_channel)
                out = torch.mul(out.unsqueeze(-1), coe_defuzzifize).sum(dim=1)

            if self.residual is not None:
                out = out + self.residual(x)
            out = self.outer(out)
        else:
            # standard defuzzification
            # Bias column created next to the coefficients so device and dtype match.
            out = self.coe_defuzzifize.new_ones((N, 1))
            if self.order == 1:
                out = torch.cat([x, out], dim=1)
            out = torch.matmul(out, self.coe_defuzzifize)  # (N, D) @ (D, OUT * R) -> (N, OUT * R)
            out = out.view(N, self.rules_size, self.out_channel)  # (N, OUT * R)  -> (N, R , OUT)
            out = torch.sum(res.unsqueeze(-1) * out, dim=1)
        return out


class FLGnnConv(torch.nn.Module):
    """Fuzzy-logic graph convolution: normalise, fuzzify, then apply the rule layer.

    :param in_channels: Input dimension
    :param out_channels: Output dimension
    :param num_mf: Number of membership functions per input channel
    :param fix_mf: Passed to ``Fuzzier`` as ``fix``
    :param norm: Passed to ``Ruler`` as ``norm``
    :param method: Accepted for interface compatibility; not used by this class
    :param value_intervals: Range the inputs are rescaled to and on which the
        membership functions are defined; defaults to ``[0, 1]``
    :param order: Passed to ``Ruler`` as ``order``
    :param windows_size: Passed to ``Ruler`` as ``window_size``
    :param stride_size: Passed to ``Ruler`` as ``stride_size``
    :param refine_ratio: Passed to ``Ruler`` as ``refine_ratio``
    :param refiner: Passed to ``Ruler`` as ``refiner``
    :param A_P2: Passed to ``Ruler`` as ``A_P2``
    :param residual: Passed to ``Ruler`` as ``residual``
    :param attention: Passed to ``Ruler`` as ``attention``
    :param cross: Passed to ``Fuzzier`` as ``cross``
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 num_mf: int = 3,
                 fix_mf: bool = False,
                 norm: bool = True,
                 method: str = "mad",
                 value_intervals=None,
                 order: int = 1,
                 windows_size: int = 3,
                 stride_size: int = 3,
                 refine_ratio: float = 1.,
                 refiner: str = "pool",
                 A_P2: bool = True,
                 residual: bool = True,
                 attention: bool = True,
                 cross: float = 0.7):
        super().__init__()

        self.value_intervals = [0, 1] if value_intervals is None else value_intervals

        self.fuzzier = Fuzzier(in_channels=in_channels,
                               num_mf=num_mf,
                               value_intervals=self.value_intervals,
                               fix=fix_mf,
                               cross=cross)

        self.ruler = Ruler(in_channel=in_channels,
                           out_channel=out_channels,
                           n_mf=num_mf,
                           norm=norm,
                           residual=residual,
                           order=order,
                           window_size=windows_size,
                           stride_size=stride_size,
                           A_P2=A_P2,
                           refiner=refiner,
                           refine_ratio=refine_ratio,
                           attention=attention)

    def reset(self):
        """Re-initialise the fuzzifier and the rule layer.

        This class does not create a ``batch_norm`` attribute itself, so its
        affine parameters are only reset when one has been attached (for
        example by a subclass).
        """
        batch_norm = getattr(self, "batch_norm", None)
        if batch_norm is not None:
            batch_norm.weight.data.fill_(1)
            batch_norm.bias.data.zero_()
        self.fuzzier.reset()
        # The rule layer's initialiser is named `init_`.
        self.ruler.init_()

    def forward(self, x, edge_index):
        """Rescale ``x`` into ``value_intervals``, fuzzify it and run the rule layer."""
        x = zero_one(x, self.value_intervals)
        x_f = self.fuzzier(x)
        x = self.ruler(x, x_f, edge_index)
        return x


class NodeEmbed(torch.nn.Module):
    """Embed atoms and bonds with the OGB encoders and merge them per node.

    :param out_feature: Embedding width passed to both encoders and the batch norm
    """

    def __init__(self, out_feature: int):
        super().__init__()
        # The bond encoder is created before the atom encoder, matching the
        # established parameter-initialisation order.
        self.edge_encoder = BondEncoder(emb_dim=out_feature)
        self.node_encoder = AtomEncoder(emb_dim=out_feature)
        self.ban = torch.nn.BatchNorm1d(num_features=out_feature)
        self.act = torch.nn.GELU()

    def forward(self, x, edge_attr, edge_index):
        """Return node embeddings enriched with the mean embedding of their edges.

        Edge embeddings are averaged onto the node given by ``edge_index[0]``,
        added to the node embedding, then batch-normalised and passed through GELU.
        """
        node_emb = self.node_encoder(x)
        bond_emb = self.edge_encoder(edge_attr)
        bond_mean = scatter(bond_emb,
                            index=edge_index[0],
                            dim=0,
                            dim_size=node_emb.shape[0],
                            reduce="mean")
        merged = node_emb + bond_mean
        return self.act(self.ban(merged))


# Running this file as a script does nothing; the guard body is a placeholder.
if __name__ == '__main__':
    pass
