import torch
from torch_geometric.data import Data
from torch_geometric.utils import add_self_loops
from torch_geometric.nn import GCNConv, GATv2Conv, GraphSAGE, GATConv, GIN, TransformerConv
from utils import get_parameter_number
from torch.cuda.amp import autocast as autocast, GradScaler
import numpy as np
import tqdm
from tqdm import tqdm
from CiteSeer_loader import citeseer

"GIN, Trans, GATv2, GAT, GCN"

# Prefer the GPU but fall back to CPU, so the script still runs (slowly) on machines
# without CUDA instead of failing at ``.to(device)``.
device = "cuda" if torch.cuda.is_available() else "cpu"
from sklearn.metrics import f1_score
from ogb.nodeproppred.evaluate import Evaluator


class GraphSage(torch.nn.Module):

    def __init__(self, in_channels: int, out_channels: int, hidden_feature: int, **kwargs):
        super().__init__()

        self.GS = GraphSAGE(in_channels=in_channels, out_channels=hidden_feature, hidden_channels=hidden_feature,
                            num_layers=3, dropout=0.5)

    def forward(self, graph: Data):
        x, batch_size, edge_index, edge_attr, batch = graph.x, graph.y.size(
            0), graph.edge_index, graph.edge_attr, graph.batch
        # x = self.node_embed(x, edge_attr, edge_index)
        edge_index, edge_attr = add_self_loops(edge_index, edge_attr)
        x = self.GS(x, edge_index, edge_attr=edge_attr)
        return x


class GAT(torch.nn.Module):

    def __init__(self, in_channels: int, out_channels: int, hidden_feature: int, **kwargs):
        super().__init__()

        self.GAT_1 = GATConv(in_channels=in_channels, out_channels=hidden_feature, concat=True, heads=6, dropout=0.6)
        self.GAT_2 = GATConv(in_channels=6 * hidden_feature, out_channels=hidden_feature, heads=6, dropout=0.6, concat=True)
        self.GAT_3 = GATConv(in_channels=6 * hidden_feature, out_channels=out_channels, heads=6, dropout=0.6, concat=False)

    def forward(self, graph: Data):
        x, batch_size, edge_index, edge_attr, batch = graph.x, graph.y.size(
            0), graph.edge_index, graph.edge_attr, graph.batch
        x = torch.nn.functional.elu(self.GAT_1(x, edge_index, edge_attr))
        x = torch.nn.functional.dropout(x, p=0.6, training=self.training)
        x = torch.nn.functional.elu(self.GAT_2(x, edge_index, edge_attr))
        x = torch.nn.functional.dropout(x, p=0.6, training=self.training)
        x = self.GAT_3(x, edge_index, edge_attr)
        return x


class GATv2(torch.nn.Module):

    def __init__(self, in_channels: int, out_channels: int, hidden_feature: int, **kwargs):
        super().__init__()

        self.GATv2_1 = GATv2Conv(in_channels=in_channels, out_channels=hidden_feature, concat=True, heads=6,
                                 edge_dim=hidden_feature)
        self.GATv2_2 = GATv2Conv(in_channels=6 * hidden_feature, out_channels=hidden_feature, concat=True, heads=6,
                                 edge_dim=hidden_feature)
        self.GATv2_3 = GATv2Conv(in_channels=6 * hidden_feature, out_channels=out_channels, concat=False, heads=6,
                                 edge_dim=in_channels)

    def forward(self, graph: Data):
        x, batch_size, edge_index, edge_attr, batch = graph.x, graph.y.size(
            0), graph.edge_index, graph.edge_attr, graph.batch
        x = torch.nn.functional.dropout(x, training=self.training)
        x = self.GATv2_1(x, edge_index, edge_attr=edge_attr)
        x = torch.relu_(x)
        x = torch.nn.functional.dropout(x, training=self.training)
        x = self.GATv2_2(x, edge_index, edge_attr=edge_attr)
        x = torch.nn.functional.dropout(x, training=self.training)
        x = torch.relu_(x)
        x = self.GATv2_2(x, edge_index, edge_attr=edge_attr)
        return x


