#!/usr/bin/env python
import argparse, json
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import numpy as np
import time, datetime
from sklearn.metrics import f1_score
import networkx as nx
import fastergcn.preprocess as preprocess
import scipy.sparse as sp
import metis
import pprint
import math

# Fix random seeds so NumPy and PyTorch (CPU and every CUDA device) draw the
# same pseudo-random sequence on each run.
np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed_all(1)

# NOTE: the command-line parser below is wrapped in a bare string literal, so
# it is evaluated as a no-op expression and never executed.  Because the block
# is disabled, `args` and `logging` are not created at import time; main_new
# reads both as module-level globals, so they must be provided before it is
# called.
"""
# Runtime arguments
parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('--data', type=str, default='cora', help='Dataset name')
parser.add_argument('--hidden_dim', type=int, default=128, help='Size of hidden dimension')
parser.add_argument('--lr', type=float, default=1.0e-4, help='Learning rate')
parser.add_argument('--dropout', type=float, default=0.1, help='Learning rate')
parser.add_argument('--batch_size', type=int, default=16, help='Batch size')
parser.add_argument('--epoch', type=int, default=10, help='Number of epochs')
parser.add_argument('--use_gcn', action='store_true', help='Whether or not to use GCN aggregator')
parser.add_argument('--cuda', type=int, default=-1, help='Which GPU to use (-1 if using CPU)')
parser.add_argument('--partition_list', nargs='+', default=[], help='List of partition sizes', required=True)
parser.add_argument('--num_layers', type=int, default=2, help='Number of GC layers')
parser.add_argument('--use_cluster', action='store_true', help='Whether or not to use cluster sampling')
parser.add_argument('--weight_decay', type=float, default=0, help='Weight decay (L2 loss on parameters).')
parser.add_argument('--bsize_list', nargs='+', default=[], help='List of balance sizes', required=True)
parser.add_argument('--use_pp', action='store_true', help='Whether or not to use AX preprocessing')
parser.add_argument('--opt', type=str, default='adam', help='Which optimizer to use')
parser.add_argument('--debug', action='store_true', help='Whether or not to print debug messages')
parser.add_argument('--log_id', type=str, default=None, help='Log ID')
parser.add_argument('--diag_lambda', type=float, default=0.0, help='Diagonal weight added to A')

args = parser.parse_args()
logging = vars(args)
"""

# Disabled consistency check between the two list-valued options.
# assert(len(args.partition_list) == len(args.bsize_list))

# multitask_data = set(['ppi', 'amazon', 'amazon-0.1', 'amazon-0.3', 'amazon2M', 'amazon2M-47'])
# Multi-label mode is hard-coded on; the trailing comment shows the former
# dataset-dependent expression.  The flag selects the loss in main_new and the
# prediction rule in calc_f1.
multitask = True  # if args.data in multitask_data else False


class Encoder(nn.Module):
    """Single graph-convolution layer: aggregate, dropout, linear, optional LayerNorm.

    Args:
        dim_in: width of the incoming node features.
        dim_out: width of the produced node embeddings.
        dropout: dropout probability applied to the aggregated features.
        use_gcn: when false, the linear layer is built for ``2 * dim_in``
            inputs because the aggregated features are concatenated with the
            node's own features.
        use_lynorm: when true, a ``LayerNorm`` is applied to the output.
    """

    def __init__(self, dim_in, dim_out, dropout, use_gcn=False, use_lynorm=False):
        super().__init__()
        self.use_gcn = use_gcn
        self.dropout = dropout
        self.use_lynorm = use_lynorm

        # Non-GCN mode feeds [aggregated | own] features to the linear layer.
        linear_in = dim_in if use_gcn else 2 * dim_in
        self.linear = nn.Linear(linear_in, dim_out, bias=True)
        if use_lynorm:
            self.lynorm = nn.LayerNorm(dim_out, elementwise_affine=True)

    def forward(self, adj, feats, prev_H, use_pp):
        """Compute the layer output for one (sub)graph.

        Args:
            adj: sparse adjacency multiplied with ``feats`` via ``torch.spmm``
                (ignored when ``use_pp`` is true).
            feats: node features; with ``use_pp`` they are used as-is, i.e.
                the caller has already prepared the aggregated input.
            prev_H: unused; kept for interface compatibility.
            use_pp: skip the aggregation step when true.

        Returns:
            Tensor with ``dim_out`` columns.
        """
        if use_pp:
            # Pre-propagated input: no aggregation and no concatenation here
            # (a former concatenation with prev_H is disabled).
            hidden = feats
        else:
            hidden = torch.spmm(adj, feats)
            if not self.use_gcn:
                hidden = torch.cat((hidden, feats), 1)

        hidden = F.dropout(hidden, self.dropout, training=self.training)
        projected = self.linear(hidden)
        return self.lynorm(projected) if self.use_lynorm else projected


