import math

import numpy as np
import torch
from torch_geometric.nn.inits import zeros
from torch import nn, optim
import pytorch_lightning as pl
from torch_sparse import SparseTensor
import torch.nn.functional as F
from torch.nn import ModuleList, Linear
from torch_geometric.nn import (
    SAGEConv,
    GCNConv,
    GATConv,
    JumpingKnowledge,
)

import DA_SAGE
from datasets.data_utils import get_norm_adj
import DA_GCN
import torch_geometric as pyg

from dag_utils import init_graph, split_graph, directed_norm_adj


def get_conv(conv_type, input_dim, output_dim, alpha, ordering_type):
    if conv_type == "gcn":
        return GCNConv(input_dim, output_dim, add_self_loops=False)
    elif conv_type == "sage":
        return SAGEConv(input_dim, output_dim)
    elif conv_type == "gat":
        return GATConv(input_dim, output_dim, heads=1)
    elif conv_type == "dir-gcn":
        return DirGCNConv(input_dim, output_dim, alpha)
    elif conv_type == "dir-sage":
        return DirSageConv(input_dim, output_dim, alpha)
    elif conv_type == "dir-gat":
        return DirGATConv(input_dim, output_dim, heads=1, alpha=alpha)
    elif conv_type == "da-sage":
        return DaDirSageConv(input_dim, output_dim, alpha)
    elif conv_type == "da-gcn":
        return DaDirGCNConv(input_dim, output_dim, alpha)
    else:
        raise ValueError(f"Convolution type {conv_type} not supported")


