import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
from tqdm import tqdm

from util import load_data, separate_data
from models.graphcnn import GraphCNN
from igsd import IGSD
from argparser import args
from torch.utils.data import DataLoader, Subset
import networkx as nx

# Supervised classification loss used by train(): applied to the model output
# on the labeled graphs against their integer labels.
criterion = nn.CrossEntropyLoss()

def load_graphs(graph_ls, idx):
    """Gather the items of ``graph_ls`` at the positions listed in ``idx``.

    The result is a new list ordered like ``idx`` (duplicates allowed); an
    out-of-range position raises ``IndexError``.
    """
    return list(map(graph_ls.__getitem__, idx))

def train(args, model, device, train_graphs, train_diffs, optimizer, epoch):
    """Run one epoch of IGSD training, optionally with self-training.

    The training set is split by position: the first
    ``int(len(train_graphs) * args.unsup_ratio)`` graphs form the unlabeled
    pool and the remaining graphs the labeled set.  Each of the
    ``args.iters_per_epoch`` iterations

    * computes cross-entropy of the model output on the whole labeled set;
    * if ``args.use_unsup_loss``, adds ``model.unsup_loss`` on the first
      ``args.num_samples_per_iter`` graphs of the unlabeled pool;
    * if ``args.use_supcon_loss`` and ``epoch > args.start_supcon_epoch``,
      adds ``model.supcon_loss`` once per class present in the labeled set;
    * steps ``optimizer`` (skipped when it is ``None``) and calls
      ``model.update_moving_average`` when either auxiliary loss is enabled;
    * if ``args.use_selftrain`` and ``epoch > args.start_selftrain_epoch``,
      pseudo-labels those unlabeled graphs and moves the ones whose predicted
      probability exceeds ``args.selftrain_threshold`` into the labeled set.

    Side effects: pseudo-labels are written onto the graph objects, which are
    shared with the caller's ``train_graphs``; ``args.num_samples_per_iter``
    is overwritten.

    Returns:
        The mean total loss per iteration.
    """
    model.train()
    if args.dataset == 'COLLAB':
        # NOTE: hard-coded cap on the number of COLLAB training graphs.
        train_graphs = train_graphs[:4100]

    total_iters = args.iters_per_epoch
    pbar = tqdm(range(total_iters), unit='batch')
    num_graph = len(train_graphs)
    num_unsup = int(num_graph * args.unsup_ratio)
    # Positions into train_graphs / train_diffs.
    unsup_idx = list(range(num_unsup))
    sup_idx = list(range(num_unsup, num_graph))
    args.num_samples_per_iter = len(unsup_idx) // total_iters

    unsup_graphs, unsup_diffs = load_graphs(train_graphs, unsup_idx), load_graphs(train_diffs, unsup_idx)
    sup_graphs, sup_diffs = load_graphs(train_graphs, sup_idx), load_graphs(train_diffs, sup_idx)
    print("Num supervised/unsupervised data: {}/{}, Unsup loss:{}, SupCon loss:{}, Self-train:{}"
          .format(len(sup_graphs), len(unsup_graphs), args.use_unsup_loss, args.use_supcon_loss, args.use_selftrain))
    loss_accum, sup_loss_accum, unsup_loss_accum, supcon_loss_accum = 0, 0, 0, 0
    for _ in pbar:
        n_self = args.num_samples_per_iter
        # Head of the unlabeled pool; unsup_graphs/unsup_diffs are aligned with unsup_idx.
        selftrain_idx = unsup_idx[:n_self]
        self_graphs = unsup_graphs[:n_self]
        self_diffs = unsup_diffs[:n_self]

        output = model(sup_graphs)
        labels = torch.LongTensor([graph.label for graph in sup_graphs]).to(device)

        # compute loss
        loss = criterion(output, labels)
        sup_loss_accum += loss.item()
        if args.use_unsup_loss and self_graphs:
            unsup_loss = model.unsup_loss(self_graphs, self_diffs)
            loss = loss + unsup_loss
            unsup_loss_accum += unsup_loss.item()
        if args.use_supcon_loss and epoch > args.start_supcon_epoch:
            for label in torch.unique(labels).tolist():
                # Select graph/diff *pairs* by the graph's label: pseudo-labels
                # are only written onto graphs, so filtering the diffs by their
                # own (possibly stale) labels could misalign the two lists.
                pairs = [(g, d) for g, d in zip(sup_graphs, sup_diffs) if g.label == label]
                graphs = [g for g, _ in pairs]
                diffs = [d for _, d in pairs]
                sup_labels = torch.LongTensor([label] * len(graphs)).to(device)
                supcon_loss = model.supcon_loss(graphs, diffs, sup_labels)
                loss = loss + supcon_loss
                supcon_loss_accum += supcon_loss.item()

        # backprop
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if args.use_unsup_loss or args.use_supcon_loss:
                model.update_moving_average()

        loss_accum += loss.item()

        # self training
        if args.use_selftrain and epoch > args.start_selftrain_epoch and self_graphs:
            probs = F.softmax(pass_data_iteratively(model, self_graphs), dim=-1)
            confidence, pred = probs.max(1)
            # pass_data_iteratively() switches the model to eval mode; switch
            # back so the remaining iterations train with dropout / batch-norm.
            model.train()
            confident, unconfident = [], []
            for i, (idx, graph) in enumerate(zip(selftrain_idx, self_graphs)):
                if confidence[i].item() > args.selftrain_threshold:
                    # Plain int, so later LongTensor([... labels ...]) builds work
                    # regardless of the device the prediction lives on.
                    graph.label = pred[i].item()
                    confident.append(idx)
                else:
                    unconfident.append(idx)
            # Only confidently pseudo-labeled graphs join the labeled set.
            # Promoting the others without a pseudo-label would feed their
            # stored dataset labels into the supervised loss; instead they go
            # to the back of the unlabeled pool to be retried later.
            processed = set(selftrain_idx)
            sup_idx += confident
            unsup_idx = [x for x in unsup_idx if x not in processed] + unconfident
            unsup_graphs, unsup_diffs = load_graphs(train_graphs, unsup_idx), load_graphs(train_diffs, unsup_idx)
            sup_graphs, sup_diffs = load_graphs(train_graphs, sup_idx), load_graphs(train_diffs, sup_idx)

            # report
            pbar.set_description('[self training] epoch: %d, Num supervised/unsupervised data: {}/{}'.format(len(sup_graphs), len(unsup_graphs)) % (epoch))
        else:
            pbar.set_description('epoch: %d' % (epoch))

    average_loss = loss_accum / total_iters
    avg_sup_loss = sup_loss_accum / total_iters
    avg_unsup_loss = unsup_loss_accum / total_iters
    avg_supcon_loss = supcon_loss_accum / total_iters
    print("loss training: %f, sup loss: %f, unsup loss: %f, supcon loss: %f" % (average_loss, avg_sup_loss, avg_unsup_loss, avg_supcon_loss))

    return average_loss


