import torch
import numpy as np
import argparse
import sys
import os
import wandb as wb
from pprint import pprint
from inscd import listener
from inscd.datahub import DataHub
# from inscd.models.static.graph import ULCDF
from inscd.models.static.graph import ORCDF
from inscd.models.static.neural import KANCD
from inscd.models.static.neural import NCDM
from inscd.models.static.classic import MIRT
from inscd.models.static.neural import KSCD
from inscd.models.static.graph import LIGHTGCN
from inscd.models.static.graph import RCD
import CAT as CAT


def set_seeds(seed: int):
    """Seed every random number generator this script draws from.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU generator plus all CUDA devices) with the same value.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Imported locally because the file header does not import ``random``;
    # the adaptive-testing loop draws from it, so it has to be seeded as well.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # ``manual_seed_all`` covers every visible GPU, not only the current one,
    # and is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)


# os.environ["WANDB_MODE"] = "offline"

def _parse_dtype(value):
    """Turn a command-line string such as ``float32`` or ``torch.float32`` into a ``torch.dtype``."""
    candidate = getattr(torch, value.split('.')[-1], None)
    if not isinstance(candidate, torch.dtype):
        raise argparse.ArgumentTypeError(f"unknown torch dtype: {value!r}")
    return candidate


# Command-line interface of the experiment script.
# NOTE: the first four options are ``required=True``, so their ``default``
# values are never used by argparse; they only document typical values.
parser = argparse.ArgumentParser()
parser.add_argument('--method', default='orcdf', type=str,
                    help="model to run, e.g. 'orcdf', 'irt' or 'ncd'", required=True)
parser.add_argument('--data_type', default='junyi', type=str, help='benchmark', required=True)
parser.add_argument('--strategy', default='BECAT', type=str,
                    help="question selection strategy: 'BECAT', 'Random' or 'MAAT'", required=True)
parser.add_argument('--seed', default=0, type=int, help='seed for exp', required=True)
# Without ``type=`` a value given on the command line would stay a plain string;
# the non-string default is passed through by argparse untouched.
parser.add_argument('--dtype', default=torch.float64, type=_parse_dtype,
                    help='dtype of tensor, e.g. float32 or float64')
parser.add_argument('--device', default='cuda', type=str, help='device for exp')
parser.add_argument('--gcn_layers', type=int, help='numbers of gcn layers', default=3)
parser.add_argument('--latent_dim', type=int, help='dimension of hidden layer', default=32)
parser.add_argument('--batch_size', type=int, help='batch size of benchmark', default=32)
parser.add_argument('--exp_type', help='experiment type', default='cat')
parser.add_argument('--pre_lr', type=float, help='learning rate for model training', default=1e-3)
parser.add_argument('--pre_epoch', type=int, help='number of epochs for model training', default=1)
parser.add_argument('--ada_lr', type=float, help='learning rate for the adaptive-testing update', default=1e-3)
parser.add_argument('--ada_epoch', type=int, help='number of epochs for each adaptive-testing update', default=10)
parser.add_argument('--if_type', type=str, help='interaction type')
parser.add_argument('--keep_prob', type=float, default=1.0, help='edge drop probability')
parser.add_argument('--weight_decay', type=float, default=0)
parser.add_argument('--ssl_temp', type=float, default=3)
parser.add_argument('--ssl_weight', type=float, default=3e-3)
parser.add_argument('--flip_ratio', type=float, default=0.05)
parser.add_argument('--mode', type=str, default='')
# Plain dict view of the parsed namespace; the rest of the script mutates it.
config_dict = vars(parser.parse_args())

# Derive the run name, the W&B tags and method-dependent defaults from the
# parsed command-line options.
method = config_dict['method']
method_name = method
datatype = config_dict['data_type']
name = "{}-{}-seed{}".format(method_name, datatype, config_dict['seed'])
tags = [method, datatype, str(config_dict['seed'])]
config_dict['name'] = name

# When no interaction type is given, fall back to the method name.
if config_dict.get('if_type') is None:
    config_dict['if_type'] = method

# Methods whose name contains 'orcdf' get a default ``weight_reg`` of 0.05
# unless one is already present in the config.
if 'orcdf' in method and config_dict.get('weight_reg') is None:
    config_dict['weight_reg'] = 0.05

pprint(config_dict)


# Start the Weights & Biases run for this experiment and record its id in the
# config so later code can refer to it.
run = wb.init(
    project="orcdf",
    name=name,
    tags=tags,
    config=config_dict,
)
config_dict['id'] = run.id


def main(config_dict):
    """Train a diagnosis model, then run the adaptive-testing loop on the test split.

    The function seeds the RNGs, loads and splits the benchmark, trains the
    model selected by ``config_dict['method']`` and then, for a fixed number of
    iterations, lets the selected strategy pick one question per test student,
    updates the model and logs the evaluation results to W&B.

    Args:
        config_dict: Experiment configuration. Keys read here include
            ``method``, ``strategy``, ``seed``, ``data_type``, ``if_type``,
            ``latent_dim``, ``device``, ``dtype``, ``gcn_layers``,
            ``keep_prob``, ``ssl_weight``, ``ssl_temp``, ``flip_ratio``,
            ``pre_lr``, ``pre_epoch``, ``ada_lr``, ``ada_epoch`` and
            ``batch_size``. Several keys are overwritten depending on the
            method (for example ``mode``, ``latent_dim``, ``pre_lr``).

    Raises:
        ValueError: If ``method`` or ``strategy`` is not one of the supported
            values. Raised before any data is loaded or any model is trained.
    """
    # Local import: the file header does not import ``random``.
    import random

    method = config_dict['method']
    strategy_name = config_dict['strategy']

    # Fail fast on unusable options instead of crashing after training.
    supported_methods = ('orcdf', 'irt', 'ncd')
    if method not in supported_methods:
        raise ValueError(f"Unsupported method {method!r}; expected one of {supported_methods}")
    supported_strategies = ('BECAT', 'Random', 'MAAT')
    if strategy_name not in supported_strategies:
        raise ValueError(f"Unsupported strategy {strategy_name!r}; expected one of {supported_strategies}")

    set_seeds(config_dict['seed'])
    datahub = DataHub(f"datasets/{config_dict['data_type']}")
    print("Number of response logs {}".format(len(datahub)))
    # Two-stage split: total -> train/other, then other -> valid/test.
    # The return values are used below as student counts (presumably the
    # number of students placed in the first target split - TODO confirm).
    train_student_num = datahub.group_split(source="total", to=["train", "other"], seed=1, slice_out=0.7)
    valid_student_num = datahub.group_split(source="other", to=["valid", "test"], seed=1, slice_out=1 - 1 / 3)
    concept_map = datahub.get_concept_map()

    def renumber_student(data):
        """Replace the ids in column 0 by dense indices 0..k-1 (in sorted id order), in place."""
        # ``return_inverse`` already yields, for every row, the position of its
        # id among the sorted unique ids, which is exactly the new numbering.
        # NOTE(review): this writes into ``data`` itself. If ``datahub[...]``
        # returns a reference rather than a copy, the datahub split is
        # renumbered too (also for 'orcdf', which later reads
        # ``datahub['train']`` again) - confirm this is intended.
        _, dense_ids = np.unique(data[:, 0], return_inverse=True)
        data[:, 0] = dense_ids
        return data

    # Number of adaptive-testing iterations (questions selected per student).
    test_length = 20

    train_data = CAT.dataset.TrainDataset(renumber_student(datahub['train']).astype(int), concept_map,
                                          train_student_num,
                                          datahub.exercise_num,
                                          datahub.knowledge_num)

    if method == 'orcdf':
        from inscd.models.static.graph.orcdf import orcdf
        model = orcdf(datahub.student_num, datahub.exercise_num, datahub.knowledge_num)
        # The ``mode`` passed to ``build`` depends on the interaction type.
        if config_dict['if_type'] in ('kancd', 'mirt', 'kscd', 'irt'):
            config_dict['mode'] = 'tfcl'
        else:
            config_dict['mode'] = 'cl'
        model.build(latent_dim=config_dict['latent_dim'], device=config_dict['device'], if_type=config_dict['if_type'],
                    gcn_layers=config_dict['gcn_layers'], keep_prob=config_dict['keep_prob'],
                    dtype=config_dict['dtype'], ssl_weight=config_dict['ssl_weight'], ssl_temp=config_dict['ssl_temp'],
                    flip_ratio=config_dict['flip_ratio'], mode=config_dict['mode'])
        model.train(datahub, valid_metrics=['auc', 'ap'], lr=config_dict['pre_lr'],
                    batch_size=config_dict['batch_size'], epoch=config_dict['pre_epoch'])
    elif method == 'irt':
        # Fixed hyper-parameters for IRT; ``ada_epoch`` keeps its configured value.
        config_dict['latent_dim'] = 10
        config_dict['pre_lr'] = 2e-3
        config_dict['pre_epoch'] = 10
        config_dict['ada_lr'] = config_dict['pre_lr']
        model = CAT.model.IRTModel(**config_dict)
        model.init_model(train_data)
        model.train(train_data, log_step=10)
    else:  # method == 'ncd' (validated above)
        config_dict['prednet_len1'] = 512
        config_dict['prednet_len2'] = 256
        config_dict['pre_epoch'] = 10
        config_dict['pre_lr'] = 3e-3
        config_dict['ada_lr'] = config_dict['pre_lr']
        config_dict['ada_epoch'] = 20
        model = CAT.model.NCDModel(**config_dict)
        model.init_model(train_data)
        model.train(train_data)

    # 'orcdf' keeps the original student ids of the test split; the CAT models
    # work on renumbered ids.
    if method == 'orcdf':
        test_logs = datahub['test']
    else:
        test_logs = renumber_student(datahub['test'])
    test_data = CAT.dataset.AdapTestDataset(test_logs.astype(int), concept_map,
                                            datahub.student_num - train_student_num - valid_student_num,
                                            datahub.exercise_num,
                                            datahub.knowledge_num)

    if strategy_name == 'BECAT':
        strategies = [CAT.strategy.BECATstrategy()]
    elif strategy_name == 'Random':
        strategies = [CAT.strategy.RandomStrategy()]
    else:  # strategy_name == 'MAAT' (validated above)
        strategies = [CAT.strategy.MAATStrategy()]

    for strategy in strategies:
        test_data.reset()
        print('-----------')
        print(f'start adaptive testing with {strategy.name} strategy')
        print('Iteration 0')
        results = model.evaluate(test_data)
        for metric, value in results.items():
            print(f'{metric}:{value}')

        is_orcdf = model.name == 'orcdf'
        if is_orcdf:
            student_list = np.unique(datahub['test'][:, 0]).astype(int).tolist()
        else:
            student_list = range(test_data.num_students)
        test_data.student_list = student_list

        # Per-student history of selected questions, passed to the BECAT strategy.
        S_sel = {sid: [] for sid in student_list}
        selected_questions = {}
        # All [student, question, response] rows selected so far, across iterations.
        select_data = []
        for it in range(1, test_length + 1):
            print(f'Iteration {it}')
            # select question
            if strategy.name == 'BECAT Strategy':
                if it == 1:
                    # BECAT starts from one random untested question per student.
                    for sid in student_list:
                        untested_questions = np.array(list(test_data.untested[sid]))
                        random_index = random.randint(0, len(untested_questions) - 1)
                        selected_questions[sid] = untested_questions[random_index]
                        S_sel[sid].append(untested_questions[random_index])
                else:
                    selected_questions = strategy.adaptest_select(model, test_data, S_sel)
                    for sid in student_list:
                        S_sel[sid].append(selected_questions[sid])
            elif strategy.name == 'Model Agnostic Adaptive Testing':
                config_tmp = {'lr': config_dict['ada_lr'],
                              'batch_size': config_dict['batch_size'],
                              'epoch': config_dict['ada_epoch']}
                selected_questions = strategy.adaptest_select(model, test_data, config_tmp)
            else:
                selected_questions = strategy.adaptest_select(model, test_data)

            for student, question in selected_questions.items():
                test_data.apply_selection(student, question)
                select_data.append([student, question, test_data.data[student][question]])

            if is_orcdf:
                # Rebuild the graph from the training logs plus everything selected so far.
                model.update_graph(np.vstack([datahub['train'], np.array(select_data)]), datahub.q_matrix)
                model.adaptest_update(test_data, lr=config_dict['ada_lr'], batch_size=config_dict['batch_size'],
                                      epoch=config_dict['ada_epoch'])
            else:
                model.adaptest_update(test_data)
            # evaluate models
            results = model.evaluate(test_data)
            # log results
            wb.log(results)
            for metric, value in results.items():
                print(f'{metric}:{value}')


# Script entry point: run the experiment and hand its return value to the
# interpreter as the process exit status.
if __name__ == '__main__':
    exit_status = main(config_dict)
    sys.exit(exit_status)