class GCN(torch.nn.Module):

    def __init__(self, in_channels: int, hidden_feature: int, out_channels: int, **kwargs):
        super().__init__()

        self.Gcn1 = GCNConv(in_channels=in_channels, out_channels=hidden_feature, cached=True)
        self.Gcn2 = GCNConv(in_channels=hidden_feature, out_channels=hidden_feature, cached=True)
        self.Gcn3 = GCNConv(in_channels=hidden_feature, out_channels=out_channels, cached=True)

    def forward(self, graph: Data):
        x, batch_size, edge_index, edge_attr, batch = graph.x, graph.y.size(
            0), graph.edge_index, graph.edge_attr, graph.batch
        edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=x.size(0))
        x = torch.nn.functional.dropout(x, p=0.5, training=self.training)
        x = self.Gcn1(x, edge_index)
        x = torch.relu_(x)
        x = torch.nn.functional.dropout(x, p=0.5, training=self.training)
        # x = self.Gcn2(x, edge_index)
        # x = torch.relu_(x)
        x = torch.nn.functional.dropout(x, p=0.5, training=self.training)
        x = self.Gcn3(x, edge_index)
        return x


class TranformerGNN(torch.nn.Module):

    def __init__(self, in_channels: int, out_channels: int, hidden_feature: int, **kwargs):
        super().__init__()

        self.transGnn1 = TransformerConv(in_channels=in_channels, out_channels=hidden_feature, concat=True, heads=6,
                                         dropout=0.5)
        self.transGnn2 = TransformerConv(in_channels=6 * hidden_feature, out_channels=hidden_feature, concat=True,
                                         heads=6, dropout=0.5)
        self.transGnn3 = TransformerConv(in_channels=6 * hidden_feature, out_channels=out_channels, concat=False,
                                         heads=6, dropout=0.5)
        self.act = torch.nn.ReLU()

    def forward(self, graph: Data):
        x, batch_size, edge_index, edge_attr, batch = graph.x, graph.y.size(
            0), graph.edge_index, graph.edge_attr, graph.batch
        x = torch.nn.functional.dropout(x, training=self.training)
        x = self.transGnn1(x, edge_index, edge_attr)
        x = self.act(x)
        x = torch.nn.functional.dropout(x, training=self.training)
        x = self.transGnn2(x, edge_index, edge_attr)
        x = self.act(x)
        x = torch.nn.functional.dropout(x, training=self.training)
        x = self.transGnn3(x, edge_index, edge_attr)
        return x


class GIN_(torch.nn.Module):

    def __init__(self, in_channels: int, out_channels: int, hidden_feature: int, **kwargs):
        super().__init__()

        self.GNN = GIN(hidden_channels=hidden_feature, in_channels=in_channels, out_channels=hidden_feature,
                       num_layers=1,
                       dropout=0.5)
        self.lin = torch.nn.Linear(in_features=hidden_feature, out_features=out_channels)

    def forward(self, graph: Data):
        x, batch_size, edge_index, edge_attr, batch = graph.x, graph.y.size(
            0), graph.edge_index, graph.edge_attr, graph.batch
        edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=x.size(0))
        x = self.GNN(x, edge_index)
        x = self.lin(x)
        return x


def f1(true, pre):
    """Compute the micro-averaged F1 score between two label tensors.

    Both tensors are detached and copied to host NumPy arrays before being handed
    to scikit-learn's ``f1_score``.

    Args:
        true: Ground-truth labels (torch tensor).
        pre: Predicted labels (torch tensor).

    Returns:
        The micro-averaged F1 score.
    """
    y_true, y_pred = (t.detach().cpu().numpy() for t in (true, pre))
    return f1_score(y_true, y_pred, average="micro")


def metric(m, dataset):
    """Return the micro-F1 of model ``m`` on the test split of ``dataset``.

    The model is run in eval mode under ``torch.no_grad()`` on the whole graph.
    Predictions are the arg-max over dim 1, and only nodes in ``dataset.test_mask``
    are scored. The model's previous train/eval mode is restored afterwards, even if
    the forward pass raises. The original always called ``m.train()``, which
    silently switched eval-mode callers into training mode.

    Args:
        m: Model called as ``m(dataset)``.
        dataset: Graph exposing ``y`` and ``test_mask``.

    Returns:
        The value of ``f1(...)`` on the test nodes.
    """
    was_training = m.training
    m.eval()
    try:
        with torch.no_grad():
            pre_y = m(dataset).argmax(dim=1)
            mask = dataset.test_mask
            return f1(dataset.y[mask], pre_y[mask])
    finally:
        m.train(was_training)


