import torch
import torch.nn as nn
from copy import deepcopy
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter_max, scatter
from torch_geometric.utils import softmax, k_hop_subgraph, add_self_loops, degree
from pyro.distributions import RelaxedBernoulliStraightThrough, RelaxedOneHotCategoricalStraightThrough
from utils import relabel, negative_sampling, batched_negative_sampling, topk, gumble_topk, sparse_to_dense


class FFN(nn.Module):
    """Three-layer SiLU feed-forward block with a learned linear skip path.

    The main path maps ``hid_dim -> 2*hid_dim -> 2*hid_dim -> hid_dim`` with a
    SiLU after every linear layer; its output is added to a separate linear
    projection of the input.
    """

    def __init__(self, hid_dim):
        super().__init__()
        widths = (hid_dim, 2 * hid_dim, 2 * hid_dim, hid_dim)
        stages = []
        # Linear layers are created in input-to-output order, each followed by SiLU.
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Linear(d_in, d_out))
            stages.append(nn.SiLU())
        self.fc = nn.Sequential(*stages)
        self.skip = nn.Linear(hid_dim, hid_dim)

    def forward(self, x):
        """Return ``fc(x) + skip(x)``; the last dimension of ``x`` must be ``hid_dim``."""
        transformed = self.fc(x)
        shortcut = self.skip(x)
        return transformed + shortcut


