import os
import logging
import argparse
import json
import networkx as nx
import numpy as np

import torch
import torch.nn.functional as F
from torch import optim

from torch_geometric.loader import DataLoader
from torch_geometric.datasets import TUDataset
from torch_geometric.utils import degree
from torch_geometric.transforms import Compose
from torch_geometric.utils import to_networkx

from model import GNN
 
# Datasets for which the script prepends the Degree transform, so that node
# features are built from node degrees.
unlabeled_datasets = ['IMDB-BINARY', 'IMDB-MULTI']

# Normalisation constants keyed by dataset name, one value per split index
# (ten splits). The table chosen by --normalize supplies the factor that
# divides the adjacency matrix in GraphOperator.
# NOTE(review): presumably the average node degree per split -- confirm how
# these were computed.
avg_degs = {
    'MUTAG': [2.21, 2.205, 2.203, 2.209, 2.204, 2.208, 2.21, 2.213, 2.205, 2.213],
    'ENZYMES': [3.844, 3.833, 3.811, 3.831, 3.825, 3.827, 3.83, 3.836, 3.837, 3.84],
    'NCI1': [2.171, 2.171, 2.171, 2.17, 2.172, 2.17, 2.171, 2.171, 2.171, 2.17],
    'PROTEINS_full': [3.733, 3.737, 3.719, 3.726, 3.744, 3.722, 3.727, 3.732, 3.731, 3.725],
    'IMDB-BINARY': [9.726, 9.562, 9.856, 9.692, 9.982, 9.835, 9.852, 9.709, 9.778, 9.631],
    'IMDB-MULTI': [10.371, 9.823, 10.275, 9.861, 10.12, 10.176, 10.298, 10.163, 10.078, 10.125],
}

# NOTE(review): presumably the average largest adjacency eigenvalue per
# split -- confirm how these were computed.
avg_lambda_max = {
    'MUTAG': [2.469, 2.466, 2.464, 2.471, 2.464, 2.468, 2.471, 2.473, 2.464, 2.471],
    'ENZYMES': [4.341, 4.343, 4.334, 4.338, 4.338, 4.346, 4.343, 4.345, 4.351, 4.338],
    'NCI1': [2.492, 2.491, 2.491, 2.491, 2.493, 2.492, 2.492, 2.49, 2.492, 2.491],
    'PROTEINS_full': [4.171, 4.174, 4.169, 4.17, 4.166, 4.156, 4.177, 4.167, 4.165, 4.174],
    'IMDB-BINARY': [10.078, 9.841, 10.138, 9.985, 10.2, 10.062, 10.1, 9.926, 10.074, 9.949],
    'IMDB-MULTI': [8.653, 8.475, 8.672, 8.442, 8.561, 8.559, 8.677, 8.526, 8.558, 8.565],
}


