import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch_geometric.utils import scatter, degree
from torch_scatter import scatter_max, scatter_min, scatter_add
import numpy as np
from math import ceil
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder

# Device for tensors this module allocates directly (index tables, padding, constants).
# Falls back to CPU so the module can still be imported and used without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Module-level list; not referenced by the code in this file.
# NOTE(review): kept in case external code imports it — confirm before removing.
counter = []

def zero_one(x, interval=None):
    """Min-max scale every column of ``x`` into ``interval``.

    The input tensor is left untouched. Scaling happens on a floating-point copy, so
    integer inputs are not truncated and autograd leaves are not modified in place.

    :param x: 2-D tensor; columns are scaled independently (min/max taken over dim 0)
    :param interval: ``[start, end]`` target range, defaults to ``[0, 1]``
    :return: scaled copy of ``x``; constant columns are set to ``end``
    """
    if interval is None:
        interval = [0, 1]
    start, end = interval
    length = end - start
    # Copy (or promote integer input to the default float dtype) before writing.
    x = x.clone() if x.is_floating_point() else x.to(torch.get_default_dtype())
    maxm, _ = torch.max(x, dim=0)
    minm, _ = torch.min(x, dim=0)
    scale = maxm - minm
    normal = scale != 0
    abnormal = ~normal
    x[:, normal] = ((x[:, normal] - minm[normal]) / scale[normal]) * length + start
    # Constant columns have no range to scale; pin them to the top of the interval.
    x[:, abnormal] = end
    return x

class GaussianMemberShap(torch.nn.Module):
    """Per-feature bank of Gaussian membership functions.

    Centers ``C`` are spread evenly over ``values_intervals``. The variance ``Sigma`` is
    chosen so that two adjacent functions intersect at membership ``cross``. Both are
    parameters of shape ``(1, in_channel, num_mf)`` and are trainable unless ``fix`` is set.

    :param in_channel: number of input features
    :param num_mf: membership functions per feature (must be at least 2)
    :param values_intervals: ``[start, end]`` range the centers are spread over
    :param fix: freeze ``C`` and ``Sigma``
    :param close: shrink the center range by one center spacing at each end
    :param cross: membership value at the midpoint of two adjacent centers, in (0, 1)
    :raises ValueError: if ``num_mf < 2`` or ``cross`` is not strictly between 0 and 1
    """

    def __init__(self,
                 in_channel: int,
                 num_mf: int,
                 values_intervals: list,
                 fix: bool = False,
                 close: bool = False,
                 cross: float = 0.6):
        super().__init__()
        # The variance is derived from the spacing of the first two centers.
        if num_mf < 2:
            raise ValueError("num_mf must be at least 2.")
        # log(cross) must be negative for a finite, positive variance.
        if not 0 < cross < 1:
            raise ValueError("cross must lie strictly between 0 and 1.")
        self.num_mf = num_mf
        self.intervals = values_intervals
        self.in_channel = in_channel
        self.close = close
        self.cross = cross
        self.fix = fix

        C, Sigma = self._initial_params()
        self.register_parameter("C", torch.nn.Parameter(C, requires_grad=not fix))
        self.register_parameter("Sigma", torch.nn.Parameter(Sigma, requires_grad=not fix))

    def _initial_params(self):
        """Return the initial ``(C, Sigma)`` tensors, each of shape ``(1, in_channel, num_mf)``."""
        start, end = self.intervals
        if self.close:
            step = (end - start) / (self.num_mf + 1)
            start += step
            end -= step
        centers = torch.linspace(start=start, end=end, steps=self.num_mf)
        # Choose Sigma so that exp(-(d/2)^2 / (2 * Sigma)) == cross, where d is the center spacing.
        sigma = torch.pow(centers[1] - centers[0], 2) * (1 / (-8 * np.log(self.cross)))
        C = torch.tile(centers.view(1, -1), (self.in_channel, 1)).unsqueeze(0)
        Sigma = torch.tile((sigma * torch.ones((self.num_mf,))).view(1, -1), (self.in_channel, 1)).unsqueeze(0)
        return C, Sigma

    def reset(self):
        """Restore the initial centers and variances in place; a no-op when ``fix`` is set.

        Copying into the existing parameters keeps their device, dtype and any optimizer
        references. Registering fresh CPU parameters would lose all three.
        """
        if self.fix:
            return
        C, Sigma = self._initial_params()
        with torch.no_grad():
            self.C.copy_(C)
            self.Sigma.copy_(Sigma)

    def forward(self, x):
        """Evaluate every membership function on every feature.

        For ``x`` of shape ``(N, in_channel)`` the result has shape ``(N, in_channel, num_mf)``
        (the trailing axis broadcasts against ``C``/``Sigma``).
        """
        x = x.unsqueeze(-1)
        # Gaussian MF; Sigma holds the variance, not the standard deviation.
        return torch.exp(- torch.pow((x - self.C), 2) / (2 * self.Sigma))