class LogReg(nn.Module):
    """Single linear layer followed by log-softmax (logistic-regression probe)."""

    def __init__(self, ft_in, nb_classes):
        super().__init__()
        self.fc = nn.Linear(ft_in, nb_classes)
        for module in self.modules():
            self.weights_init(module)

    def weights_init(self, m):
        """Xavier-uniform the weight and zero the bias of ``m`` if it is an ``nn.Linear``."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, seq):
        """Return log-probabilities over the last dimension of ``fc(seq)``."""
        scores = self.fc(seq)
        return scores.log_softmax(dim=-1)


class GCN(MessagePassing):
    """Message-passing layer that adds self-loops and scales neighbour features.

    Each message is the source feature scaled either by the symmetric degree
    normalisation (when no ``edge_weight`` is given) or by ``edge_weight``
    itself (self-loops get weight 1).
    """

    def __init__(self, emb_dim):
        super().__init__()
        # NOTE(review): this layer is registered but never applied in
        # forward/message within this class - confirm whether that is intended.
        self.linear = nn.Linear(emb_dim, emb_dim)

    def forward(self, x, edge_index, edge_weight=None):
        """Propagate ``x`` over ``edge_index`` augmented with self-loops."""
        n_nodes = x.size(0)
        edge_index, _ = add_self_loops(edge_index, num_nodes=n_nodes)
        if edge_weight is not None:
            # The appended self-loop edges get unit weight.
            loop_weight = x.new_ones((n_nodes,))
            edge_weight = torch.cat((edge_weight, loop_weight), dim=0)
        src, dst = edge_index
        inv_sqrt_deg = degree(dst, n_nodes, dtype=x.dtype).pow(-0.5)
        norm = inv_sqrt_deg[src] * inv_sqrt_deg[dst]
        return self.propagate(edge_index=edge_index, x=x, edge_weight=edge_weight, norm=norm)

    def message(self, x_j, edge_weight, norm):
        """Scale source features by ``norm``, or by ``edge_weight`` when provided."""
        # NOTE(review): the degree normalisation is not used when edge weights are given.
        scale = norm if edge_weight is None else edge_weight
        return scale.view(-1, 1) * x_j


class GIN(MessagePassing):
    """Message-passing layer whose aggregated output goes through a residual MLP."""

    def __init__(self, emb_dim):
        super().__init__()
        hidden = 2 * emb_dim
        self.ffn = nn.Sequential(
            nn.Linear(emb_dim, hidden),
            nn.LayerNorm(hidden),
            nn.SiLU(),
            nn.Linear(hidden, emb_dim),
        )

    def forward(self, x, edge_index, edge_weight=None):
        """Propagate ``x`` along ``edge_index`` (no self-loops are added here)."""
        return self.propagate(edge_index=edge_index, x=x, edge_weight=edge_weight)

    def message(self, x_j, edge_weight):
        """Return source features, scaled per edge when ``edge_weight`` is given."""
        if edge_weight is None:
            return x_j
        return edge_weight.view(-1, 1) * x_j

    def update(self, aggr_out):
        """Apply the MLP to the aggregated messages and add them back (residual)."""
        refined = self.ffn(aggr_out)
        return refined + aggr_out


class Discriminator(nn.Module):
    """Bilinear critic scoring node embeddings against cross-view graph readouts."""

    def __init__(self, n_h):
        super(Discriminator, self).__init__()
        self.bilinear = nn.Bilinear(n_h, n_h, 1)

        for m in self.modules():
            self.weights_init(m)

    def weights_init(self, m):
        """Xavier-uniform the weight and zero the bias of ``m`` if it is ``nn.Bilinear``."""
        if isinstance(m, nn.Bilinear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    def forward(self, g1, g2, h1, h2, h1_corrupt, h2_corrupt, batch1, batch2):
        """Score positive and corrupted (node, graph) pairs across the two views.

        Args:
            g1, g2: per-graph readouts of view 1 / view 2 (one row per graph).
            h1, h2: per-node embeddings of view 1 / view 2.
            h1_corrupt, h2_corrupt: corrupted per-node embeddings.
            batch1, batch2: graph index of every node in view 1 / view 2.

        Returns:
            1-D tensor ``cat(pos_1, pos_2, neg_1, neg_2)``.
        """
        # Give every node the readout of its own graph. Indexing by the batch
        # vector matches repeat_interleave(g, bincount(batch)) for sorted batch
        # vectors and also works when trailing graphs contain no nodes.
        g1 = g1[batch1]
        g2 = g2[batch2]

        # squeeze(-1) rather than squeeze(): a single-node batch must stay 1-D,
        # otherwise torch.cat below fails on zero-dimensional tensors.
        # positive
        pos_1 = self.bilinear(h2, g1).squeeze(-1)
        pos_2 = self.bilinear(h1, g2).squeeze(-1)

        # negative
        neg_1 = self.bilinear(h2_corrupt, g1).squeeze(-1)
        neg_2 = self.bilinear(h1_corrupt, g2).squeeze(-1)

        logits = torch.cat((pos_1, pos_2, neg_1, neg_2), 0)
        return logits


class Discriminator1(nn.Module):
    """Bilinear critic scoring two sets of node embeddings against one graph readout."""

    def __init__(self, n_h):
        super(Discriminator1, self).__init__()
        self.bilinear = nn.Bilinear(n_h, n_h, 1)

        for m in self.modules():
            self.weights_init(m)

    def weights_init(self, m):
        """Xavier-uniform the weight and zero the bias of ``m`` if it is ``nn.Bilinear``."""
        if isinstance(m, nn.Bilinear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    def forward(self, c, h1, h2, batch):
        """Return ``cat(score(h1, c), score(h2, c))`` as a 1-D tensor.

        Args:
            c: per-graph readout (one row per graph).
            h1: node embeddings scored as positives.
            h2: node embeddings scored as negatives.
            batch: graph index of every node.
        """
        # Per-node copy of the owning graph's readout. Equivalent to
        # repeat_interleave(c, bincount(batch)) for sorted batch vectors.
        c = c[batch]
        # squeeze(-1) keeps a single-node batch 1-D so torch.cat does not fail.
        pos = self.bilinear(h1, c).squeeze(-1)
        neg = self.bilinear(h2, c).squeeze(-1)
        logits = torch.cat((pos, neg), 0)
        return logits


class NodeEncoder(nn.Module):
    """Two-layer GCN encoder producing node embeddings and a mean graph readout."""

    def __init__(self, in_dim, hid_dim, dropout):
        super().__init__()
        self.linear = nn.Linear(in_dim, hid_dim)
        self.gcn1 = GCN(hid_dim)
        self.gcn2 = GCN(hid_dim)
        self.norm1 = nn.LayerNorm(hid_dim)
        self.norm2 = nn.LayerNorm(hid_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, edge_index, edge_weight, batch, mask):
        """Encode nodes and pool them per graph.

        Args:
            x: node features.
            edge_index: edge list passed to both GCN layers.
            edge_weight: optional per-edge weights (may be ``None``).
            batch: graph index of every node.
            mask: optional boolean node mask; when given, only the selected
                nodes contribute to the readout.

        Returns:
            ``(h, g)``: node embeddings and the per-graph mean of the
            (optionally masked) node embeddings.
        """
        h = self.linear(x)
        stages = ((self.gcn1, self.norm1), (self.gcn2, self.norm2))
        for conv, norm in stages:
            h = conv(x=h, edge_index=edge_index, edge_weight=edge_weight)
            h = norm(h)
            h = self.drop(F.silu(h))

        if mask is None:
            pooled_src, pooled_idx = h, batch
        else:
            pooled_src, pooled_idx = h[mask], batch[mask]
        # dim_size comes from the unmasked batch so every graph keeps a row in g.
        g = scatter(pooled_src, pooled_idx, dim=0, dim_size=batch.max() + 1, reduce='mean')
        return h, g


class GNN(nn.Module):
    """Stack of GIN/GCN layers with an input residual, optional heads and a readout."""

    def __init__(self, in_dim, hid_dim, n_layer, dropout, network='gin', readout='sum', head=True):
        super().__init__()
        self.network = network
        self.head = head
        self.readout = readout
        self.n_layer = n_layer
        self.drop = nn.Dropout(dropout)
        self.linear = nn.Linear(in_dim, hid_dim)
        # Any value other than 'gin' selects the GCN layer.
        conv_cls = GIN if network == 'gin' else GCN
        self.layers = nn.ModuleList([conv_cls(hid_dim) for _ in range(n_layer)])
        self.norm = nn.ModuleList([nn.LayerNorm(hid_dim) for _ in range(n_layer)])
        if self.head:
            self.node_head = FFN(hid_dim)
            self.graph_head = FFN(hid_dim)

    def forward(self, x, edge_index, edge_weight, batch, mask):
        """Encode nodes and pool them per graph.

        Args:
            x: node features.
            edge_index: edge list passed to every layer.
            edge_weight: optional per-edge weights (may be ``None``).
            batch: graph index of every node.
            mask: optional boolean node mask restricting which nodes are pooled.

        Returns:
            ``(h, g)``: node embeddings and per-graph embeddings pooled with
            the ``readout`` reduction.
        """
        projected = self.linear(x)
        h = projected
        for conv, norm in zip(self.layers, self.norm):
            h = conv(x=h, edge_index=edge_index, edge_weight=edge_weight)
            h = norm(h)
            h = self.drop(F.silu(h))
        # Residual connection from the input projection around the whole stack.
        h = projected + h
        if self.head:
            h = self.node_head(h)

        if mask is None:
            pooled_src, pooled_idx = h, batch
        else:
            pooled_src, pooled_idx = h[mask], batch[mask]
        g = scatter(pooled_src, pooled_idx, dim=0, dim_size=batch.max() + 1, reduce=self.readout)
        if self.head:
            g = self.graph_head(g)
        return h, g


class IdentityAugment(nn.Module):
    """Augmentation that leaves the batch untouched.

    The constructor arguments are accepted for signature parity with the other
    augmentations and are not used.
    """

    def __init__(self, in_dim, hid_dim, dropout=0.2, random_augment=False, random_drop=0.2):
        super().__init__()

    @staticmethod
    def _nodes_with_edges(batch):
        """Boolean mask over nodes marking the endpoints of at least one edge."""
        touched = batch.x.new_zeros((batch.x.size(0),), dtype=torch.bool)
        touched[batch.edge_index.flatten()] = True
        return touched

    def forward(self, batch, h, g, temperature):
        """Return ``batch`` unchanged together with its connected-node mask."""
        return batch, self._nodes_with_edges(batch)

    @torch.no_grad()
    def inference(self, batch, h, g):
        """Same as ``forward`` but evaluated without gradient tracking."""
        return batch, self._nodes_with_edges(batch)


class EdgeAugment(nn.Module):
    """Edge-level augmentation: keeps, drops or adds edges.

    With ``random_augment=False`` a small MLP scores existing edges and sampled
    non-edges, and a subset is selected. With ``random_augment=True`` existing
    edges are dropped via ``nn.Dropout(random_drop)`` applied to a vector of
    ones (so nothing is dropped while the module is in eval mode).

    Both entry points mutate ``batch`` in place (``edge_index`` and, in the
    learned mode, ``edge_weight``) and return it together with a boolean mask
    of nodes that still touch an edge.
    """

    def __init__(self, in_dim, hid_dim, random_augment=False, random_drop=0.2):
        super(EdgeAugment, self).__init__()
        self.random_augment = random_augment
        if not self.random_augment:
            # Input is the summed endpoint embedding plus one connectivity flag.
            self.head = nn.Sequential(
                nn.Linear(hid_dim + 1, 2 * hid_dim),
                nn.LayerNorm(2 * hid_dim),
                nn.SiLU(),
                nn.Linear(2 * hid_dim, 1),)
        else:
            self.drop = nn.Dropout(random_drop)

    @staticmethod
    def _sample_negative_edges(batch):
        """Sample non-edges; the project helpers may return ``None`` when none exist.

        Errors raised by the samplers propagate unchanged instead of being
        swallowed (the previous bare ``except`` left ``neg_edge`` unbound).
        """
        pos_edge = batch.edge_index
        if len(batch.batch.unique()) > 1:
            return batched_negative_sampling(batch)
        # NOTE(review): inference previously called negative_sampling(batch);
        # both paths now use the call form of the training path - confirm
        # against utils.negative_sampling.
        return negative_sampling(pos_edge, num_neg_samples=pos_edge.size(1))

    def _score_candidate_edges(self, batch, h):
        """Return ``(edges, logits)`` for existing edges plus sampled non-edges."""
        pos_edge = batch.edge_index
        neg_edge = self._sample_negative_edges(batch)

        # This is to handle fully-connected graphs where there are no negative examples.
        if neg_edge is not None:
            # Order candidate edges by the graph of their source node.
            __, idx = torch.sort(torch.gather(batch.batch, 0, torch.cat((pos_edge[0], neg_edge[0]), dim=-1)))
            edges = torch.cat((pos_edge, neg_edge), dim=-1)[:, idx]
            # 1 for existing edges, 0 for sampled non-edges; created in h's dtype
            # so the concatenation below does not depend on type promotion.
            connectivity = h.new_ones(edges.size(1))
            connectivity[idx >= pos_edge.size(1)] = 0.
        else:
            edges = pos_edge
            connectivity = h.new_ones(edges.size(1))

        enc = torch.cat((h[edges[0]] + h[edges[1]], connectivity.view(-1, 1)), dim=-1)
        # squeeze(-1) keeps a single candidate edge as a 1-element vector.
        logits = self.head(enc).squeeze(-1)
        return edges, logits

    def _random_edge_keep(self, batch):
        """Boolean keep-mask over existing edges produced by dropout on ones."""
        keep = torch.ones(batch.edge_index.size(1), dtype=torch.float, device=batch.x.device)
        return self.drop(keep).bool()

    @staticmethod
    def _nodes_with_edges(batch):
        """Boolean mask over nodes marking the endpoints of at least one edge."""
        mask = batch.x.new_zeros((batch.x.size(0),), dtype=torch.bool)
        mask[batch.edge_index.flatten()] = True
        return mask

    def forward(self, batch, h, g, temperature):
        """Augment ``batch`` for training.

        Args:
            batch: object exposing ``x``, ``edge_index`` and (learned mode) ``batch``.
            h: node embeddings used to score edges (learned mode only).
            g: unused.
            temperature: relaxation temperature of the Bernoulli edge sampler.

        Returns:
            ``(batch, mask)`` where ``mask`` marks nodes that still have an edge.
        """
        if not self.random_augment:
            edges, logits = self._score_candidate_edges(batch, h)
            mask = RelaxedBernoulliStraightThrough(temperature, logits=logits).rsample().to(torch.bool)
            batch.edge_index = edges[:, mask]
            batch.edge_weight = logits[mask]
        else:
            batch.edge_index = batch.edge_index[:, self._random_edge_keep(batch)]
        return batch, self._nodes_with_edges(batch)

    @torch.no_grad()
    def inference(self, batch, h, g):
        """Augment ``batch`` without gradient tracking.

        In the learned mode candidate edges are selected with the project
        helper ``topk(logits, 0.5, graph_of_edge)`` instead of being sampled.
        """
        if not self.random_augment:
            edges, logits = self._score_candidate_edges(batch, h)
            mask = topk(logits, 0.5, torch.gather(batch.batch, 0, edges[0]))
            batch.edge_index = edges[:, mask]
            batch.edge_weight = logits[mask]
        else:
            # Uses the same float-based dropout mask as forward; applying
            # dropout directly to a bool tensor is not supported in train mode.
            batch.edge_index = batch.edge_index[:, self._random_edge_keep(batch)]
        return batch, self._nodes_with_edges(batch)


class NodeAugment(nn.Module):
    """Node-level augmentation: drops nodes by removing all of their edges.

    With ``random_augment=False`` an MLP scores every node and a subset is
    selected through project helpers (``gumble_topk`` in training, ``topk`` at
    inference). With ``random_augment=True`` nodes are dropped via
    ``nn.Dropout(random_drop)`` applied to a vector of ones (so nothing is
    dropped while the module is in eval mode).

    Both entry points mutate ``batch`` in place and return it together with a
    boolean mask of the nodes that still touch an edge.
    """

    def __init__(self, in_dim, hid_dim, ratio, random_augment=False, random_drop=0.2):
        super(NodeAugment, self).__init__()
        self.ratio = ratio
        self.random_augment = random_augment
        if not self.random_augment:
            self.head = nn.Sequential(
                nn.Linear(hid_dim, 2 * hid_dim),
                nn.LayerNorm(2 * hid_dim),
                nn.SELU(),
                nn.Linear(2 * hid_dim, 1),)
        else:
            self.drop = nn.Dropout(random_drop)

    def _node_logits(self, batch, h, g):
        """Score every node from its embedding plus its graph's embedding."""
        n_nodes = torch.bincount(batch.batch)
        # squeeze(-1) keeps a single-node batch as a 1-element vector.
        return self.head(h + torch.repeat_interleave(g, n_nodes, dim=0)).squeeze(-1)

    def _random_node_keep(self, batch):
        """Boolean keep-mask over nodes; dropout is applied graph by graph."""
        ones = torch.ones(batch.x.size(0), device=batch.x.device)
        ends = torch.bincount(batch.batch).cumsum(dim=0).tolist()
        starts = [0] + ends[:-1]
        return torch.cat([self.drop(ones[s:e]) for s, e in zip(starts, ends)]).bool()

    @staticmethod
    def _prune_edges(batch, node_keep, logits=None):
        """Keep only edges whose two endpoints are kept; update ``batch`` in place.

        When ``logits`` is given, ``batch.edge_weight`` is set to the sum of the
        endpoint logits of every surviving edge.
        """
        src, dst = batch.edge_index[0], batch.edge_index[1]
        edge_keep = (torch.gather(node_keep, 0, src) * torch.gather(node_keep, 0, dst)).bool()
        if logits is not None:
            edge_weight = torch.gather(logits, 0, src) + torch.gather(logits, 0, dst)
            batch.edge_weight = edge_weight[edge_keep]
        batch.edge_index = batch.edge_index[:, edge_keep]
        mask = batch.x.new_zeros((batch.x.size(0), ), dtype=torch.bool)
        mask[batch.edge_index.flatten()] = True
        return batch, mask

    def forward(self, batch, h, g, temperature):
        """Augment ``batch`` for training.

        Args:
            batch: object exposing ``x``, ``edge_index`` and ``batch``.
            h: node embeddings (learned mode only).
            g: per-graph embeddings (learned mode only).
            temperature: divisor applied to the logits before ``gumble_topk``.

        Returns:
            ``(batch, mask)`` where ``mask`` marks nodes that still have an edge.
        """
        if not self.random_augment:
            logits = self._node_logits(batch, h, g)
            node_keep = gumble_topk(logits / temperature, self.ratio, batch.batch)
            return self._prune_edges(batch, node_keep, logits)
        # The previous implementation deep-copied the whole batch here and
        # never used the copy; that dead work has been removed.
        return self._prune_edges(batch, self._random_node_keep(batch))

    @torch.no_grad()
    def inference(self, batch, h, g):
        """Augment ``batch`` without gradient tracking."""
        if not self.random_augment:
            logits = self._node_logits(batch, h, g)
            # NOTE(review): the ratio is hard-coded to 0.9 here while forward
            # uses self.ratio - confirm this difference is intended.
            node_keep = topk(logits, 0.9, batch.batch)
            return self._prune_edges(batch, node_keep, logits)
        return self._prune_edges(batch, self._random_node_keep(batch))


