import torch
from torch_geometric.data import Data
from torch_geometric.utils import add_self_loops
from torch_geometric.nn import GCNConv, GATv2Conv, GraphSAGE, GATConv, GIN, TransformerConv
from torch_scatter import scatter_sum
from torch_scatter import scatter
from utils import get_parameter_number
from torch.cuda.amp import autocast as autocast, GradScaler
import numpy as np
import tqdm
from tqdm import tqdm
from molfreesolv_loader import molfreesolv
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder

"GIN, Trans, GATv2, GAT, GCN"

# Compute device for models and batches. Fall back to CPU when no CUDA device is
# present so the script can still run; torch.cuda.amp's autocast / GradScaler are
# expected to disable themselves (with a warning) in that case.
device = "cuda" if torch.cuda.is_available() else "cpu"
from ogb.graphproppred.evaluate import Evaluator


class NodeEmbed(torch.nn.Module):
    """Initial node embedding that mixes atom features with their bond features.

    Atoms are embedded with OGB's ``AtomEncoder`` and bonds with ``BondEncoder``,
    both to ``out_feature`` channels. Every node then gets the mean embedding of
    the edges whose source (``edge_index[0]``) is that node added to it. The sum
    is passed through BatchNorm and LeakyReLU.

    NOTE(review): assumes ``x`` / ``edge_attr`` are OGB-style integer-coded atom /
    bond features, as the OGB encoders expect. Confirm against the data loader.
    """

    def __init__(self, out_feature: int):
        super().__init__()
        # Submodules are created in this order so parameter initialisation draws
        # from the RNG identically; attribute names are part of the state_dict.
        self.edge_encoder = BondEncoder(emb_dim=out_feature)
        self.node_encoder = AtomEncoder(emb_dim=out_feature)
        self.ban = torch.nn.BatchNorm1d(num_features=out_feature)
        self.act = torch.nn.LeakyReLU()

    def forward(self, x, edge_attr, edge_index):
        """Return one ``out_feature``-wide embedding row per node of ``x``."""
        atoms = self.node_encoder(x)
        bonds = self.edge_encoder(edge_attr)
        # dim_size pads the result to every node, so nodes without outgoing edges
        # still receive a row.
        bond_mean = scatter(bonds, index=edge_index[0], dim=0,
                            dim_size=atoms.shape[0], reduce="mean")
        return self.act(self.ban(atoms + bond_mean))


class GraphSage(torch.nn.Module):
    """GraphSAGE graph model.

    Pipeline: NodeEmbed -> 3-layer GraphSAGE -> per-graph sum pooling ->
    BatchNorm -> ReLU -> Linear head.

    Args:
        in_channels: Accepted so all models share one constructor signature. The
            first GraphSAGE layer reads NodeEmbed's output, which is always
            ``hidden_feature`` wide, so this value does not size any layer.
        out_channels: Width of the prediction head.
        hidden_feature: Width of node embeddings, hidden layers and pooled vector.
        **kwargs: Ignored; lets a full training config be splatted in.
    """

    def __init__(self, in_channels: int, out_channels: int, hidden_feature: int, **kwargs):
        super().__init__()

        # The input is NodeEmbed's hidden_feature-wide output. Sizing this layer with
        # in_channels made forward fail whenever in_channels != hidden_feature.
        self.GS = GraphSAGE(in_channels=hidden_feature, out_channels=hidden_feature, hidden_channels=hidden_feature,
                            num_layers=3, dropout=0.5)
        self.act = torch.nn.ReLU()
        self.node_embed = NodeEmbed(hidden_feature)
        self.batch_norm = torch.nn.BatchNorm1d(num_features=hidden_feature)
        self.lin = torch.nn.Linear(in_features=hidden_feature, out_features=out_channels)

    def forward(self, graph: Data):
        """Return one prediction row per graph in the batched ``graph``."""
        x, edge_index, edge_attr, batch = graph.x, graph.edge_index, graph.edge_attr, graph.batch
        # Pass num_nodes explicitly. Otherwise it is inferred as edge_index.max() + 1,
        # and trailing isolated nodes (e.g. single-atom molecules) get no self-loop.
        edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=x.size(0))
        x = self.node_embed(x, edge_attr, edge_index)
        # NOTE(review): SAGEConv does not appear to consume edge features, so edge_attr
        # is likely ignored by the GraphSAGE stack. Confirm.
        x = self.GS(x, edge_index, edge_attr=edge_attr)
        x = scatter_sum(src=x, index=batch, dim=0)  # sum-pool nodes into their graph
        x = self.batch_norm(x)
        x = self.act(x)
        x = self.lin(x)
        return x


