from __future__ import division
from __future__ import print_function

import argparse
import os
import random
import time
import uuid

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data

# The wildcard import of `utils` stays ahead of the two explicit imports that
# originally followed it, so those explicit names are still bound last.
from utils import *
import torch.nn as nn
from sklearn.metrics import f1_score
from dmp_conv import *

# Training settings
parser = argparse.ArgumentParser()
parser.add_argument('--num_layers', type=int, default=5,
                    help='Total number of convolution layers in the network.')
parser.add_argument('--epochs', type=int, default=8000,
                    help='Maximum number of epochs to train.')
parser.add_argument('--lr', type=float, default=0.001, help='Adam learning rate.')
parser.add_argument('--weight_decay', type=float, default=0,
                    help='Adam weight decay.')
parser.add_argument('--hidden', type=int, default=128,
                    help='Hidden width passed to each convolution layer.')
parser.add_argument('--dropout', type=float, default=0.1, help='Dropout probability.')
parser.add_argument('--heads', type=int, default=16,
                    help='Number of heads in the input and hidden layers.')
parser.add_argument('--output_heads', type=int, default=1,
                    help='Number of heads in the output layer.')
parser.add_argument('--model', type=int, default=1,
                    help='Model id; used as the checkpoint sub-directory name.')
parser.add_argument('--lambda_', type=float, default=2,
                    help='lambda_ value forwarded to every convolution layer.')
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--dev', type=int, default=0, help='device id')
parser.add_argument('--patience', type=int, default=2000, help='Patience')
parser.add_argument('--test', action='store_true', default=False, help='evaluation on test set.')


args = parser.parse_args()

# Seed every random number generator the script touches.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

# Use the requested GPU when CUDA is present; otherwise fall back to the CPU
# instead of failing later when tensors are moved to a non-existent device.
cudaid = "cuda:" + str(args.dev) if torch.cuda.is_available() else "cpu"
device = torch.device(cudaid)

# Checkpoint path: a unique id plus the main hyper-parameters, with the `.pt`
# extension at the end. The directory is created up front so torch.save does
# not fail on a fresh checkout.
checkpt_dir = f'pretrained/{args.model}'
os.makedirs(checkpt_dir, exist_ok=True)
checkpt_file = f'{checkpt_dir}/' + uuid.uuid4().hex
checkpt_file = checkpt_file + f'_{args.hidden}_{args.heads}_{args.num_layers}'
checkpt_file = checkpt_file + f'_{args.dropout}_{args.lr}_{args.lambda_}.pt'
train_adj,val_adj,test_adj,train_feat,val_feat,test_feat,train_labels,val_labels, test_labels,train_nodes, val_nodes, test_nodes = load_ppi()

