# Script set-up: parse the run configuration, seed torch/NumPy, pin the visible
# GPU, load the hypergraph dataset and build the model, optimiser and LR scheduler.
from config import config
args = config.parse()
import ipdb
import random
from data_hyper import data
# Debugging aid (disabled): drop into an interactive console with the current variables.
#!import code; code.interact(local=vars())
import copy
import os, torch, numpy as np
import time
from model import networks
import torch, os, numpy as np, scipy.sparse as sp
import torch.optim as optim, torch.nn.functional as F
from torch.autograd import Variable
from tqdm import tqdm
from model import utils

import torch.multiprocessing as mp


# Seed torch and NumPy from the configured seed.
# NOTE(review): the stdlib `random` module is imported above but is not seeded here.
torch.manual_seed(args.seed)
np.random.seed(args.seed)
import logging
# NOTE(review): this assigns an attribute on the `logging` module itself rather than
# on a Logger object -- confirm it has the intended effect.
logging.propagate = False 
# Raise the root logger's threshold to ERROR.
logging.getLogger().setLevel(logging.ERROR)



# gpu, seed
# NOTE(review): these variables are set after `torch` has been imported; presumably
# CUDA has not been initialised yet at this point, so the device selection still
# applies -- confirm.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"        
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so setting
# it here presumably only affects child processes -- confirm.
os.environ['PYTHONHASHSEED'] = str(args.seed)

# load data
# `train` and `test` hold the split indices; they are turned into index tensors
# further down, before the functions of the same names are defined.
dataset, train, test = data.load(args)
#print("length of train is", len(train))

#wandb.init(config=args)
HyperTensorGCN = {}  # not referenced again in the visible code
E =  dataset['hypergraph']
X, Y = dataset['features'], dataset['labels']

# hypergcn and optimiser
#hardcoded args.c = Y.shape[1]
# Record the number of feature columns (d) and label columns (c) on args before
# the network is constructed with it.
args.d, args.c = X.shape[1], Y.shape[1]
hypertensorgcn = networks.HyperTensorGCN(E, X, args)
optimiser = optim.Adam(list(hypertensorgcn.parameters()), lr=args.rate, weight_decay=args.decay)
#scheduler = optim.lr_scheduler.ExponentialLR(optimiser, gamma=0.95)

# Step decay: multiply the learning rate by 0.2 every 200 scheduler steps.
# NOTE(review): no call to scheduler.step() is visible in this file -- confirm
# whether the decay is meant to be active.
scheduler = optim.lr_scheduler.StepLR(optimiser, 200, gamma=0.2, last_epoch=-1)

def normalize(mx):
    """Row-normalize a dense or scipy-sparse matrix.

    Every row is divided by its sum; rows whose sum is zero are left as
    all-zero rows instead of producing inf/nan.

    arguments:
    mx: 2-D numpy array, numpy matrix or scipy sparse matrix

    returns:
    the row-normalized matrix (a dense container for dense input, a scipy
    sparse matrix for sparse input)
    """
    rowsum = np.asarray(mx.sum(1)).flatten()
    # Integer (or boolean) sums cannot be raised to a negative power by
    # NumPy, so promote them to float first.
    if not np.issubdtype(rowsum.dtype, np.inexact):
        rowsum = rowsum.astype(np.float64)
    # 1/0 is expected for empty rows and is zeroed out right below, so the
    # divide-by-zero warning is only noise here.
    with np.errstate(divide="ignore"):
        r_inv = np.power(rowsum, -1)
    r_inv[np.isinf(r_inv)] = 0.
    if sp.issparse(mx):
        # Keep sparse input sparse: left-multiply by a sparse diagonal matrix.
        return sp.diags(r_inv).dot(mx)
    # Scale each row by broadcasting instead of materialising a dense
    # N x N diagonal matrix.
    return np.multiply(mx, r_inv[:, np.newaxis])

# Row-normalise the features and wrap them in a float tensor.
X = torch.FloatTensor(np.array(normalize(X)))

# Reduce the label matrix to the column index of each non-zero entry.
# NOTE(review): presumably Y is one-hot (exactly one non-zero per row) -- confirm.
Y = torch.LongTensor(np.where(np.array(Y))[1])

# Index tensors for the two splits returned by data.load.
idx_train, idx_test = torch.LongTensor(train), torch.LongTensor(test)

# cuda: use the GPU whenever one is available
args.Cuda = torch.cuda.is_available()
if args.Cuda:
    hypertensorgcn.cuda()
    X = X.cuda()
    Y = Y.cuda()
    idx_test = idx_test.cuda()
    idx_train = idx_train.cuda()
    # Replace every hypergraph entry with a tensor on the GPU.
    for key in list(E.keys()):
        E[key] = torch.Tensor(list(E[key])).cuda()


def train(epoch):
    """
    Run one full-batch optimisation step on the training nodes.

    Note: defining this function rebinds the module-level name `train`, which
    previously held the training indices returned by data.load; `idx_train`
    is built from those indices before this definition.

    arguments:
    epoch: index of the current epoch (accepted from the caller, not used)

    returns:
    None; the parameters of `hypertensorgcn` are updated in place
    """
    hypertensorgcn.train()
    optimiser.zero_grad()
    # Forward pass over the whole graph; only the training rows enter the loss.
    output = hypertensorgcn(E, X)
    loss_train = F.nll_loss(output[idx_train], Y[idx_train])
    loss_train.backward()
    optimiser.step()



def test():
    """
    Evaluate the model on the test nodes and print the test accuracy.

    Note: defining this function rebinds the module-level name `test`, which
    previously held the test indices returned by data.load; `idx_test` is
    built from those indices before this definition.

    returns:
    (loss, accuracy) on the test split as Python floats
    """
    hypertensorgcn.eval()
    # Evaluation needs no gradients; disabling autograd avoids building a
    # graph for the full forward pass.
    with torch.no_grad():
        output = hypertensorgcn(E, X)
        loss_test = F.nll_loss(output[idx_test], Y[idx_test])
        acc_test = accuracy(output[idx_test], Y[idx_test])
    acc_value = acc_test.item()
    print(acc_value)
    return loss_test.item(), acc_value

def accuracy(Z, Y):
    """
    Fraction of rows whose highest-scoring column matches the label.

    arguments:
    Z: 2-D tensor of per-class scores, one row per sample
    Y: 1-D tensor of ground truth class indices

    returns:
    0-dim double tensor: number of correct predictions divided by len(Y)
    """
    # Z.max(1) yields (values, indices); the indices are the predicted classes.
    _, predicted = Z.max(1)
    predicted = predicted.type_as(Y)
    hits = predicted.eq(Y).double().sum()
    return hits / len(Y)


# Train model
# Call train() once per configured epoch, then evaluate a single time on the
# test split.
# NOTE(review): the StepLR `scheduler` built during set-up is never stepped in
# the visible code, so the learning rate presumably stays at args.rate --
# confirm whether a per-epoch scheduler.step() was intended.
t_total = time.time()  # start timestamp; not read again in the visible code
for epoch in range(args.epochs):
    train(epoch)


test()







