import argparse
import datetime
import logging
import sys
import time
from copy import deepcopy

import hydra
import torch
import torch.nn.functional as F
import torch.optim as optim

from utils import *

from data.load_dataset import load_data

from dhg.random import set_seed
from model.hds_ode import hds_ode
from dhg.metrics import HypergraphVertexClassificationEvaluator as Evaluator


# Command-line options for the experiment. They are parsed at import time and
# read as the module-level ``args`` by every function below.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default="Cora-CA",
                    help='Dataset to use.')
parser.add_argument('--model_name', type=str, default="hds_ode",
                    help='model to use.')
parser.add_argument('--lr', type=float, default=1e-2,
                    help='Initial learning rate.')
parser.add_argument('--weight_decay', type=float, default=5e-4,
                    help='Weight decay (L2 loss on parameters).')
parser.add_argument('--repeat', type=int, default=5,
                    help='Number of repeat times to test.')
parser.add_argument('--epochs', type=int, default=200,
                    help='Number of epochs to train.')
parser.add_argument('--layer_num', type=int, default=40,
                    help='Number of layers.')
parser.add_argument('--step', type=int, default=20,
                    help='Steps of diffusion.')
parser.add_argument('--alpha_v', type=float, default=0.05,
                    help='alpha of vertices in diffusion.')
parser.add_argument('--alpha_e', type=float, default=0.9,
                    help='alpha of hyperedge in diffusion.')
parser.add_argument('--dropout', type=float, default=0.15,
                    help='Dropout rate')
parser.add_argument('--split', type=str, default='random',
                    help='type of split')
parser.add_argument('--num_per_class', type=int, default=10,
                    help='Number of training vertex per class.')
parser.add_argument('--num_development', type=int, default=1500,
                    help='Number of training+val vertices.')
parser.add_argument('--test_ind_ratio', type=float, default=0.2,
                    help='Ratio of inductive vertices of test vertices.')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='Disables CUDA training.')
parser.add_argument('--seed', type=int, default=2022, help='Random seed.')

# ``@hydra.main`` (on ``main`` below) parses sys.argv again with Hydra's own
# parser, which rejects the argparse flags declared above. Plain
# ``parse_args`` would in turn reject Hydra overrides (``key=value``,
# ``--multirun``, ...). Consume only our flags here and leave everything else
# in sys.argv for Hydra.
args, _hydra_argv = parser.parse_known_args()
sys.argv = sys.argv[:1] + _hydra_argv
args.cuda = not args.no_cuda and torch.cuda.is_available()

log = logging.getLogger(__name__)


def load_model(model_name, dim_features, num_classes):
    """Construct the model selected by ``model_name``.

    Hyper-parameters other than the input/output sizes (layer count, diffusion
    steps, alphas, dropout) are read from the module-level argparse ``args``.

    Args:
        model_name: Name of the model to build; only ``'hds_ode'`` is known.
        dim_features: Input feature dimension passed to the model.
        num_classes: Number of output classes passed to the model.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If ``model_name`` is not a known model. Failing here is
            clearer than returning ``None``, which would only surface later as
            an AttributeError on ``model.parameters()``.
    """
    if model_name == 'hds_ode':
        return hds_ode(dim_features, num_classes,
                       args.layer_num, args.step, alpha_v=args.alpha_v, alpha_e=args.alpha_e, drop_rate=args.dropout)
    raise ValueError(f"model doesn't exist: {model_name!r}")


def train(net, X, A, lbls, train_idx, optimizer, epoch):
    """Perform one full-batch optimisation step on the selected vertices.

    Args:
        net: Model, called as ``net(X, A)``; switched to training mode here.
        X: Vertex features passed to ``net``.
        A: Structure argument passed to ``net``.
        lbls: Labels for every row of ``net``'s output.
        train_idx: Index or mask selecting the rows that contribute to the loss.
        optimizer: Optimizer that updates ``net``'s parameters.
        epoch: Epoch number, used only in the printed progress line.

    Returns:
        float: Cross-entropy on the selected rows, computed before the update.
    """
    net.train()
    tic = time.time()

    optimizer.zero_grad()
    logits = net(X, A)[train_idx]
    targets = lbls[train_idx]
    loss = F.cross_entropy(logits, targets)
    loss.backward()
    optimizer.step()

    elapsed = time.time() - tic
    loss_value = loss.item()
    print(f"Epoch: {epoch}, Time: {elapsed:.5f}s, Loss: {loss_value:.5f}")
    return loss_value


