import torch
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
from torch_geometric import utils
from networks import  Net
import torch.nn.functional as F
import argparse
import os
import random
import numpy as np
import sys
from itertools import chain
import pickle

# Limit PyTorch's intra-op CPU thread pool to a single thread.
torch.set_num_threads(1)

parser = argparse.ArgumentParser()

parser.add_argument('--seed', type=int, default=777,
                    help='seed')
parser.add_argument('--batch_size', type=int, default=128,
                    help='batch size')
parser.add_argument('--lr', type=float, default=0.0005,
                    help='learning rate')
# Spelled with a dash on the command line; argparse exposes it as ``args.lr_pool``.
parser.add_argument('--lr-pool', type=float, default=0.0005,
                    help='learning rate for pooling')
parser.add_argument('--weight_decay', type=float, default=0.0001,
                    help='weight decay')
parser.add_argument('--nhid', type=int, default=128,
                    help='hidden size')
parser.add_argument('--pooling_ratio', type=float, default=0.5,
                    help='pooling ratio')
parser.add_argument('--dropout_ratio', type=float, default=0.5,
                    help='dropout ratio')
parser.add_argument('--dataset', type=str, default='DD',
                    help='DD/PROTEINS/NCI1/NCI109/Mutagenicity')
parser.add_argument('--epochs', type=int, default=100000,
                    help='maximum number of epochs')
parser.add_argument('--patience', type=int, default=50,
                    help='patience for earlystopping')
parser.add_argument('--pooling_layer_type', type=str, default='GCNConv',
                    help='pooling layer type')
parser.add_argument('--gpu', type=str, default="0",
                    help='gpu id')

args = parser.parse_args()
# Use the device selected by --gpu; the default "0" yields 'cuda:0' as before.
# Fall back to CPU when no CUDA device is available.
args.device = f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu'

# Seed Python's `random`, NumPy, the PyTorch CPU generator and the current
# CUDA device with the same --seed value, in that order.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(args.seed)
del _seed_fn

# Ask cuDNN for deterministic algorithms and disable its autotuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# The TUDataset root directory is data/<dataset name>.
dataset = TUDataset(os.path.join('data', args.dataset), name=args.dataset)
# Model input/output sizes are read from the dataset and stored on `args`.
args.num_features = dataset.num_features
args.num_classes = dataset.num_classes

def test(model, loader, device=None):
    """Evaluate ``model`` on every batch of ``loader``.

    The model is switched to eval mode and run without gradient tracking.

    Args:
        model: network whose output is fed to ``F.nll_loss``, so it is
            expected to return per-class log-probabilities.
        loader: iterable of batches exposing ``.to(device)`` and ``.y``;
            ``len(loader.dataset)`` must be the number of samples.
        device: device batches are moved to; defaults to the global
            ``args.device``.

    Returns:
        ``(accuracy, mean_nll_loss)``, both averaged over
        ``len(loader.dataset)`` samples.
    """
    if device is None:
        device = args.device
    model.eval()
    correct = 0.
    loss = 0.
    # Evaluation only: skip building the autograd graph to save memory/time.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data)
            pred = out.argmax(dim=1)
            correct += pred.eq(data.y).sum().item()
            # Summed (not averaged) per batch so the final division gives a
            # per-sample mean even when the last batch is smaller.
            loss += F.nll_loss(out, data.y, reduction='sum').item()
    return correct / len(loader.dataset), loss / len(loader.dataset)


# Despite its name this is not a pytest test; keep pytest from collecting it.
test.__test__ = False

# Build one model and snapshot its initial weights to 'init.pth'; the per-fold
# loop reloads this file so every fold starts from the same parameters.
base_model = Net(args).to(args.device)
torch.save(base_model.state_dict(), 'init.pth')