class GAT(torch.nn.Module):
    """GAT graph model.

    Pipeline: NodeEmbed -> GATConv (6 concatenated heads) -> ELU -> GATConv
    (heads merged, ``concat=False``) -> per-graph sum pooling -> BatchNorm ->
    ReLU -> Linear head.

    Args:
        in_channels: Accepted so all models share one constructor signature. The
            first attention layer reads NodeEmbed's ``hidden_feature``-wide output,
            so this value does not size any layer.
        out_channels: Width of the prediction head.
        hidden_feature: Per-head width of the attention layers and width of the
            node embeddings / pooled graph vector.
        **kwargs: Ignored; lets a full training config be splatted in.
    """

    def __init__(self, in_channels: int, out_channels: int, hidden_feature: int, **kwargs):
        super().__init__()

        # The input is NodeEmbed's hidden_feature-wide output. Sizing this layer with
        # in_channels made forward fail whenever in_channels != hidden_feature.
        self.GAT_1 = GATConv(in_channels=hidden_feature, out_channels=hidden_feature, concat=True, heads=6,
                             dropout=0.1)
        # GAT_2 is constructed but currently skipped in forward. Its parameters are
        # still registered, so they appear in self.parameters().
        self.GAT_2 = GATConv(in_channels=6 * hidden_feature, out_channels=hidden_feature, heads=6, dropout=0.1,
                             concat=True)
        self.GAT_3 = GATConv(in_channels=6 * hidden_feature, out_channels=hidden_feature, heads=6, dropout=0.1,
                             concat=False)
        self.node_embed = NodeEmbed(hidden_feature)
        self.batch_norm = torch.nn.BatchNorm1d(num_features=hidden_feature)
        self.lin = torch.nn.Linear(in_features=hidden_feature, out_features=out_channels)
        self.act = torch.nn.ReLU()

    def forward(self, graph: Data):
        """Return one prediction row per graph in the batched ``graph``."""
        x, edge_index, edge_attr, batch = graph.x, graph.edge_index, graph.edge_attr, graph.batch
        x = self.node_embed(x, edge_attr, edge_index)
        # NOTE(review): the GATConv layers are built without edge_dim, so edge_attr
        # probably does not influence attention here. Confirm before relying on it.
        x = torch.nn.functional.elu(self.GAT_1(x, edge_index, edge_attr))
        x = self.GAT_3(x, edge_index, edge_attr)
        x = scatter_sum(src=x, index=batch, dim=0)  # sum-pool nodes into their graph
        x = self.batch_norm(x)
        x = self.act(x)
        x = self.lin(x)
        return x


class GATv2(torch.nn.Module):
    """GATv2 graph model.

    Pipeline: NodeEmbed -> GATv2Conv (6 concatenated heads) -> ReLU -> GATv2Conv
    (heads merged, ``concat=False``) -> per-graph sum pooling -> BatchNorm ->
    ReLU -> Linear head.

    ``GATv2_2`` is constructed but not applied in ``forward``. ``in_channels``
    and ``**kwargs`` are accepted for a uniform constructor signature and are
    not used to size any layer.
    """

    def __init__(self, in_channels: int, out_channels: int, hidden_feature: int, **kwargs):
        super().__init__()

        n_heads = 6
        wide = n_heads * hidden_feature  # width after concatenating all heads

        self.GATv2_1 = GATv2Conv(in_channels=hidden_feature, out_channels=hidden_feature,
                                  concat=True, heads=n_heads, dropout=0.1)
        self.GATv2_2 = GATv2Conv(in_channels=wide, out_channels=hidden_feature,
                                  concat=True, heads=n_heads, dropout=0.1)
        self.GATv2_3 = GATv2Conv(in_channels=wide, out_channels=hidden_feature,
                                  concat=False, heads=n_heads, dropout=0.1)
        self.node_embed = NodeEmbed(hidden_feature)
        self.batch_norm = torch.nn.BatchNorm1d(num_features=hidden_feature)
        self.lin = torch.nn.Linear(in_features=hidden_feature, out_features=out_channels)
        self.act = torch.nn.ReLU()

    def forward(self, graph: Data):
        """Return one prediction row per graph in the batched ``graph``."""
        x, edge_index, edge_attr, batch = graph.x, graph.edge_index, graph.edge_attr, graph.batch
        num_graphs = graph.y.size(0)  # batch size; currently unused below

        h = self.node_embed(x, edge_attr, edge_index)
        h = torch.relu_(self.GATv2_1(h, edge_index))
        h = self.GATv2_3(h, edge_index)
        pooled = scatter_sum(src=h, index=batch, dim=0)  # one row per graph
        return self.lin(self.act(self.batch_norm(pooled)))