class EncoderNoFeat(nn.Module):
    """First-layer encoder for sparse, pre-propagated input features.

    The layer multiplies a sparse feature matrix with a dense weight and
    optionally applies ``LayerNorm``.  Dropout is not supported on sparse
    input, and the ``bias`` parameter is registered and initialised but not
    added in ``forward``.

    Args:
        dim_in: number of feature columns.
        dim_out: width of the produced embeddings.
        dropout: must not be positive.
        use_gcn: when false the weight has ``2 * dim_in`` rows.
        use_lynorm: when true, a ``LayerNorm`` is applied to the output.

    Raises:
        ValueError: if ``dropout`` is greater than zero.
    """

    def __init__(self, dim_in, dim_out, dropout, use_gcn=False, use_lynorm=False):
        super(EncoderNoFeat, self).__init__()
        # Reject unsupported configurations before allocating anything, and
        # report them as an exception rather than terminating the interpreter.
        if dropout > 0:
            raise ValueError("dropout on Sparse Tensor not supported")

        if use_lynorm:
            self.lynorm = nn.LayerNorm(dim_out, elementwise_affine=True)
        dim_in = dim_in if use_gcn else 2 * dim_in
        self.dropout = dropout
        self.weight = nn.Parameter(torch.empty(dim_in, dim_out, dtype=torch.float32))
        self.bias = nn.Parameter(torch.empty(dim_out, dtype=torch.float32))
        self.use_lynorm = use_lynorm
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight and bias with the same scheme the original used."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in)
        init.uniform_(self.bias, -bound, bound)

    def forward(self, adj, feats, prev_H, use_pp):
        """Return ``feats @ weight`` (optionally layer-normalised).

        Args:
            adj: unused; kept for interface compatibility with ``Encoder``.
            feats: sparse feature matrix multiplied via ``torch.spmm``.
            prev_H: unused.
            use_pp: must be true; this layer performs no aggregation.

        Raises:
            AssertionError: if ``use_pp`` is false.
        """
        # Explicit raise keeps the original exception type but is not
        # stripped when Python runs with -O.
        if not use_pp:
            raise AssertionError("EncoderNoFeat requires use_pp=True")

        # NOTE: dropout and bias are not applied here.
        output = torch.spmm(feats, self.weight)

        if self.use_lynorm:
            output = self.lynorm(output)

        return output


class ClusterGCN(nn.Module):
    """Stack of encoder layers with ReLU between them.

    The stack always consists of an input encoder, ``num_layers - 2`` hidden
    encoders and an output encoder producing ``num_class`` columns.

    Args:
        feature: stored on the instance as ``self.feature``; not used by
            ``forward``.
        dim_in: width of the input features.
        dim_hidden: width of every hidden representation.
        num_class: width of the output logits.
        dropout: dropout probability forwarded to each encoder.
        num_layers: total number of encoder layers; must be at least 2.
        use_gcn: forwarded to each encoder.
        use_pp: passed to the first encoder's ``forward`` only; later layers
            always aggregate with ``adj``.
        use_feat: when false the first layer is an ``EncoderNoFeat``.

    Raises:
        ValueError: if ``num_layers`` is smaller than 2.
    """

    def __init__(
        self,
        feature,
        dim_in,
        dim_hidden,
        num_class,
        dropout,
        num_layers=2,
        use_gcn=False,
        use_pp=False,
        use_feat=True,
    ):
        super(ClusterGCN, self).__init__()
        # The constructor always builds an input and an output encoder, so a
        # smaller value cannot be honoured; previously forward() would then
        # re-apply the first encoder and never reach the classifier layer.
        if num_layers < 2:
            raise ValueError(
                "num_layers must be at least 2, got {}".format(num_layers)
            )

        self.feature = feature
        self.num_layers = num_layers
        self.encoders = nn.ModuleList()

        first_cls = Encoder if use_feat else EncoderNoFeat
        self.encoders.append(
            first_cls(dim_in, dim_hidden, dropout, use_gcn=use_gcn, use_lynorm=True)
        )
        for _ in range(num_layers - 2):
            self.encoders.append(
                Encoder(
                    dim_hidden,
                    dim_hidden,
                    dropout,
                    use_gcn=use_gcn,
                    use_lynorm=True,
                )
            )
        # Output layer: no LayerNorm on the logits.
        self.encoders.append(
            Encoder(dim_hidden, num_class, dropout, use_gcn=use_gcn, use_lynorm=False)
        )

        self.use_pp = use_pp
        self.activation = nn.ReLU(inplace=True)
        self.use_feat = use_feat

    def forward(self, info1, info2, adj, feats, orig_feats):
        """Run all encoders and return the output logits.

        Args:
            info1: unused.
            info2: unused.
            adj: sparse adjacency for the current (sub)graph.
            feats: input to the first encoder.
            orig_feats: forwarded to the first encoder as ``prev_H``.

        Returns:
            Output of the last encoder (no activation applied).
        """
        x = self.activation(self.encoders[0](adj, feats, orig_feats, self.use_pp))
        # Hidden layers always aggregate with adj (use_pp=False).
        for encoder in self.encoders[1:-1]:
            x = self.activation(encoder(adj, x, None, False))
        return self.encoders[-1](adj, x, None, False)