def metric_ogb(m, dataset, evaluator_name="ogbn-arxiv"):
    """Score model ``m`` on the test split with the official OGB node evaluator.

    Predictions are the arg-max over the last dim, kept as a column
    (``keepdim=True``) to match OGB's expected ``y_pred`` layout. The model's
    previous train/eval mode is restored afterwards, even if the forward pass
    raises; the original always forced ``m.train()``.

    Args:
        m: Model called as ``m(dataset)``.
        dataset: Graph exposing ``y`` and ``test_mask``. Presumably ``y`` is shaped
            ``(num_nodes, 1)`` as in OGB node datasets — TODO confirm.
        evaluator_name: OGB dataset name passed to ``Evaluator``; defaults to the
            previously hard-coded ``"ogbn-arxiv"``.

    Returns:
        Whatever ``Evaluator.eval`` returns (a dict of metrics).
    """
    was_training = m.training
    m.eval()
    evaluator = Evaluator(evaluator_name)
    try:
        with torch.no_grad():
            pre_y = m(dataset).argmax(dim=-1, keepdim=True)
            mask = dataset.test_mask
            return evaluator.eval({"y_true": dataset.y[mask], "y_pred": pre_y[mask]})
    finally:
        m.train(was_training)


def fit(model, config, dataset):
    """Train ``model`` full-batch on ``dataset`` and return the best test metric seen.

    Each epoch does one mixed-precision (``autocast`` + ``GradScaler``) Adam step on
    the nodes in ``dataset.train_mask``, then evaluates with ``metric``.

    Args:
        model: ``torch.nn.Module`` called as ``model(dataset)``. This is the model
            that gets trained; the original mistakenly trained the module-level
            global ``m`` instead.
        config: Dict. Reads ``"epoch"`` (default 10), ``"optim"`` (keyword arguments
            for ``torch.optim.Adam``, default ``{"lr": 0.05, "weight_decay": 5e-4}``)
            and ``"type"`` (one of ``"regress"``, ``"binary_classify"``,
            ``"multi_classify"``).
        dataset: Graph exposing ``y``, ``train_mask`` and ``test_mask``.

    Returns:
        The maximum value returned by ``metric(model, dataset)`` over all epochs.

    Raises:
        KeyError: If ``config["type"]`` is missing or not a supported task type.
    """
    epoch = config.get("epoch", 10)
    optim_config = config.get("optim", {"lr": 0.05, "weight_decay": 5e-4})

    criterion = {
        "regress": torch.nn.MSELoss,
        "binary_classify": torch.nn.BCEWithLogitsLoss,
        "multi_classify": torch.nn.CrossEntropyLoss,
    }[config.get("type")]()

    optimizer = torch.optim.Adam(params=model.parameters(), **optim_config)
    scaler = GradScaler()
    model.train()
    metric_test = 0
    with tqdm(total=epoch) as pbar:
        for e in range(epoch):
            optimizer.zero_grad()
            with autocast():
                pre_y = model(dataset)
                loss = criterion(pre_y[dataset.train_mask], dataset.y[dataset.train_mask])
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            # NOTE(review): ``metric`` scores the *test* split and the best value over
            # epochs is kept. That is model selection on test data, so the reported
            # number is optimistic; use a validation mask for selection if one exists.
            metric_test = max(metric_test, metric(model, dataset))
            pbar.set_description(
                f'epoch: {e + 1} \\ {epoch}, loss_train: {loss.item()}, best_f1: {metric_test}')
            pbar.update()
    return metric_test


if __name__ == '__main__':
    # Train GraphSage from scratch 5 times on CiteSeer and report the mean and std of
    # the best score returned by ``fit``, then the parameter count of the last model.
    cfg = {"in_channels": 3703,
           "hidden_feature": 256,
           "out_channels": 6,
           "optim":
               {"lr": 0.0005, "weight_decay": 1e-1},
           "type": "multi_classify",
           "epoch": 100,
           }
    # NOTE(review): assumes citeseer() returns data already on ``device`` — TODO confirm.
    dataset = citeseer()
    res = []
    for _ in range(5):
        # The global name ``m`` is kept deliberately: it is also what
        # get_parameter_number reports on after the loop.
        m = GraphSage(**cfg).to(device)
        res.append(fit(m, cfg, dataset))
    print(np.mean(res), np.std(res))
    print(get_parameter_number(m))
