import numpy as np
import torch
import wandb

from pprint import pprint
from inscd import listener
from inscd.datahub import DataHub
from inscd.models.static.neural import NCDM
from inscd.models.static.graph import ULCDF
from inscd.models.static.neural import KANCD
from inscd.models.static.classic import MIRT
from inscd.models.static.graph import RCD, ORCDF
from inscd.models.static.neural import CDMFKC
from inscd.models.static.neural import KSCD
from inscd.plot.tsne import plot_tsne, plot_tsne_pure, plot_mas
from inscd.plot.mad_map import plot_mnd_map

# wandb.init(
#     project="test inscd"
# )
import CAT

# Register ``print`` with the inscd listener (presumably the sink for the
# library's progress messages -- confirm against the inscd documentation).
listener.update(print)

# Seed every random number generator this script draws from so that runs are
# repeatable.  The stdlib ``random`` module is seeded as well because the BECAT
# warm-start step in the adaptive-testing loop picks questions with
# ``random.randint``; seeding only NumPy and torch leaves that step
# non-deterministic.
import random

seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Load the response logs and split them by student group:
# "total" -> "train" / "other", then "other" -> "valid" / "test".
# NOTE(review): ``slice_out`` is presumably the share given to the first
# target split (0.7 and 2/3 respectively) -- confirm against DataHub.
datahub = DataHub("datasets/EdNet-1")
print(f"Number of response logs {len(datahub)}")
train_student_num = datahub.group_split(source="total", to=["train", "other"], seed=1, slice_out=0.7)
valid_student_num = datahub.group_split(source="other", to=["valid", "test"], seed=1, slice_out=1 - 1 / 3)

# Build the adaptive-testing dataset from the "test" split; it is accessed
# later as ``datahub['CAT']``.
datahub.get_CAT_Dataset('test')

# Construct and pre-train the ORCDF diagnosis model on the training split.
model = ORCDF(datahub.student_num, datahub.exercise_num, datahub.knowledge_num)
model.build(latent_dim=32, device='cuda:0', if_type='ncd',
            gcn_layers=1, keep_prob=1.0,
            dtype=torch.float64, ssl_weight=3e-3, ssl_temp=0.8,
            flip_ratio=0.05, mode='cl')
model.train(datahub, valid_metrics=['auc', 'ap'], lr=1e-3, batch_size=4096,
            epoch=4, weight_decay=0)
# Hyper-parameters handed to ``model.cat_adaptest_update`` (and to the MAAT
# strategy) on every adaptive-testing step.
update_config = {'lr': 1e-3, 'batch_size': 4096, 'weight_decay': 0, 'epoch': 10}

# Experiment configuration; 'strategy' names the question-selection strategy
# and must be one of 'BECAT', 'Random' or 'MAAT'.
config_dict = {'strategy': 'Random'}
import random

# Instantiate the question-selection strategy named by
# ``config_dict['strategy']``.  ``strategies`` is kept as a list because the
# adaptive-testing loop iterates over it.
strategy_name = config_dict['strategy']
if strategy_name == 'BECAT':
    strategies = [CAT.strategy.BECATstrategy()]
elif strategy_name == 'Random':
    strategies = [CAT.strategy.RandomStrategy()]
elif strategy_name == 'MAAT':
    strategies = [CAT.strategy.MAATStrategy()]
else:
    # Name the offending value so a typo in the configuration is easy to spot.
    raise ValueError(
        f"Unknown strategy {strategy_name!r}; expected 'BECAT', 'Random' or 'MAAT'"
    )
# Run one adaptive-testing session per configured strategy: evaluate the model
# before any question has been asked, then for ``test_length`` rounds select one
# question per student, record it, update the model and evaluate again.
for strategy in strategies:
    datahub['CAT'].reset()
    print('-----------')
    print(f'start adaptive testing with {strategy.name} strategy')
    print('Iteration 0')
    results = model.cat_evaluate(datahub)
    for name, value in results.items():
        print(f'{name}:{value}')

    # Student ids are taken from the first column of the "test" split.
    student_list = np.unique(datahub['test'][:, 0]).astype(int).tolist()
    test_length = 15
    # Per-student history of selected questions; only the BECAT strategy
    # receives it (as ``S_sel_dict``).
    S_sel = {sid: [] for sid in student_list}
    selected_questions = {}
    # Accumulates one [student, question, value] row per selection across all
    # rounds.  NOTE(review): the third entry is presumably the student's
    # recorded response to that question -- confirm against the CAT dataset.
    select_data = []
    for it in range(1, test_length + 1):
        print(f'Iteration {it}')

        # Select one question per student.
        if strategy.name == 'BECAT':
            if it == 1:
                # First round: BECAT's own selector is not called; each
                # student gets a uniformly random untested question.
                for sid in student_list:
                    untested_questions = np.array(list(datahub['CAT'].untested[sid]))
                    random_index = random.randint(0, len(untested_questions) - 1)
                    selected_questions[sid] = untested_questions[random_index]
            else:
                selected_questions = strategy.adaptest_select(
                    model, datahub, student_list=student_list, S_sel_dict=S_sel)
            # Extend each student's selection history for the next round.
            for sid in student_list:
                S_sel[sid].append(selected_questions[sid])
        elif strategy.name == 'MAAT':
            selected_questions = strategy.adaptest_select(
                model, datahub, update_config=update_config, student_list=student_list)
        else:
            selected_questions = strategy.adaptest_select(
                model, datahub, student_list=student_list)

        # Register the selections with the CAT dataset and log them.
        for student, question in selected_questions.items():
            datahub['CAT'].apply_selection(student, question)
            select_data.append([student, question, datahub['CAT'].data[student][question]])

        # Rebuild the model's graph from the training logs plus every row
        # selected so far, then fine-tune with ``update_config``.  The same
        # two calls apply regardless of the model type.
        observed_logs = np.vstack([datahub['train'], np.array(select_data)])
        model.update_graph(observed_logs, datahub.q_matrix)
        model.cat_adaptest_update(datahub, update_config)

        # Evaluate the updated model and report the metrics for this round.
        results = model.cat_evaluate(datahub)
        for name, value in results.items():
            print(f'{name}:{value}')
