# Dataset index (1-based):
# Cora: 1, Citeseer: 2, Pubmed: 3, Physics: 4, CS: 5, Computers: 6,
# Photos: 7, Reddit: 8, Github_social: 9, Twitch.DE: 10, Twitch.FR: 11, Wiki.Croc: 12,
# Wiki.squirrel: 13


import torch
import torch.nn as nn
import torch.nn.functional as F
import DataLoader
from DataLoader import load_data
from SparseGAT import sparseGAT, evaluate
import time
import numpy as np
import matplotlib.pyplot as plt

# Compute device: the first CUDA GPU when torch can see one, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Number of target classes per dataset; position i holds dataset index i + 1
# (the script looks entries up with ``num_classes[choice - 1]``).
num_classes = [
    7,      # 1: Cora
    6,      # 2: Citeseer
    3,      # 3: Pubmed
    5,      # 4: Physics
    15,     # 5: CS
    10,     # 6: Computers
    8,      # 7: Photos
    41,     # 8: Reddit
    2,      # 9: Github_social
    2,      # 10: Twitch.DE
    2,      # 11: Twitch.FR
    11631,  # 12: Wiki.Croc
    5201,   # 13: Wiki.squirrel
]

# How many independent training runs the reported metrics are averaged over.
num_trials = 3

# Dataset indices (1-based) that the script will process.
choice_vals = [4]

# Per-dataset file names, in the same order as ``num_classes``; the script
# prints them while loading and embeds them in the name of the results file.
filenames = [
    'V_Cora.csv',
    'citeseer_Reff.txt',
    'V_Pubmed.csv',
    'V_Phy.csv',
    'V_CS.csv',
    'Amazon_computers_Reff.txt',
    'Amazon_Photo_Reff.txt',
    'V_R_eff_Reddit.csv',
    'V_git.csv',
    'V_twitch_DE.csv',
    'V_twitch_FR.csv',
    'V_wiki_crocs.csv',
    'V_wiki_squirrels.csv',
]


def sigma(x):
    """Return the element-wise logistic sigmoid ``1 / (1 + exp(-x))`` of tensor *x*.

    The value is computed with ``torch.sigmoid`` instead of the literal
    formula: for strongly negative inputs ``exp(-x)`` overflows to ``inf``,
    which still produces the right forward value (0) but makes autograd
    evaluate ``0 * inf`` and return a NaN gradient.
    """
    return torch.sigmoid(x)


def g(x):
    """Clamp *x* to the closed interval [0, 1].

    Values below 0 map to 0, values of 1 or more map to 1, and anything in
    between is returned unchanged.
    """
    # Lower bound first, mirroring max(x, 0).
    if 0 > x:
        return 0
    # Then the upper bound, mirroring min(1, ...).
    if x < 1:
        return x
    return 1