def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a SciPy sparse matrix to a float32 torch sparse COO tensor.

    Args:
        sparse_mx: any SciPy sparse matrix; it is converted to COO format and
            cast to ``float32`` first.

    Returns:
        A ``torch`` sparse COO tensor with the same shape and non-zero entries.
    """
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)
    )
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    # torch.sparse_coo_tensor is the supported factory; the legacy
    # torch.sparse.FloatTensor constructor is deprecated.
    return torch.sparse_coo_tensor(indices, values, shape)


def dist_norm(x, y):
    """Return the Euclidean distance between ``x`` and ``y`` after each is
    scaled to unit norm."""
    unit_x, unit_y = (vec / np.linalg.norm(vec) for vec in (x, y))
    return np.linalg.norm(unit_x - unit_y)


def dist_dist(x, y):
    """Return ``scipy.stats.entropy(x, y)`` for the two distributions."""
    # Imported lazily so SciPy's stats package is only loaded when needed.
    import scipy.stats

    return scipy.stats.entropy(x, y)


def renormalize_adj(adj):
    """Row-normalise a sparse adjacency matrix.

    Every row is divided by its sum; a tiny epsilon keeps all-zero rows from
    dividing by zero (such rows stay zero).
    """
    row_totals = np.array(adj.sum(axis=1)).flatten()
    inverse_degree = sp.diags(1.0 / (row_totals + 1e-20), 0)
    return inverse_degree.dot(adj)


def add_diag(adj, diag_lambda, mask=None):
    """Add self-loop weight to ``adj``, row-normalise, then boost the diagonal.

    Steps:
      1. For each row compute ``min(1, 1 / k)`` where ``k`` is the number of
         entries greater than ``1e-8``; rows listed in ``mask`` get 0.
      2. Add those values on the diagonal and row-normalise the result.
      3. Add ``diag_lambda`` times the (normalised) diagonal once more.

    The returned matrix is NOT re-normalised after step 3.

    Args:
        adj: SciPy sparse matrix (the row sum must expose ``.A1``).
        diag_lambda: multiplier for the extra diagonal term in step 3.
        mask: optional row indices (list or array) that receive no self-loop
            weight in step 1.

    Returns:
        The modified sparse matrix.
    """
    norm_vals = np.minimum(1.0, 1.0 / ((adj > 1e-8).sum(axis=1).A1 + 1e-20))

    # Identity test: `mask != None` is evaluated element-wise for NumPy
    # arrays and cannot be used as a condition.
    if mask is not None:
        norm_vals[mask] = 0
    adj = adj + sp.diags(norm_vals)
    adj = renormalize_adj(adj)
    adj = adj + diag_lambda * sp.diags(adj.diagonal())
    return adj


def main_new(
    name,
    dim_hidden,
    use_gcn,
    partition_list,
    bsize_list,
    use_cluster,
    use_pp,
    diag_lambda,
):
    """Train a ClusterGCN on dataset ``name`` and print validation/test F1.

    The function loads the data with ``preprocess.load_data_new``, optionally
    post-processes every adjacency with ``add_diag``, builds the model and
    optimizer, trains for ``args.epoch`` epochs (evaluating on the validation
    set after every epoch and on the test set when it is non-empty), prints
    timing statistics and, when ``args.log_id`` is set, stores the run in
    ``logs_final.json``.

    Besides its parameters it reads the module-level globals ``args``
    (``cuda``, ``dropout``, ``num_layers``, ``batch_size``, ``opt``, ``lr``,
    ``weight_decay``, ``epoch``, ``debug``, ``log_id``), ``logging`` and
    ``multitask``; these must be defined before the call.

    It returns early, after printing a message, when ``use_pp`` is false,
    when ``args.opt`` is neither ``"adam"`` nor ``"sgd"``, or when features
    are used and ``train_feats`` is a SciPy CSR matrix.

    Args:
        name: dataset name passed to ``preprocess.load_data_new``.
        dim_hidden: hidden width passed to ``ClusterGCN``.
        use_gcn: forwarded to ``ClusterGCN``; when false, the propagated
            features are horizontally stacked with the raw features.
        partition_list: partition sizes; converted to ``int`` here.
        bsize_list: balance sizes; converted to ``int`` here and paired with
            ``partition_list`` by ``zip``.
        use_cluster: build per-partition training/validation data.
        use_pp: must be true for training to proceed.
        diag_lambda: when positive, ``add_diag`` is applied to all adjacencies.

    Returns:
        None.
    """
    # device
    if args.cuda != -1:
        device = torch.device("cuda:{}".format(args.cuda))
    else:
        device = torch.device("cpu")

    partition_list = list(map(int, partition_list))
    bsize_list = list(map(int, bsize_list))
    # Load data
    # NOTE(review): the meaning of each returned item is inferred from its
    # name only; confirm against preprocess.load_data_new.
    (
        num_data,
        train_adj,
        full_adj,
        train_part_adj_list,
        full_part_adj,
        feats,
        train_feats,
        test_feats,
        labels_new,
        train,
        val,
        test,
        train_parts_list,
        full_parts,
    ) = preprocess.load_data_new(name, partition_list=partition_list)

    num_features = feats.shape[1]
    num_class = labels_new.shape[1]

    print(labels_new[0])

    if diag_lambda > 0.0:
        print("Add diag lambda {} to all adjs!".format(diag_lambda))
        # Only the training adjacency masks out validation and test nodes.
        train_adj = add_diag(train_adj, diag_lambda, list(val) + list(test))
        full_adj = add_diag(full_adj, diag_lambda)
        full_part_adj = add_diag(full_part_adj, diag_lambda)
        train_part_adj_list = [
            add_diag(adj, diag_lambda) for adj in train_part_adj_list
        ]

        # recalculate AX
        train_feats = train_adj.dot(feats)
        test_feats = full_adj.dot(feats)

    print(train_adj[0, :])
    print("full")
    print(full_adj[0, :])
    # print(train_parts)
    # pprint.pprint([sum(labels_new[pt]) for pt in train_parts_list[0]])

    print("train nonzeros", len(train_adj.data))
    print("full nonzeros", len(full_adj.data))

    # Not used later in this function.
    target_dist = sum(labels_new)

    print("train/val/test: %d/%d/%d" % (len(train), len(val), len(test)))

    if not use_pp:
        print("no PP not yet done")
        return
        # NOTE(review): the two assignments below are unreachable.
        train_feats = feats
        test_feats = feats

    if not use_gcn:
        # Non-GCN mode: the model's first layer expects [AX | X].
        if isinstance(train_feats, sp.csr_matrix):
            train_feats = sp.hstack([train_feats, feats]).tocsr()
            test_feats = sp.hstack([test_feats, feats]).tocsr()
        else:
            train_feats = np.hstack([train_feats, feats])
            test_feats = np.hstack([test_feats, feats])

    # Keep the un-partitioned adjacencies; the names train_adj/full_adj may be
    # rebound to the partitioned versions below.
    real_train_adj = train_adj
    real_full_adj = full_adj
    if partition_list[0] > 1:
        print("using first partitioned graph!")
        train_adj = train_part_adj_list[0]
        full_adj = full_part_adj
    print("train_adj shape", train_adj.shape, real_train_adj.shape)

    # Not used later in this function.
    adj_lists = []
    # A square (num_data x num_data) feature matrix is treated as "no
    # features", which selects the sparse first layer in the model.
    if feats.shape == (num_data, num_data):
        use_feat = False
        print("Feature is not used!")
    else:
        use_feat = True
    # Build model
    model = ClusterGCN(
        feats,
        num_features,
        dim_hidden,
        num_class,
        dropout=args.dropout,
        num_layers=args.num_layers,
        use_gcn=use_gcn,
        use_pp=use_pp,
        use_feat=use_feat,
    )
    model.to(device)
    # NOTE(review): the CUDA memory queries in this function run even when
    # device is the CPU device; confirm they behave as intended on CPU runs.
    print(
        "current memory after model",
        torch.cuda.memory_allocated(device=device) / 1024 / 1024,
    )

    # Loss function
    if multitask:
        print("Using multi-label loss")
        loss_f = nn.BCEWithLogitsLoss()
        labels = labels_new
    else:
        print("Using multi-class loss")
        loss_f = nn.CrossEntropyLoss()
        labels = labels_new.argmax(axis=1)

    # Iterators
    # Not used later in this function (see the message printed for cluster
    # sampling below).
    batch_size = args.batch_size

    # Training
    if args.opt == "adam":
        optimizer = torch.optim.Adam(
            model.parameters(), lr=args.lr, weight_decay=args.weight_decay
        )
    elif args.opt == "sgd":
        optimizer = torch.optim.SGD(
            model.parameters(), lr=args.lr, weight_decay=args.weight_decay
        )
    else:
        print("Unknown optimizer")
        return

    # Per-batch wall-clock times and per-epoch metric history.
    times = []
    times_neigh = []
    elapsed_times = []
    val_f1 = []
    test_f1 = []

    # Labels are kept on the CPU; batches are moved to the device on demand.
    if multitask:
        # labels = torch.FloatTensor(labels).to(device)
        labels = torch.FloatTensor(labels)
    else:
        # labels = torch.LongTensor(labels).to(device)
        labels = torch.LongTensor(labels)

    print(
        "current memory after label",
        torch.cuda.memory_allocated(device=device) / 1024 / 1024,
    )
    if use_cluster:
        train_set = set(train)
        partition_data_list = []
        # One entry per (partition size, balance size) pair; zip stops at the
        # shortest of the four sequences.
        for psize, bsize, train_part_adj, train_parts in zip(
            partition_list, bsize_list, train_part_adj_list, train_parts_list
        ):
            # only_train_parts[pid]: node ids of the part that are training nodes.
            # train_odrs[pid]: their positions inside the part.
            only_train_parts = [[] for _ in range(len(train_parts))]
            train_odrs = [[] for _ in range(len(train_parts))]
            part_adjs = []
            train_sparse_feats = []
            train_labels = []
            for pid, part in enumerate(train_parts):
                for nid, nd in enumerate(part):
                    if nd in train_set:
                        only_train_parts[pid].append(nd)
                        train_odrs[pid].append(nid)
                # should be train_adj
                part_adjs.append(
                    sparse_mx_to_torch_sparse_tensor(
                        train_part_adj[part, :][:, part]
                    ).to(device)
                )

                # NOTE(review): labels are stored for every node of the part,
                # while the single-cluster loss below indexes the logits with
                # train_odrs[pid]; the shapes agree only if every node of the
                # part is a training node - confirm.
                train_labels.append(labels[part])
                if not use_feat:
                    train_sparse_feats.append(
                        sparse_mx_to_torch_sparse_tensor(train_feats[part, :]).to(
                            device
                        )
                    )

            # print("train", list(map(len, only_train_parts)))

            partition_data = dict()
            partition_data["partition_size"] = psize
            partition_data["balance_size"] = bsize
            partition_data["only_train_parts"] = only_train_parts
            partition_data["train_parts"] = train_parts
            partition_data["train_odrs"] = train_odrs
            partition_data["train_labels"] = train_labels
            partition_data["part_adjs"] = part_adjs
            if not use_feat:
                partition_data["train_sparse_feats"] = train_sparse_feats

            partition_data_list.append(partition_data)

        # Same bookkeeping for validation nodes over the full-graph partition.
        # These structures are only consumed by the commented-out
        # evaluate_cluster call further down.
        val_set = set(val)
        val_parts = [[] for _ in range(len(full_parts))]
        val_odrs = [[] for _ in range(len(full_parts))]
        val_part_adjs = []
        for pid, part in enumerate(full_parts):
            for nid, nd in enumerate(part):
                if nd in val_set:
                    val_parts[pid].append(nd)
                    val_odrs[pid].append(nid)
            val_part_adjs.append(
                sparse_mx_to_torch_sparse_tensor(full_adj[part, :][:, part]).to(device)
            )

        print("cluster sampling is enabled. arg batch size is not used")
    print(
        "current memory after part",
        torch.cuda.memory_allocated(device=device) / 1024 / 1024,
    )
    idx_full_parts = list(range(len(full_parts)))
    print("start optimizing...")

    print(type(train_feats))
    # TODO: Handle the case when train_feats is a sparse matrix
    if use_feat and isinstance(train_feats, sp.csr_matrix):
        print("Sparse train_feats case not done yet")
        return
        # NOTE(review): the three conversions below are unreachable.
        feats = feats.toarray()
        train_feats = train_feats.toarray()
        test_feats = test_feats.toarray()

    print(
        "current memory before feats",
        torch.cuda.memory_allocated(device=device) / 1024 / 1024,
    )
    if use_feat:
        # Dense features stay on the CPU; slices are moved per batch.
        # feats = torch.from_numpy(feats).float().to(device)
        feats = torch.from_numpy(feats).float()
        # train_feats = torch.from_numpy(train_feats).float().to(device)
        train_feats = torch.from_numpy(train_feats).float()
        # test_feats = torch.from_numpy(test_feats).float().to(device)
        test_feats = torch.from_numpy(test_feats).float()

    # real_full_batch_adj = sparse_mx_to_torch_sparse_tensor(real_full_adj).to(device)
    real_full_batch_adj = sparse_mx_to_torch_sparse_tensor(real_full_adj)

    print("current memory", torch.cuda.memory_allocated(device=device) / 1024 / 1024)
    print(
        "current cached memory", torch.cuda.memory_cached(device=device) / 1024 / 1024
    )

    idx_parts = list(range(partition_list[0]))
    for epoch in range(args.epoch):

        if use_cluster:
            # Always train on the first partition configuration.
            nowp = 0
            partition_data = partition_data_list[nowp]

            psize = partition_data["partition_size"]
            bsize = partition_data["balance_size"]
            only_train_parts = partition_data["only_train_parts"]
            train_parts = partition_data["train_parts"]
            train_odrs = partition_data["train_odrs"]
            train_labels = partition_data["train_labels"]
            part_adjs = partition_data["part_adjs"]
            if not use_feat:
                train_sparse_feats = partition_data["train_sparse_feats"]

        print(
            "current memory", torch.cuda.memory_allocated(device=device) / 1024 / 1024
        )
        print(
            "current cached memory",
            torch.cuda.memory_cached(device=device) / 1024 / 1024,
        )
        # New random cluster order every epoch.
        np.random.shuffle(idx_parts)

        # NOTE(review): bsize, psize, only_train_parts, train_parts,
        # train_odrs, train_labels and part_adjs are bound only inside the
        # `if use_cluster:` blocks above, so this section fails with a
        # NameError when use_cluster is false.
        if bsize > 1:
            # Randomly choose bsize clusters per batch
            for i, st in enumerate(range(0, psize, bsize)):
                start_time = time.time()
                optimizer.zero_grad()

                batch = []
                # idx_parts[st:st+bsize]
                for pt_idx in idx_parts[st : st + bsize]:
                    batch += only_train_parts[pt_idx]

                # NOTE: this works when clusters contain only training nodes
                batch_neigh = batch

                # NOTE: Better to do renormalization when adding off-diagonal edges
                # batch_adj = sparse_mx_to_torch_sparse_tensor(renormalize_adj_ver2(real_train_adj[batch_neigh, :][:,batch_neigh])).to(device)
                batch_adj = sparse_mx_to_torch_sparse_tensor(
                    renormalize_adj(real_train_adj[batch_neigh, :][:, batch_neigh])
                ).to(device)

                # Non-renormalization version
                # batch_adj = sparse_mx_to_torch_sparse_tensor(real_train_adj[batch_neigh, :][:,batch_neigh]).to(device)

                # Not adding off-diagonal edges
                # batch_adj = sparse_mx_to_torch_sparse_tensor(train_part_adj_list[nowp][batch_neigh, :][:,batch_neigh]).to(device)

                if isinstance(train_feats, sp.csr_matrix):
                    batch_feats = sparse_mx_to_torch_sparse_tensor(
                        train_feats[batch_neigh, :]
                    ).to(device)
                    batch_orig_feats = None
                else:
                    batch_feats = train_feats[batch_neigh, :].to(device)
                    batch_orig_feats = feats[batch_neigh, :].to(device)

                # Time spent assembling the batch (adjacency + features).
                times_neigh.append(time.time() - start_time)

                # forward prop
                logits = model(
                    (None, None), (None, None), batch_adj, batch_feats, batch_orig_feats
                )
                # loss = loss_f(logits, labels[batch])
                loss = loss_f(logits, labels[batch].to(device))

                # backward
                loss.backward()
                optimizer.step()

                end_time = time.time()
                times.append(end_time - start_time)

                if args.debug and (i % 20 == 0):
                    micro, macro = calc_f1(
                        labels[batch].cpu(), logits.cpu().data.numpy(), multitask
                    )
                    print(
                        "batch %d loss %f F1 mi %f ma %f time %f time_neigh %f batch_neigh %d"
                        % (
                            i,
                            loss.item(),
                            micro,
                            macro,
                            times[-1],
                            times_neigh[-1],
                            len(batch_neigh),
                        )
                    )
        else:
            # 1 cluster
            for i, pid in enumerate(idx_parts):
                # Skip clusters without any training node.
                if len(only_train_parts[pid]) == 0:
                    continue
                start_time = time.time()
                optimizer.zero_grad()
                batch_nodes = only_train_parts[pid]
                cluster_nodes = train_parts[pid]
                # Positions of the training nodes inside the cluster.
                mapping = train_odrs[pid]
                # batch_labels = labels[batch_nodes]
                batch_labels = train_labels[pid].to(device)

                # batch_adj = sparse_mx_to_torch_sparse_tensor(full_adj[cluster_nodes,:][:,cluster_nodes]).to(device)
                batch_adj = part_adjs[pid]
                if isinstance(train_feats, sp.csr_matrix):
                    # train_sparse_feats is bound only when use_feat is false.
                    # batch_feats = train_sparse_feats[pid].to(device)
                    batch_feats = train_sparse_feats[pid]
                    batch_orig_feats = None
                else:
                    batch_feats = train_feats[cluster_nodes, :].to(device)
                    batch_orig_feats = feats[cluster_nodes, :].to(device)
                    # batch_feats = train_feats[cluster_nodes, :]
                    # batch_orig_feats = feats[cluster_nodes, :]
                times_neigh.append(time.time() - start_time)

                # forward prop
                logits = model(
                    (None, None), (None, None), batch_adj, batch_feats, batch_orig_feats
                )

                # calculate loss
                loss = loss_f(logits[mapping], batch_labels)
                # backward
                loss.backward()
                optimizer.step()

                end_time = time.time()
                times.append(end_time - start_time)
                # print("\n\n")

                if args.debug and i % 20 == 0:
                    # print(batch_labels)
                    micro, macro = calc_f1(
                        batch_labels.cpu(),
                        logits[mapping].cpu().data.numpy(),
                        multitask,
                    )
                    print(
                        "batch %d loss %f F1 %f time %f time_neigh %f"
                        % (i, loss.item(), micro, times[-1], times_neigh[-1])
                    )
        # `epoch % 1 == 0` is always true: validation runs after every epoch.
        if epoch % 1 == 0:
            model.eval()
            # NOTE(review): `device` is passed positionally before `use_feat`;
            # confirm this matches evaluate's parameter list.
            micro, macro = evaluate(
                model,
                val,
                labels,
                real_full_batch_adj,
                test_feats,
                feats,
                device,
                use_feat,
            )
            # micro, macro = evaluate_cluster(model, val, labels, real_full_adj, test_feats, feats,
            #                            val_part_adjs, val_parts, val_odrs, full_parts, idx_full_parts, device, use_feat)
            model.train()

            # Cumulative training time so far, paired with the validation F1.
            elapsed_times.append(sum(times))
            val_f1.append(micro)
            print(
                "epoch {} val F1 micro {} macro {} elapsed train time {}".format(
                    epoch, micro, macro, sum(times)
                )
            )
        if len(test) != 0:
            model.eval()
            micro, macro = evaluate(
                model,
                test,
                labels,
                real_full_batch_adj,
                test_feats,
                feats,
                device,
                use_feat,
            )
            print("Test F1 {}".format(micro))
            test_f1.append(micro)
            model.train()

    # Final test evaluation after the last epoch.
    if len(test) != 0:
        model.eval()
        micro, macro = evaluate(
            model,
            test,
            labels,
            real_full_batch_adj,
            test_feats,
            feats,
            device,
            use_feat,
        )
        print("Test F1 {}".format(micro))
    # micro, macro = evaluate_cluster(model, test, labels, real_full_adj, test_feats, feats, device)
    # print("Test F1 {}".format("N/A"))
    # Include the test F1 in the history only when one was recorded per epoch.
    if len(test_f1) == len(val_f1):
        f1_time_pairs = list(zip(val_f1, elapsed_times, test_f1))
    else:
        f1_time_pairs = list(zip(val_f1, elapsed_times))
    pprint.pprint(f1_time_pairs)
    print("Average batch time:", np.mean(times))
    print("Average neigh time:", np.mean(times_neigh))
    print("Total running time:", np.sum(times))

    if args.log_id != None:
        # `logging` is the module-level dict of run settings, not the stdlib
        # logging module.
        logging["date"] = str(datetime.datetime.now())
        logging["valf1_logs"] = f1_time_pairs

        LOG_PATH = "logs_final.json"
        # NOTE(review): the log file must already exist because it is read
        # before it is written, and neither file handle is closed explicitly.
        logs = json.load(open(LOG_PATH, "r"))
        LOG_ID = args.log_id
        if args.log_id in logs:
            # On a clash the entry is stored under the id prefixed with "_".
            print("Log ID conflicts")
            LOG_ID = "_" + LOG_ID
            logs[LOG_ID] = logging
        else:
            logs[LOG_ID] = logging
        print("Save as {} in {}".format(LOG_ID, LOG_PATH))
        json.dump(logs, open(LOG_PATH, "w"), indent=4)
    else:
        print("No logging")


def evaluate_cluster(
    model,
    data,
    labels,
    full_adj,
    test_feats,
    feats,
    val_part_adjs,
    val_parts,
    val_odrs,
    full_parts,
    idx_full_parts,
    device,
    use_feat,
):
    """Evaluate ``model`` cluster by cluster and return (micro F1, macro F1).

    For every partition that contains evaluation nodes, the model is run on
    the whole partition and the rows belonging to the evaluation nodes are
    collected; F1 is then computed over all collected rows with ``calc_f1``
    using the module-level ``multitask`` flag.

    Args:
        model: network called as ``model((None, None), (None, None), adj,
            feats, orig_feats)``.
        data: unused; kept for interface compatibility.
        labels: label container indexed with lists of node ids.
        full_adj: unused; kept for interface compatibility.
        test_feats: feature tensor indexed by node id, fed to the model.
        feats: feature tensor indexed by node id, passed as ``orig_feats``.
        val_part_adjs: per-partition sparse adjacency tensors.
        val_parts: per-partition lists of evaluation node ids.
        val_odrs: per-partition positions of those nodes inside the partition.
        full_parts: per-partition lists of all node ids.
        idx_full_parts: partition ids to visit.
        device: device the batch tensors are moved to.
        use_feat: unused; kept for interface compatibility.

    Returns:
        Tuple ``(micro_f1, macro_f1)``.
    """
    total_pred = []
    total_lab = []
    start_time = time.time()
    # Inference only: skip autograd bookkeeping, as evaluate() does.
    with torch.no_grad():
        for pid in idx_full_parts:
            batch_nodes = val_parts[pid]
            if len(batch_nodes) == 0:
                continue
            cluster_nodes = full_parts[pid]
            mapping = val_odrs[pid]

            batch_adj = val_part_adjs[pid].to(device)
            batch_feats = test_feats[cluster_nodes, :].to(device)
            batch_orig_feats = feats[cluster_nodes, :].to(device)

            # forward prop
            logits = model(
                (None, None), (None, None), batch_adj, batch_feats, batch_orig_feats
            )

            total_pred.append(logits[mapping].detach().cpu().numpy())
            total_lab.append(np.asarray(labels[batch_nodes]))

    # Concatenating along axis 0 equals vstack for 2-D rows and also works
    # for 1-D label slices of different lengths.
    total_pred = np.concatenate(total_pred, axis=0)
    total_lab = np.concatenate(total_lab, axis=0)
    print("cluster", total_pred.shape, "time", time.time() - start_time)

    return calc_f1(total_lab, total_pred, multitask)


def evaluate(
    model, data, labels, adj, feats, orig_feats, *extra, use_feat=None, device=None
):
    """Run ``model`` on the full graph and return (micro F1, macro F1) on ``data``.

    Two positional call shapes are accepted after ``orig_feats``:

    * ``evaluate(..., orig_feats, use_feat)`` - the historical signature;
    * ``evaluate(..., orig_feats, device, use_feat)`` - with an explicit device.

    ``use_feat`` and ``device`` may also be given as keywords.

    Args:
        model: network called as ``model((None, None), (None, None), adj,
            feats, orig_feats)``.
        data: node indices to score.
        labels: label tensor indexed with ``data``.
        adj: torch sparse adjacency of the full graph.
        feats: dense feature tensor when ``use_feat`` is true, otherwise a
            SciPy sparse matrix that is converted to a torch sparse tensor.
        orig_feats: dense tensor passed to the model when ``use_feat`` is true.
        use_feat: selects the dense or the sparse input path.
        device: where inputs are moved; when omitted, the device of the
            model's first parameter is used (CPU if the model has none).

    Returns:
        Tuple ``(micro_f1, macro_f1)`` from ``calc_f1``.

    Raises:
        TypeError: on too many positional arguments or a missing ``use_feat``.
    """
    if len(extra) > 2:
        raise TypeError(
            "evaluate() takes at most 8 positional arguments ({} given)".format(
                6 + len(extra)
            )
        )
    if len(extra) == 2:
        device, use_feat = extra
    elif len(extra) == 1:
        use_feat = extra[0]
    if use_feat is None:
        raise TypeError("evaluate() missing required argument: 'use_feat'")

    if device is None:
        # Follow the model instead of assuming the default CUDA device.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # forward prop
    with torch.no_grad():
        if use_feat:
            logits = model(
                (None, None),
                (None, None),
                adj.to(device),
                feats.to(device),
                orig_feats.to(device),
            )
        else:
            sparse_feats = sparse_mx_to_torch_sparse_tensor(feats).to(device)
            logits = model(
                (None, None), (None, None), adj.to(device), sparse_feats, None
            )

    total_pred = logits[data].detach().cpu().numpy()
    total_lab = labels[data].cpu()

    return calc_f1(total_lab, total_pred, multitask)


def calc_f1(y_true, y_pred, multitask):
    """Return ``(micro_f1, macro_f1)`` for raw model scores.

    Args:
        y_true: ground-truth labels in the format ``f1_score`` expects.
        y_pred: raw scores (logits), one row per sample.
        multitask: when true, every score is thresholded at 0 independently
            (multi-label); otherwise the arg-max column is the prediction.

    Returns:
        Tuple of micro-averaged and macro-averaged F1.
    """
    if multitask:
        # Threshold a copy: the caller's array may share memory with the
        # logits tensor and must not be overwritten.
        y_pred = np.array(y_pred, copy=True)
        y_pred[y_pred > 0] = 1
        y_pred[y_pred <= 0] = 0
    else:
        y_pred = np.argmax(y_pred, axis=1)
    return (
        f1_score(y_true, y_pred, average="micro"),
        f1_score(y_true, y_pred, average="macro"),
    )


if __name__ == "__main__":
    print(args)
    print("now running new version")
    main_new(
        args.data,
        args.hidden_dim,
        args.use_gcn,
        args.partition_list,
        args.bsize_list,
        args.use_cluster,
        args.use_pp,
        args.diag_lambda,
    )