# Pass data to the model in minibatches for inference, to avoid memory overflow.
def pass_data_iteratively(model, graphs, minibatch_size=64):
    """Run ``model`` over ``graphs`` in minibatches and concatenate the outputs.

    Intended for inference only.  The forward passes run under
    ``torch.no_grad()``, so no autograd graph is built.  This is what actually
    bounds memory; detaching afterwards does not.  The model is put in eval
    mode for the passes, and its previous train/eval mode is restored
    afterwards (also on error).

    Args:
        model: module called on a list of graphs; its outputs for successive
            minibatches are concatenated along dim 0.
        graphs: indexable sequence of graphs.
        minibatch_size: number of graphs per forward pass.

    Returns:
        The concatenated outputs, with ``requires_grad=False``.
        ``torch.cat`` raises when ``graphs`` is empty.
    """
    was_training = model.training
    model.eval()
    outputs = []
    try:
        with torch.no_grad():
            for start in range(0, len(graphs), minibatch_size):
                stop = min(start + minibatch_size, len(graphs))
                outputs.append(model([graphs[j] for j in range(start, stop)]))
    finally:
        model.train(was_training)
    return torch.cat(outputs, 0)


def _accuracy(model, device, graphs):
    """Return the fraction of ``graphs`` whose argmax prediction equals ``graph.label``."""
    output = pass_data_iteratively(model, graphs)
    pred = output.max(1)[1]
    labels = torch.LongTensor([graph.label for graph in graphs]).to(device)
    correct = pred.eq(labels).sum().cpu().item()
    return correct / float(len(graphs))