class Fuzzier(torch.nn.Module):
    """Fuzzification layer wrapping a bank of Gaussian membership functions.

    :param in_channels: number of input features
    :param num_mf: membership functions per feature
    :param value_intervals: ``[start, end]`` range the centers are spread over
    :param fix: freeze the membership-function parameters
    :param cross: membership value where adjacent functions intersect
    """

    def __init__(self,
                 in_channels: int,
                 num_mf: int,
                 value_intervals: list,
                 fix: bool = False,
                 cross: float = 0.7):
        super().__init__()
        self.MFs = GaussianMemberShap(
            in_channel=in_channels,
            num_mf=num_mf,
            values_intervals=value_intervals,
            fix=fix,
            cross=cross,
        )

    def reset(self):
        """Re-initialize the wrapped membership functions."""
        self.MFs.reset()

    def forward(self, x):
        """Return the membership degrees of ``x`` under every function."""
        return self.MFs(x)


class Ruler(torch.nn.Module):
    """Fuzzy rule layer with sliding-window rules and pooled rule extraction.

    Each node's memberships are multiplied with those of its neighbours along every edge
    and averaged onto ``edge_index[1]``. The averages are grouped into windows of
    ``window_size`` consecutive features. Every choice of one membership function per
    feature in a window is a rule, and its firing strength is the product of the chosen
    memberships. The firing strengths are then defuzzified with a 0- or 1-order TS-FIS
    consequent.
    """

    def __init__(self,
                 in_channel: int,
                 out_channel: int,
                 n_mf: int,
                 order: int,
                 window_size: int = 3,
                 stride_size: int = 3,
                 norm: bool = False,
                 concat: bool = True,
                 method: str = "mad",
                 extract_ratio: float = 1,
                 visualize: bool = False,
                 ):
        """
        :param in_channel: input features
        :param out_channel: output features
        :param n_mf: number of membership function
        :param order: the order of defuzzification (0 or 1)
        :param window_size: the slide windows size
        :param stride_size: the slide stride size
        :param norm: normalize the firing strengths within each window block
        :param concat: defuzzify with a single linear map over the (extracted) firing
            strengths, plus the raw input when ``order == 1``
        :param method: stored but not used by this layer
        :param extract_ratio: fraction of rules kept by adaptive average pooling (concat
            mode only); a falsy value keeps all rules
        :param visualize: plot per-rule mean firing strength on every forward (debugging aid)
        :raises ValueError: if ``order`` is not 0 or 1
        """
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.n_mf = n_mf
        self.num_blocks = int(ceil((in_channel - window_size) / stride_size)) + (window_size >= stride_size)
        # Full rule count; reset() derives rules_size from this so repeated resets don't compound.
        self._total_rules = self.num_blocks * pow(n_mf, window_size)
        self.rules_size = self._total_rules
        # Feature rows the last window reads beyond the real ones (filled with ones in forward).
        self.padding = max(0, stride_size * (self.num_blocks - 1) + window_size - in_channel)
        self.norm = norm
        self.order = order
        self.method = method
        self.stride = stride_size
        self.window = window_size
        self.concat = concat
        self.extract_ratio = extract_ratio
        self.visualize = visualize
        self.extractor = None
        self.recorder = [[], []]  # record the firing strength of each rule for each input
        # Buffer (not saved in state_dict) so the index table follows module.to(...).
        self.register_buffer("subIndex", self.SourceSubIndex(window_size, n_mf), persistent=False)
        self.reset()

        # Currently unused by forward.
        self.ban = torch.nn.BatchNorm1d(num_features=out_channel)
        self.act = torch.nn.LeakyReLU()

    @staticmethod
    def SourceSubIndex(N: int, M: int):
        """Flat indices selecting one of ``M`` memberships for each of ``N`` features.

        Row ``i`` of the ``(M**N, N)`` table is a base-``M`` counter (first column fastest).
        Column ``j`` is offset by ``j * M`` so it indexes a flattened ``(N, M)`` window.
        Returns the raveled table on ``device``.
        """
        subindex = [[0] * N for _ in range(M ** N)]
        for i in range(1, M ** N):
            subindex[i] = subindex[i - 1][::]
            for j in range(N):
                if subindex[i][j] < M - 1:
                    subindex[i][j] += 1
                    break
                else:
                    subindex[i][j] = 0
        return torch.add(torch.LongTensor(subindex), torch.arange(start=0, end=M * N, step=M)).ravel().to(device)

    def _set_coefficients(self, init):
        """Install ``init`` as ``coe_defuzzifize``, copying into the existing parameter when shapes match.

        In-place copying keeps the parameter's device and any optimizer references across resets.
        """
        current = self._parameters.get("coe_defuzzifize")
        if current is not None and current.shape == init.shape:
            with torch.no_grad():
                current.copy_(init)
        else:
            self.register_parameter("coe_defuzzifize", torch.nn.Parameter(init))

    def reset(self):
        """(Re)initialize the defuzzification coefficients and, in concat mode, the extractor."""
        if self.order not in (0, 1):
            raise ValueError("order must be either 0 or 1.")
        self.rules_size = self._total_rules
        if self.concat:
            if self.extract_ratio:
                self.rules_size = int(ceil(self.rules_size * self.extract_ratio))
                self.extractor = torch.nn.AdaptiveAvgPool1d(self.rules_size)
            # Rows: one per kept rule, followed by one per raw input feature for order 1.
            rows = self.rules_size + (self.in_channel if self.order == 1 else 0)
            coe_defuzzifize = torch.randn(size=(rows, self.out_channel))
        elif self.order == 0:
            coe_defuzzifize = torch.nn.init.xavier_normal_(
                torch.empty(size=(1, self.out_channel * self.rules_size)))
        else:
            coe_defuzzifize = torch.randn(size=(self.in_channel + 1, self.out_channel * self.rules_size))
        self._set_coefficients(coe_defuzzifize)
        self.recorder = [[], []]

    def full_expand(self, x):
        """Outer-product expansion of ``(N, B, D, M)`` memberships into ``(N, B, M**D)`` strengths.

        Not called by forward, which uses ``subIndex`` instead.
        """
        N, B, D, M = x.shape
        tmp = x[:, :, 0, :]
        for n in range(1, D):
            # 10.1109/TFUZZ.2020.2992856
            tmp = torch.mul(tmp.unsqueeze(-2), x[:, :, n, :].unsqueeze(-1))
            tmp = tmp.view(N, B, -1)
        return tmp

    @staticmethod
    def _plot_firing_strength(res):
        """Bar-plot the mean firing strength of each rule for three hard-coded node ranges.

        NOTE(review): the row ranges (":15000" student, "15000:15300" teacher,
        "15300:15400" logistic) look specific to one dataset — confirm before reuse.
        """
        groups = {
            "student": res[:15000, :],
            "teacher": res[15000:15300, :],
            "logistic": res[15300:15400, :],
        }
        f = plt.figure()
        plt.rcParams.update({"font.size": 15})
        for c, (identity, rows) in enumerate(groups.items(), start=1):
            value = torch.mean(rows, dim=0).detach().cpu().numpy()
            ax = f.add_subplot(3, 1, c)
            rules = [f"R: {i}" for i in range(len(value))]
            sns.barplot(y=[""] + rules + [""], x=[0] + list(value) + [0], ax=ax, color='#00C957')
            ax.set_title(identity)
            plt.xlabel("firing strength")
            plt.ylabel("rules")
        plt.show()

    def forward(self, x, x_f, edge_index):
        """Compute rule firing strengths and defuzzify them.

        :param x: raw node features; concatenated with the firing strengths when ``order == 1``
        :param x_f: node memberships of shape ``(N, in_channel, n_mf)``
        :param edge_index: edge list; products are aggregated onto ``edge_index[1]``
        :return: ``(N, out_channel)`` outputs
        """
        N = x_f.size(0)
        dev = x_f.device

        # μ_uv = t(μ(u), μ(v)) with the product t-norm, per edge.
        x_s = x_f[edge_index[0]] * x_f[edge_index[1]]

        # μ_ui = S(μ_uv): mean over the edges ending at each node. dim_size keeps one row per
        # node even if trailing nodes have no edges. The degree clamp turns 0/0 into 0.
        degree_ = degree(edge_index[1], num_nodes=N).clamp(min=1)
        mu = scatter_add(x_s, edge_index[1], dim=0, dim_size=N) / degree_[:, None, None]

        # Pad with ones (identity of the product) so the last window is full-size.
        if self.padding:
            mu = torch.cat([mu, torch.ones(size=(N, self.padding, self.n_mf), device=dev, dtype=mu.dtype)],
                           dim=-2)

        # Sliding windows over the feature axis: (N, B, window, n_mf).
        starts = range(0, self.num_blocks * self.stride, self.stride)
        res = torch.stack([mu[:, s:s + self.window, :] for s in starts], dim=1)
        _, B, W, _ = res.shape

        # Firing strength of each rule: product over the window of the selected memberships.
        res = torch.prod(torch.index_select(res.view(N, B, -1), index=self.subIndex, dim=-1).view(N, B, -1, W),
                         dim=-1)

        # Per-block normalization; the clamp keeps all-zero blocks at 0 instead of NaN.
        if self.norm:
            res = res / res.sum(dim=-1, keepdim=True).clamp(min=torch.finfo(res.dtype).tiny)
        res = res.view(N, -1)

        if self.visualize:
            self._plot_firing_strength(res)

        # Q = (q1, ..., ql, r) are the consequent coefficients for X = (x1, ..., xl, 1).
        # Order 1 uses the raw input in the consequent; order 0 uses only the constant r.
        if self.concat:
            if self.extractor is not None:
                res = self.extractor(res)
            feats = torch.cat([res, x], dim=1) if self.order == 1 else res
            return torch.matmul(feats, self.coe_defuzzifize)

        # Standard defuzzification: weight each rule's consequent by its firing strength.
        x_plus = torch.ones(size=(N, 1), device=dev)
        if self.order == 1:
            x_plus = torch.cat([x, x_plus], dim=1)
        out = torch.matmul(x_plus, self.coe_defuzzifize)  # (N, D) @ (D, OUT * R) -> (N, OUT * R)
        out = out.view(N, self.rules_size, self.out_channel)  # -> (N, R, OUT)
        return torch.sum(res.unsqueeze(-1) * out, dim=1)


