import torch
from torch_geometric.data import Data
from torch_geometric.utils import add_self_loops
from torch_geometric.nn import GCNConv, global_mean_pool, GATv2Conv, GraphSAGE, GATConv, GIN, TransformerConv
from torch_scatter import scatter_sum

from molpcba_layer import NodeEmbed
from torch.cuda.amp import autocast as autocast
from torch.cuda.amp import GradScaler
from utils import get_parameter_number
import numpy as np
from utils import Grid_Search, get_parameter_number

import tqdm
from tqdm import tqdm
import molpcba_loader

"GIN, Trans, GATv2, GAT, GCN"

# Module-wide device used by metric(), fit() and the __main__ script.
# Prefer the GPU, but fall back to CPU so the code still runs without CUDA;
# on CUDA machines the value is "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"
from sklearn.metrics import f1_score
from ogb.graphproppred.evaluate import Evaluator


class GraphSage(torch.nn.Module):
    """Two-layer GraphSAGE encoder returning per-node embeddings.

    Unlike the NodeEmbed-based models in this file, no pooling or output head is
    applied: the output width is ``hidden_feature``. ``out_channels`` is accepted
    only for signature compatibility.
    """

    def __init__(self, in_channels: int, out_channels: int, hidden_feature: int, **kwargs):
        super().__init__()

        self.GS = GraphSAGE(in_channels=in_channels, out_channels=hidden_feature, hidden_channels=hidden_feature,
                            num_layers=2, dropout=0.5)
        # Registered (so it is part of state_dict) but not applied in forward().
        self.batch_norm = torch.nn.BatchNorm1d(num_features=hidden_feature)

    def forward(self, graph: Data):
        x, edge_index, edge_attr = graph.x, graph.edge_index, graph.edge_attr
        # Pass num_nodes explicitly. Otherwise the node count is inferred from
        # edge_index, and isolated nodes at the end of the batch get no self-loop.
        edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=x.size(0))
        return self.GS(x, edge_index, edge_attr=edge_attr)


class GAT(torch.nn.Module):
    """Graph-level predictor built from three GATConv layers.

    Pipeline: NodeEmbed -> GATConv (8 heads, concatenated) -> ELU -> GATConv
    (8 heads, concatenated) -> ELU -> GATConv (1 head) -> per-graph sum pooling
    -> BatchNorm -> GELU -> Linear(out_channels).

    ``in_channels`` is accepted for signature compatibility but unused. The first
    attention layer consumes NodeEmbed's ``hidden_feature``-wide output.
    """

    def __init__(self, in_channels: int, out_channels: int, hidden_feature: int, **kwargs):
        super().__init__()

        # Concatenating 8 heads yields 8 * hidden_feature features, hence the next layer's input width.
        self.GAT_1 = GATConv(in_channels=hidden_feature, out_channels=hidden_feature, concat=True, heads=8, dropout=0.1)
        self.GAT_2 = GATConv(in_channels=8 * hidden_feature, out_channels=hidden_feature, heads=8, dropout=0.1,
                             concat=True)
        self.GAT_3 = GATConv(in_channels=8 * hidden_feature, out_channels=hidden_feature, heads=1, dropout=0.1)
        self.node_embed = NodeEmbed(hidden_feature)
        self.batch_norm = torch.nn.BatchNorm1d(num_features=hidden_feature)
        self.lin = torch.nn.Linear(in_features=hidden_feature, out_features=out_channels)
        self.act = torch.nn.GELU()

    def forward(self, graph: Data):
        x, edge_index, edge_attr, batch = graph.x, graph.edge_index, graph.edge_attr, graph.batch
        # Rows of graph.y. Assumed to be one row per graph: fit() masks predictions
        # with y's shape.
        batch_size = graph.y.size(0)
        edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=batch.size()[0])
        x = self.node_embed(x, edge_attr, edge_index)
        # NOTE(review): edge_attr is passed to the GATConv layers, but they are built
        # without edge_dim. Confirm whether edge features actually reach the attention.
        x = torch.nn.functional.elu(self.GAT_1(x, edge_index, edge_attr))
        x = torch.nn.functional.elu(self.GAT_2(x, edge_index, edge_attr))
        x = self.GAT_3(x, edge_index, edge_attr)
        # dim_size pins exactly one pooled row per graph, even when the last graphs
        # in the batch contribute no entries to `batch`.
        x = scatter_sum(src=x, index=batch, dim=0, dim_size=batch_size)
        x = self.batch_norm(x)
        x = self.act(x)
        x = self.lin(x)
        return x