class GCN(torch.nn.Module):
    """GCN graph model.

    Pipeline: NodeEmbed -> dropout -> GCNConv -> ReLU -> dropout -> GCNConv ->
    per-graph sum pooling -> BatchNorm -> ReLU -> Linear head.

    Args:
        in_channels: Accepted so all models share one constructor signature. The
            first GCN layer reads NodeEmbed's ``hidden_feature``-wide output, so
            this value does not size any layer.
        hidden_feature: Width of node embeddings, hidden layers and pooled vector.
        out_channels: Width of the prediction head.
        **kwargs: Ignored; lets a full training config be splatted in.
    """

    def __init__(self, in_channels: int, hidden_feature: int, out_channels: int, **kwargs):
        super().__init__()

        # The input is NodeEmbed's hidden_feature-wide output. Sizing this layer with
        # in_channels made forward fail whenever in_channels != hidden_feature.
        self.Gcn1 = GCNConv(in_channels=hidden_feature, out_channels=hidden_feature, cached=False)
        # Gcn2 is constructed but currently skipped in forward. Its parameters are
        # still registered, so they appear in self.parameters().
        self.Gcn2 = GCNConv(in_channels=hidden_feature, out_channels=hidden_feature, cached=False)
        self.Gcn3 = GCNConv(in_channels=hidden_feature, out_channels=hidden_feature, cached=False)
        self.node_embed = NodeEmbed(hidden_feature)
        self.batch_norm = torch.nn.BatchNorm1d(num_features=hidden_feature)
        self.lin = torch.nn.Linear(in_features=hidden_feature, out_features=out_channels)
        self.act = torch.nn.ReLU()

    def forward(self, graph: Data):
        """Return one prediction row per graph in the batched ``graph``."""
        x, edge_index, edge_attr, batch = graph.x, graph.edge_index, graph.edge_attr, graph.batch
        edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=x.size(0))
        x = self.node_embed(x, edge_attr, edge_index)
        x = torch.nn.functional.dropout(x, p=0.3, training=self.training)
        x = self.Gcn1(x, edge_index)
        x = torch.relu_(x)
        x = torch.nn.functional.dropout(x, p=0.3, training=self.training)
        x = self.Gcn3(x, edge_index)
        x = scatter_sum(src=x, index=batch, dim=0)  # sum-pool nodes into their graph
        x = self.batch_norm(x)
        x = self.act(x)
        x = self.lin(x)
        return x


class TranformerGNN(torch.nn.Module):
    """Graph-Transformer model.

    Pipeline: NodeEmbed -> TransformerConv (6 concatenated heads) -> GELU ->
    TransformerConv (heads merged, ``concat=False``) -> per-graph sum pooling
    -> BatchNorm -> GELU -> Linear head.

    Args:
        in_channels: Accepted so all models share one constructor signature. The
            first layer reads NodeEmbed's ``hidden_feature``-wide output, so this
            value does not size any layer.
        out_channels: Width of the prediction head.
        hidden_feature: Per-head width of the attention layers and width of the
            node embeddings / pooled graph vector.
        **kwargs: Ignored; lets a full training config be splatted in.
    """

    def __init__(self, in_channels: int, out_channels: int, hidden_feature: int, **kwargs):
        super().__init__()

        # The input is NodeEmbed's hidden_feature-wide output. Sizing this layer with
        # in_channels made forward fail whenever in_channels != hidden_feature.
        self.transGnn1 = TransformerConv(in_channels=hidden_feature, out_channels=hidden_feature, concat=True,
                                         heads=6, dropout=0.5)
        # transGnn2 is constructed but currently skipped in forward. Its parameters are
        # still registered, so they appear in self.parameters().
        self.transGnn2 = TransformerConv(in_channels=6 * hidden_feature, out_channels=hidden_feature, concat=True,
                                         heads=6, dropout=0.3)
        self.transGnn3 = TransformerConv(in_channels=6 * hidden_feature, out_channels=hidden_feature, concat=False,
                                         heads=6, dropout=0.3)
        # GELU is the single activation used between layers and after pooling.
        # (A ReLU was previously assigned here and immediately overwritten.)
        self.act = torch.nn.GELU()
        self.node_embed = NodeEmbed(hidden_feature)
        self.batch_norm = torch.nn.BatchNorm1d(num_features=hidden_feature)
        self.lin = torch.nn.Linear(in_features=hidden_feature, out_features=out_channels)

    def forward(self, graph: Data):
        """Return one prediction row per graph in the batched ``graph``."""
        x, edge_index, edge_attr, batch = graph.x, graph.edge_index, graph.edge_attr, graph.batch
        x = self.node_embed(x, edge_attr, edge_index)
        x = self.transGnn1(x, edge_index)
        x = self.act(x)
        x = self.transGnn3(x, edge_index)
        x = scatter_sum(src=x, index=batch, dim=0)  # sum-pool nodes into their graph
        x = self.batch_norm(x)
        x = self.act(x)
        x = self.lin(x)
        return x


