# Dataset index (the value of `choice`; entry `choice - 1` of num_classes and filenames):
# Cora: 1, Citeseer: 2, Pubmed: 3, Physics: 4, CS: 5, Computers: 6,
# Photos: 7, Reddit: 8, Github_social: 9, Twitch.DE: 10, Twitch.FR: 11, Wiki.Croc: 12,
# Wiki.squirrel: 13


import torch
import torch.nn as nn
import torch.nn.functional as F
import DataLoader
from DataLoader import load_data
import GAT_defn
import time
import numpy as np
import random
import copy

# Compute device: the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Output dimension (number of target classes) per dataset, in dataset-index order;
# dataset `choice` uses num_classes[choice - 1].
num_classes = [
    7,      # 1  Cora
    6,      # 2  Citeseer
    3,      # 3  Pubmed
    5,      # 4  Physics
    15,     # 5  CS
    10,     # 6  Computers
    8,      # 7  Photos
    41,     # 8  Reddit
    2,      # 9  Github_social
    2,      # 10 Twitch.DE
    2,      # 11 Twitch.FR
    # NOTE(review): 11631 and 5201 match the node counts commonly reported for the
    # Wikipedia crocodile / squirrel graphs rather than class counts -- confirm.
    11631,  # 12 Wiki.Croc
    5201,   # 13 Wiki.squirrel
]

# Sparsification levels (epsilon) swept for every dataset.
# Alternative sweep: [0.2, 0.5, 0.75, 0.8]
eps_vals = [0.5, 0.9]

# Number of repeated trials whose metrics are averaged (alternative: 5).
num_trials = 3

# Dataset indices (see the table at the top of the file) to run.
choice_vals = [4]

# Files holding the pre-computed resistance values, one per dataset index;
# dataset `choice` uses filenames[choice - 1].
filenames = [
    'V_Cora.csv',                 # 1  Cora
    'citeseer_Reff.txt',          # 2  Citeseer
    'V_Pubmed.csv',               # 3  Pubmed
    'V_Phy.csv',                  # 4  Physics
    'V_CS.csv',                   # 5  CS
    'Amazon_computers_Reff.txt',  # 6  Computers
    'Amazon_photo_Reff.txt',      # 7  Photos
    'V_R_eff_Reddit.csv',         # 8  Reddit
    'V_git.csv',                  # 9  Github_social
    'V_twitch_DE.csv',            # 10 Twitch.DE
    'V_twitch_FR.csv',            # 11 Twitch.FR
    'V_wiki_crocs.csv',           # 12 Wiki.Croc
    'V_wiki_squirrels.csv',       # 13 Wiki.squirrel
]


def GAT(g, features, num_classes, device, labels, mask, test_mask, val_mask, num_epochs):
    """Train a GAT model and report the test accuracy of its best-validation weights.

    The model (``GAT_defn.GAT`` with 8 attention heads of hidden size 8) is
    trained with Adam (lr=1e-3) on the nodes selected by ``mask``, using the
    NLL loss of the log-softmax outputs. After every epoch the validation
    accuracy on ``val_mask`` is computed. The weights with the highest
    validation accuracy seen after epoch 10 are written to
    ``./FastGAT__max_acc``. They are then reloaded to measure accuracy on
    ``test_mask``.

    Args:
        g: graph handed to ``GAT_defn.GAT`` (moved with ``g.to(device)``) and
            to ``GAT_defn.evaluate``.
        features: node feature tensor; ``features.size()[1]`` is the input dim.
        num_classes: output dimension of the model.
        device: device the model is placed on.
        labels: node labels, indexed with the masks.
        mask, test_mask, val_mask: node selectors for train / test / validation.
        num_epochs: number of training epochs (must be >= 1).

    Returns:
        Tuple of (final training loss, validation accuracy of the last epoch,
        mean seconds per epoch excluding the first two epochs -- nan when no
        epoch was timed, list of per-epoch validation accuracies,
        test accuracy of the checkpointed model).

    Raises:
        ValueError: if ``num_epochs`` < 1.
    """
    if num_epochs < 1:
        # The final report reads `epoch` and `loss` produced by the loop.
        raise ValueError("num_epochs must be >= 1, got {}".format(num_epochs))

    checkpoint_path = "./FastGAT__max_acc"
    # CUDA kernels run asynchronously; synchronize around the timed region so the
    # measured epoch time includes the GPU work and not only the kernel launches.
    sync_cuda = str(device).startswith("cuda") and torch.cuda.is_available()

    # create the model: 8 attention heads, each with hidden size 8
    net = GAT_defn.GAT(g.to(device),
                       in_dim=features.size()[1],
                       hidden_dim=8,
                       out_dim=num_classes,
                       num_heads=8)
    net.to(device)

    # create optimizer
    lr = 1e-3
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    dur = []          # seconds per epoch; epochs 0 and 1 are skipped as warm-up
    acc_GAT_log = []  # validation accuracy after each epoch

    def mean_dur():
        # np.mean([]) emits a RuntimeWarning; report nan quietly instead.
        return float(np.mean(dur)) if dur else float("nan")

    print("Starting epochs with lr=", lr, 'for a num_epochs=', num_epochs, '\n')
    max_acc = 0
    checkpoint_saved = False
    for epoch in range(num_epochs):
        timed = epoch >= 2
        if timed:
            if sync_cuda:
                torch.cuda.synchronize()
            t0 = time.perf_counter()

        logits = net(features)
        logp = F.log_softmax(logits, 1)
        loss = F.nll_loss(logp[mask], labels[mask].to(device))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if timed:
            if sync_cuda:
                torch.cuda.synchronize()
            dur.append(time.perf_counter() - t0)

        acc = GAT_defn.evaluate(net, g, features, labels, val_mask)
        acc_GAT_log.append(acc)

        # Keep the weights with the best validation accuracy after epoch 10.
        if acc > max_acc and epoch > 10:
            max_acc = acc
            torch.save(net.state_dict(), checkpoint_path)
            checkpoint_saved = True

        if epoch % 20 == 0:
            print("Epoch {:05d} | Loss {:.4f} |  Accuracy {:.4f}| Time(s) {:.4f}".format(
                epoch, loss.item(), acc, mean_dur()))

    if not checkpoint_saved:
        # No epoch after 10 beat an accuracy of 0 (or training was too short):
        # fall back to the final weights instead of loading a stale checkpoint
        # left behind by an earlier call.
        torch.save(net.state_dict(), checkpoint_path)

    net_opt = GAT_defn.GAT(g.to(device),
                           in_dim=features.size()[1],
                           hidden_dim=8,
                           out_dim=num_classes,
                           num_heads=8)
    # The reloaded model must live on the same device as the trained one.
    net_opt.to(device)
    net_opt.load_state_dict(torch.load(checkpoint_path, map_location=device))
    test_acc = GAT_defn.evaluate(net_opt, g, features, labels, test_mask)
    print("Epoch {:05d} | Loss {:.4f} | Test Accuracy {:.4f}| Time(s) {:.4f}".format(
        epoch, loss.item(), test_acc, mean_dur()))

    return loss.item(), acc, mean_dur(), acc_GAT_log, test_acc


