import argparse, time, os, pickle
import random

import numpy as np

import dgl
import torch
import torch.optim as optim

import sys
sys.path.append("..")
from models import LANDER
from dataset import LanderDataset

###########
# ArgParser
def _build_parser():
    """Create the command-line parser for this training script.

    ``--levels`` and ``--knn_k`` are kept as strings here; the script later
    splits them on commas into lists of integers.
    """
    option_table = (
        # Dataset
        ("--data_path", dict(type=str, required=True)),
        ("--levels", dict(type=str, default="1")),
        ("--faiss_gpu", dict(action="store_true")),
        ("--model_filename", dict(type=str, default="lander.pth")),
        # KNN
        ("--knn_k", dict(type=str, default="10")),
        ("--num_workers", dict(type=int, default=0)),
        # Model
        ("--hidden", dict(type=int, default=512)),
        ("--num_conv", dict(type=int, default=1)),
        ("--dropout", dict(type=float, default=0.)),
        ("--gat", dict(action="store_true")),
        ("--gat_k", dict(type=int, default=1)),
        ("--balance", dict(action="store_true")),
        ("--use_cluster_feat", dict(action="store_true")),
        ("--use_focal_loss", dict(action="store_true")),
        # Training
        ("--epochs", dict(type=int, default=100)),
        ("--batch_size", dict(type=int, default=1024)),
        ("--lr", dict(type=float, default=0.1)),
        ("--momentum", dict(type=float, default=0.9)),
        ("--weight_decay", dict(type=float, default=1e-5)),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli


parser = _build_parser()
args = parser.parse_args()
# Echo the resolved configuration so it ends up in the run log.
print(args)

###########################
# Environment Configuration
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def setup_seed(seed):
    """Seed the random number generators used by this script.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and every CUDA
    device), and configures cuDNN for deterministic behaviour.

    Args:
        seed: Integer seed applied to all of the generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Silently ignored by PyTorch when CUDA is not available.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # With benchmarking enabled cuDNN may select different kernels from run
    # to run, which defeats the deterministic flag above; pin it off.
    torch.backends.cudnn.benchmark = False
    # NOTE(review): nothing here seeds DGL's own sampling RNG; confirm
    # whether the neighbor sampler also needs seeding for reproducible runs.


# setup_seed(20)

##################
# Data Preparation
# The pickle holds a 5-tuple; this script only uses the features and labels.
# NOTE: unpickling can execute arbitrary code, so --data_path must only
# point at files from a trusted source.
with open(args.data_path, 'rb') as data_file:
    payload = pickle.load(data_file)
path2idx, features, labels, _, masks = payload
# lidx = np.where(masks==0)
# features = features[lidx]
# labels = labels[lidx]
print("features.shape:", features.shape)
print("labels.shape:", labels.shape)


# --knn_k and --levels are comma-separated integer lists consumed pairwise:
# the i-th entry of each configures the i-th LanderDataset.
k_list = [int(k) for k in args.knn_k.split(',')]
lvl_list = [int(lvl) for lvl in args.levels.split(',')]
if len(k_list) != len(lvl_list):
    # zip() below would otherwise silently drop the unmatched trailing
    # values, training on fewer configurations than were requested.
    raise ValueError(
        '--knn_k and --levels must list the same number of values, '
        'got %d and %d' % (len(k_list), len(lvl_list)))

gs = []        # graphs collected from every dataset, in build order
nbrs = []      # the matching entries of each dataset's ``nbrs``
ks = []        # k used to build the graph at the same index in ``gs``
datasets = []  # the LanderDataset objects themselves
for k, level in zip(k_list, lvl_list):
    print("k:", k)
    print("levels:", level)
    dataset = LanderDataset(features=features, labels=labels, k=k,
                            levels=level, faiss_gpu=args.faiss_gpu)
    gs.extend(dataset.gs)
    # Record k once per graph so ``gs`` and ``ks`` stay index-aligned.
    ks.extend(k for _ in dataset.gs)
    nbrs.extend(dataset.nbrs)
    datasets.append(dataset)

# with open("./dataset.pkl", 'rb') as f:
#     datasets = pickle.load(f)
# for i in range(len(datasets)):
#     dataset = datasets[i]
#     k = k_list[i]
#     gs += [g for g in dataset.gs]
#     ks += [k for g in dataset.gs]
#     nbrs += [nbr for nbr in dataset.nbrs]


# Persist the constructed datasets to a fixed file in the working directory;
# an existing file of that name is overwritten.
with open("./dataset.pkl", 'wb') as cache_file:
    pickle.dump(datasets, cache_file)

print('Dataset Prepared.')

def set_train_sampler_loader(g, k):
    """Build a shuffled mini-batch loader that covers every node of ``g``.

    Args:
        g: Graph to draw mini-batches from.
        k: The k value associated with ``g``; each sampling layer uses a
            fanout of ``k - 1``.

    Returns:
        A ``dgl.dataloading.NodeDataLoader`` over all nodes of ``g`` that
        uses the batch size and worker count given on the command line.
    """
    # One fanout entry per sampled layer (num_conv + 1 layers); using the
    # same fanout everywhere keeps the number of sampled edges fixed.
    num_layers = args.num_conv + 1
    neighbor_sampler = dgl.dataloading.MultiLayerNeighborSampler(
        [k - 1] * num_layers)
    all_nodes = torch.arange(g.number_of_nodes())
    loader_options = dict(
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers,
    )
    return dgl.dataloading.NodeDataLoader(
        g, all_nodes, neighbor_sampler, **loader_options)

# One loader per graph, each paired with the k stored at the same index.
train_loaders = [set_train_sampler_loader(graph, graph_k)
                 for graph, graph_k in zip(gs, ks)]

##################
# Model Definition
# The model's input width is taken from the node features of the first graph.
feature_dim = gs[0].ndata['features'].shape[1]
print("feature dimension:", feature_dim)

# Every model setting comes straight from the command line.
model_options = dict(
    feature_dim=feature_dim,
    nhid=args.hidden,
    num_conv=args.num_conv,
    dropout=args.dropout,
    use_GAT=args.gat,
    K=args.gat_k,
    balance=args.balance,
    use_cluster_feat=args.use_cluster_feat,
    use_focal_loss=args.use_focal_loss,
)
model = LANDER(**model_options)
model = model.to(device)
# Switch to training mode before the optimisation loop starts.
model.train()

#################
# Hyperparameters
# SGD with momentum and weight decay, all configured from the command line.
opt = optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)

# Every loader contributes the same number of batches per epoch: the length
# of the first loader.
num_loaders = len(train_loaders)
num_batch_per_loader = len(train_loaders[0])
# From here on ``train_loaders`` holds live iterators instead of loaders.
train_loaders = list(map(iter, train_loaders))

# The schedule spans every (epoch, batch, loader) step of the training loop
# and anneals the learning rate down to 1e-5.
total_steps = args.epochs * num_batch_per_loader * num_loaders
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    opt, T_max=total_steps, eta_min=1e-5)

print('Start Training.')

###############
# Training Loop
for epoch in range(args.epochs):
    # Per-step loss values, averaged for the end-of-epoch summary line.
    loss_den_val_total = []
    loss_conn_val_total = []
    loss_val_total = []
    for batch in range(num_batch_per_loader):
        # Take one mini-batch from each loader in turn.
        for loader_id in range(num_loaders):
            try:
                minibatch = next(train_loaders[loader_id])
            except StopIteration:
                # This loader's iterator is exhausted: start a fresh shuffled
                # pass over the same graph. Only exhaustion is handled here,
                # so genuine loading errors and Ctrl-C still propagate
                # instead of being retried silently.
                train_loaders[loader_id] = iter(set_train_sampler_loader(gs[loader_id], ks[loader_id]))
                minibatch = next(train_loaders[loader_id])
            # Only the sampled bipartite blocks are fed to the model; the
            # other two items of the mini-batch are not used, so they are
            # not copied to the device.
            input_nodes, sub_g, bipartites = minibatch
            bipartites = [b.to(device) for b in bipartites]
            opt.zero_grad()
            output_bipartite = model(bipartites)
            # compute_loss returns the loss tensor plus two component values
            # that are only logged (formatted with %.6f below).
            loss, loss_den_val, loss_conn_val = model.compute_loss(output_bipartite)
            loss_den_val_total.append(loss_den_val)
            loss_conn_val_total.append(loss_conn_val)
            loss_val_total.append(loss.item())
            loss.backward()
            opt.step()
            # Progress line on every 10th batch index (printed per loader).
            if (batch + 1) % 10 == 0:
                print('epoch: %d, batch: %d / %d, loader_id : %d / %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f'%
                      (epoch, batch, num_batch_per_loader, loader_id, num_loaders,
                       loss.item(), loss_den_val, loss_conn_val))
            # The learning-rate schedule advances once per optimizer step.
            scheduler.step()
    print('epoch: %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f'%
          (epoch, np.array(loss_val_total).mean(),
           np.array(loss_den_val_total).mean(), np.array(loss_conn_val_total).mean()))
    # Checkpoint after every epoch, overwriting the previous checkpoint.
    torch.save(model.state_dict(), args.model_filename)

# Final save; this also writes a checkpoint when --epochs is 0.
torch.save(model.state_dict(), args.model_filename)