@torch.no_grad()
def infer(net, X, A, lbls, idx, evaluator, test=False):
    """Score ``net`` on the vertices selected by ``idx`` without tracking gradients.

    Args:
        net: Model, called as ``net(X, A)``; switched to eval mode here.
        X: Vertex features passed to ``net``.
        A: Structure argument passed to ``net``.
        lbls: Labels for every row of ``net``'s output.
        idx: Index or mask selecting the rows to score.
        evaluator: Object providing ``validate(labels, outputs)`` and
            ``test(labels, outputs)``.
        test: When true, use ``evaluator.test``; otherwise ``evaluator.validate``.

    Returns:
        Whatever the chosen evaluator method returns.
    """
    net.eval()

    predictions = net(X, A)[idx]
    targets = lbls[idx]
    scorer = evaluator.test if test else evaluator.validate
    return scorer(targets, predictions)


@torch.no_grad()
def test(net, X_t, A_t, lbls_t, mask_t, X, A, lbls, mask, prod_mask, evaluator):
    net.eval()
    # transductive
    outs_t = net(X_t, A_t)
    res_t = evaluator.test(lbls_t[mask_t], outs_t[mask_t])
    # inductive
    outs = net(X, A)
    res_i = evaluator.test(lbls[mask], outs[mask])
    # product
    outs = net(X, A)
    res_p = evaluator.test(lbls[prod_mask], outs[prod_mask])
    res = {}
    for k, v in res_p.items():
        res[f"prod_{k}"] = v
    for k, v in res_i.items():
        res[f"ind_{k}"] = v
    for k, v in res_t.items():
        res[f"trans_{k}"] = v
    return res


@hydra.main(version_base=None)
def main(cfg):
    """Run ``args.repeat`` train/evaluate rounds and log the aggregated metrics.

    Each round loads the dataset and calls ``product_split`` to obtain an
    observed vertex index (``obs_idx``) with train/val/test masks over it, plus
    an inductive test mask. It then trains on the sub-hypergraph induced by
    ``obs_idx`` and keeps the weights with the best validation score. Those
    weights are evaluated with ``test`` in the transductive, inductive and
    product settings.

    Args:
        cfg: Config object injected by ``hydra.main``. Not used; every setting
            is read from the module-level argparse ``args``.
    """
    log.info(args)
    set_seed(args.seed)
    device = torch.device("cuda") if args.cuda else torch.device("cpu")
    evaluator = Evaluator(["accuracy", "f1_score", {"f1_score": {"average": "micro"}}])

    res_list = MultiExpMetric()
    for _ in range(args.repeat):
        # Data is reloaded every round (presumably so a random split is redrawn).
        X, HG, lbl, dim_features, num_classes, train_mask, val_mask, test_mask = load_data(
            args.dataset, args.split, args.num_per_class, args.num_development
        )
        obs_idx, obs_train_mask, obs_val_mask, obs_test_mask, test_ind_mask = product_split(
            train_mask, val_mask, test_mask, args.test_ind_ratio
        )
        # Observed (transductive) view: features/labels/sub-hypergraph on obs_idx.
        HG_t = sub_hypergraph(HG, obs_idx)
        X_t, lbl_t = X[obs_idx], lbl[obs_idx]
        model = load_model(args.model_name, dim_features, num_classes)
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        X, lbl = X.to(device), lbl.to(device)
        HG = HG.to(device)
        X_t, lbl_t = X_t.to(device), lbl_t.to(device)
        HG_t = HG_t.to(device)
        model = model.to(device)

        # Start below any attainable score so the first validated epoch always
        # records a checkpoint. Starting at 0 left best_state as None (and
        # crashed load_state_dict) when no validation score ever exceeded 0.
        best_state = None
        best_epoch, best_val = 0, float("-inf")
        for epoch in range(args.epochs):
            train(model, X_t, HG_t, lbl_t, obs_train_mask, optimizer, epoch)
            # infer() is already decorated with @torch.no_grad().
            val_res = infer(model, X_t, HG_t, lbl_t, obs_val_mask, evaluator)
            if val_res > best_val:
                log.info(f"update best: {val_res:.5f}")
                best_epoch = epoch
                best_val = val_res
                best_state = deepcopy(model.state_dict())
        log.info("\ntrain finished!")
        log.info(f"best val: {best_val:.5f}")
        # test
        log.info("test...")
        if best_state is not None:
            model.load_state_dict(best_state)
        else:
            # Reached e.g. when args.epochs <= 0 or every validation score is NaN.
            log.warning("no checkpoint recorded (epochs=%d); testing the current weights", args.epochs)
        res = test(model, X_t, HG_t, lbl_t, obs_test_mask, X, HG, lbl, test_ind_mask, test_mask, evaluator)
        log.info(f"final result: epoch: {best_epoch}")
        log.info(res)
        log.info('#' * 30)
        res_list.update(res)

    log.info(f"{res_list}")


if __name__ == '__main__':
    # hydra.main supplies ``cfg``; the call takes no arguments.
    main()