class Ruler_A(torch.nn.Module):
    """Fuzzy rule layer with sliding-window rules and top-k rule selection.

    The layer combines neighbour memberships along edges and builds window rules like the
    pooled variant. In concat mode it keeps, per node, only the ``select_size`` rules with
    the largest firing strength and gathers their consequent coefficients.
    """

    def __init__(self,
                 in_channel: int,
                 out_channel: int,
                 n_mf: int,
                 order: int,
                 window_size: int = 3,
                 stride_size: int = 3,
                 norm: bool = True,
                 concat: bool = True,
                 method: str = "mad",
                 extract_ratio: float = 1,
                 ):
        """
        :param in_channel: input features
        :param out_channel: output features
        :param n_mf: number of membership function
        :param order: the order of defuzzification (0 or 1)
        :param window_size: the slide windows size
        :param stride_size: the slide stride size
        :param norm: normalize the firing strengths within each window block
        :param concat: defuzzify with the top-k firing strengths (plus the raw input
            when ``order == 1``)
        :param method: stored but not used by this layer
        :param extract_ratio: fraction of rules kept per node by top-k (concat mode only);
            a falsy value keeps all rules
        :raises ValueError: if ``order`` is not 0 or 1
        """
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.n_mf = n_mf
        self.num_blocks = int(ceil((in_channel - window_size) / stride_size)) + (window_size >= stride_size)
        self.rules_size = self.num_blocks * pow(n_mf, window_size)
        # Feature rows the last window reads beyond the real ones (filled with ones in forward).
        self.padding = max(0, stride_size * (self.num_blocks - 1) + window_size - in_channel)
        self.norm = norm
        self.order = order
        self.method = method
        self.stride = stride_size
        self.window = window_size
        self.concat = concat
        self.extract_ratio = extract_ratio
        self.extractor = None
        self.recorder = [[], []]  # record the firing strength of each rule for each input
        # topk cannot select more rules than exist.
        self.select_size = min(self.rules_size, int(ceil(self.rules_size * self.extract_ratio)))

        self.reset()

    def _set_coefficients(self, init):
        """Install ``init`` as ``coe_defuzzifize``, copying into the existing parameter when shapes match.

        In-place copying keeps the parameter's device and any optimizer references across resets.
        """
        current = self._parameters.get("coe_defuzzifize")
        if current is not None and current.shape == init.shape:
            with torch.no_grad():
                current.copy_(init)
        else:
            self.register_parameter("coe_defuzzifize", torch.nn.Parameter(init))

    def reset(self):
        """(Re)initialize the defuzzification coefficients and, in concat mode, the extractor."""
        if self.order not in (0, 1):
            raise ValueError("order must be either 0 or 1.")
        if self.concat:
            self.extractor = torch.topk if self.extract_ratio else None
            # Rows: one per rule, followed by one per raw input feature for order 1.
            rows = self.rules_size + (self.in_channel if self.order == 1 else 0)
            coe_defuzzifize = torch.randn(size=(rows, self.out_channel))
        elif self.order == 0:
            coe_defuzzifize = torch.nn.init.xavier_normal_(
                torch.empty(size=(1, self.out_channel * self.rules_size)))
        else:
            coe_defuzzifize = torch.randn(size=(self.in_channel + 1, self.out_channel * self.rules_size))
        self._set_coefficients(coe_defuzzifize)
        self.recorder = [[], []]

    @staticmethod
    def full_expand(x):
        """Outer-product expansion of ``(N, B, D, M)`` memberships into ``(N, B, M**D)`` strengths."""
        N, B, D, M = x.shape
        tmp = x[:, :, 0, :]
        for n in range(1, D):
            next_layer = x[:, :, n, :]
            # 10.1109/TFUZZ.2020.2992856
            tmp = torch.matmul(tmp.unsqueeze(-1), next_layer.unsqueeze(-2))
            tmp = tmp.view(N, B, -1)
        return tmp

    def forward(self, x, x_f, edge_index):
        """Compute rule firing strengths and defuzzify them.

        :param x: raw node features; concatenated with the firing strengths when ``order == 1``
        :param x_f: node memberships of shape ``(N, in_channel, n_mf)``
        :param edge_index: edge list; products are aggregated onto ``edge_index[1]``
        :return: ``(N, out_channel)`` outputs
        """
        N = x_f.size(0)
        dev = x_f.device

        # μ_uv = t(μ(u), μ(v)) with the product t-norm, per edge.
        x_s = x_f[edge_index[0]] * x_f[edge_index[1]]

        # μ_ui = S(μ_uv): mean over the edges ending at each node. dim_size keeps one row per
        # node even if trailing nodes have no edges. The degree clamp turns 0/0 into 0.
        degree_ = degree(edge_index[1], num_nodes=N).clamp(min=1)
        mu = scatter_add(x_s, edge_index[1], dim=0, dim_size=N) / degree_[:, None, None]

        # Pad with ones (identity of the product), not zeros. Zero padding would zero
        # every rule of the last block and, with norm=True, yield 0/0 = NaN.
        if self.padding:
            mu = torch.cat([mu, torch.ones(size=(N, self.padding, self.n_mf), device=dev, dtype=mu.dtype)],
                           dim=-2)

        # Sliding windows over the feature axis (N, B, window, n_mf), expanded into rules.
        starts = range(0, self.num_blocks * self.stride, self.stride)
        res = self.full_expand(torch.stack([mu[:, s:s + self.window, :] for s in starts], dim=1))

        # Per-block normalization; the clamp keeps all-zero blocks at 0 instead of NaN.
        if self.norm:
            res = res / res.sum(dim=-1, keepdim=True).clamp(min=torch.finfo(res.dtype).tiny)
        res = res.reshape(N, -1)

        if self.concat:
            if self.extractor is None:
                feats = torch.cat([res, x], dim=1) if self.order == 1 else res
                return torch.matmul(feats, self.coe_defuzzifize)
            # Keep the select_size strongest rules of each node and gather their coefficients.
            res, index = self.extractor(res, k=self.select_size, dim=1, largest=True)
            coe = torch.index_select(self.coe_defuzzifize, dim=0, index=index.view(-1)).view(
                N, self.select_size, self.out_channel)
            if self.order == 1:
                # Rows after the rule block hold the coefficients of the raw input features.
                coe_x = self.coe_defuzzifize[self.rules_size:].unsqueeze(0).expand(N, -1, -1)
                coe = torch.cat([coe, coe_x], dim=1)
                feats = torch.cat([res, x], dim=1)
            else:
                feats = res
            return torch.mul(feats.unsqueeze(dim=-1), coe).sum(dim=1)

        # Standard defuzzification: weight each rule's consequent by its firing strength.
        x_plus = torch.ones(size=(N, 1), device=dev)
        if self.order == 1:
            x_plus = torch.cat([x, x_plus], dim=1)
        out = torch.matmul(x_plus, self.coe_defuzzifize)  # (N, D) @ (D, OUT * R) -> (N, OUT * R)
        out = out.view(N, self.rules_size, self.out_channel)  # -> (N, R, OUT)
        return torch.sum(res.unsqueeze(-1) * out, dim=1)


