import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
import numpy as np
from utils import utils
# from conformal_vae.utils import ptc_utils


class GraphAttentionEdge(Module):
    """Multi-head graph attention that aggregates per-edge features onto nodes.

    The non-zero pattern of ``adj`` (expanded to ``hops``-hop reachability)
    defines an edge list ``(gather_idx[k], scatter_idx[k])``. ``forward``
    projects each edge's feature vector with a multi-head kernel, scores every
    edge with a learned attention vector, and writes the attention-weighted
    average of the edge features onto the edge's ``gather_idx`` node.
    """

    def __init__(self, feats_in, feats_out, hops, adj, n_heads, bias_flag=True, feats_in_eff=None,
                 feats_out_eff=None, att_output=False):
        """
        Args:
            feats_in: Number of input features per edge; used as the kernel's
                input size when ``feats_in_eff`` is not given.
            feats_out: Number of output features per attention head.
            hops: Number of hops used to expand ``adj`` into the edge list
                (values <= 1 keep exactly the non-zeros of ``adj``).
            adj: Square adjacency matrix (dense or sparse COO); its non-zero
                entries define the edges.
            n_heads: Number of attention heads.
            bias_flag: If True, add a learnable bias of shape ``[feats_out_eff, 1]``.
            feats_in_eff: Kernel input size; defaults to ``feats_in``.
            feats_out_eff: Size of the output feature axis; defaults to
                ``feats_out`` when ``att_output`` else ``feats_out * n_heads``.
            att_output: If True, the projected heads are averaged before
                attention, so the layer emits ``feats_out`` channels.
        """
        super(GraphAttentionEdge, self).__init__()
        if feats_in_eff is None:
            feats_in_eff = feats_in
        if feats_out_eff is None:
            # Matches the channel count produced by forward's final reshape.
            feats_out_eff = feats_out * (1 if att_output else n_heads)

        self.feats_in = feats_in
        self.feats_out = feats_out
        self.feats_out_eff = feats_out_eff
        self.hops = hops
        self.n_heads = n_heads
        self.att_output = att_output

        gather_idx, scatter_idx = self.create_indices(adj)
        self.register_buffer('gather_idx', gather_idx)
        self.register_buffer('scatter_idx', scatter_idx)
        # Number of output nodes: one past the largest row index with an edge.
        self.register_buffer('n_pts', torch.max(self.gather_idx) + 1)

        self.bias_flag = bias_flag

        kernel = torch.zeros([feats_in_eff, feats_out, n_heads], dtype=torch.float32)
        nn.init.xavier_uniform_(kernel)
        kernel = torch.reshape(kernel, [feats_in_eff, feats_out * n_heads])
        self.kernel = nn.Parameter(kernel, requires_grad=True)

        # One attention vector per head (a single one when heads are averaged);
        # stored as [heads, 1, feats_out] so it can left-multiply in forward.
        if self.att_output:
            att = torch.ones([1, feats_out], dtype=torch.float32)
        else:
            att = torch.ones([n_heads, feats_out], dtype=torch.float32)
        nn.init.xavier_uniform_(att)
        att.unsqueeze_(1)
        self.att = nn.Parameter(att, requires_grad=True)

        if bias_flag:
            bias = torch.zeros(feats_out_eff, 1, dtype=torch.float32)
            nn.init.xavier_uniform_(bias)
            self.bias = Parameter(bias, requires_grad=True)
        else:
            self.register_parameter('bias', None)

    def create_indices(self, adj):
        """Return ``(row, col)`` index tensors of all pairs reachable in 1..hops steps.

        The adjacency is binarized (any non-zero entry is an edge) and the
        reachability is accumulated as booleans, so negative weights cannot
        cancel and repeated products cannot overflow. Indices come out in
        row-major order, as produced by ``torch.nonzero``.
        """
        adj_dense = adj.to_dense() if adj.is_sparse else adj
        step = (adj_dense != 0).to(torch.float32)
        reach = step != 0
        walk = step
        for _ in range(self.hops - 1):
            # Pairs connected by a walk of exactly one more hop.
            walk = (torch.mm(walk, step) != 0).to(torch.float32)
            reach = reach | (walk != 0)
        indices = torch.nonzero(reach, as_tuple=False).t()
        return indices[0], indices[1]

    def create_sparse_mat(self, mat):
        """Convert the dense tensor ``mat`` to a sparse COO tensor of its non-zeros."""
        indices = torch.nonzero(mat, as_tuple=False).t()
        values = mat[tuple(indices)]
        return torch.sparse_coo_tensor(indices, values, mat.size())

    def forward(self, x):
        """Aggregate edge features onto their gather nodes with attention.

        Args:
            x: Edge features of shape ``[NB, feats_in_eff, NE]`` where ``NE``
                equals ``len(self.gather_idx)`` (edge order of ``create_indices``).

        Returns:
            Tensor of shape ``[NB, feats_out_eff, n_pts]``.
        """
        z = x.transpose(1, 2)  # NB x NE x F_in
        h = torch.matmul(z, self.kernel)  # NB x NE x (F_out * NH)
        h = torch.reshape(h, [h.shape[0], h.shape[1], self.feats_out, self.n_heads])  # NB x NE x F_out x NH
        if self.att_output:
            h = torch.mean(h, dim=-1, keepdim=True)  # NB x NE x F_out x 1

        # att [NH, 1, F_out] @ [NB, NH, F_out, NE] -> [NB, NH, 1, NE]
        edge_e = torch.matmul(self.att, h.permute(0, 3, 2, 1))
        edge_e = edge_e.permute(0, 3, 2, 1)  # NB x NE x 1 x NH
        # Clamping at -50 caps exp(-leaky_relu(.)) at exp(10) for negative scores.
        edge_e = torch.clamp(edge_e, min=-5E1)
        edge_e = torch.exp(- F.leaky_relu(edge_e, negative_slope=0.2))

        n_pts = int(self.n_pts)
        # Per-node normalizer: sum of edge weights gathered onto each node.
        e_rowsum = torch.zeros([edge_e.shape[0], n_pts, 1, h.shape[-1]],
                               device=edge_e.device, dtype=edge_e.dtype)
        e_rowsum.index_add_(dim=1, index=self.gather_idx, source=edge_e)

        h_prime = edge_e * h
        h_new = torch.zeros([h_prime.shape[0], n_pts, self.feats_out, h_prime.shape[-1]],
                            device=h_prime.device, dtype=h_prime.dtype)
        h_new.index_add_(dim=1, index=self.gather_idx, source=h_prime)

        # Epsilon keeps nodes without incoming edges at zero instead of NaN.
        h_new = h_new / (e_rowsum + 1E-8)

        h_new = torch.reshape(h_new, [h_new.shape[0], h_new.shape[1], -1])  # NB x n_pts x (F_out * heads)
        h_new = h_new.transpose(1, 2)
        if self.bias_flag:
            h_new = h_new + self.bias
        return h_new