class GATv2(torch.nn.Module):
    """Graph-level predictor built from three GATv2Conv layers.

    Pipeline: NodeEmbed -> dropout -> GATv2Conv (6 heads, concatenated) -> GELU
    -> GATv2Conv (6 heads, concatenated) -> GELU -> GATv2Conv (6 heads, not
    concatenated) -> per-graph sum pooling -> BatchNorm -> GELU -> Linear(out_channels).
    """

    def __init__(self, out_channels: int, hidden_feature: int, **kwargs):
        super().__init__()

        # Concatenating 6 heads yields 6 * hidden_feature features, hence the next layer's input width.
        self.GATv2_1 = GATv2Conv(in_channels=hidden_feature, out_channels=hidden_feature, concat=True, heads=6,
                                 )
        self.GATv2_2 = GATv2Conv(in_channels=6 * hidden_feature, out_channels=hidden_feature, concat=True, heads=6,
                                 )
        self.GATv2_3 = GATv2Conv(in_channels=6 * hidden_feature, out_channels=hidden_feature, concat=False, heads=6,
                                 )
        self.node_embed = NodeEmbed(hidden_feature)
        self.act = torch.nn.GELU()
        self.batch_norm = torch.nn.BatchNorm1d(num_features=hidden_feature)
        self.lin = torch.nn.Linear(in_features=hidden_feature, out_features=out_channels)

    def forward(self, graph: Data):
        x, edge_index, edge_attr, batch = graph.x, graph.edge_index, graph.edge_attr, graph.batch
        # Rows of graph.y. Assumed to be one row per graph: fit() masks predictions
        # with y's shape.
        batch_size = graph.y.size(0)
        x = self.node_embed(x, edge_attr, edge_index)
        # Dropout with PyTorch's default probability on the embedded node features.
        x = torch.nn.functional.dropout(x, training=self.training)
        x = self.act(self.GATv2_1(x, edge_index))
        x = self.act(self.GATv2_2(x, edge_index))
        x = self.GATv2_3(x, edge_index)
        # dim_size pins exactly one pooled row per graph, even when the last graphs
        # in the batch contribute no entries to `batch`.
        x = scatter_sum(src=x, index=batch, dim=0, dim_size=batch_size)
        x = self.batch_norm(x)
        x = self.act(x)
        x = self.lin(x)
        return x


class GCN(torch.nn.Module):
    """Two-layer GCN encoder returning per-node embeddings of width ``hidden_feature``.

    NOTE(review): ``lin`` is registered but never applied, and no graph pooling is
    done. The output is therefore node-level, unlike the NodeEmbed-based models in
    this file.
    """

    def __init__(self, in_channels: int, hidden_feature: int, out_channels: int, **kwargs):
        super().__init__()

        # No caching (GCNConv's default). Every mini-batch is a different graph, so
        # the normalised adjacency must be recomputed on each call. With cached=True,
        # the first batch's normalisation would be reused for all later batches.
        self.Gcn1 = GCNConv(in_channels=in_channels, out_channels=hidden_feature)
        self.Gcn2 = GCNConv(in_channels=hidden_feature, out_channels=hidden_feature)
        self.lin = torch.nn.Linear(in_features=hidden_feature, out_features=out_channels)

    def forward(self, graph: Data):
        x, edge_index = graph.x, graph.edge_index
        # Edge attributes are not used by the GCN layers, so only edge_index is kept.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        x = torch.nn.functional.dropout(x, p=0.6, training=self.training)
        x = self.Gcn1(x, edge_index)
        x = torch.relu_(x)
        x = torch.nn.functional.dropout(x, p=0.6, training=self.training)
        x = self.Gcn2(x, edge_index)
        x = torch.relu_(x)
        return x


class TranformerGNN(torch.nn.Module):
    """Graph-level predictor built from two TransformerConv layers.

    Pipeline: NodeEmbed -> dropout -> TransformerConv (6 heads, concatenated) -> ReLU
    -> dropout -> TransformerConv (6 heads, not concatenated) -> per-graph sum
    pooling -> BatchNorm -> GELU -> Linear(out_channels).

    ``in_channels`` is accepted for signature compatibility but unused.
    """

    def __init__(self, in_channels: int, out_channels: int, hidden_feature: int, **kwargs):
        super().__init__()

        n_heads = 6
        # The first layer concatenates its heads, so the second sees n_heads * hidden_feature inputs.
        self.transGnn1 = TransformerConv(in_channels=hidden_feature, out_channels=hidden_feature,
                                         heads=n_heads, concat=True, dropout=0)
        self.transGnn2 = TransformerConv(in_channels=n_heads * hidden_feature, out_channels=hidden_feature,
                                         heads=n_heads, concat=False, dropout=0)

        self.node_embed = NodeEmbed(hidden_feature)
        self.batch_norm = torch.nn.BatchNorm1d(num_features=hidden_feature)
        self.lin = torch.nn.Linear(in_features=hidden_feature, out_features=out_channels)

    def forward(self, graph: Data):
        batch = graph.batch
        edge_index, edge_attr = add_self_loops(graph.edge_index, graph.edge_attr, num_nodes=batch.size(0))

        h = self.node_embed(graph.x, edge_attr, edge_index)
        h = torch.nn.functional.dropout(h, training=self.training)
        h = torch.relu_(self.transGnn1(h, edge_index))
        h = torch.nn.functional.dropout(h, training=self.training)
        h = self.transGnn2(h, edge_index)

        # Sum node states into one vector per graph, then normalise, activate and project.
        pooled = scatter_sum(src=h, index=batch, dim=0)
        return self.lin(torch.nn.functional.gelu(self.batch_norm(pooled)))