class FLGnnConv(torch.nn.Module):
    """Fuzzy-logic graph convolution: min-max scaling -> fuzzification -> rule layer.

    :param in_channels: input features
    :param out_channels: output features
    :param num_mf: membership functions per feature
    :param fix_mf: freeze the membership-function parameters
    :param norm: normalize firing strengths in the rule layer
    :param method: forwarded to the rule layer
    :param value_intervals: ``[start, end]`` range inputs are scaled to, default ``[0, 1]``
    :param order: defuzzification order (0 or 1)
    :param windows_size: rule window size
    :param stride_size: rule window stride
    :param extract_ratio: fraction of rules kept by the extractor
    :param extractor: ``"pool"`` (any string containing it) selects ``Ruler``, otherwise ``Ruler_A``
    :param concat: concat-mode defuzzification in the rule layer
    :param cross: membership value where adjacent functions intersect
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 num_mf: int = 3,
                 fix_mf: bool = False,
                 norm: bool = True,
                 method: str = "mad",
                 value_intervals=None,
                 order: int = 1,
                 windows_size: int = 3,
                 stride_size: int = 3,
                 extract_ratio: float = 1,
                 extractor: str = "pool",
                 concat: bool = True,
                 cross: float = 0.7):
        super().__init__()

        self.value_intervals = [0, 1] if value_intervals is None else value_intervals

        self.fuzzier = Fuzzier(in_channels=in_channels,
                               num_mf=num_mf,
                               value_intervals=self.value_intervals,
                               fix=fix_mf,
                               cross=cross)

        ruler_cls = Ruler if "pool" in extractor else Ruler_A
        self.ruler = ruler_cls(in_channel=in_channels,
                               out_channel=out_channels,
                               n_mf=num_mf,
                               norm=norm,
                               order=order,
                               method=method,
                               window_size=windows_size,
                               stride_size=stride_size,
                               concat=concat,
                               extract_ratio=extract_ratio)

    def reset(self):
        """Re-initialize the membership functions and the rule layer.

        This layer has no batch-norm submodule; the former ``self.batch_norm`` calls always
        raised AttributeError.
        """
        self.fuzzier.reset()
        self.ruler.reset()

    def forward(self, x, edge_index):
        """Scale ``x`` column-wise into ``value_intervals``, fuzzify it, and apply the rule layer."""
        x = zero_one(x, self.value_intervals)
        x_f = self.fuzzier(x)
        return self.ruler(x, x_f, edge_index)


class NodeEmbed(torch.nn.Module):
    """Linear node embedding, optionally concatenated with aggregated edge embeddings.

    :param node_features: input node feature width
    :param out_feature: width of each embedding (node and, if enabled, edge)
    :param edge_feature: input edge feature width; falsy disables the edge branch
    """

    def __init__(self, node_features: int, out_feature: int, edge_feature: int = None):
        super().__init__()
        has_edges = bool(edge_feature)
        self.nodes = torch.nn.Linear(in_features=node_features, out_features=out_feature)
        self.act = torch.nn.LeakyReLU()
        # The edge branch appends its own out_feature columns, so BatchNorm sees twice the width.
        self.bat = torch.nn.BatchNorm1d(num_features=out_feature * 2 if has_edges else out_feature)
        self.edges = torch.nn.Linear(in_features=edge_feature, out_features=out_feature) if has_edges else None

    def forward(self, x, edge_attr=None, edge_index=None, agg="mean"):
        """Embed nodes, append edge embeddings reduced onto ``edge_index[0]`` with ``agg``,
        then apply batch norm and LeakyReLU."""
        x = self.nodes(x)
        if self.edges is not None:
            e_out = self.edges(edge_attr)
            e_out = scatter(e_out, dim=0, index=edge_index[0], dim_size=x.shape[0], reduce=agg)
            x = torch.cat([x, e_out], dim=1)
        x = self.bat(x)
        x = self.act(x)
        return x


class MolEmbed(torch.nn.Module):
    """OGB molecule embedding: atom embedding concatenated with the mean bond embedding.

    Bond embeddings are averaged onto the node in ``edge_index[0]``. The two halves are
    sized so the output always has exactly ``hidden`` features, even when ``hidden`` is odd.
    """

    def __init__(self, hidden: int):
        super().__init__()
        bond_dim = hidden // 2
        # The atom half takes the remainder so atom + bond widths sum to `hidden`.
        self.atom_encoder = AtomEncoder(emb_dim=hidden - bond_dim)
        self.bond_encoder = BondEncoder(emb_dim=bond_dim)

    def forward(self, x, edge_attr, edge_index, size):
        """Return ``(size, hidden)`` node features; ``size`` is the number of nodes."""
        edge_attr = scatter(self.bond_encoder(edge_attr), edge_index[0], dim=0, dim_size=size, reduce="mean")
        x = self.atom_encoder(x)
        return torch.cat([x, edge_attr], dim=1)


if __name__ == '__main__':
    # No standalone entry point: this module only defines layers and helpers for import.
    pass