class GraphAugment(nn.Module):
    """Subgraph augmentation: keeps the k-hop neighbourhood of seed nodes.

    With ``random_augment=False`` an MLP scores nodes and seeds are chosen from
    those scores; with ``random_augment=True`` one seed is drawn uniformly per
    graph. Edges outside the k-hop neighbourhood of the seeds are removed.

    Both entry points mutate ``batch`` in place and return it together with a
    boolean mask of the nodes that still touch an edge.
    """

    def __init__(self, in_dim, hid_dim, n_hops, random_augment=False, random_drop=0.2):
        super(GraphAugment, self).__init__()
        self.n_hops = n_hops
        self.random_augment = random_augment
        if not self.random_augment:
            self.head = nn.Sequential(
                nn.Linear(hid_dim, 2 * hid_dim),
                nn.LayerNorm(2 * hid_dim),
                nn.SELU(),
                nn.Linear(2 * hid_dim, 1),)

    def _node_logits(self, batch, h, g):
        """Score every node from its embedding plus its graph's embedding."""
        n_nodes = torch.bincount(batch.batch)
        # flatten() keeps a single-node batch as a 1-element vector.
        return self.head(h + torch.repeat_interleave(g, n_nodes, dim=0)).flatten()

    @staticmethod
    def _random_seeds(batch):
        """Draw one uniformly random node index per graph (as Python ints)."""
        n_nodes = torch.bincount(batch.batch)
        # Offset of each graph's first node, assuming nodes are grouped by graph.
        starts = (n_nodes.cumsum(dim=0) - n_nodes).tolist()
        return [start + torch.randint(0, n, (1,), device=batch.x.device).item()
                for start, n in zip(starts, n_nodes.tolist())]

    @staticmethod
    def _keep_k_hop(batch, seeds, n_hops, logits=None):
        """Restrict ``batch`` to edges inside the k-hop neighbourhood of ``seeds``.

        When ``logits`` is given, ``batch.edge_weight`` is set to the sum of the
        endpoint logits of every surviving edge.
        """
        # num_nodes is passed explicitly: otherwise it is inferred from
        # edge_index.max() + 1 and a seed that is an isolated node with a
        # larger index falls outside the internal node mask.
        __, __, __, edge_keep = k_hop_subgraph(seeds, n_hops, batch.edge_index, num_nodes=batch.x.size(0))
        if logits is not None:
            edge_weight = torch.gather(logits, 0, batch.edge_index[0]) + torch.gather(logits, 0, batch.edge_index[1])
            batch.edge_weight = edge_weight[edge_keep]
        batch.edge_index = batch.edge_index[:, edge_keep]
        mask = batch.x.new_zeros((batch.x.size(0), ), dtype=torch.bool)
        mask[batch.edge_index.flatten()] = True
        return batch, mask

    def forward(self, batch, h, g, temperature):
        """Augment ``batch`` for training using ``self.n_hops`` hops.

        Args:
            batch: object exposing ``x``, ``edge_index`` and ``batch``.
            h: node embeddings (learned mode only).
            g: per-graph embeddings (learned mode only).
            temperature: temperature of the relaxed one-hot seed sampler.

        Returns:
            ``(batch, mask)`` where ``mask`` marks nodes that still have an edge.
        """
        if self.random_augment:
            return self._keep_k_hop(batch, self._random_seeds(batch), self.n_hops)

        logits = self._node_logits(batch, h, g)
        node_prob = softmax(logits, batch.batch)
        # presumably pads per-graph probabilities to a dense (graphs, max_nodes)
        # layout and returns the flat positions of the real nodes - verify
        # against utils.sparse_to_dense.
        node_prob, index = sparse_to_dense(node_prob, batch.batch, pad=.0)
        node_sample = RelaxedOneHotCategoricalStraightThrough(temperature, probs=node_prob).rsample()
        node_sample = torch.gather(node_sample.flatten(), 0, index)
        seeds = torch.nonzero(node_sample).flatten()
        return self._keep_k_hop(batch, seeds, self.n_hops, logits)

    @torch.no_grad()
    def inference(self, batch, h, g, n_hops=4):
        """Augment ``batch`` without gradient tracking.

        The batch is first passed through the project helper ``relabel``. In
        the learned mode the highest-scoring node of each graph is the seed.

        NOTE(review): this uses the ``n_hops`` argument (default 4) rather than
        ``self.n_hops`` - confirm that is intended.
        """
        _, batch = relabel(batch)
        if self.random_augment:
            return self._keep_k_hop(batch, self._random_seeds(batch), n_hops)

        logits = self._node_logits(batch, h, g)
        __, node_sample = scatter_max(logits, batch.batch)
        return self._keep_k_hop(batch, node_sample, n_hops, logits)


