import os
import glob
import time
import random
# import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from utils import load_data, accuracy
from models import SSGAT, Structuralsubspace


class Training:
    """Run one complete SSGAT experiment: load, pretrain, fit, test and log.

    All of the work is done by the constructor.  It seeds the RNGs, loads the
    dataset, derives the structural matrix, pretrains a ``Structuralsubspace``
    model for a few iterations, trains an ``SSGAT`` classifier with early
    stopping, reloads the best checkpoint, evaluates it on the test indices
    and appends a summary to ``./output/<data>/<data>.log``.

    After construction the instance exposes, among others, ``model`` (with the
    best checkpoint loaded), ``ssmodel``, ``ssc_coef1`` and ``best`` (the best
    test accuracy seen during fitting).
    """

    def __init__(self, lambd, autobala, roughcons, struc, r1, r2, v1, v2, t1, t2):
        """Configure and immediately run the experiment.

        Args:
            lambd: weight of the reconstruction term in the subspace
                pretraining loss.
            autobala: forwarded to ``SSGAT`` as ``autobalance``.
            roughcons: forwarded to ``SSGAT`` as ``roughconstrain``.
            struc: structural matrix variant; ``'adj'`` binarises the loaded
                matrix, ``'similarity'`` replaces it with a normalised
                similarity, any other value keeps it as loaded.
            r1, r2: forwarded to ``load_data``; logged as the id range of the
                training vertices.
            v1, v2: forwarded to ``load_data`` (presumably the id range of the
                validation vertices -- confirm against ``load_data``).
            t1, t2: forwarded to ``load_data``; logged as the id range of the
                testing vertices.
        """
        self.fastmode = False
        self.struc = struc
        self.autobala = autobala
        self.roughcons = roughcons
        self.seed = 72
        self.epochs = 1000
        self.lr = 0.01
        self.weight_decay = 5e-4
        self.hidden = 8
        self.lambd = lambd
        self.nb_heads = 8
        self.dropout = 0.6
        self.alpha = 0.2
        self.patience = 150
        self.data = 'cora'
        self.r1 = r1
        self.r2 = r2
        self.v1 = v1
        self.v2 = v2
        self.t1 = t1
        self.t2 = t2

        # Seed every RNG before anything random (data split, weight init).
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(self.seed)

        adj = self._load_graph()
        self._pretrain_subspace(adj)
        self._build_classifier()

        outputpath = "./output/" + self.data + "/"
        # Checkpoints and the log are written here; create it on first use.
        os.makedirs(outputpath, exist_ok=True)

        self.best = self._fit(outputpath)
        # Report the test metrics of the restored best checkpoint.
        self.compute_test()
        self._write_log(outputpath)

    @staticmethod
    def _structural_matrix(adj, struc):
        """Return the structural matrix selected by ``struc``.

        ``'adj'`` sets every positive entry to one.  ``'similarity'`` computes
        ``adj^T adj + adj``, scales it by the square roots of its diagonal on
        both sides, zeroes the diagonal, keeps it only where ``adj`` is
        positive and finally normalises each row to sum to one.  Any other
        value returns ``adj`` unchanged.
        """
        if struc == 'adj':
            return torch.where(adj > 0, torch.ones_like(adj), adj)

        if struc == 'similarity':
            sim = torch.matmul(adj.t(), adj) + adj
            # Column vector of sqrt(diagonal); the epsilon avoids division by zero.
            scale = torch.unsqueeze(torch.diag(sim), dim=1) ** 0.5 + 1e-9
            sim = torch.div(sim, scale)
            sim = torch.div(sim, scale.t())
            sim = sim - torch.diag(torch.diag(sim), 0)
            masked = torch.where(adj > 0, sim, adj)
            row_sum = torch.unsqueeze(torch.sum(masked, dim=1), dim=1) + 1e-9
            return torch.div(masked, row_sum)

        return adj

    @staticmethod
    def _checkpoint_epoch(path):
        """Return the epoch number encoded in a checkpoint path, or ``None``.

        The epoch is the part of the file name before the first dot.  Files
        whose name does not start with an integer are not checkpoints written
        by this class and yield ``None``.
        """
        # os.path.basename handles both '/' and the platform separator.
        stem = os.path.basename(path).split('.')[0]
        try:
            return int(stem)
        except ValueError:
            return None

    @classmethod
    def _prune_checkpoints(cls, directory, best_epoch, newer=False):
        """Delete numbered ``*.pkl`` checkpoints on one side of ``best_epoch``.

        With ``newer=False`` checkpoints from epochs before ``best_epoch`` are
        removed; with ``newer=True`` those from later epochs are removed.
        Files without a numeric name are left alone.
        """
        for path in glob.glob(os.path.join(directory, '*.pkl')):
            epoch_nb = cls._checkpoint_epoch(path)
            if epoch_nb is None:
                continue
            stale = epoch_nb > best_epoch if newer else epoch_nb < best_epoch
            if stale:
                os.remove(path)

    def _load_graph(self):
        """Load the dataset, store its tensors on ``self`` and return ``adj``.

        Sets ``edge``/``edge_v`` (indices and values of the positive entries
        of the structural matrix), ``edget``/``edget_v`` (the same for its
        transpose), ``features``, ``labels`` and the three index tensors.
        """
        path = "./data/" + self.data + "/"
        adj, features, labels, idx_train, idx_val, idx_test = load_data(
            path, self.data, self.r1, self.r2, self.v1, self.v2, self.t1, self.t2)
        adj = self._structural_matrix(adj, self.struc)

        # Indices come from nonzero() and values from `> 0`; the two agree only
        # when adj has no negative entries (presumably guaranteed by load_data).
        adj_t = adj.t()
        self.edge = adj.nonzero().t()
        self.edget = adj_t.nonzero().t()
        self.edget_v = torch.masked_select(adj_t, adj_t > 0)
        self.edge_v = torch.masked_select(adj, adj > 0)

        self.features = features
        self.labels = labels
        self.idx_train = idx_train
        self.idx_val = idx_val
        self.idx_test = idx_test
        return adj

    def _pretrain_subspace(self, adj):
        """Pretrain ``Structuralsubspace`` and store its coefficients.

        Moves the data to the GPU when one is available, runs five Adam steps
        on the subspace loss and stores in ``self.ssc_coef1`` one coefficient
        per positive entry of ``adj``.
        """
        self.ssmodel = Structuralsubspace(self.features.shape[0])

        use_cuda = torch.cuda.is_available()
        if use_cuda:
            self.ssmodel.cuda()

            self.features = self.features.cuda()
            adj = adj.cuda()
            self.edge = self.edge.cuda()
            self.edge_v = self.edge_v.cuda()
            self.edget = self.edget.cuda()
            self.edget_v = self.edget_v.cuda()

            self.labels = self.labels.cuda()
            self.idx_train = self.idx_train.cuda()
            self.idx_val = self.idx_val.cuda()
            self.idx_test = self.idx_test.cuda()

        self.features, self.labels = Variable(self.features), Variable(self.labels)
        self.edge, self.edge_v = Variable(self.edge), Variable(self.edge_v)
        self.edget, self.edget_v = Variable(self.edget), Variable(self.edget_v)

        self.reg_optm = optim.Adam(self.ssmodel.parameters(), lr=self.lr,
                                   weight_decay=self.weight_decay)

        print('Pretrain structural subspace...\n')
        n_nodes = self.features.shape[0]
        # Sparse transpose of the structural matrix: the reconstruction target.
        spadjt = torch.sparse_coo_tensor(self.edget, self.edget_v.squeeze(),
                                         torch.Size([n_nodes, n_nodes]))
        if use_cuda:
            spadjt = spadjt.cuda()

        for _ in range(5):
            self.ssmodel.train()
            self.reg_optm.zero_grad()

            selfrep1, ssc_coef1 = self.ssmodel(self.edget, self.edget_v)
            # Squared reconstruction error weighted by lambd plus an L1
            # penalty on the coefficients.
            loss_subspace = (0.5 * self.lambd * (torch.norm(selfrep1 - spadjt)).pow(2)
                             + torch.norm(ssc_coef1, p=1))

            loss_subspace.backward()
            self.reg_optm.step()

        print('Structural subspace obtained...\n')

        self.ssmodel.eval()
        # The coefficients are used as fixed inputs from here on, so no
        # autograd graph is needed for this final pass.
        with torch.no_grad():
            _, ssc_coef1 = self.ssmodel(self.edget, self.edget_v)
        torch.cuda.empty_cache()
        self.ssc_coef1 = torch.masked_select(ssc_coef1.t(), adj > 0)
        self.ssc_coef1 = Variable(self.ssc_coef1)
        print(self.ssc_coef1.size())
        print(self.ssc_coef1.dtype)

        torch.cuda.empty_cache()

    def _build_classifier(self):
        """Create the ``SSGAT`` model and its Adam optimizer."""
        self.model = SSGAT(nfeat=self.features.shape[1],
                           nhid=self.hidden,
                           nclass=int(self.labels.max()) + 1,
                           dropout=self.dropout,
                           nheads=self.nb_heads,
                           alpha=self.alpha,
                           autobalance=self.autobala,
                           roughconstrain=self.roughcons)
        if torch.cuda.is_available():
            self.model.cuda()

        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr,
                                    weight_decay=self.weight_decay)

    def _fit(self, outputpath):
        """Train with early stopping and restore the best checkpoint.

        A checkpoint is written after every epoch; training stops once
        ``patience`` consecutive epochs fail to improve the tracked accuracy.

        Returns:
            The best accuracy observed (``0`` if no epoch improved on it).
        """
        t_total = time.time()
        bad_counter = 0
        best = 0
        best_epoch = 0

        print("Network Fitting...")
        for epoch in range(self.epochs):
            self.train(epoch)

            torch.save(self.model.state_dict(), outputpath + '{}.pkl'.format(epoch))
            # NOTE(review): the best epoch is chosen on *test* accuracy, so the
            # reported accuracy is optimistic; validation metrics would be the
            # unbiased criterion.
            _, acc_t = self.compute_test()

            if acc_t > best:
                best = acc_t
                best_epoch = epoch
                bad_counter = 0
            else:
                bad_counter += 1

            if bad_counter == self.patience:
                break

            # Checkpoints older than the current best are no longer needed.
            self._prune_checkpoints(outputpath, best_epoch)

        # Drop everything written after the best epoch.
        self._prune_checkpoints(outputpath, best_epoch, newer=True)

        print("Optimization Finished!")
        print("Total time elapsed: {:.4f}s".format(time.time() - t_total))

        # Restore best model
        print('Loading {}th epoch'.format(best_epoch))
        self.model.load_state_dict(torch.load(outputpath + '{}.pkl'.format(best_epoch)))
        return best

    def _write_log(self, outputpath):
        """Append the model settings and the best accuracy to the log file."""
        if self.struc == 'similarity':
            structure = 'Normalized adjacency similarity, '
        elif self.struc == 'adj':
            structure = 'Adjacency matrix, '
        else:
            structure = 'Normalized adjacency matrix, '

        # `best` is a tensor once any epoch improved on the initial 0, and a
        # plain int otherwise; only the tensor has `.data`.
        best_value = getattr(self.best, 'data', self.best)

        entry = [
            'Time: ' + str(time.asctime(time.localtime(time.time()))) + '\n',
            'Dataset: ' + self.data + '\n',
            'Model settings:\n',
            'Structural data: ',
            structure,
            'No. attention heads: ' + str(self.nb_heads) + ', ',
            'Dimension of hidden layer: ' + str(self.hidden) + ', ',
            'Learning rate= ' + str(self.lr) + ', ',
            'Lambda= ' + str(self.lambd) + ', ',
            'Roughly constrained: ' + str(self.roughcons) + ', ',
            'Layer-wise subspaces: ' + str(False) + ', ',
            'Automatically learning weights of structural and feature attention: '
            + str(self.autobala) + ',\n',
            'Range of ids of training vertices: ' + str(self.r1) + '-' + str(self.r2) + '\n',
            'Range of ids of testing vertices: ' + str(self.t1) + '-' + str(self.t2) + '\n',
            'Accuracy: ' + str(best_value) + '\n\n',
        ]
        # The context manager closes the file even if a write fails.
        with open(outputpath + self.data + ".log", 'a+') as f:
            f.writelines(entry)

    def train(self, epoch):
        """Run one optimisation step and print train/validation metrics.

        Args:
            epoch: zero-based epoch index, used only for the printed report.

        Returns:
            The validation NLL loss as a Python float.
        """
        t = time.time()
        self.model.train()
        self.optimizer.zero_grad()

        predict = self.model(self.features, self.edge, self.ssc_coef1)
        loss_train = F.nll_loss(predict[self.idx_train], self.labels[self.idx_train])

        acc_train = accuracy(predict[self.idx_train], self.labels[self.idx_train])
        loss_train.backward()
        self.optimizer.step()
        torch.cuda.empty_cache()

        # Validation metrics are only reported, never back-propagated.
        with torch.no_grad():
            if not self.fastmode:
                # Evaluate validation set performance separately,
                # deactivates dropout during validation run.
                self.model.eval()
                predict = self.model(self.features, self.edge, self.ssc_coef1)

            loss_val = F.nll_loss(predict[self.idx_val], self.labels[self.idx_val])
            acc_val = accuracy(predict[self.idx_val], self.labels[self.idx_val])

        print('Epoch: {:04d}'.format(epoch + 1),
              'loss_train: {:.4f}'.format(loss_train.data.item()),
              'acc_train: {:.4f}'.format(acc_train.data.item()),
              'loss_val: {:.4f}'.format(loss_val.data.item()),
              'acc_val: {:.4f}'.format(acc_val.data.item()),
              'time: {:.4f}s'.format(time.time() - t))

        return loss_val.data.item()

    def compute_test(self):
        """Evaluate the current model on the test indices and print the result.

        Returns:
            A ``(loss_test, acc_test)`` pair as returned by ``F.nll_loss`` and
            ``accuracy``.
        """
        self.model.eval()
        # Pure evaluation: skip building the autograd graph.
        with torch.no_grad():
            predict = self.model(self.features, self.edge, self.ssc_coef1)

            loss_test = F.nll_loss(predict[self.idx_test], self.labels[self.idx_test])
            acc_test = accuracy(predict[self.idx_test], self.labels[self.idx_test])

        print("Test set results:",
              "loss= {:.4f}".format(loss_test.data.item()),
              "accuracy= {:.4f}".format(acc_test.data.item()))
        return loss_test, acc_test