class GIN_(torch.nn.Module):
    """Graph-level predictor built on PyG's 3-layer GIN.

    Pipeline: NodeEmbed -> GIN (3 layers, dropout 0.1) -> per-graph sum pooling
    -> BatchNorm -> GELU -> Linear(out_channels).

    ``in_channels`` is accepted for signature compatibility but unused.
    """

    def __init__(self, in_channels: int, out_channels: int, hidden_feature: int, **kwargs):
        super().__init__()

        self.GNN = GIN(hidden_channels=hidden_feature, in_channels=hidden_feature, out_channels=hidden_feature,
                       num_layers=3,
                       dropout=0.1)
        self.node_embed = NodeEmbed(hidden_feature)
        self.batch_norm = torch.nn.BatchNorm1d(num_features=hidden_feature)
        self.lin = torch.nn.Linear(in_features=hidden_feature, out_features=out_channels)

    def forward(self, graph: Data):
        x, edge_index, edge_attr, batch = graph.x, graph.edge_index, graph.edge_attr, graph.batch
        # Rows of graph.y. Assumed to be one row per graph: fit() masks predictions
        # with y's shape.
        batch_size = graph.y.size(0)

        edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=batch.size()[0])
        x = self.node_embed(x, edge_attr, edge_index)
        x = self.GNN(x, edge_index)
        # Pool node embeddings into one vector per graph, matching the other
        # NodeEmbed-based models. dim_size keeps one row per graph even if
        # trailing graphs contribute no nodes.
        x = scatter_sum(src=x, index=batch, dim=0, dim_size=batch_size)
        # Apply the BatchNorm module; do not overwrite it with its output.
        x = self.batch_norm(x)
        # Apply GELU as a function; torch.nn.GELU(x) would only construct a module.
        x = torch.nn.functional.gelu(x)
        x = self.lin(x)
        return x


def _collect_predictions(m, dataset):
    """Run ``m`` over every batch in ``dataset`` and stack labels and predictions.

    Returns a ``(y_true, y_pred)`` pair of NumPy arrays. Labels are reshaped to the
    prediction shape so the two arrays line up row for row.
    """
    y_true, y_pred = [], []
    for graphs in dataset:
        graphs = graphs.to(device)
        pre_y = m(graphs)
        y_true.append(graphs.y.view(pre_y.shape).detach().cpu())
        y_pred.append(pre_y.detach().cpu())
    return torch.cat(y_true, dim=0).numpy(), torch.cat(y_pred, dim=0).numpy()


def metric(m, train_dataset, test_dataset):
    """Evaluate ``m`` on two datasets with the ogbg-molpcba evaluator.

    Args:
        m: model mapping a batched graph to predictions.
        train_dataset: first dataset to score. Its result is printed with the
            "train" label; callers may also pass a validation set here.
        test_dataset: second dataset to score; its result is returned.

    Returns:
        The evaluator's result dict for ``test_dataset`` (fit() reads its "ap" key).

    The model is put into eval mode for scoring. Its previous train/eval mode is
    restored afterwards, even if evaluation raises.
    """
    was_training = m.training
    m.eval()
    evl = Evaluator("ogbg-molpcba")
    try:
        with torch.no_grad():
            y_true, y_pred = _collect_predictions(m, train_dataset)
            evl_train = evl.eval({"y_true": y_true, "y_pred": y_pred})

            y_true, y_pred = _collect_predictions(m, test_dataset)
            evl_test = evl.eval({"y_true": y_true, "y_pred": y_pred})

            print(f"test: {evl_test},  train: {evl_train}")
    finally:
        m.train(was_training)
    return evl_test


