import numpy as np
import torch
import wandb
import warnings
warnings.filterwarnings('ignore')
from pprint import pprint
from inscd import listener
from inscd.datahub import DataHub
from inscd.models.static.neural import NCDM
from inscd.models.static.graph import ULCDF
from inscd.models.static.neural import KANCD
from inscd.models.static.classic import MIRT
from inscd.models.static.neural import KSCD
from inscd.models.static.graph import RCD
from inscd.models.static.graph import LIGHTGCN
from inscd.models.static.neural import DCD
from inscd.models.static.graph import SCD
from inscd.models.static.graph import GCMC
from inscd.models.static.graph import RGCN
from inscd.models.static.graph import ORCDF

# wandb.init(
#     project="test inscd"
# )

# Register the built-in print with inscd's listener
# (presumably as the callback used to report progress -- confirm in inscd).
listener.update(print)

# A single seed is shared by NumPy, PyTorch (CPU and CUDA) and the data split.
seed = 2
for seed_rng in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_rng(seed)

# Open the Assist17 dataset directory and split its "total" source into
# "train" and "test" using the same seed.
datahub = DataHub("datasets/Assist17")
datahub.random_split(
    source="total",
    to=["train", "test"],
    seed=seed,
)
#
# ncdm = NCDM(datahub.student_num, datahub.exercise_num, datahub.knowledge_num, save_flag=True)
# ncdm.build(device='cuda:0')
# ncdm.train(datahub, "train", "test", valid_metrics=['auc', 'acc'], batch_size=4096, lr=4e-3, epoch=5, weight_decay=0)
# mas_ncdm = ncdm.mastery_list[1]
#
# orcdf = ORCDF(datahub.student_num, datahub.exercise_num, datahub.knowledge_num, save_flag=True)
# orcdf.build(device='cuda:0', if_type='ncd', ssl_temp=3, ssl_weight=1e-3)
# orcdf.train(datahub, "train", "test", valid_metrics=['auc', 'acc', 'mad', 'doa'], batch_size=4096, lr=4e-3, weight_decay=0, epoch=4)
#
# mas_orcdf = orcdf.mastery_list[3]
#
# print(np.sum(np.abs(mas_orcdf - mas_ncdm)) / datahub.student_num)
# print(np.sum(np.abs(mas_orcdf - mas_ncdm)) / datahub.student_num / datahub.knowledge_num)


# gcmc = GCMC(datahub.student_num, datahub.exercise_num, datahub.knowledge_num)
# gcmc.build(device='cuda:1', dtype=torch.float64, gcn_layers=1, if_type='ncd')
# gcmc.train(datahub, f"train", f"test", valid_metrics=['auc', 'acc'], batch_size=4096, lr=1e-3, weight_decay=5e-6, epoch=5)

# Build and train an RGCN model on the "train" split, evaluating AUC and
# accuracy on the "test" split.
#
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the script still runs on GPU-less machines instead of failing in build().
# NOTE(review): assumes inscd's build() accepts device='cpu' -- confirm.
rgcn_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

rgcn = RGCN(datahub.student_num, datahub.exercise_num, datahub.knowledge_num)
rgcn.build(device=rgcn_device, dtype=torch.float32, gcn_layers=2, if_type='ncd')
# Split names are plain literals; the original f-strings had no placeholders.
rgcn.train(datahub, "train", "test", valid_metrics=['auc', 'acc'], batch_size=4096, lr=5e-4, weight_decay=0, epoch=20)





# ncdm = NCDM(datahub.student_num, datahub.exercise_num, datahub.knowledge_num)
# ncdm.build(device='cuda:0')
# ncdm.train(datahub, "train", "test", valid_metrics=['auc', 'acc'], batch_size=256, lr=4e-3)

# dcd = DCD(datahub.student_num, datahub.exercise_num, datahub.knowledge_num)
# dcd.build(device="cuda:0", dtype=torch.float64)
# dcd.train(datahub, "train", "test", valid_metrics=['auc', 'acc', 'ap', 'rmse', 'ap', 'f1'], batch_size=4096, lr=5e-4, epoch=400)
# for index in range(1, 6):
# kancd = KANCD(datahub.student_num, datahub.exercise_num, datahub.knowledge_num)
# kancd.build(32, device='cuda:2')
# kancd.train(datahub, f"train", f"test", valid_metrics=['auc', 'acc'], batch_size=256, lr=4e-3, weight_decay=0, epoch=15)

# scd = SCD(datahub.student_num, datahub.exercise_num, datahub.knowledge_num)
# scd.build(device='cuda:0', dtype=torch.float64)
# scd.train(datahub, f"train", f"test", valid_metrics=['auc', 'acc'], batch_size=256, lr=4e-3, weight_decay=5e-6, epoch=20)

# ulcdf = ULCDF(datahub.student_num, datahub.exercise_num, datahub.knowledge_num)
# ulcdf.build(latent_dim=32, device='cuda:1', if_type='dp-linear', gcn_layers=1, dtype=torch.float64,
#             activation='ELU', keep_prob=0.5)
# ulcdf.train(datahub, f"train", f"test", valid_metrics=['auc', 'acc'], batch_size=4096, lr=5e-4, weight_decay=5e-6,
#             epoch=20)