class EarlyStopper:
    """Patience-based early stopping on validation accuracy (higher is better).

    The stopper remembers the best accuracy seen so far. Each call to
    ``early_stop`` that reports an accuracy more than ``min_delta`` below
    that best counts as a "bad" epoch; matching or beating the best resets
    the count. Once ``patience`` bad epochs have accumulated since the last
    reset, ``early_stop`` returns True.

    Args:
        patience: Number of bad epochs tolerated before stopping.
        min_delta: Width of the tolerance band below the best accuracy.
            Accuracies inside ``(best - min_delta, best)`` neither reset nor
            increase the counter. With the default of 0 every accuracy below
            the best counts as a bad epoch.
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.max_validation_acc = 0.

    def early_stop(self, validation_acc):
        """Record one validation accuracy and report whether to stop.

        Args:
            validation_acc: Validation accuracy of the epoch just finished.

        Returns:
            True when the bad-epoch counter has reached ``patience``,
            otherwise False.
        """
        if validation_acc >= self.max_validation_acc:
            # New best (ties included): remember it and start counting afresh.
            self.max_validation_acc = validation_acc
            self.counter = 0
        elif validation_acc < (self.max_validation_acc - self.min_delta):
            # Accuracy is maximised, so the tolerance band lies *below* the
            # best value. Adding min_delta here would make this branch fire
            # for every non-improving epoch and turn min_delta into a no-op.
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

class Degree:
    """Graph transform that overwrites ``data.x`` with per-node degrees."""

    def __call__(self, data):
        """Set ``data.x`` to a single-column float tensor of node degrees.

        The degree is counted from the first row of ``data.edge_index`` over
        ``data.num_nodes`` nodes. The same ``data`` object is modified in
        place and returned.
        """
        first_row = data.edge_index[0]
        node_degrees = degree(first_row, data.num_nodes, dtype=torch.float)
        # One feature column per node.
        data.x = node_degrees.unsqueeze(1)
        return data

class GraphOperator:
    """Graph transform that attaches the exponential of the scaled adjacency.

    For each graph, the adjacency matrix (undirected, self-loops removed) is
    divided by ``norm_factor`` and exponentiated through its
    eigendecomposition. The strictly positive entries of the result are
    stored on the data object in sparse (index, value) form.

    Args:
        norm_factor: Scalar that divides the adjacency matrix before the
            exponential is taken.
    """

    def __init__(self, norm_factor=1.):
        self.norm_factor = norm_factor

    def __call__(self, data):
        """Add ``edge_index_exp_adj`` and ``exp_adj_flat`` to ``data``.

        ``edge_index_exp_adj`` is a 2 x K long tensor of (row, col) positions
        and ``exp_adj_flat`` a K x 1 float tensor with the matching values,
        where K is the number of strictly positive entries. The same ``data``
        object is modified in place and returned.
        """
        graph = to_networkx(data, to_undirected=True)
        graph.remove_edges_from(nx.selfloop_edges(graph))
        scaled_adj = nx.to_numpy_array(graph)/self.norm_factor

        # exp(A) = U diag(exp(lambda)) U^T for the symmetric matrix A.
        eigvals, eigvecs = np.linalg.eigh(scaled_adj)
        spectrum = np.diag(np.exp(eigvals))
        exp_adj = np.linalg.multi_dot((eigvecs, spectrum, eigvecs.T))

        # Keep only strictly positive entries, stored as coordinates + values.
        row, col = np.where(exp_adj>0)
        positions = np.array([row, col])
        values = exp_adj[row,col]

        data.edge_index_exp_adj = torch.tensor(positions, dtype=torch.long)
        data.exp_adj_flat = torch.from_numpy(values).unsqueeze(1).float()
        return data

# Argument parser
# Every numeric option declares its type so that values given on the command
# line are converted before use.
parser = argparse.ArgumentParser(description='InvGNN')
parser.add_argument('--dataset', default='MUTAG', help='Dataset name')
parser.add_argument('--lr', type=float, default=1e-3, help='Initial learning rate')
parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate')
parser.add_argument('--batch-size', type=int, default=64, help='Batch size')
parser.add_argument('--epochs', type=int, default=500, help='Number of epochs to train')
parser.add_argument('--normalize', default='lambda_max', choices=['avg_deg','lambda_max'], help='How to normalize adjacency')
# Without type=int a user-supplied patience stays a string and the early
# stopper's `counter >= patience` comparison raises TypeError.
parser.add_argument('--patience', type=int, default=50, help='Patience for early stopping')
args = parser.parse_args()

# Each run writes to a fresh per-dataset log file under results/.
log_path = "results/output_"+args.dataset+".log"

# logging.basicConfig opens the file immediately and fails with
# FileNotFoundError if the directory is missing, so create it first.
os.makedirs(os.path.dirname(log_path), exist_ok=True)

# Drop the log of any previous run for this dataset.
if os.path.exists(log_path):
    os.remove(log_path)

logging.basicConfig(
    filename=log_path,
    level=logging.INFO,
    format="%(asctime)s - %(message)s"
)

# Only these two datasets are loaded with use_node_attr=True.
use_node_attr = args.dataset in ('ENZYMES', 'PROTEINS_full')

# Every line of the split file is parsed as JSON; `splits` ends up holding
# the value parsed from the last line.
with open('data_splits/'+args.dataset+'_splits.json','rt') as split_file:
    for raw_line in split_file:
        splits = json.loads(raw_line)

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train(epoch, loader, optimizer):
    """Run one optimisation pass over ``loader`` with the module-level ``model``.

    Args:
        epoch: Not used by the function body; kept in the signature for the
            existing call sites.
        loader: Iterable of batches; each batch must offer ``.to(device)``,
            ``.y`` and ``.num_graphs``. ``loader.dataset`` provides the
            normalising count.
        optimizer: Optimiser stepped once per batch.

    Returns:
        The per-batch NLL loss weighted by the number of graphs in the batch,
        summed over all batches and divided by ``len(loader.dataset)``.
    """
    model.train()
    weighted_loss_sum = 0

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        output = model(batch)
        batch_loss = F.nll_loss(output, batch.y)
        batch_loss.backward()
        # Weight the batch loss by batch size so the final division is per graph.
        weighted_loss_sum += batch.num_graphs * batch_loss.item()
        optimizer.step()

    num_graphs = len(loader.dataset)
    return weighted_loss_sum / num_graphs


def val(loader):
    """Evaluate the module-level ``model`` on ``loader``.

    Args:
        loader: Iterable of batches; each batch must offer ``.to(device)``
            and ``.y``. ``loader.dataset`` provides the normalising count.

    Returns:
        A ``(loss, accuracy)`` tuple: the summed NLL loss and the number of
        correct predictions, each divided by ``len(loader.dataset)``.
    """
    model.eval()
    loss_all = 0
    correct = 0

    # Nothing is back-propagated here; disabling autograd avoids building a
    # graph for every forward pass and leaves the computed values unchanged.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            output = model(data)
            loss_all += F.nll_loss(output, data.y, reduction='sum').item()
            # Predicted class = index of the largest output along dim 1.
            pred = output.max(1)[1]
            correct += pred.eq(data.y).sum().item()
    return loss_all / len(loader.dataset), correct / len(loader.dataset)


def test(loader):
    """Compute the accuracy of the module-level ``model`` on ``loader``.

    Args:
        loader: Iterable of batches; each batch must offer ``.to(device)``
            and ``.y``. ``loader.dataset`` provides the normalising count.

    Returns:
        Number of correct predictions divided by ``len(loader.dataset)``.
    """
    model.eval()
    correct = 0

    # Nothing is back-propagated here; disabling autograd avoids building a
    # graph for every forward pass and leaves the predictions unchanged.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            # Predicted class = index of the largest output along dim 1.
            pred = model(data).max(1)[1]
            correct += pred.eq(data.y).sum().item()
    return correct / len(loader.dataset)


# Experiment driver: for each of the 10 splits, (1) grid-search hidden size
# and depth on the validation fold, (2) retrain the selected configuration
# three times and record the test accuracy measured at the best validation
# epoch, (3) store the mean of those three test accuracies. The mean and
# standard deviation over the splits are reported at the end.
acc = []
for split_idx in range(10):
    split_header = '---------------- Split {} ----------------'.format(split_idx)
    print(split_header)
    logging.info(split_header)

    # Pick the normalisation table requested on the command line.
    if args.normalize == 'lambda_max':
        norm_table = avg_lambda_max
    elif args.normalize == 'avg_deg':
        norm_table = avg_degs
    norm_factor = norm_table[args.dataset][split_idx]

    # The dataset is rebuilt per split because the transform depends on the
    # split-specific normalisation factor.
    dataset_root = './datasets/'+args.dataset
    if args.dataset in unlabeled_datasets:
        dataset = TUDataset(
            root=dataset_root,
            name=args.dataset,
            transform=Compose([Degree(), GraphOperator(norm_factor)]),
        )
    else:
        dataset = TUDataset(
            root=dataset_root,
            name=args.dataset,
            use_node_attr=use_node_attr,
            transform=GraphOperator(norm_factor),
        )

    # Only the first model-selection fold of each split is used.
    current_split = splits[split_idx]
    selection_fold = current_split['model_selection'][0]
    train_index = selection_fold['train']
    val_index = selection_fold['validation']
    test_index = current_split['test']

    test_dataset = dataset[test_index]
    val_dataset = dataset[val_index]
    train_dataset = dataset[train_index]

    val_loader = DataLoader(val_dataset, batch_size=args.batch_size)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

    # ---- Model selection ----
    # best_val_acc is shared by all configurations: a configuration is
    # selected whenever any of its epochs matches or beats the best
    # validation accuracy seen so far in this split.
    best_val_acc = None
    best_hyperparams = {}
    search_grid = [(h, d) for h in [32,64,128] for d in [2,3,4]]
    for hidden_dim, n_layers in search_grid:
        # train()/val()/test() read the module-level name `model`.
        model = GNN(dataset.num_features, hidden_dim, n_layers, dataset.num_classes, args.dropout).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

        print('------------- Hyperparameters ------------')
        print('Hidden dim:', hidden_dim, 'num layers:', n_layers)
        logging.info(f"Hyperparameters :: Hidden dim: {hidden_dim}, num layers: {n_layers}")

        early_stopper = EarlyStopper(patience=args.patience)
        for epoch in range(1, args.epochs+1):
            train_loss = train(epoch, train_loader, optimizer)
            val_loss, val_acc = val(val_loader)
            if best_val_acc is None or val_acc >= best_val_acc:
                best_val_acc = val_acc
                best_hyperparams.update({'hidden_dim': hidden_dim, 'n_layers': n_layers})
            if epoch % 20 == 0:
                epoch_msg = 'Epoch: {:03d}, Train Loss: {:.7f}, Val Loss: {:.7f}, Val Acc: {:.7f}'.format(
                    epoch, train_loss, val_loss, val_acc)
                print(epoch_msg)
                logging.info(epoch_msg)
            if early_stopper.early_stop(val_acc):
                break

        print('Best Val Acc:', best_val_acc)
        logging.info(f"Best Val Acc {best_val_acc}")

    print('--------- Best Hyperparameters -----------')
    best_hidden_dim = best_hyperparams['hidden_dim']
    best_n_layers = best_hyperparams['n_layers']
    print('Best Hidden dim:', best_hidden_dim, 'Best num layers:', best_n_layers)
    logging.info(f"Best Hyperparameters :: Best hidden dim: {best_hidden_dim}, Best num layers: {best_n_layers}")

    # ---- Final assessment: three fresh runs of the selected configuration ----
    test_accs = []
    for run in range(3):
        print('--------------------------------')
        logging.info('---------------------------------')
        model = GNN(dataset.num_features, best_hidden_dim, best_n_layers, dataset.num_classes, args.dropout).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

        best_val_acc = None
        early_stopper = EarlyStopper(patience=args.patience)
        for epoch in range(1, args.epochs+1):
            train_loss = train(epoch, train_loader, optimizer)
            val_loss, val_acc = val(val_loader)
            # The test set is evaluated only at epochs that match or improve
            # the best validation accuracy of this run.
            if best_val_acc is None or val_acc >= best_val_acc:
                best_val_acc = val_acc
                test_acc = test(test_loader)
            if epoch % 20 == 0:
                epoch_msg = 'Epoch: {:03d}, Train Loss: {:.7f}, Val Acc: {:.7f}, Test Acc: {:.7f}'.format(
                    epoch, train_loss, val_acc, test_acc)
                print(epoch_msg)
                logging.info(epoch_msg)
            if early_stopper.early_stop(val_acc):
                break

        test_accs.append(test_acc)

    split_test_acc = np.mean(test_accs)
    print("test accs:", test_accs)
    print("Final test acc:", split_test_acc)
    logging.info(f"Test accs: {test_accs}")
    logging.info(f"Final test acc: {split_test_acc}")
    acc.append(split_test_acc)

acc = torch.tensor(acc)
mean_acc = acc.mean()
std_acc = acc.std()
print('---------------- Final Result ----------------')
# NOTE(review): '{:7f}' is a minimum-width spec, unlike the '{:.7f}' precision
# used in the per-epoch messages -- confirm which one is intended.
print('Mean: {:7f}, Std: {:7f}'.format(mean_acc, std_acc))
logging.info('Final Result :: Mean: {:7f}, Std: {:7f}'.format(mean_acc, std_acc))
