import numpy as np
import torch
import wandb

from pprint import pprint
from inscd import listener
from inscd.datahub import DataHub
from inscd.models.static.neural import NCDM
from inscd.models.static.graph import ULCDF
from inscd.models.static.neural import KANCD
from inscd.models.static.classic import MIRT
from inscd.models.static.neural import KSCD
from inscd.models.static.graph import RCD
from inscd.models.static.graph import LIGHTGCN
from inscd.models.static.neural import DCD
from inscd.models.static.graph import SCD
from inscd.models.static.graph import ICDM

# wandb.init(
#     project="test inscd"
# )

# Register `print` with inscd's listener (presumably so library progress/metric
# messages are echoed to stdout — confirm against inscd.listener).
listener.update(print)

# Fixed seed shared by the NumPy/PyTorch RNGs below and by the data split.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
# Seed every visible GPU, not just the current one: the model in this script is
# built on a non-default device (cuda:1), and torch.cuda.manual_seed only targets
# the current device. manual_seed_all is silently ignored when CUDA is unavailable.
torch.cuda.manual_seed_all(seed)

# Load the FrcSub response logs and carve the "total" subset into seeded
# "train" / "test" subsets so the split is reproducible across runs.
datahub = DataHub("datasets/FrcSub")
# Alternative evaluation protocol (k-fold cross-validation):
# folds = datahub.k_fold_split(5, seed)
datahub.random_split(seed=seed, source="total", to=["train", "test"])

# Sanity-check output: number of response logs, then the Q-matrix shape.
print(f"Number of response logs {len(datahub)}")
print(datahub.q_matrix.shape)
# dcd = DCD(datahub.student_num, datahub.exercise_num, datahub.knowledge_num)
# dcd.build(device="cuda:0", dtype=torch.float64)
# dcd.train(datahub, "train", "test", valid_metrics=['auc', 'acc', 'ap', 'rmse', 'ap', 'f1'], batch_size=4096, lr=5e-4, epoch=400)
# for index in range(1, 6):
# icdm = ICDM(datahub.student_num, datahub.exercise_num, datahub.knowledge_num)
# icdm.build(latent_dim=32, device='cuda:0', dtype=torch.float64, if_type='glif', khop=1, gcn_layers=3)
# icdm.train(datahub, f"train", f"test", valid_metrics=['auc', 'acc'], batch_size=32, lr=4e-3, weight_decay=5e-6, epoch=10)


# kancd = KANCD(datahub.student_num, datahub.exercise_num, datahub.knowledge_num)
# kancd.build(32, device='cuda:2')
# kancd.train(datahub, f"train", f"test", valid_metrics=['auc', 'acc'], batch_size=32, lr=4e-3, weight_decay=0, epoch=15)

# scd = SCD(datahub.student_num, datahub.exercise_num, datahub.knowledge_num)
# scd.build(device='cuda:1', dtype=torch.float64)
# scd.train(datahub, f"train", f"test", valid_metrics=['auc', 'acc'], batch_size=4096, lr=4e-3, weight_decay=5e-6,
#           epoch=20)

# Pick the training device. The experiment was configured for "cuda:1"; keep it
# when a second GPU exists, otherwise fall back to the first GPU and finally to
# the CPU instead of crashing on machines with fewer than two GPUs.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Build and train a ULCDF model on the "train" subset, validating on "test".
ulcdf = ULCDF(datahub.student_num, datahub.exercise_num, datahub.knowledge_num)
ulcdf.build(
    latent_dim=32,
    device=device,
    if_type="dp-linear",
    gcn_layers=3,
    dtype=torch.float64,
    activation="ELU",
    # NOTE(review): keep_prob=1 presumably disables graph dropout — confirm in ULCDF.
    keep_prob=1,
)
ulcdf.train(
    datahub,
    "train",
    "test",
    valid_metrics=["auc", "acc", "doa"],
    batch_size=256,
    lr=4e-3,
    weight_decay=5e-6,
    epoch=20,
)
#
# rcd = RCD(datahub.student_num, datahub.exercise_num, datahub.knowledge_num)
# rcd.build(device='cuda:1', if_type='rcd', dtype=torch.float64)
# rcd.train(datahub, "train", "test", valid_metrics=['auc', 'acc', 'ap', 'doa', 'mad'], batch_size=256, lr=3e-3, weight_decay=0)

# lightgcn = LIGHTGCN(datahub.student_num, datahub.exercise_num, datahub.knowledge_num)
# lightgcn.build(device='cuda:0', if_type='ncd', gcn_layers=3, dtype=torch.float64, keep_prob=0.9)
# lightgcn.train(datahub, "train", "test", valid_metrics=['auc', 'acc', 'ap', 'rmse', 'ap', 'f1', 'doa'], batch_size=256, lr=3e-3, weight_decay=0)
# ncdm = NCDM(datahub.student_num, datahub.exercise_num, datahub.knowledge_num)
# ncdm.build()
# ncdm.train(datahub, "train", "test", valid_metrics=['auc', 'ap', 'doa'], batch_size=32)
# test_results = ncdm.score(datahub, "test", metrics=['auc', 'ap', 'doa'])
# pprint(test_results)

# ncdm = NCDM(datahub.student_num, datahub.exercise_num, datahub.knowledge_num)
# ncdm.build()
# ncdm.train(datahub, "train", "test", valid_metrics=['auc', 'ap', 'doa'], batch_size=32)

# mirt = MIRT(datahub.student_num, datahub.exercise_num, datahub.knowledge_num)
# mirt.build(latent_dim=16, device='cuda:0', if_type='sum')
# mirt.train(datahub, "train", "test", valid_metrics=['auc', 'ap'], batch_size=256, lr=0.01, weight_decay=0, epoch=20)

# kscd = KSCD(datahub.student_num, datahub.exercise_num, datahub.knowledge_num)
# kscd.build(latent_dim=20, device='cuda:0', dtype=torch.float64)
# kscd.train(datahub, "train", "test", valid_metrics=['auc', 'ap', 'doa'], batch_size=256, weight_decay=0)