class FeatureAugment(nn.Module):
    """Feature-level augmentation: masks entries of the node feature matrix.

    Learned mode projects ``batch.x`` with a linear layer and multiplies it by
    a per-entry gate predicted from the node embeddings. Random mode applies
    ``nn.Dropout(random_drop)`` to ``batch.x`` in ``forward`` and leaves the
    features untouched in ``inference``. ``batch`` is modified in place.
    """

    def __init__(self, in_dim, hid_dim, random_augment=False, random_drop=0.2):
        super().__init__()
        self.random_augment = random_augment
        if self.random_augment:
            self.drop = nn.Dropout(random_drop)
        else:
            self.linear = nn.Linear(in_dim, in_dim)
            wide = 2 * hid_dim
            self.head = nn.Sequential(
                nn.Linear(hid_dim, wide),
                nn.LayerNorm(wide),
                nn.SELU(),
                nn.Linear(wide, in_dim),
            )

    @staticmethod
    def _nodes_with_edges(batch):
        """Boolean mask over nodes marking the endpoints of at least one edge."""
        touched = batch.x.new_zeros((batch.x.size(0),), dtype=torch.bool)
        touched[batch.edge_index.flatten()] = True
        return touched

    def forward(self, batch, h, g, temperature):
        """Augment ``batch.x`` for training and return ``(batch, node_mask)``."""
        if self.random_augment:
            batch.x = self.drop(batch.x)
        else:
            batch.x = self.linear(batch.x)
            gate_logits = self.head(h)
            gate = RelaxedBernoulliStraightThrough(temperature=temperature, logits=gate_logits).rsample()
            batch.x = batch.x * gate
        return batch, self._nodes_with_edges(batch)

    @torch.no_grad()
    def inference(self, batch, h, g):
        """Augment ``batch.x`` with a hard (rounded sigmoid) gate, without gradients."""
        if not self.random_augment:
            batch.x = self.linear(batch.x)
            gate = torch.sigmoid(self.head(h)).round()
            batch.x = batch.x * gate
        return batch, self._nodes_with_edges(batch)