class GIN_(torch.nn.Module):
    """GIN graph model.

    Pipeline: NodeEmbed -> 3-layer GIN -> per-graph sum pooling -> BatchNorm ->
    GELU -> Linear head.

    Args:
        in_channels: Accepted so all models share one constructor signature. The
            GIN stack reads NodeEmbed's ``hidden_feature``-wide output, so this
            value does not size any layer.
        out_channels: Width of the prediction head.
        hidden_feature: Width of node embeddings, hidden layers and pooled vector.
        **kwargs: Ignored; lets a full training config be splatted in.
    """

    def __init__(self, in_channels: int, out_channels: int, hidden_feature: int, **kwargs):
        super().__init__()

        # The input is NodeEmbed's hidden_feature-wide output. Sizing the stack with
        # in_channels made forward fail whenever in_channels != hidden_feature.
        self.GNN = GIN(hidden_channels=hidden_feature, in_channels=hidden_feature, out_channels=hidden_feature,
                       num_layers=3,
                       dropout=0.3)
        self.node_embed = NodeEmbed(hidden_feature)
        self.batch_norm = torch.nn.BatchNorm1d(num_features=hidden_feature)
        self.lin = torch.nn.Linear(in_features=hidden_feature, out_features=out_channels)
        self.act = torch.nn.GELU()

    def forward(self, graph: Data):
        """Return one prediction row per graph in the batched ``graph``."""
        x, edge_index, edge_attr, batch = graph.x, graph.edge_index, graph.edge_attr, graph.batch
        edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=x.size(0))
        x = self.node_embed(x, edge_attr, edge_index)
        # NOTE(review): GIN layers already add each node's own features via (1 + eps) * x,
        # so the explicit self-loops above likely count the node twice. Confirm intent.
        x = self.GNN(x, edge_index)
        x = scatter_sum(src=x, index=batch, dim=0)  # sum-pool nodes into their graph
        x = self.batch_norm(x)
        x = self.act(x)
        x = self.lin(x)
        return x


def metric(m, train_dataset, test_dataset, evaluator_name: str = "ogbg-molesol"):
    """Score ``m`` on ``test_dataset`` with an OGB ``Evaluator``.

    The model is put in eval mode for the pass. Afterwards the caller's
    original train/eval mode is restored, even if evaluation raises.

    Args:
        m: Model mapping a batched graph to predictions.
        train_dataset: Unused; kept for call-site compatibility.
        test_dataset: Iterable of graph batches (each carrying ``.y``) to score.
        evaluator_name: OGB dataset name passed to ``Evaluator``. Defaults to the
            previously hard-coded ``"ogbg-molesol"``.

    Returns:
        The evaluator's result dict (``fit`` reads its ``"rmse"`` entry).

    Raises:
        ValueError: If ``test_dataset`` yields no batches.
    """
    was_training = m.training
    m.eval()
    evl = Evaluator(evaluator_name)
    try:
        with torch.no_grad():
            y_true, y_pred = [], []
            for graphs in test_dataset:
                graphs = graphs.to(device)
                pre_y = m(graphs)
                # Flatten both to a single column so targets and predictions align row-for-row.
                y_true.append(graphs.y.view(-1, 1).cpu())
                y_pred.append(pre_y.view(-1, 1).cpu())

            if not y_true:
                raise ValueError("metric(): test_dataset yielded no batches")

            input_dict = {
                "y_true": torch.cat(y_true, dim=0).numpy(),
                "y_pred": torch.cat(y_pred, dim=0).numpy(),
            }
            return evl.eval(input_dict)
    finally:
        # Restore the caller's mode instead of unconditionally switching to train mode.
        m.train(was_training)


