import torch
import torch.nn as nn
import numpy as np
import math
import torch.nn.functional as F
from parser_1 import _parser
from torch.nn.parameter import Parameter

# Module-level options object, built once when this module is imported by
# calling the project-local `_parser()` helper from `parser_1`.
# NOTE(review): presumably this parses command-line arguments, which would make
# importing this module have that side effect — confirm against parser_1.
args = _parser()

class GraphAttentionNetwork(nn.Module):
    """
    Graph attention (GAT) layer applied independently to every graph of a batch.

    Each sample is projected with the shared weight ``W``. Pairwise attention
    logits are ``LeakyReLU(h @ a1 + (h @ a2)^T)``; logits of node pairs whose
    adjacency entry is not positive are replaced by a large negative constant,
    the result is normalised with a row-wise softmax and used to aggregate the
    projected node features.

    Args:
        in_features: number of columns of each per-sample input matrix.
        out_features: number of columns of each per-sample output matrix.
        dropout: dropout probability applied to the attention coefficients.
        alpha: negative slope of the LeakyReLU applied to the attention logits.
        concat: when True, ELU is applied to the output; otherwise the
            aggregated features are returned unchanged.
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionNetwork, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        # Parameters are created directly on the GPU when one is available,
        # which is what the historical `.type(torch.cuda.FloatTensor)` did.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        gain = math.sqrt(2.0)

        def _xavier(rows, cols):
            # Float32 parameter of shape (rows, cols), Xavier-normal initialised.
            weight = torch.empty(rows, cols, dtype=torch.float32, device=device)
            return nn.Parameter(nn.init.xavier_normal_(weight, gain=gain),
                                requires_grad=True)

        self.W = _xavier(in_features, out_features)
        self.a1 = _xavier(out_features, 1)
        self.a2 = _xavier(out_features, 1)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, input, adj):
        """
        Run the attention layer on every sample of the batch.

        Args:
            input: batch of node-feature matrices; each ``input[i]`` must be a
                2-D matrix with ``in_features`` columns (it is fed to
                ``torch.mm``).
            adj: indexable by sample; entries of ``adj[i]`` greater than zero
                mark the node pairs that may attend to each other.

        Returns:
            Tensor stacking the per-sample outputs along a new leading
            dimension. The result stays on the inputs' device and remains part
            of the autograd graph, so gradients reach ``W``, ``a1`` and ``a2``.
        """
        per_graph = []
        for i in range(input.shape[0]):
            h = torch.mm(input[i], self.W)

            # (N, 1) + (1, N) broadcasts to the full N x N matrix of logits.
            f_1 = torch.matmul(h, self.a1)
            f_2 = torch.matmul(h, self.a2)
            e = self.leakyrelu(f_1 + f_2.transpose(0, 1))

            # Very negative logits vanish under softmax, masking non-edges.
            masked = torch.where(adj[i] > 0, e, torch.full_like(e, -9e15))
            attention = F.softmax(masked, dim=1)
            attention = F.dropout(attention, self.dropout, training=self.training)
            # Keep the result as a tensor: converting to numpy would detach it
            # from autograd and fail for CUDA tensors.
            per_graph.append(torch.matmul(attention, h))

        if per_graph:
            h_prime = torch.stack(per_graph)
        else:
            # Empty batch: return an empty tensor with the output feature size.
            empty_shape = (0,) + tuple(input.shape[1:-1]) + (self.out_features,)
            h_prime = input.new_zeros(empty_shape)

        if self.concat:
            return F.elu(h_prime)
        return h_prime

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.in_features} -> {self.out_features})"


# GraphConvolution layers and models
class GraphConvNew(nn.Module):
    """
    Graph Convolution Layer & Additional tricks (power of adjacency matrix and weighted self connections)

    Args:
        in_features: number of input features per node.
        out_features: number of output features per node.
        n_relations: number of relation types (adjacency matrices); the
            aggregated features of all relations are concatenated before the
            linear layer.
        activation: optional callable applied to the masked output.
        adj_sq: when True, the adjacency matrix is squared before use.
        scale_identity: when True, self connections get weight 2 instead of 1.
    """

    def __init__(self, in_features, out_features, n_relations=1,
                 activation=None, adj_sq=False, scale_identity=False):
        super(GraphConvNew, self).__init__()
        self.fc = nn.Linear(in_features=in_features * n_relations, out_features=out_features)
        self.n_relations = n_relations
        self.activation = activation
        self.adj_sq = adj_sq
        self.scale_identity = scale_identity

    def laplacian_batch(self, A):
        """
        Add (optionally scaled) self connections to a batch of adjacency matrices.

        Args:
            A: tensor of shape (batch, N, N).

        Returns:
            ``A + I`` (or ``A @ A + I`` when ``adj_sq`` is set, with ``2 * I``
            when ``scale_identity`` is set). No degree normalisation is applied.
        """
        N = A.shape[1]
        if self.adj_sq:
            A = torch.bmm(A, A)  # use A^2 to increase graph connectivity
        # Build the identity on the adjacency's own device so the layer works
        # wherever its inputs live, without depending on global settings.
        I = torch.eye(N, device=A.device).unsqueeze(0)
        if self.scale_identity:
            I = 2 * I  # increase weight of self connections
        # adding I represents self connections of nodes; D_hat is not applied
        return I + A

    def forward(self, data):
        """
        Apply the graph convolution.

        Args:
            data: sequence whose first three items are the node features ``x``,
                the adjacency ``A`` (a 3-D tensor gets a trailing relation
                dimension added) and the node ``mask`` (a 2-D mask gets a
                trailing dimension added).

        Returns:
            Tuple ``(x, A, mask)`` with the updated node features, the
            adjacency in its 4-D form and the mask in its 3-D form.
        """
        x, A, mask = data[:3]
        if len(A.shape) == 3:
            A = A.unsqueeze(3)
        x_hat = [torch.bmm(self.laplacian_batch(A[:, :, :, rel]), x)
                 for rel in range(self.n_relations)]
        x = self.fc(torch.cat(x_hat, 2))

        if len(mask.shape) == 2:
            mask = mask.unsqueeze(2)
        # Zero the dummy nodes again: self.fc adds its bias to them, which
        # would otherwise affect node embeddings in the following layers.
        x = x * mask
        if self.activation is not None:
            x = self.activation(x)
        return x, A, mask


class GraphConvNew1(nn.Module):
    """
    Graph Convolution Layer with Additional tricks (power of adjacency matrix and weighted self connections)

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        n_relations (int): Number of relation types (adjacency matrices).
        activation (callable): Activation function.
        adj_sq (bool): Whether to use the square of the adjacency matrix.
        scale_identity (bool): Whether to scale the identity matrix.
    """

    def __init__(self, in_features, out_features, n_relations=1,
                 activation=None, adj_sq=False, scale_identity=False):
        # super() must name this class; naming a sibling class raises TypeError.
        super(GraphConvNew1, self).__init__()
        self.fc = nn.Linear(in_features * n_relations, out_features)
        self.n_relations = n_relations
        self.activation = activation
        self.adj_sq = adj_sq
        self.scale_identity = scale_identity

    def laplacian_batch(self, A):
        """
        Add (optionally scaled) self connections to a batch of adjacency matrices.

        Args:
            A (torch.Tensor): Adjacency matrices, shape (batch_size, N, N, n_relations).

        Returns:
            torch.Tensor: ``A + I`` per relation (``A @ A + I`` when ``adj_sq``
            is set), shape (batch_size, N, N, n_relations). No degree
            normalisation is applied.
        """
        # Unpacking also enforces that A is 4-D.
        _, N, _, _ = A.size()

        if self.adj_sq:
            # torch.matmul multiplies the LAST two dimensions, so move the
            # relation axis out of the way to square each N x N adjacency.
            per_relation = A.permute(0, 3, 1, 2)  # (batch, relations, N, N)
            A = torch.matmul(per_relation, per_relation).permute(0, 2, 3, 1)

        I = torch.eye(N, device=A.device).unsqueeze(0).unsqueeze(-1)  # Shape (1, N, N, 1)

        if self.scale_identity:
            I = 2 * I  # Increase weight of self connections

        return A + I  # Add self connections

    def forward(self, data):
        """
        Forward pass for the graph convolution layer.

        Args:
            data (tuple): Sequence whose first three items are the node
                features, the adjacency matrices (3-D input gets a trailing
                relation dimension) and the mask (2-D input gets a trailing
                dimension).

        Returns:
            tuple: Updated node features, adjacency matrices (4-D form, without
            the added self connections), and mask (3-D form).
        """
        x, A, mask = data[:3]

        if A.dim() == 3:
            A = A.unsqueeze(-1)  # Shape (batch_size, N, N, 1)

        n_relations = A.size(-1)
        A_hat = self.laplacian_batch(A)

        # Aggregate neighbours separately for every relation, then concatenate
        # the per-relation features along the feature axis.
        x_hat = [torch.matmul(A_hat[..., rel], x) for rel in range(n_relations)]
        x = self.fc(torch.cat(x_hat, dim=-1))

        if mask.dim() == 2:
            mask = mask.unsqueeze(-1)  # Shape (batch_size, N, 1)

        x = x * mask  # Apply mask to node features

        if self.activation is not None:
            x = self.activation(x)

        return x, A, mask

class GraphConvolutionNetwork(nn.Module):
    """
    Simple GCN layer: ``adj @ (input @ weight) + bias``.

    Args:
        in_features: number of input features per node.
        out_features: number of output features per node.
        bias: when True, a learnable bias of size ``out_features`` is added.
    """
    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolutionNetwork, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Storage is uninitialised here; reset_parameters() fills it.
        self.weight = Parameter(torch.empty(in_features, out_features, dtype=torch.float32))
        if bias:
            self.bias = Parameter(torch.empty(out_features, dtype=torch.float32))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight and bias uniformly in ``[-1/sqrt(out_features), 1/sqrt(out_features)]``."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        nn.init.uniform_(self.weight, -stdv, stdv)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -stdv, stdv)

    def forward(self, input, adj):
        """
        Apply the graph convolution.

        Args:
            input: node features. A 2-D tensor is treated as one graph and
                ``adj`` as its single adjacency matrix; otherwise the first
                dimension is the batch and every ``input[i]`` must be a 2-D
                matrix (it is fed to ``torch.mm``).
            adj: adjacency matrix (2-D input) or a per-sample indexable
                collection of adjacency matrices, as accepted by ``torch.spmm``.

        Returns:
            Tensor of propagated node features with the bias added when
            present. The result stays on the inputs' device and remains part
            of the autograd graph, so gradients reach ``weight``.
        """
        if input.dim() == 2:
            # Single unbatched graph.
            output = torch.spmm(adj, torch.mm(input, self.weight))
        else:
            # Keep results as tensors: converting to numpy would detach them
            # from autograd and fail for CUDA tensors.
            per_graph = [torch.spmm(adj[i], torch.mm(input[i], self.weight))
                         for i in range(input.shape[0])]
            if per_graph:
                output = torch.stack(per_graph)
            else:
                # Empty batch: return an empty tensor with the output feature size.
                empty_shape = (0,) + tuple(input.shape[1:-1]) + (self.out_features,)
                output = input.new_zeros(empty_shape)

        if self.bias is not None:
            output = output + self.bias
        return output

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.in_features} -> {self.out_features})"