class Policy(nn.Module):
    """Maps a set of graph embeddings to a distribution over ``n_aug`` augmentations.

    ``aggregation='gru'`` summarises the embeddings with a GRU run over them in
    descending order of their feature sum; ``aggregation='set'`` uses a
    project-sum-project encoder. Any other value feeds the embeddings to the
    head as they are. The ``dropout`` argument is accepted but not used.
    """

    def __init__(self, hid_dim, dropout, n_aug=5, aggregation='gru'):
        super().__init__()
        self.n_aug = n_aug
        self.aggregation = aggregation
        if aggregation == 'gru':
            self.gru = nn.GRU(hid_dim, hid_dim, batch_first=True)
        elif aggregation == 'set':
            self.projection1 = nn.Sequential(nn.Linear(hid_dim, hid_dim), nn.ELU())
            self.projection2 = nn.Sequential(nn.Linear(hid_dim, hid_dim), nn.ELU())
        wide = 2 * hid_dim
        self.head = nn.Sequential(
            nn.Linear(hid_dim, wide),
            nn.LayerNorm(wide),
            nn.SELU(),
            nn.Linear(wide, self.n_aug),
        )

    def forward(self, g, temperature, num_sample):
        """Sample augmentation choices.

        Args:
            g: graph embeddings, one row per graph.
            temperature: temperature of the relaxed one-hot sampler.
            num_sample: number of samples to draw.

        Returns:
            ``(a, probs)``: straight-through one-hot samples reshaped to
            ``(-1, n_aug)`` and the softmax of the head's logits.
        """
        if self.aggregation == 'gru':
            order = torch.sort(g.sum(1), descending=True).indices
            sequence = g[order].unsqueeze(0)
            # Only the final hidden state of the GRU is used as the summary.
            __, g = self.gru(sequence)
        elif self.aggregation == 'set':
            pooled = self.projection1(g).sum(0)
            g = self.projection2(pooled)
        p = self.head(g.squeeze(0))
        sampler = RelaxedOneHotCategoricalStraightThrough(temperature, logits=p)
        a = sampler.rsample((num_sample,)).view(-1, self.n_aug)
        return a, F.softmax(p, -1)