def fit(m, config: dict, train_dataset, test_dataset, valid_dataset, lr_statue=None):
    """Train ``m`` with mixed precision and checkpoint the best-scoring model.

    Args:
        m: model whose output has the same shape as ``graph.y`` (the loss masks
            predictions with a mask built from ``graph.y``).
        config: training configuration. Keys read here:
            - ``"epoch"`` (default 10);
            - ``"optim"`` (Adam kwargs, default ``{"lr": 0.05, "weight_decay": 5e-4}``);
            - ``"type"``, one of ``"regress"``, ``"binary_classify"`` or
              ``"multi_classify"``, which selects the loss.
        train_dataset: iterable of graph batches used for optimisation.
        test_dataset: iterable of graph batches whose AP drives checkpointing.
        valid_dataset: iterable of graph batches that is only scored and printed.
        lr_statue: optional optimizer ``state_dict`` to resume from.

    Returns:
        The best ``"ap"`` seen on ``test_dataset`` (0 if it never rose above 0).

    NOTE(review): checkpoints are selected by *test* AP. metric() returns the score
    of its second dataset argument, which is ``test_dataset`` in both branches
    below. Selecting on the validation set would avoid tuning on test data.
    """
    epoch = config.get("epoch", 10)
    optim_config = config.get("optim", {"lr": 0.05, "weight_decay": 5e-4})

    optimizer = torch.optim.Adam(params=m.parameters(), **optim_config)
    if lr_statue is not None:
        # Resume from a previously saved optimizer state_dict.
        optimizer.load_state_dict(lr_statue)

    critical = {
        "regress": torch.nn.MSELoss,
        "binary_classify": torch.nn.BCEWithLogitsLoss,
        "multi_classify": torch.nn.CrossEntropyLoss,
    }[config.get("type")]()

    # `verbose` is omitted: it defaulted to False and is deprecated in recent PyTorch.
    reduce_schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                                 mode='min',
                                                                 factor=0.9,
                                                                 patience=450,
                                                                 threshold=1e-5,
                                                                 threshold_mode='rel',
                                                                 cooldown=0,
                                                                 min_lr=0.00005,
                                                                 eps=1e-6)

    m.train()
    pbar = tqdm(total=len(train_dataset) * epoch)
    scaler = GradScaler()
    loss_mean = -1
    loss_test = -1  # latest test AP, shown in the progress bar (-1 until first evaluation)
    sub_res = 0
    for e in range(epoch):
        loss_record = []

        for graphs in train_dataset:
            graphs = graphs.to(device)
            optimizer.zero_grad()
            # NaN != NaN, so missing (NaN) labels are excluded from the loss.
            is_labeled = graphs.y == graphs.y
            with autocast():
                pre_y = m(graphs)
                loss = critical(pre_y.to(torch.float32)[is_labeled], graphs.y.float()[is_labeled])
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            pbar.set_description(f'epoch: {e + 1}, loss_train: {loss}, loss_mean: {loss_mean}, loss_test: {loss_test}')
            pbar.update()
            loss_record.append(loss.detach().cpu().numpy())
            # The plateau scheduler is driven by the per-batch training loss.
            reduce_schedule.step(loss)

        loss_mean = np.mean(loss_record)

        # Evaluate only after every second epoch (e = 1, 3, 5, ...).
        if (e + 1) % 2:
            continue
        with torch.no_grad():
            # When e is a multiple of 5 (e = 5, 15, 25, ... since e is odd here),
            # also report training-set AP. Otherwise report validation AP.
            # Both calls return the test-set result.
            first_dataset = train_dataset if e % 5 == 0 else valid_dataset
            test_res = metric(m, first_dataset, test_dataset)["ap"]
            loss_test = test_res
            if test_res > sub_res:
                sub_res = test_res
                model_set = m.state_dict()
                lr_set = optimizer.state_dict()
                torch.save({"FL-GNN": model_set, "opt": lr_set, "config": config}, f"molplcba-{test_res}-{e}.tar")
    pbar.close()
    return sub_res


if __name__ == '__main__':
    # Hyper-parameter grid. Each value is a list of candidates; presumably
    # Grid_Search yields one config dict per combination (see utils.Grid_Search).
    grid = {"in_channels": [356],
            "hidden_feature": [256],
            "out_channels": [128],
            "optim":
                [{"lr": 0.001, "weight_decay": 1e-5}],
            "type": ["binary_classify"],
            "epoch": [100],
            }

    # Load onto the CPU so a GPU-saved checkpoint also loads without CUDA.
    # load_state_dict copies the tensors into the model, which is then moved to `device`.
    state_dict = torch.load(r"molplcba-0.24530249296226403-29.tar", map_location="cpu")["FL-GNN"]
    train, valid, test = molpcba_loader.molpcba()
    for cfg in Grid_Search(grid):
        res = []
        print(cfg)
        for _ in range(1):
            m = GATv2(**cfg)
            m.load_state_dict(state_dict)
            m = m.to(device)
            score = fit(m, cfg, train, test, valid)
            res.append(score)
        print(np.mean(res), np.std(res))
        # Parameter count of the last model trained for this config.
        print(get_parameter_number(m))
