import argparse
import os
import pickle
import time

import numpy as np
import torch
import torch.optim as optim
from dataset import LanderDataset
from models import LANDER

import dgl

###########
# ArgParser
parser = argparse.ArgumentParser()

# Dataset
parser.add_argument("--data_path", type=str, required=True)
parser.add_argument("--test_data_path", type=str, required=True)
parser.add_argument("--levels", type=str, default="1")
parser.add_argument("--faiss_gpu", action="store_true")
parser.add_argument("--model_filename", type=str, default="lander.pth")

# KNN
parser.add_argument("--knn_k", type=str, default="10")

# Model
parser.add_argument("--hidden", type=int, default=512)
parser.add_argument("--num_conv", type=int, default=4)
parser.add_argument("--dropout", type=float, default=0.0)
parser.add_argument("--gat", action="store_true")
parser.add_argument("--gat_k", type=int, default=1)
parser.add_argument("--balance", action="store_true")
parser.add_argument("--use_cluster_feat", action="store_true")
parser.add_argument("--use_focal_loss", action="store_true")

# Training
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--lr", type=float, default=0.1)
parser.add_argument("--momentum", type=float, default=0.9)
parser.add_argument("--weight_decay", type=float, default=1e-5)

args = parser.parse_args()

###########################
# Environment Configuration
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

##################
# Data Preparation
def prepare_dataset_graphs(data_path, k_list, lvl_list):
    with open(data_path, "rb") as f:
        features, labels = pickle.load(f)
    gs = []
    for k, l in zip(k_list, lvl_list):
        dataset = LanderDataset(
            features=features,
            labels=labels,
            k=k,
            levels=l,
            faiss_gpu=args.faiss_gpu,
        )
        gs += [g.to(device) for g in dataset.gs]
    return gs


k_list = [int(k) for k in args.knn_k.split(",")]
lvl_list = [int(l) for l in args.levels.split(",")]
gs = prepare_dataset_graphs(args.data_path, k_list, lvl_list)
test_gs = prepare_dataset_graphs(args.test_data_path, k_list, lvl_list)

##################
# Model Definition
feature_dim = gs[0].ndata["features"].shape[1]
model = LANDER(
    feature_dim=feature_dim,
    nhid=args.hidden,
    num_conv=args.num_conv,
    dropout=args.dropout,
    use_GAT=args.gat,
    K=args.gat_k,
    balance=args.balance,
    use_cluster_feat=args.use_cluster_feat,
    use_focal_loss=args.use_focal_loss,
)
model = model.to(device)
model.train()
best_model = None
best_loss = np.Inf

#################
# Hyperparameters
opt = optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    opt, T_max=args.epochs, eta_min=1e-5
)

###############
# Training Loop
for epoch in range(args.epochs):
    all_loss_den_val = 0
    all_loss_conn_val = 0
    for g in gs:
        opt.zero_grad()
        g = model(g)
        loss, loss_den_val, loss_conn_val = model.compute_loss(g)
        all_loss_den_val += loss_den_val
        all_loss_conn_val += loss_conn_val
        loss.backward()
        opt.step()
    scheduler.step()
    print(
        "Training, epoch: %d, loss_den: %.6f, loss_conn: %.6f"
        % (epoch, all_loss_den_val, all_loss_conn_val)
    )
    # Report test
    all_test_loss_den_val = 0
    all_test_loss_conn_val = 0
    with torch.no_grad():
        for g in test_gs:
            g = model(g)
            loss, loss_den_val, loss_conn_val = model.compute_loss(g)
            all_test_loss_den_val += loss_den_val
            all_test_loss_conn_val += loss_conn_val
    print(
        "Testing, epoch: %d, loss_den: %.6f, loss_conn: %.6f"
        % (epoch, all_test_loss_den_val, all_test_loss_conn_val)
    )
    if all_test_loss_conn_val + all_test_loss_den_val < best_loss:
        best_loss = all_test_loss_conn_val + all_test_loss_den_val
        print("New best epoch", epoch)
        torch.save(model.state_dict(), args.model_filename + "_best")
    torch.save(model.state_dict(), args.model_filename)

torch.save(model.state_dict(), args.model_filename)