def spGAT(g, features, num_classes, device, labels, mask, test_mask, val_mask):
    """Train a ``sparseGAT`` classifier and score its best checkpoint on the test split.

    The model is trained for 200 epochs with Adam (lr=1e-2) on the nodes
    selected by ``mask``. After every epoch it is evaluated on ``val_mask`` and
    ``mask``; from epoch 20 onwards, each new best validation accuracy writes
    the weights to ``./sparseGAT_max_acc``. A fresh model is finally loaded from
    that checkpoint and evaluated on ``test_mask``.

    Args:
        g: Graph object handed to ``sparseGAT`` and ``evaluate`` (its type is
            not visible in this file; presumably a DGL graph -- confirm).
            NOTE: inside this function the name shadows the module-level
            function ``g``.
        features: Node feature tensor; ``features.size()[1]`` is used as the
            model's input dimension.
        num_classes: Output dimension of the model.
        device: Torch device the models are moved to.
        labels: Label tensor, indexed with the masks.
        mask: Selector of the training nodes.
        test_mask: Selector of the test nodes.
        val_mask: Selector of the validation nodes.

    Returns:
        A 9-tuple ``(final_loss, final_val_acc, mean_epoch_time, loss_log,
        acc_log, train_acc_log, test_acc, net_opt, alpha_log)`` where the logs
        hold one entry per epoch, ``mean_epoch_time`` averages the wall-clock
        time of epochs 1.., ``net_opt`` is the model restored from the
        checkpoint and ``test_acc`` is its accuracy on ``test_mask``.

    Side effects:
        Prints one progress line per epoch and writes ``./sparseGAT_max_acc``.
    """
    # Constants of the regularisation term; lam weights it against the NLL loss.
    beta = 1 / 4
    gamma = -0.1
    zeta = 1.1
    lam = 0.5
    checkpoint_path = "./sparseGAT_max_acc"

    def build_model():
        # Both the trained and the restored model must share this architecture
        # so that the saved state dict can be loaded.
        return sparseGAT(g,
                         in_dim=features.size()[1],
                         hidden_dim=8,
                         out_dim=num_classes,
                         num_heads=8)

    net = build_model()
    net.to(device)

    # create optimizer
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)

    # main loop
    dur = []
    num_epochs = 200
    loss_log = []
    acc_log = []
    train_acc_log = []
    print("starting epochs...")
    max_acc = 0
    checkpoint_saved = False
    alpha_log = []
    # Constant shift applied to logalpha inside the regulariser; hoisted
    # because it does not change between epochs.
    loss_subtract_term = beta * np.log(-1 * gamma / zeta)
    for epoch in range(num_epochs):
        # Epoch 0 is left out of the timing statistics.
        if epoch >= 1:
            t0 = time.time()

        logits = net(features)
        logp = F.log_softmax(logits, 1)

        # Classification loss on the training nodes plus the weighted
        # regulariser over net.logalpha.
        loss = F.nll_loss(logp[mask], labels[mask].to(device)) + lam * torch.sum(
            sigma(net.logalpha - loss_subtract_term))

        # Store a detached tensor: keeping `loss` itself would keep every
        # epoch's autograd graph alive for the whole run.
        loss_log.append(loss.detach())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 1:
            dur.append(time.time() - t0)

        acc = evaluate(net, g, features, labels, val_mask)
        train_acc = evaluate(net, g, features, labels, mask)

        acc_log.append(acc)
        train_acc_log.append(train_acc)

        # Copy the values: on CPU `.numpy()` shares memory with the parameter,
        # which the optimizer updates in place, so un-copied snapshots would
        # all end up showing the final values.
        alpha_log.append(net.logalpha.detach().cpu().numpy().copy())

        if acc > max_acc and epoch >= 20:
            max_acc = acc
            torch.save(net.state_dict(), checkpoint_path)
            checkpoint_saved = True

        # No timing sample exists yet at epoch 0; report nan without
        # triggering numpy's empty-mean warning.
        mean_epoch_time = np.mean(dur) if dur else float("nan")
        print("Epoch {:05d} | Loss {:.4f} | Train Acc {:.4f} | Val Accuracy {:.4f}| Time(s) {:.4f}".format(
            epoch, loss.item(), train_acc, acc, mean_epoch_time))

    if not checkpoint_saved:
        # No epoch qualified for a checkpoint; save the final weights so the
        # load below never picks up a stale file left by an earlier run.
        torch.save(net.state_dict(), checkpoint_path)

    net_opt = build_model()
    # Keep the restored model on the same device as the one that was trained.
    net_opt.to(device)
    net_opt.load_state_dict(torch.load(checkpoint_path))
    # Model selection used val_mask; the reported score uses the held-out
    # test_mask.
    test_acc = evaluate(net_opt, g, features, labels, test_mask)

    return loss.item(), acc, np.mean(dur), loss_log, acc_log, train_acc_log, test_acc, net_opt, alpha_log


# Module-level copies of the regulariser constants; spGAT defines its own
# locals with the same values and nothing in the visible script reads these.
beta = 1 / 4
gamma = -0.1
zeta = 1.1

# For every selected dataset: load it, train/evaluate spGAT `num_trials` times
# and store the aggregated test accuracy and timing in an .npz file.
for choice in choice_vals:
    print("Loading data from ", filenames[choice - 1])
    # load data
    # NOTE: from here on the module-level name `g` refers to the loaded graph,
    # not to the clamp function defined earlier in this file.
    g, features, labels, mask, test_mask, N, Ne, val_mask = load_data(choice)

    acc_spGAT_trials = []
    dur_spGAT_trials = []

    # Change the device.
    # Tensor.to() returns the moved tensor and leaves the original untouched,
    # so the results have to be rebound.
    # The graph type is not visible here: some graph APIs move in place and
    # return None, so only rebind when something is returned -- TODO confirm.
    moved_graph = g.to(device)
    if moved_graph is not None:
        g = moved_graph
    features = features.to(device)
    labels = labels.to(device)
    mask = mask.to(device)
    test_mask = test_mask.to(device)
    val_mask = val_mask.to(device)

    for trials in range(num_trials):
        (loss_GAT, acc_GAT, dur_GAT, loss_log, acc_log, train_acc_log,
         test_acc, net_opt, alpha_log) = spGAT(g, features, num_classes[choice - 1], device,
                                               labels, mask, test_mask, val_mask)

        acc_spGAT_trials.append(test_acc)
        dur_spGAT_trials.append(dur_GAT)

    mean_acc = np.mean(acc_spGAT_trials)
    # Average of the per-trial mean epoch times.
    mean_dur = np.mean(dur_spGAT_trials)
    # Despite the name (and the "with_var" in the file name) this is the
    # standard deviation of the per-trial accuracies.
    var = np.std(acc_spGAT_trials)

    # Saved positionally, i.e. under the keys arr_0 (mean accuracy),
    # arr_1 (mean duration) and arr_2 (accuracy standard deviation).
    filename = 'F1_SparseGAT_results_with_var_' + filenames[choice - 1] + '.npz'
    np.savez(filename, mean_acc, mean_dur, var)