# Shuffle the sample indices once and cut them into 10 contiguous folds.
# Boundaries are floor(len * k / 10), so fold sizes differ by at most one.
len_dataset = len(dataset)
indices = np.random.permutation(len_dataset)
fold_bounds = [(len_dataset * k) // 10 for k in range(10 + 1)]
folds = [indices[start:stop].tolist()
         for start, stop in zip(fold_bounds, fold_bounds[1:])]

# Print (sum of indices, size) per fold -- presumably a quick checksum for
# comparing splits between runs.
for fold in folds:
    print(sum(fold), len(fold))

def seed_worker(worker_id):
    """Seed NumPy and ``random`` inside a DataLoader worker process.

    Meant to be passed as ``DataLoader(worker_init_fn=...)``; PyTorch only
    calls it when the loader uses worker processes (``num_workers > 0``).
    The NumPy/``random`` seeds are derived from the worker's torch seed.

    Args:
        worker_id: index of the worker (unused; required by the hook signature).
    """
    # np.random.seed only accepts values in [0, 2**32).
    worker_seed = torch.initial_seed() % 2**32
    print(worker_seed)
    # Only ``np`` is imported in this file; ``numpy.random`` raised NameError.
    np.random.seed(worker_seed)
    random.seed(worker_seed)
    
# 10-fold cross-validation. NOTE: statement order here fixes the sequence of
# RNG draws (model construction, NumPy shuffle, DataLoader shuffling), which the
# pickled-split check below depends on -- do not reorder.
for fold_idx in range(10):
    # Drop any best-model checkpoint left by a previous fold or run. Otherwise,
    # if validation loss never improves on this fold (e.g. it is NaN), the test
    # step would silently evaluate another fold's model, whose training data
    # includes this fold's test set.
    if os.path.exists('latest.pth'):
        os.remove('latest.pth')

    # Every fold starts from the weights saved in 'init.pth'. Building Net
    # presumably consumes torch RNG (parameter init), so keep it even though
    # the weights are overwritten right away.
    model = Net(args).to(args.device)
    model.load_state_dict(torch.load('init.pth'))

    # Three parameter groups: weight/bias of the conv, score and linear layers,
    # and the p_pos/p_neg and q_pos/q_neg parameters of their ``gnp`` sub-modules.
    non_pool_params = list(chain.from_iterable([[layer.weight, layer.bias]
                                                for layer in [model.conv1, model.pool1.score_layer,
                                                              model.conv2, model.pool2.score_layer,
                                                              model.conv3, model.pool3.score_layer,
                                                              model.lin1, model.lin2, model.lin3]]))
    pool_params_p = list(chain.from_iterable([[layer.p_pos, layer.p_neg]
                                              for layer in [model.conv1.gnp, model.pool1.score_layer.gnp,
                                                            model.conv2.gnp, model.pool2.score_layer.gnp,
                                                            model.conv3.gnp, model.pool3.score_layer.gnp,
                                                            model.gnp1, model.gnp2, model.gnp3]]))
    pool_params_q = list(chain.from_iterable([[layer.q_pos, layer.q_neg]
                                              for layer in [model.conv1.gnp, model.pool1.score_layer.gnp,
                                                            model.conv2.gnp, model.pool2.score_layer.gnp,
                                                            model.conv3.gnp, model.pool3.score_layer.gnp,
                                                            model.gnp1, model.gnp2, model.gnp3]]))

    # Parameter counts: non-pool, pool (p + q), and the model total.
    print(sum([param.numel() for param in non_pool_params]),
          sum([param.numel() for param in pool_params_p]) + sum([param.numel() for param in pool_params_q]),
          sum([param.numel() for param in model.parameters()]))

    # Hold out fold ``fold_idx`` for testing; shuffle the other nine folds and
    # split them 90% / 10% into train / validation.
    trainval_indices = list(chain.from_iterable(folds[:fold_idx] + folds[fold_idx+1:]))
    trainval_indices = np.array(trainval_indices[:])
    np.random.shuffle(trainval_indices)
    trainval_indices = trainval_indices.tolist()
    train_indices = trainval_indices[:(len(trainval_indices) * 9) // 10]
    val_indices = trainval_indices[(len(trainval_indices) * 9) // 10:]
    test_indices = folds[fold_idx]
    # Sanity checks: the splits are disjoint and together cover the dataset.
    assert(len(set(trainval_indices) & set(test_indices)) == 0 and len(set(trainval_indices) ^ set(test_indices)) == len(dataset))
    assert(len(set(train_indices) & set(val_indices)) == 0 and len(set(train_indices) ^ set(val_indices)) == len(trainval_indices))
    print(sum(train_indices), sum(val_indices), sum(test_indices), end=' ')

    # Compare the splits position-by-position with a pickled reference
    # (train, val, test) for this dataset/seed/fold and abort unless every
    # position matches, i.e. require identical splits across runs.
    # NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
    with open(f"indices/{args.dataset}_{args.seed}_{fold_idx}.pkl", "rb") as f_pkl:
        _expected_train, _expected_val, _expected_test = pickle.load(f_pkl)
        consistency = 0
        for i1, i2 in zip(train_indices, _expected_train):
            if i1 == i2: consistency += 1
        for i1, i2 in zip(val_indices, _expected_val):
            if i1 == i2: consistency += 1
        for i1, i2 in zip(test_indices, _expected_test):
            if i1 == i2: consistency += 1
        print(consistency)
        assert(consistency == len(dataset))

    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    val_dataset = torch.utils.data.Subset(dataset, val_indices)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, worker_init_fn=seed_worker)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size,shuffle=False, worker_init_fn=seed_worker)

    # Adam (with weight decay) for the ordinary weights; RMSprop without weight
    # decay for the gnp parameters, with --lr-pool for the p group and --lr
    # for the q group.
    optimizer = torch.optim.Adam(non_pool_params, lr=args.lr, weight_decay=args.weight_decay)
    optimizer_pool = torch.optim.RMSprop([{'params': pool_params_p, 'lr': args.lr_pool},
                                          {'params': pool_params_q, 'lr': args.lr}],
                                         lr=args.lr, weight_decay=0.)
    min_loss = 1e10
    patience = 0

    for epoch in range(args.epochs):
        model.train()
        for data in train_loader:
            data = data.to(args.device)
            out = model(data)
            loss = F.nll_loss(out, data.y)
            loss.backward()
            # Only the gnp parameters are clipped (global L2 norm capped at 1000).
            torch.nn.utils.clip_grad_norm_(pool_params_p + pool_params_q, 1000., norm_type=2.0)
            optimizer.step()
            optimizer_pool.step()
            optimizer.zero_grad()
            optimizer_pool.zero_grad()
        val_acc,val_loss = test(model,val_loader)
        print("Validation loss:{}\taccuracy:{}".format(val_loss,val_acc))
        # Checkpoint on every improvement in validation loss; stop after more
        # than --patience consecutive epochs without improvement.
        if val_loss < min_loss:
            torch.save(model.state_dict(),'latest.pth')
            min_loss = val_loss
            patience = 0
        else:
            patience += 1
        if patience > args.patience:
            break

    # Evaluate the best-validation checkpoint of THIS fold on the held-out fold.
    if not os.path.exists('latest.pth'):
        raise RuntimeError(
            f"fold {fold_idx}: validation loss never improved, so no checkpoint was saved")
    test_dataset = torch.utils.data.Subset(dataset, test_indices)
    test_loader = DataLoader(test_dataset,batch_size=1,shuffle=False, worker_init_fn=seed_worker)
    test_model = Net(args).to(args.device)
    test_model.load_state_dict(torch.load('latest.pth'))
    test_acc, test_loss = test(test_model,test_loader)
    print("Test accuarcy for fold {}: {}".format(fold_idx, test_acc))
    # Create the output directory so saving does not fail after a full fold of training.
    os.makedirs('checkpoints', exist_ok=True)
    torch.save(test_model.state_dict(), f"checkpoints/{args.dataset}_{args.seed}_{fold_idx}.pth")