class CalphaBlock(Module):
    """Conv2d -> BatchNorm2d -> ReLU block that convolves along dim 2 only.

    Kernel size, stride and padding act on the height axis
    (``(kernel_size, 1)``, ``(stride, 1)``, ``(padding, 0)``); the width
    axis is passed through with a 1-wide kernel.
    """

    def __init__(self, feats_in, feats_out, kernel_size, stride=1, bias_flag=True, padding=0):
        super(CalphaBlock, self).__init__()
        # Constructor hyper-parameters, kept for introspection.
        self.feats_in, self.feats_out = feats_in, feats_out
        self.kernel_size = kernel_size
        self.bias_flag = bias_flag
        self.padding = padding

        self.conv = torch.nn.Conv2d(
            in_channels=feats_in,
            out_channels=feats_out,
            kernel_size=(kernel_size, 1),
            stride=(stride, 1),
            padding=(padding, 0),
            bias=bias_flag,
        )
        self.bn = torch.nn.BatchNorm2d(feats_out)
        self.act = torch.nn.ReLU()

    def forward(self, x):
        """Apply convolution, batch normalization and ReLU in sequence."""
        return self.act(self.bn(self.conv(x)))


class GraphAttention(Module):
    """Multi-head graph attention between two (sub)sampled point sets.

    Input and output points live in one shared "padded" index space:
    ``sample_idx[0][i]`` is the padded index of input point ``i`` and
    ``sample_idx[1][j]`` the padded index of output point ``j``. Each row
    ``(g, s)`` of ``gather_idx`` is an edge from input point ``g`` to output
    point ``s``. Every output point receives the attention-weighted average
    of the projected features of its source points.
    """

    def __init__(self, feats_in, feats_out, gather_idx, sample_idx, n_heads, bias_flag=True, feats_in_eff=None,
                 feats_out_eff=None, att_output=False):
        """
        Args:
            feats_in: Number of input features per point; used as the kernel's
                input size when ``feats_in_eff`` is not given.
            feats_out: Number of output features per attention head.
            gather_idx: Integer tensor of shape ``[NE, 2]``; column 0 indexes
                input points, column 1 indexes output points.
            sample_idx: Pair of index sequences (lists, arrays or tensors)
                mapping input / output point positions to padded indices.
            n_heads: Number of attention heads.
            bias_flag: If True, add a learnable bias of shape ``[feats_out_eff, 1]``.
            feats_in_eff: Kernel input size; defaults to ``feats_in``.
            feats_out_eff: Size of the output feature axis; defaults to
                ``feats_out`` when ``att_output`` else ``feats_out * n_heads``.
            att_output: If True, heads are averaged before attention, so the
                layer emits ``feats_out`` channels.
        """
        super(GraphAttention, self).__init__()
        if feats_in_eff is None:
            feats_in_eff = feats_in
        if feats_out_eff is None:
            # Matches the channel count produced by forward's final reshape.
            feats_out_eff = feats_out * (1 if att_output else n_heads)

        self.feats_in = feats_in
        self.feats_out = feats_out
        self.feats_out_eff = feats_out_eff
        self.n_heads = n_heads
        self.att_output = att_output

        g_idx = gather_idx[:, 0]
        s_idx = gather_idx[:, 1]
        # Build the sample buffers on gather_idx's device so they can be indexed by it.
        device = gather_idx.device
        self.register_buffer('sample_idx0', self._as_index(sample_idx[0], device))
        self.register_buffer('sample_idx1', self._as_index(sample_idx[1], device))
        # Edge endpoints translated into the padded index space.
        self.register_buffer('g_idx_pad', self.sample_idx0[g_idx])
        self.register_buffer('s_idx_pad', self.sample_idx1[s_idx])

        self.register_buffer('n_pts_out', torch.max(s_idx) + 1)
        self.register_buffer('n_pts_total',
                             torch.max(torch.max(self.sample_idx0), torch.max(self.sample_idx1)) + 1)

        self.bias_flag = bias_flag

        kernel = torch.zeros([feats_in_eff, feats_out, n_heads])
        nn.init.xavier_uniform_(kernel)
        kernel = torch.reshape(kernel, [feats_in_eff, feats_out * n_heads])
        self.kernel = nn.Parameter(kernel, requires_grad=True)

        # Attention vectors act on [target; source] features, hence 2 * feats_out;
        # stored as [heads, 1, 2 * feats_out] so they can left-multiply in forward.
        if self.att_output:
            att = torch.ones([1, 2 * feats_out])
        else:
            att = torch.ones([n_heads, 2 * feats_out])
        nn.init.xavier_uniform_(att)
        att.unsqueeze_(1)
        self.att = nn.Parameter(att, requires_grad=True)

        if bias_flag:
            bias = torch.zeros(feats_out_eff, 1)
            nn.init.xavier_uniform_(bias)
            self.bias = Parameter(bias, requires_grad=True)
        else:
            self.register_parameter('bias', None)

    @staticmethod
    def _as_index(values, device):
        """Return ``values`` as a fresh int64 tensor on ``device`` (never aliasing the input)."""
        if isinstance(values, torch.Tensor):
            return values.detach().to(device=device, dtype=torch.long).clone()
        return torch.as_tensor(values, dtype=torch.long, device=device).clone()

    def forward(self, x):
        """Attend from input points to output points.

        Args:
            x: Point features of shape ``[NB, feats_in_eff, N_in]`` with
                ``N_in == len(sample_idx[0])``.

        Returns:
            Tensor of shape ``[NB, feats_out_eff, len(sample_idx[1])]``.
        """
        z = x.transpose(1, 2)  # NB x N_in x F_in
        h = torch.matmul(z, self.kernel)  # NB x N_in x (F_out * NH)
        h = torch.reshape(h, [h.shape[0], h.shape[1], self.feats_out, self.n_heads])
        if self.att_output:
            h = torch.mean(h, dim=-1, keepdim=True)

        n_total = int(self.n_pts_total)
        # Scatter input points into the shared padded index space.
        h_pad = torch.zeros([h.shape[0], n_total, h.shape[2], h.shape[3]], dtype=h.dtype, device=h.device)
        h_pad.index_add_(dim=1, index=self.sample_idx0, source=h)

        # Per-edge [target; source] features: NB x NE x 2*F_out x NH.
        edge_h = torch.cat([h_pad[:, self.s_idx_pad], h_pad[:, self.g_idx_pad]], dim=-2)
        # att [NH, 1, 2*F_out] @ [NB, NH, 2*F_out, NE] -> [NB, NH, 1, NE] -> NB x NE x 1 x NH
        edge_e = torch.matmul(self.att, edge_h.transpose(1, 3)).transpose(1, 3)
        # Clamping at -50 caps exp(-leaky_relu(.)) at exp(10) for negative scores.
        edge_e = torch.clamp(edge_e, min=-5E1)
        edge_e = torch.exp(- F.leaky_relu(edge_e, negative_slope=0.2))

        # Per-target normalizer: sum of incoming edge weights.
        e_rowsum = torch.zeros([edge_e.shape[0], n_total, 1, h.shape[-1]], dtype=edge_e.dtype, device=edge_e.device)
        e_rowsum.index_add_(dim=1, index=self.s_idx_pad, source=edge_e)

        h_prime = edge_e * h_pad[:, self.g_idx_pad]
        h_new = torch.zeros([h_prime.shape[0], n_total, self.feats_out, h_prime.shape[-1]],
                            dtype=h_prime.dtype, device=h_prime.device)
        h_new.index_add_(dim=1, index=self.s_idx_pad, source=h_prime)

        # Epsilon keeps targets without incoming edges at zero instead of NaN.
        h_new = h_new / (e_rowsum + 1E-8)
        # Keep only the output points, in sample_idx[1] order.
        h_new = h_new[:, self.sample_idx1]

        h_new = torch.reshape(h_new, [h_new.shape[0], h_new.shape[1], -1])
        h_new = h_new.transpose(1, 2)
        if self.bias_flag:
            h_new = h_new + self.bias
        return h_new