def test(args, model, device, train_graphs, test_graphs, test_diffs, epoch):
    """Evaluate ``model`` and return ``(train_accuracy, test_accuracy)``.

    ``args``, ``test_diffs`` and ``epoch`` are not used here; they are kept
    for interface compatibility with existing callers.
    """
    model.eval()
    acc_train = _accuracy(model, device, train_graphs)
    acc_test = _accuracy(model, device, test_graphs)
    return acc_train, acc_test


def main():
    """Train and evaluate IGSD on one cross-validation fold.

    All settings come from the module-level ``args``.  Each epoch runs
    ``train`` and then ``test``, and prints train/test accuracy together with
    the best test accuracy seen so far.
    """
    print(args)
    # Seed every RNG in use.  `random` drives shuffle() below, so it must be
    # seeded too for a reproducible data order.
    random.seed(0)
    torch.manual_seed(0)
    np.random.seed(0)
    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    graphs, diffs, num_classes = load_data(args.dataset, args.degree_as_tag)
    feat_dim = graphs[0].node_features.shape[1]

    # 10-fold cross validation: run the experiment on the fold given by args.fold_idx.
    train_graphs, test_graphs, train_diffs, test_diffs = separate_data(graphs, diffs, args.seed, args.fold_idx)
    train_graphs, train_diffs = shuffle(train_graphs, train_diffs)
    test_graphs, test_diffs = shuffle(test_graphs, test_diffs)

    online_encoder = GraphCNN(args.num_layers, args.num_mlp_layers, train_graphs[0].node_features.shape[1], args.hidden_dim,
                              num_classes, args.final_dropout, args.learn_eps, args.graph_pooling_type,
                              args.neighbor_pooling_type, device).to(device)

    sup_encoder = GraphCNN(args.num_layers, args.num_mlp_layers, train_graphs[0].node_features.shape[1], args.hidden_dim,
                           num_classes, args.final_dropout, args.learn_eps, args.graph_pooling_type,
                           args.neighbor_pooling_type, device).to(device)

    # NOTE(review): this passes args.num_layer (not args.num_layers) — confirm it is a separate option.
    model = IGSD(online_encoder, sup_encoder, feat_dim, args.num_layer, projection_size=args.hidden_dim, projection_hidden_size=args.projection_hidden_size)
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # Halve the learning rate every 50 epochs (the scheduler is stepped once per epoch).
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
    best_acc = -1
    for epoch in range(1, args.epochs + 1):
        avg_loss = train(args, model, device, train_graphs, train_diffs, optimizer, epoch)
        # Step the scheduler after the optimizer has stepped.  Calling it
        # first, as before, skips the first value of the LR schedule and
        # triggers a warning in PyTorch >= 1.1.
        scheduler.step()
        acc_train, acc_test = test(args, model, device, train_graphs, test_graphs, test_diffs, epoch)
        if acc_test > best_acc:
            best_acc = acc_test
        print("(%s) accuracy train: %f test: %f best test: %f" % (args.dataset, acc_train, acc_test, best_acc))
        if args.filename != "":
            # NOTE: mode 'w' rewrites the file every epoch, so it ends up
            # holding only the last epoch's results.
            with open(args.filename, 'w') as f:
                f.write("%f %f %f" % (avg_loss, acc_train, acc_test))
                f.write("\n")
        print("")

        print(model.online_encoder.eps)

def shuffle(a, b):
    """Shuffle two parallel sequences with the same random permutation.

    Uses the module-level ``random`` generator, so the result follows
    ``random.seed``.

    Args:
        a: first sequence.
        b: second sequence; ``b[i]`` belongs with ``a[i]``.

    Returns:
        ``(a_out, b_out)``: two tuples holding the shuffled elements, still
        pairwise aligned.  Empty input yields ``((), ())`` instead of raising.
        As with ``zip``, trailing elements of the longer input are dropped.
    """
    pairs = list(zip(a, b))
    random.shuffle(pairs)
    if not pairs:
        # zip(*[]) yields nothing, so unpacking it into two names would raise.
        return (), ()
    a_out, b_out = zip(*pairs)
    return a_out, b_out

# Script entry point: run training/evaluation only when this file is executed directly.
if __name__ == '__main__':
    main()
