import os
import glob
import time
import random
# import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from utils import load_data
from models import Structuralsubspace


class Trainingsubspace:
    """Pretrain a structural self-expressive subspace model on a graph.

    All work happens in the constructor. It:

    1. seeds ``random``, NumPy and PyTorch;
    2. loads the dataset via ``load_data``;
    3. optionally re-weights the adjacency matrix according to ``struc``;
    4. fits ``Structuralsubspace`` for ``epochs`` Adam steps;
    5. stores the learned coefficients at the positive entries of the
       (preprocessed) adjacency in ``self.ssc_coef1``;
    6. writes those coefficients to ``ssc_path`` with ``np.savetxt``.

    Args:
        lambd: Weight of the reconstruction term in the subspace loss.
        struc: Adjacency preprocessing mode.
            ``'adj'`` sets every positive entry to 1.
            ``'similarity'`` replaces each positive entry by a normalized
            ``adj.T @ adj + adj`` score with the diagonal removed, then
            row-normalizes.
            Any other value uses the adjacency returned by ``load_data``
            unchanged.
        data: Dataset name; files are read from ``./data/<data>/``.
        seed: Seed for ``random``, NumPy and PyTorch (CPU and, if present, CUDA).
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        epochs: Number of pretraining optimization steps.
        ssc_path: File the extracted coefficients are written to.
    """

    def __init__(self, lambd, struc, *, data='cora', seed=72, lr=0.01,
                 weight_decay=5e-4, epochs=200, ssc_path='ssc_coef'):
        self.struc = struc
        self.seed = seed
        self.lr = lr
        self.weight_decay = weight_decay
        self.lambd = lambd
        self.data = data
        # Split parameters forwarded verbatim to load_data.
        self.r1, self.r2 = 1, 2
        self.v1, self.v2 = 1, 2
        self.t1, self.t2 = 1, 2

        # Seeding must precede load_data and model construction, both of which
        # may consume random state.
        self._seed_everything()

        # load data
        path = f"./data/{self.data}/"
        adj, features, _labels, _idx_train, _idx_val, _idx_test = load_data(
            path, self.data, self.r1, self.r2, self.v1, self.v2, self.t1, self.t2)
        adj = self._preprocess_adj(adj, self.struc)
        n_nodes = features.shape[0]

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        adj = adj.to(device)

        # edge / edge_v: indices and values of the positive entries of adj;
        # edget / edget_v: the same for adj.T. Indices and values come from the
        # same mask, so they always have equal length and matching order.
        self.edge, self.edge_v = self._positive_entries(adj)
        self.edget, self.edget_v = self._positive_entries(adj.t())

        # pretrain subspace
        self.ssmodel = Structuralsubspace(n_nodes).to(device)
        self.reg_optm = optim.Adam(self.ssmodel.parameters(), lr=self.lr,
                                   weight_decay=self.weight_decay)

        print('Pretrain structural subspace...\n')
        # Sparse target for the reconstruction term; it inherits the device of
        # edget / edget_v.
        spadjt = torch.sparse_coo_tensor(self.edget, self.edget_v,
                                         torch.Size([n_nodes, n_nodes]))

        for _ in range(epochs):
            self.ssmodel.train()
            self.reg_optm.zero_grad()

            selfrep1, ssc_coef1 = self.ssmodel(self.edget, self.edget_v)
            # 0.5 * lambd * ||selfrep - adj^T||^2 + ||coefficients||_1
            loss_subspace = (0.5 * self.lambd * torch.norm(selfrep1 - spadjt).pow(2)
                             + torch.norm(ssc_coef1, p=1))

            loss_subspace.backward()
            self.reg_optm.step()

        print('Structural subspace obtained...\n')

        self.ssmodel.eval()
        selfrep1, ssc_coef1 = self.ssmodel(self.edget, self.edget_v)
        torch.cuda.empty_cache()
        # Keep only the coefficients (read through the transpose of the
        # coefficient output) at the positive entries of adj. The tensor keeps
        # its autograd history, as before.
        self.ssc_coef1 = torch.masked_select(ssc_coef1.t(), adj > 0)

        np.savetxt(ssc_path, self.ssc_coef1.detach().cpu().numpy())

    def _seed_everything(self):
        """Seed ``random``, NumPy and PyTorch (plus CUDA when available)."""
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(self.seed)

    @staticmethod
    def _preprocess_adj(adj, struc):
        """Return ``adj`` re-weighted according to ``struc``.

        See the class docstring for the supported modes. Any other value of
        ``struc`` returns ``adj`` unchanged.
        """
        if struc == 'adj':
            return torch.where(adj > 0, torch.ones_like(adj), adj)
        if struc == 'similarity':
            sim = torch.matmul(adj.t(), adj) + adj
            # Normalize sim_ij by (sqrt(sim_ii) + eps) * (sqrt(sim_jj) + eps);
            # eps guards nodes whose diagonal entry of sim is zero.
            d = torch.unsqueeze(torch.diag(sim), dim=1) ** 0.5 + 1e-9
            sim = torch.div(torch.div(sim, d), d.t())
            sim = sim - torch.diag(torch.diag(sim), 0)  # drop self-similarity
            adj = torch.where(adj > 0, sim, adj)
            # Row-normalize; eps guards rows that sum to zero.
            row_sum = torch.unsqueeze(torch.sum(adj, dim=1), dim=1) + 1e-9
            return torch.div(adj, row_sum)
        return adj

    @staticmethod
    def _positive_entries(mat):
        """Return ``(indices, values)`` for the strictly positive entries of 2-D ``mat``.

        ``indices`` has shape ``(2, nnz)`` and ``values`` has shape ``(nnz,)``.
        Both are in row-major order of ``mat``, so the pair can be passed
        straight to ``torch.sparse_coo_tensor``.
        """
        mask = mat > 0
        return mask.nonzero().t(), torch.masked_select(mat, mask)