class SubGraphSampler(nn.Module):
    """Samples seed nodes and keeps only edges in their k-hop neighbourhood.

    With ``random_augment=False`` an MLP scores nodes and seeds are selected
    through the project helper ``gumble_topk``; otherwise ``n_samples`` seeds
    are drawn uniformly with replacement. ``batch`` is modified in place.
    """

    def __init__(self, in_dim, hid_dim, random_augment=False, random_drop=0.2):
        super(SubGraphSampler, self).__init__()
        self.random_augment = random_augment
        if not self.random_augment:
            self.head = nn.Sequential(
                nn.Linear(hid_dim, 2 * hid_dim),
                nn.LayerNorm(2 * hid_dim),
                nn.SELU(),
                nn.Linear(2 * hid_dim, 1),)

    def forward(self, batch, h, g, temperature, n_samples, n_hops=4):
        """Restrict ``batch`` to the k-hop neighbourhood of sampled seed nodes.

        Args:
            batch: object exposing ``x`` and ``edge_index``.
            h: node embeddings (learned mode only).
            g: unused.
            temperature: divisor applied to the logits before ``gumble_topk``.
            n_samples: number of seed nodes.
            n_hops: neighbourhood radius.

        Returns:
            ``(batch, mask)`` where ``mask`` marks nodes that still have an edge.
        """
        # Passed to k_hop_subgraph explicitly: otherwise the node count is
        # inferred from edge_index.max() + 1 and a seed that is an isolated
        # node with a larger index falls outside the internal node mask.
        num_nodes = batch.x.size(0)
        if not self.random_augment:
            # squeeze(-1) keeps a single-node input as a 1-element vector.
            logits = self.head(h).squeeze(-1)
            selected = gumble_topk(logits/temperature, n_samples, None)
            seeds = torch.nonzero(selected).flatten()
            __, __, __, edge_keep = k_hop_subgraph(seeds, n_hops, batch.edge_index, num_nodes=num_nodes)
            # Edge weight is the sum of the two endpoint logits.
            edge_weight = torch.gather(logits, 0, batch.edge_index[0]) + torch.gather(logits, 0, batch.edge_index[1])
            batch.edge_index = batch.edge_index[:, edge_keep]
            batch.edge_weight = edge_weight[edge_keep]
        else:
            idx = torch.randint(0, num_nodes, (n_samples,), device=batch.x.device)
            __, __, __, edge_keep = k_hop_subgraph(idx, n_hops, batch.edge_index, num_nodes=num_nodes)
            batch.edge_index = batch.edge_index[:, edge_keep]
        mask = batch.x.new_zeros((batch.x.size(0), ), dtype=torch.bool)
        mask[batch.edge_index.flatten()] = True
        return batch, mask


class Augmentations(nn.Module):
    """Container holding one instance of each augmentation module.

    The modules are stored in ``self.augmentations`` in the fixed order
    edge, node, feature, identity.
    """

    def __init__(self, feat_dim, hid_dim, ratio, random):
        super().__init__()
        # Every augmentation shares the same learned-vs-random switch.
        shared = {'random_augment': random}
        members = [
            EdgeAugment(feat_dim, hid_dim, **shared),
            NodeAugment(feat_dim, hid_dim, ratio, **shared),
            FeatureAugment(feat_dim, hid_dim, **shared),
            IdentityAugment(feat_dim, hid_dim, **shared),
        ]
        self.augmentations = nn.ModuleList(members)