num_epochs = 200
for choice in choice_vals:

    # Run the full-graph GAT baseline on every dataset except Reddit (choice 8).
    # The same flag decides which results are written to the .npz file below.
    full_flag = choice != 8

    print("Loading data from ", filenames[choice - 1])
    # load data
    g, features, labels, mask, test_mask, N, Ne, val_mask = load_data(choice)

    # Change the device. Tensor.to() returns a new tensor instead of moving the
    # original, so the result must be re-bound for the move to take effect.
    features = features.to(device)
    labels = labels.to(device)
    mask = mask.to(device)
    test_mask = test_mask.to(device)
    val_mask = val_mask.to(device)
    # NOTE(review): depending on the DGL version, DGLGraph.to() either moves the
    # graph in place or returns a new graph (discarded here). GAT() moves the
    # graph it receives itself, so this call is left unchanged.
    g.to(device)

    full_acc_trials = []
    full_dur_trials = []

    if full_flag:
        for trial in range(num_trials):
            print("Using full version, for ", filenames[choice - 1])
            # Add a self-loop on every node of a copy, so that g itself (which
            # is sparsified below) is left untouched.
            g_deep_copy = copy.deepcopy(g)
            g_deep_copy.add_edges(g_deep_copy.nodes(), g_deep_copy.nodes())

            loss_GAT, acc_GAT, dur_GAT, acc_GAT_log, test_acc_full = GAT(
                g_deep_copy, features, num_classes[choice - 1], device, labels,
                mask, test_mask, val_mask, num_epochs)

            full_acc_trials.append(test_acc_full)
            full_dur_trials.append(dur_GAT)

        mean_acc_full = np.mean(full_acc_trials)
        mean_dur_full = np.mean(full_dur_trials)
        full_var = np.std(full_acc_trials)  # standard deviation across trials

    # Only consumed by the commented-out random-sparsification baseline below;
    # kept so the global `random` state advances exactly as before.
    rand_seed = random.randint(1, 100000)

    mean_acc_eps = []
    mean_dur_eps = []
    var_acc_eps = []
    Ne_sp_eps = []

    for epsilon in eps_vals:

        acc_GAT_trials = []
        dur_GAT_trials = []

        for trial in range(num_trials):
            g_sp, Ne_sp = DataLoader.generate_spare_graph(g, N, choice, epsilon)
            # g_sp, Ne_sp = DataLoader.generate_randomly_spare_graph(g,N, Ne, epsilon,rand_seed)

            print("eps, trial, Ne, Ne_sp", epsilon, trial, Ne, Ne_sp)

            # Change the device (see the NOTE on DGLGraph.to() above).
            g_sp.to(device)

            print("Using Sparse version, for ", filenames[choice - 1])
            (loss_sparse_GAT, acc_sparse_GAT, dur_sparse_GAT,
             acc_sparse_GAT_log, test_acc_sparse) = GAT(
                g_sp, features, num_classes[choice - 1], device, labels,
                mask, test_mask, val_mask, num_epochs)

            acc_GAT_trials.append(test_acc_sparse)
            dur_GAT_trials.append(dur_sparse_GAT)

        print("For Dataset and random subsampling with epsilon", filenames[choice - 1], epsilon)
        print("Average accuracy:", np.mean(acc_GAT_trials))
        print("Average time/epoch:", np.mean(dur_GAT_trials))

        mean_acc_eps.append(np.mean(acc_GAT_trials))
        mean_dur_eps.append(np.mean(dur_GAT_trials))
        var_acc_eps.append(np.std(acc_GAT_trials))  # standard deviation
        # Edge count of the sparsified graph from the last trial only.
        Ne_sp_eps.append(Ne_sp)

    # Arrays are stored positionally (arr_0, arr_1, ...); the full-graph
    # baseline results come first when they were computed.
    filename = 'F1score_results_with_var_' + filenames[choice - 1] + '.npz'
    if full_flag:
        np.savez(filename, mean_acc_full, mean_dur_full, full_var, mean_acc_eps, mean_dur_eps, var_acc_eps, Ne_sp_eps)
    else:
        np.savez(filename, mean_acc_eps, mean_dur_eps, var_acc_eps, Ne_sp_eps)
#