def fit(m, config: dict, train_dataset, test_dataset, valid_dataset):
    """Train ``m`` with mixed precision and return its test RMSE at the best validation epoch.

    Model selection uses ``valid_dataset``. The returned score is the
    ``test_dataset`` RMSE measured at the epoch with the lowest validation RMSE,
    so the test split is never used to pick the epoch.

    Args:
        m: Model already placed on ``device``; called as ``m(batch)``.
        config: Training options:
            ``"epoch"`` - number of epochs (default 10);
            ``"optim"`` - kwargs for ``torch.optim.Adam``
            (default ``{"lr": 0.05, "weight_decay": 5e-4}``);
            ``"type"`` - ``"regress"``, ``"binary_classify"`` or
            ``"multi_classify"`` (required; selects the loss);
            ``"eta_min"`` - floor of the cosine LR schedule (default 0.0035).
        train_dataset: Iterable of training batches; must support ``len()``.
        test_dataset: Iterable of test batches.
        valid_dataset: Iterable of validation batches used for model selection.

    Returns:
        Test RMSE at the best-validation epoch (``inf`` if no epoch ran).

    Raises:
        KeyError: If ``config["type"]`` is missing or not one of the keys above.
    """
    import warnings

    epoch = config.get("epoch", 10)
    optim_config = config.get("optim", {"lr": 0.05, "weight_decay": 5e-4})

    critical = {
        "regress": torch.nn.MSELoss,
        "binary_classify": torch.nn.BCEWithLogitsLoss,
        "multi_classify": torch.nn.CrossEntropyLoss,
    }[config.get("type")]()

    optimizer = torch.optim.Adam(**optim_config, params=m.parameters())

    # One cosine cycle spanning the whole run (T_0 == epoch). eta_min is the floor of
    # that cycle. If it exceeds the base lr, the schedule raises the lr over training
    # instead of decaying it, so warn about that configuration.
    eta_min = config.get("eta_min", 0.0035)
    base_lr = optim_config.get("lr", 1e-3)  # 1e-3 is torch.optim.Adam's default lr
    if eta_min > base_lr:
        warnings.warn(
            f"eta_min={eta_min} exceeds the base learning rate {base_lr}; "
            "CosineAnnealingWarmRestarts will increase the lr during training.",
            stacklevel=2,
        )
    reduce_schedule = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=epoch, T_mult=1, eta_min=eta_min
    )

    m.train()
    scaler = GradScaler()
    loss_mean = -1
    best_valid = float("inf")
    sub_res = float("inf")  # test RMSE at the best validation epoch so far
    with tqdm(total=len(train_dataset) * epoch) as pbar:
        for e in range(epoch):
            loss_record = []
            for graphs in train_dataset:
                if len(graphs) == 1:
                    # Presumably skipped because the models' BatchNorm1d over pooled
                    # graph embeddings cannot train on a batch of a single graph.
                    pbar.update()  # keep the bar in step with its total
                    continue
                graphs = graphs.to(device)
                optimizer.zero_grad()
                with autocast():
                    pre_y = m(graphs)
                    loss = critical(pre_y, graphs.y.float())
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                loss_value = loss.item()
                loss_record.append(loss_value)
                pbar.set_description(
                    f"epoch: {e + 1}, loss_train: {loss_value:.4f}, loss_mean: {loss_mean}, "
                    f"best_valid: {best_valid}, test_at_best_valid: {sub_res}"
                )
                pbar.update()

            # Select the epoch on validation; only score the test split when it improves.
            valid_res = metric(m, train_dataset, valid_dataset)["rmse"]
            if valid_res < best_valid:
                best_valid = valid_res
                sub_res = metric(m, train_dataset, test_dataset)["rmse"]

            if loss_record:  # every batch may have been skipped
                loss_mean = np.mean(loss_record)
            # Advance the schedule by exactly one epoch (step(e) lagged one epoch behind).
            reduce_schedule.step()

    return sub_res


if __name__ == '__main__':
    # Shared hyper-parameters. The whole dict is splatted into the model constructor
    # (which ignores the training-only keys via **kwargs) and also passed to fit().
    cfg = {
        "hidden_feature": 256,
        "in_channels": 256,
        "out_channels": 1,
        "optim": {"lr": 0.001, "weight_decay": 3e-3},
        "type": "regress",
        "epoch": 150,
    }
    train, valid, test = molfreesolv()

    # Five independent runs of the same architecture; report the mean and standard
    # deviation of the per-run score returned by fit().
    res = []
    for _ in range(5):
        m = GAT(**cfg).to(device)
        met = fit(m, cfg, train, test, valid)
        res.append(met)
    print(np.mean(res), np.std(res))
    # Parameter statistics of the last model built (every run uses the same architecture).
    print(get_parameter_number(m))