class DaDirSageConv(torch.nn.Module):
    def __init__(self, input_dim, output_dim, alpha):
        super(DaDirSageConv, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        self.conv_src_to_dst = DA_SAGE.DASAGEConv(input_dim, output_dim, share_init=True, cached=True, root_weight=False)
        self.conv_dst_to_src = DA_SAGE.DASAGEConv(input_dim, output_dim, share_init=True, cached=True, root_weight=False, degree_index=1)
        self.lin_self = Linear(input_dim, output_dim)
        self.alpha = alpha

    def forward(self, x, edge_index):
        x_ = None
        if self.alpha > 0:
            edge_index_t = torch.stack([edge_index[1], edge_index[0]], dim=0)
            x_ = self.alpha * self.conv_dst_to_src(x, edge_index_t)
        if self.alpha < 1:
            if x_ is None:
                x_ = (1 - self.alpha) * self.conv_src_to_dst(x, edge_index)
            else:
                x_ += (1 - self.alpha) * self.conv_src_to_dst(x, edge_index)
        return self.lin_self(x) + x_


class DaDirGCNConv(torch.nn.Module):
    def __init__(self, input_dim, output_dim, alpha):
        super(DaDirGCNConv, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        self.conv_src_to_dst = DA_GCN.DAGCNConv(input_dim, output_dim, alpha=1.0, share_init=True, add_self_loops=False, norm='dir',
                                cached=True)
        self.conv_dst_to_src = DA_GCN.DAGCNConv(input_dim, output_dim, alpha=1.0, share_init=True, add_self_loops=False, norm='dir',
                                cached=True, degree_index=1)
        self.alpha = alpha

    def forward(self, x, edge_index):
        edge_index_t = torch.stack([edge_index[1], edge_index[0]], dim=0)

        return (1 - self.alpha) * self.conv_src_to_dst(x, edge_index) + self.alpha * self.conv_dst_to_src(
            x, edge_index_t
        )

class DirGCNConv(torch.nn.Module):
    def __init__(self, input_dim, output_dim, alpha):
        super(DirGCNConv, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        self.lin_src_to_dst = Linear(input_dim, output_dim)
        self.lin_dst_to_src = Linear(input_dim, output_dim)
        self.alpha = alpha
        self.adj_norm, self.adj_t_norm = None, None

    def forward(self, x, edge_index):
        if self.adj_norm is None:
            row, col = edge_index
            num_nodes = x.shape[0]
            adj = SparseTensor(row=row, col=col, sparse_sizes=(num_nodes, num_nodes))
            self.adj_norm = get_norm_adj(adj, norm="dir")

            adj_t = SparseTensor(row=col, col=row, sparse_sizes=(num_nodes, num_nodes))
            self.adj_t_norm = get_norm_adj(adj_t, norm="dir")

        x_ = None
        if self.alpha > 0:
            x_ = self.alpha * self.lin_src_to_dst(self.adj_norm @ x)
        if self.alpha < 1:
            if x_ is None:
                x_ = (1 - self.alpha) * self.lin_dst_to_src(self.adj_t_norm @ x)
            else:
                x_ += (1 - self.alpha) * self.lin_dst_to_src(self.adj_t_norm @ x)
        return x_





class DirSageConv(torch.nn.Module):
    def __init__(self, input_dim, output_dim, alpha):
        super(DirSageConv, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        self.conv_src_to_dst = SAGEConv(input_dim, output_dim, flow="source_to_target", root_weight=False)
        self.conv_dst_to_src = SAGEConv(input_dim, output_dim, flow="target_to_source", root_weight=False)
        self.lin_self = Linear(input_dim, output_dim)
        self.alpha = alpha

    def forward(self, x, edge_index):
        x_ = None
        if self.alpha > 0:
            x_ = self.alpha * self.conv_dst_to_src(x, edge_index)
        if self.alpha < 1:
            if x_ is None:
                x_ = (1 - self.alpha) * self.conv_src_to_dst(x, edge_index)
            else:
                x_ += (1 - self.alpha) * self.conv_src_to_dst(x, edge_index)
        return self.lin_self(x) + x_


class DirGATConv(torch.nn.Module):
    def __init__(self, input_dim, output_dim, heads, alpha):
        super(DirGATConv, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        self.conv_src_to_dst = GATConv(input_dim, output_dim, heads=heads)
        self.conv_dst_to_src = GATConv(input_dim, output_dim, heads=heads)
        self.alpha = alpha

    def forward(self, x, edge_index):
        edge_index_t = torch.stack([edge_index[1], edge_index[0]], dim=0)

        return (1 - self.alpha) * self.conv_src_to_dst(x, edge_index) + self.alpha * self.conv_dst_to_src(
            x, edge_index_t
        )


class GNN(torch.nn.Module):
    def __init__(
            self,
            num_features,
            num_classes,
            hidden_dim,
            num_layers=2,
            dropout=0,
            conv_type="dir-gcn",
            jumping_knowledge=False,
            normalize=False,
            alpha=1 / 2,
            learn_alpha=False,
            ordering_type='degree',
    ):
        super(GNN, self).__init__()

        self.alpha = nn.Parameter(torch.ones(1) * alpha, requires_grad=learn_alpha)
        output_dim = hidden_dim if jumping_knowledge else num_classes
        if num_layers == 1:
            self.convs = ModuleList([get_conv(conv_type, num_features, output_dim, self.alpha, ordering_type)])
        else:
            self.convs = ModuleList([get_conv(conv_type, num_features, hidden_dim, self.alpha, ordering_type)])
            for _ in range(num_layers - 2):
                self.convs.append(get_conv(conv_type, hidden_dim, hidden_dim, self.alpha, ordering_type))
            self.convs.append(get_conv(conv_type, hidden_dim, output_dim, self.alpha, ordering_type))

        if jumping_knowledge is not None:
            input_dim = hidden_dim * num_layers if jumping_knowledge == "cat" else hidden_dim
            self.lin = Linear(input_dim, num_classes)
            self.jump = JumpingKnowledge(mode=jumping_knowledge, channels=hidden_dim, num_layers=num_layers)

        self.num_layers = num_layers
        self.dropout = dropout
        self.jumping_knowledge = jumping_knowledge
        self.normalize = normalize
        self.conv_type = conv_type

    def forward(self, x, edge_index):
        xs = []
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i != len(self.convs) - 1 or self.jumping_knowledge:
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)
                if self.normalize:
                    x = F.normalize(x, p=2, dim=1)
            xs += [x]

        if self.jumping_knowledge is not None:
            x = self.jump(xs)
            x = self.lin(x)

        return torch.nn.functional.log_softmax(x, dim=1)


class LightingFullBatchModelWrapper(pl.LightningModule):
    def __init__(self, model, lr, weight_decay, train_mask, val_mask, test_mask, evaluator=None):
        super().__init__()
        self.model = model
        self.lr = lr
        self.weight_decay = weight_decay
        self.evaluator = evaluator
        self.train_mask, self.val_mask, self.test_mask = train_mask, val_mask, test_mask

    def training_step(self, batch, batch_idx):
        x, y, edge_index = batch.x, batch.y.long(), batch.edge_index
        out = self.model(x, edge_index)

        loss = nn.functional.nll_loss(out[self.train_mask], y[self.train_mask].squeeze())
        self.log("train_loss", loss)

        y_pred = out.max(1)[1]
        train_acc = self.evaluate(y_pred=y_pred[self.train_mask], y_true=y[self.train_mask])
        self.log("train_acc", train_acc, prog_bar=True)
        # val_acc = self.evaluate(y_pred=y_pred[self.val_mask], y_true=y[self.val_mask])
        # self.log("val_acc", val_acc)

        return loss

    def evaluate(self, y_pred, y_true):
        if self.evaluator:
            acc = self.evaluator.eval({"y_true": y_true, "y_pred": y_pred.unsqueeze(1)})["acc"]
        else:
            acc = y_pred.eq(y_true.squeeze()).sum().item() / y_pred.shape[0]

        return acc

    def validation_step(self, batch, batch_idx):
        out = self.model(batch.x, batch.edge_index)
        y_pred = out.max(1)[1]
        val_acc = self.evaluate(y_pred=y_pred[self.val_mask], y_true=batch.y[self.val_mask])
        val_loss = nn.functional.nll_loss(out[self.val_mask], batch.y[self.val_mask].squeeze())
        self.log("val_acc", val_acc, prog_bar=True)
        self.log("val_loss", val_loss)
        return val_loss

    def test_step(self, batch, batch_idx):
        x, y, edge_index = batch.x, batch.y.long(), batch.edge_index
        out = self.model(x, edge_index)

        y_pred = out.max(1)[1]
        val_acc = self.evaluate(y_pred=y_pred[self.test_mask], y_true=y[self.test_mask])
        self.log("test_acc", val_acc)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        return optimizer


def get_model(args):
    return GNN(
        num_features=args.num_features,
        hidden_dim=args.hidden_dim,
        num_layers=args.num_layers,
        num_classes=args.num_classes,
        dropout=args.dropout,
        conv_type=args.conv_type,
        jumping_knowledge=args.jk,
        normalize=args.normalize,
        alpha=args.alpha,
        learn_alpha=args.learn_alpha,
        ordering_type=args.ordering_type,
    )