class Net(torch.nn.Module):
    """Stack of graph convolution layers ending in an element-wise sigmoid.

    The stack holds one input layer, ``args.num_layers - 2`` hidden
    ``DMPConv`` layers and one output layer. Hidden layers are wrapped in a
    residual connection, and every layer except the last is followed by ELU
    and dropout.

    NOTE(review): ``conv_to_use`` and ``Linear`` are not defined in this
    file; they are presumably supplied by the wildcard imports -- confirm.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.layers = nn.ModuleList()

        # Input layer: raw feature width -> args.hidden per head.
        self.conv1 = conv_to_use(
            train_feat.shape[-1],
            args.hidden,
            heads=args.heads,
            dropout=args.dropout, lambda_=args.lambda_)
        self.layers.append(self.conv1)

        # Hidden layers; range() is empty when args.num_layers <= 2.
        for _ in range(args.num_layers - 2):
            self.layers.append(DMPConv(args.hidden * args.heads,
                    args.hidden,
                    heads=args.heads,
                    dropout=args.dropout,
                    lambda_=args.lambda_))

        # Output layer: one value per label column, heads not concatenated.
        self.conv2 = conv_to_use(
            args.hidden * args.heads,
            train_labels.shape[-1],
            heads=args.output_heads,
            concat=False,
            dropout=args.dropout, lambda_=args.lambda_)
        self.layers.append(self.conv2)

    def reset_parameters(self):
        """Call ``reset_parameters`` on every ``Linear`` sub-module."""
        def weight_reset(m):
            if isinstance(m, Linear):
                m.reset_parameters()
        self.apply(weight_reset)

    def forward(self, x, edge_index):
        """Run the layer stack and return sigmoid outputs.

        Args:
            x: node feature tensor passed to the first layer.
            edge_index: edge indices forwarded unchanged to every layer.

        Returns:
            ``torch.sigmoid`` of the output layer's result.
        """
        # Positions are derived from the actual stack, not from
        # args.num_layers: the stack always contains the input and output
        # layer, so with args.num_layers < 2 the two counts disagree and the
        # output layer would otherwise be treated as a hidden one.
        last = len(self.layers) - 1
        for ii, layer in enumerate(self.layers):
            out = layer(x, edge_index)
            # Residual connection on hidden layers only.
            x = out if ii in (0, last) else out + x
            if ii != last:
                x = F.elu(x)
                x = F.dropout(x, p=args.dropout, training=self.training)
        return torch.sigmoid(x)


# Instantiate the network and place it on the selected device.
model = Net()
model = model.to(device)
# To warm-start from a saved run, load its weights at this point:
# model.load_state_dict(torch.load(checkpt_file))

# Adam over all model parameters, configured from the command line.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)


# Binary cross-entropy on the model's sigmoid outputs.
loss_fcn = torch.nn.BCELoss()
# adapted from DGL
def evaluate(feats, model, idx, subgraph, labels, loss_fcn):
    """Score ``model`` on one graph without tracking gradients.

    Args:
        feats: node features passed to the model.
        model: network to evaluate; switched to eval mode here.
        idx: only the first ``idx`` rows of the output and labels are scored.
        subgraph: graph structure argument passed to the model.
        labels: target tensor; cast to float for the loss.
        loss_fcn: loss callable applied to (outputs, float targets).

    Returns:
        ``(micro-averaged F1, loss as a Python float)``.
    """
    model.eval()
    with torch.no_grad():
        probs = model(feats, subgraph)[:idx]
        targets = labels[:idx]
        loss_value = loss_fcn(probs, targets.float()).item()

        # Threshold the outputs at 0.5 to get hard 0/1 predictions.
        probs_np = probs.detach().cpu().numpy()
        hard_pred = np.where(probs_np > 0.5, 1, 0)
        targets_np = targets.detach().cpu().numpy()
        micro_f1 = f1_score(targets_np, hard_pred, average='micro')
    return micro_f1, loss_value


# Indices 0..19, served one at a time in shuffled order; each batch is a
# one-element tensor holding a single index.
idx = torch.arange(20, dtype=torch.long)
loader = Data.DataLoader(idx, batch_size=1, shuffle=True, num_workers=0)

def train():
    """Run one optimisation pass over every batch produced by ``loader``.

    Each batch holds a single graph index. The graph's edge indices, features
    and labels are moved to ``device``, and the loss is taken over the first
    ``train_nodes[batch]`` rows only.

    Returns:
        ``(mean training loss per batch, 0.0)``. No training F1 is computed;
        the second value is a constant kept so the return shape is unchanged.
    """
    model.train()
    loss_tra = 0
    acc_tra = 0.0
    # Average over the real number of batches rather than a literal 20.
    num_batches = len(loader)
    for batch in loader:
        graph_id = batch[0]
        batch_adj = train_adj[graph_id].to(device)._indices()
        batch_feature = train_feat[graph_id].to(device)
        batch_label = train_labels[graph_id].to(device)
        num_nodes = train_nodes[batch]

        optimizer.zero_grad()
        output = model(batch_feature, batch_adj)
        # Targets are cast to float, matching evaluate().
        loss_train = loss_fcn(output[:num_nodes], batch_label[:num_nodes].float())
        loss_train.backward()
        optimizer.step()
        loss_tra += loss_train.item()

    if num_batches:
        loss_tra /= num_batches
    return loss_tra, acc_tra

def validation():
    """Evaluate the model on the two validation graphs.

    Returns:
        ``(mean loss, mean micro-F1)`` averaged over both graphs.
    """
    losses = []
    scores = []
    for graph_id in range(2):
        edge_index = val_adj[graph_id].to(device)._indices()
        feats = val_feat[graph_id].to(device)
        targets = val_labels[graph_id].to(device)
        f1, loss = evaluate(feats, model, val_nodes[graph_id], edge_index, targets, loss_fcn)
        losses.append(loss)
        scores.append(f1)
    loss_val = sum(losses) / 2
    acc_val = sum(scores) / 2
    return loss_val, acc_val

def test(load_best=True):
    """Evaluate the model on the two test graphs.

    Args:
        load_best: when true, first restore the weights stored at
            ``checkpt_file``; otherwise score the current weights.

    Returns:
        ``(mean loss, mean micro-F1)`` averaged over both graphs.
    """
    if load_best:
        best_state = torch.load(checkpt_file)
        model.load_state_dict(best_state)

    total_loss = 0
    total_f1 = 0
    for graph_id in range(2):
        edge_index = test_adj[graph_id].to(device)._indices()
        feats = test_feat[graph_id].to(device)
        targets = test_labels[graph_id].to(device)
        f1, loss = evaluate(feats, model, test_nodes[graph_id], edge_index, targets, loss_fcn)
        total_loss += loss
        total_f1 += f1
    return total_loss / 2, total_f1 / 2

t_total = time.time()
bad_counter = 0
# Validation F1 has to exceed this floor before the first checkpoint is
# written, so a run may finish without ever saving one.
acc = 0.93
best_epoch = 0
checkpoint_saved = False
for epoch in range(args.epochs):
    loss_tra,acc_tra = train()
    loss_val,acc_val = validation()
    # The test graphs are only scored (with the current weights) once
    # validation F1 is above 0.96; otherwise 0 is logged as a placeholder.
    if acc_val > 0.96:
        loss_test,acc_test = test(load_best=False)
    else:
        loss_test, acc_test = 0, 0

    print('Epoch:{:04d}'.format(epoch+1),
        'train',
        'loss:{:.3f}'.format(loss_tra),
        '| val',
        'loss:{:.3f}'.format(loss_val),
        'f1:{:.3f}'.format(acc_val*100),
        '| test f1:{:.3f}'.format(acc_test*100))

    if acc_val > acc:
        # New best validation score: remember it and checkpoint the weights.
        acc = acc_val
        best_epoch = epoch
        torch.save(model.state_dict(), checkpt_file)
        checkpoint_saved = True
        bad_counter = 0
    else:
        bad_counter += 1

    # Early stopping after `patience` consecutive epochs without improvement.
    if bad_counter == args.patience:
        break

if args.test:
    # Loading a checkpoint that was never written would raise, so fall back
    # to the weights currently in memory and say so.
    if not checkpoint_saved:
        print('No checkpoint was saved (validation f1 never exceeded the '
              'initial threshold); testing the final weights instead.')
    acc = test(load_best=checkpoint_saved)[1]

print("Train cost: {:.4f}s".format(time.time() - t_total))
print('Load {}th epoch'.format(best_epoch))
print("Test" if args.test else "Val","f1.:{:.2f}".format(acc*100))



