#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2025/9/24 20:10
# @Author  : hb
# @File    : train.py
import datetime
import json
import os
import argparse
import shutil
import time
from typing import Dict

import pandas as pd

from data_trans import  expand_kf_data
from data_load import split_generator
from torch.utils.data import DataLoader
import torch
import numpy as np
import random

from dkt import DKT
from akt import AKT
from dkvmn import DKVMN
from rekt import ReKT
from simplekt import SimpleKT
from ukt import UKT
from fluckt import FlucKT
from leftokt import LEFOKT_AKT


def setup_seed(seed):
    """Seed every random number generator used during training.

    Seeds Python's ``random`` module, NumPy's global generator and torch
    (CPU plus all CUDA devices), and switches cuDNN to deterministic mode so
    repeated runs with the same seed are reproducible.

    :param seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def logs2str(logs: Dict) -> str:
    """Render a metrics mapping as comma-joined ``name:value`` pairs.

    Every value is formatted with four decimal places, e.g.
    ``{"loss": 0.5, "auc": 1}`` -> ``"loss:0.5000,auc:1.0000"``.
    An empty mapping yields an empty string.
    """
    return ",".join(f'{name}:{value:.4f}' for name, value in logs.items())


def to_device(value, device):
    """Recursively move every tensor contained in ``value`` onto ``device``.

    * ``torch.Tensor`` -> ``value.to(device)``
    * ``dict``         -> new dict with each value converted (key order kept)
    * ``tuple``/``list`` -> new *list* of converted items (tuples become lists)

    :raises ValueError: for any other type, including scalars and ``None``.
    """
    if isinstance(value, torch.Tensor):
        return value.to(device)
    if isinstance(value, dict):
        return {key: to_device(item, device) for key, item in value.items()}
    if isinstance(value, (tuple, list)):
        return [to_device(item, device) for item in value]
    raise ValueError("nonsupport type!")

def create_model(model_name, model_configs):
    """Instantiate the knowledge-tracing model registered under ``model_name``.

    :param model_name: one of 'dkt', 'dkvmn', 'akt', 'ukt', 'simplekt',
        'rekt', 'leftokt', 'fluckt'.
    :param model_configs: mapping forwarded unchanged as keyword arguments
        to the selected model's constructor.
    :return: the constructed model instance.
    :raises ValueError: if ``model_name`` does not match any known model.
    """
    registry = (
        ('dkt', DKT),
        ('dkvmn', DKVMN),
        ('akt', AKT),
        ('ukt', UKT),
        ('simplekt', SimpleKT),
        ('rekt', ReKT),
        ('leftokt', LEFOKT_AKT),
        ('fluckt', FlucKT),
    )
    for registered_name, model_cls in registry:
        if model_name == registered_name:
            return model_cls(**model_configs)
    raise ValueError("unknown model name!")
def _should_save(model, val_logs, minitor, last_value, laset_epoch, epoch, weight_path):
    if minitor in val_logs.keys() and val_logs[minitor] - last_value > 0.0001:
        torch.save(model.state_dict(), weight_path)
        return val_logs[minitor], epoch
    else:
        return last_value, laset_epoch

def train(model, train_data, test_data, val_data=None, log_dir=None, max_epochs=1,
          device="cpu", patience=10, valid_interval=1, **kwargs):
    """Train ``model`` with AUC-based checkpointing and early stopping, then test it.

    Each epoch runs ``model.train_step`` on every batch of ``train_data``
    while drawing a console progress bar. Every ``valid_interval`` epochs the
    model is evaluated on ``val_data`` (``test_data`` when ``val_data`` is
    None). Whenever the validation ``auc`` improves by more than 1e-4, the
    state dict is saved to ``<log_dir>/weights/model.pth``. Training stops
    early once more than ``patience * valid_interval`` epochs have passed
    since the best epoch; a non-positive product disables early stopping.
    At the end, the best checkpoint saved by *this* run (if any) is reloaded
    and the model is evaluated on ``test_data``.

    Side effects: writes per-epoch metrics to ``<log_dir>/train.csv`` and the
    final test metrics to ``<log_dir>/test_result.csv``.

    :param model: model exposing ``reset_state``, ``train_step``,
        ``test_step`` and ``compute_metrics`` in addition to ``to``, ``train``,
        ``state_dict`` and ``load_state_dict``.
    :param train_data: iterable of training batches.
    :param test_data: iterable of test batches.
    :param val_data: iterable of validation batches; defaults to ``test_data``.
    :param log_dir: output directory. Required despite the ``None`` default,
        because ``os.path.join(None, ...)`` raises TypeError.
    :param max_epochs: maximum number of training epochs.
    :param device: device passed to ``model.to`` and ``to_device``.
    :param patience: early-stopping patience, measured in validation rounds.
    :param valid_interval: validate every this many epochs.
    :param kwargs: ignored, only printed. This lets callers pass a whole
        config dict.
    :return: dict of test metrics from ``model.compute_metrics()``.
    """
    if len(kwargs) > 0:
        print(f"unused params for train:{kwargs}")
    model.to(device)
    model.reset_state()
    os.makedirs(os.path.join(log_dir, "weights"), exist_ok=True)
    weight_path = os.path.join(log_dir, "weights", 'model.pth')
    csv_log_path = os.path.join(log_dir, "train.csv")
    test_log_path = os.path.join(log_dir, "test_result.csv")
    train_logs_data = []
    if val_data is None:
        val_data = test_data
    # The number of batches is unknown until the first epoch has been iterated.
    total_batch = None
    progressbar_str = ""
    max_auc = 0
    # -1 means "no checkpoint has been saved by this run yet".
    best_epoch = -1
    width = 30
    for epoch in range(max_epochs):
        model.train()
        model.reset_state()
        print(f"Epoch {epoch + 1}/{max_epochs}")
        last_train_out_str = ""
        epoch_start = time.time()
        batch = 0
        logs = {}
        # Reset every epoch. Otherwise an epoch that skips validation would
        # re-print and re-record the previous epoch's validation metrics.
        val_logs = {}
        for batch, data in enumerate(train_data):
            step_start = time.time()
            data = to_device(data, device)
            logs = model.train_step(data)
            now = time.time()
            time_step = round((now - step_start) * 1000)
            if total_batch is None:
                progressbar_str = f"{batch + 1:7d}/unknown,[{'=' * width}]"
            else:
                numdigits = int(np.log10(total_batch)) + 1
                prog_width = round((batch + 1) / total_batch * width)
                progressbar_str = (f"{batch + 1:{numdigits}d}/{total_batch},"
                                   f"[{'=' * prog_width}{'.' * round(width - prog_width)}]")
            if total_batch is not None and batch + 1 == total_batch:
                progressbar_str += f"- {round(now - epoch_start):^3d}s :{time_step}ms/step,"
            else:
                progressbar_str += f"- ETA :{time_step}ms/step,"
            train_out_str_add = progressbar_str + logs2str(logs)
            # Backspaces + carriage return redraw the progress line in place.
            print(len(last_train_out_str) * "\b" + "\r" + train_out_str_add, end="")
            last_train_out_str = train_out_str_add
        tr_logs = {"train_" + name: val for name, val in model.compute_metrics().items()}
        trlogs_str = logs2str(tr_logs)
        print(len(last_train_out_str) * "\b" + "\r" + progressbar_str + trlogs_str, end="")
        if total_batch is None:
            total_batch = batch + 1
        if epoch % valid_interval == 0:
            val_metrics = evaluate(model, val_data, device=device)
            # Checkpoint only on a meaningful (> 1e-4) AUC improvement.
            if "auc" in val_metrics and val_metrics["auc"] - max_auc > 0.0001:
                max_auc = val_metrics["auc"]
                best_epoch = epoch
                torch.save(model.state_dict(), weight_path)
            val_logs = {"val_" + name: val for name, val in val_metrics.items()}
        epoch_line = "\r" + progressbar_str + trlogs_str
        if val_logs:
            epoch_line += "," + logs2str(val_logs)
        print(epoch_line)
        # NOTE: the column is literally named "enpoch"; kept as-is so existing
        # consumers of train.csv keep working.
        train_log_record = {"enpoch": epoch}
        train_log_record.update(tr_logs)
        train_log_record.update(val_logs)
        train_logs_data.append(train_log_record)
        if 0 < patience * valid_interval < (epoch - best_epoch):
            break

    pd.DataFrame(train_logs_data).to_csv(csv_log_path, index=False, encoding='utf-8',
                                         float_format='%.5f')
    # Reload only a checkpoint this run actually wrote. A model.pth left in
    # log_dir by an earlier run must not replace the current weights.
    if best_epoch >= 0:
        stat_dict = torch.load(weight_path, map_location=device)
        model.load_state_dict(stat_dict)
    test_logs = evaluate(model, test_data, device=device)
    test_logs1 = {"test_" + name: val for name, val in test_logs.items()}
    print(logs2str(test_logs1))
    pd.DataFrame([test_logs]).to_csv(test_log_path, index=False, encoding='utf-8',
                                     float_format='%.5f')
    return test_logs

def evaluate(self, data_set, device="cpu"):
    model.to(device)
    model.reset_state()
    model.eval()
    with torch.no_grad():
        for step, data in enumerate(data_set):
            data = to_device(data, device)
            model.test_step(data)
    logs = model.compute_metrics()
    return logs
# Command-line options. The parsed namespace is also forwarded wholesale as
# keyword arguments to the model constructor and to train().
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', default='dkt',
                    help='Model name: dkt, dkvmn, akt, ukt, simplekt, rekt, leftokt, fluckt')
parser.add_argument('--data_name', default='assist2009', help='Dataset name: assist0910,assist2012')
parser.add_argument('--logs_base', default='../result', help='logs dir path')
parser.add_argument('--data_base', default='../dataset', help='data base  path')
parser.add_argument('--mode', type=str, default='KC', choices=['Q', 'KC', "Ours"])
parser.add_argument('--max_len', type=int, default=100, help='The max length of sequence')
parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
parser.add_argument('--batch_size', type=int, default=32, help='Batch Size')
parser.add_argument("--emb_size", default=64, type=int, help="embedding size of model")
parser.add_argument('--dropout', type=float, default=0.1, help='Dropout of representation')
parser.add_argument('--max_epochs', type=int, default=100, help='Max number of epochs for training')  ## 500
parser.add_argument('--optimizer', type=str, default='Adam', choices=['SGD', 'Adam'])
parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
parser.add_argument('--weight_decay', type=float, default=0.00001, help='l2')
parser.add_argument('--valid_interval', type=int, default=1, help='the number of epoch to eval')
parser.add_argument('--patience', type=int, default=5, help='the number of epoch to wait before early stop')
parser.add_argument('--seed', type=int, default=100, help='the number of random seed')
parser.add_argument('--folds', type=int, default=5, help='cross evaluate folds')
parser.add_argument('--fast_run', action='store_true', default=False,
                    help='fast run in sample data')
# Default to the first GPU, but respect a device selection the caller has
# already made in the environment instead of silently overriding it.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
if __name__ == "__main__":
    # Script entry point. For each data fold (or the single un-split dataset)
    # a fresh model is built, trained and tested. The per-fold test metrics
    # are averaged into <log_dir>/test_result.csv, together with the
    # across-fold standard deviation of auc/acc/rmse, stored as
    # "<metric>_p" (-1 when not computable).
    args = parser.parse_args()
    setup_seed(args.seed)
    if not torch.cuda.is_available():
        if args.device == "cuda":
            print("warning:cuda is not enable!!!!!!!!!!!!!!!!!!")
        args.device = "cpu"
    device = args.device
    dataset_dir = os.path.join(os.path.abspath(args.data_base), args.data_name)
    _info_dict = {}
    with open(os.path.join(dataset_dir, "info.json"), "r", encoding="utf-8") as f:
        _info_dict.update(json.load(f))
    # vars() returns the namespace's own __dict__, so this update also changes
    # `args`. In particular, args.data_name becomes info.json's "name" from
    # here on, including in log_path below.
    model_configs = vars(args)
    model_configs.update(dict(data_name=_info_dict["name"], skill_num=_info_dict["skill_num"],
                              problem_num=_info_dict["problem_num"], group_num=_info_dict["group_num"]))
    log_path = f'{args.data_name}_{args.model_name}_{args.mode.upper()}_{datetime.datetime.now().strftime("%Y%m%d%H%M%S")}'
    log_dir = os.path.join(os.path.abspath(args.logs_base), log_path)
    try:
        os.makedirs(log_dir, exist_ok=True)
        with open(os.path.join(log_dir, "run_config.json"), "w", encoding="utf-8") as f:
            json.dump(model_configs, f, indent=4)
        if args.folds > 0:
            sub_data_dir = [os.path.join(dataset_dir, f'k{i}') for i in range(args.folds)]
            # expand_kf_data is expected to (re)create the missing k{i} split dirs.
            if not all(os.path.exists(d) for d in sub_data_dir):
                expand_kf_data(dataset_dir, max_len=-1)
            folds_params = [(d, os.path.join(log_dir, f'k{k}')) for k, d in enumerate(sub_data_dir)]
        else:
            folds_params = [(dataset_dir, log_dir)]
        all_results = []
        for datas_dir_s, logs_dir_s in folds_params:
            model = create_model(args.model_name, model_configs)
            model.compile_model(optimizer=args.optimizer, lr=args.lr, weight_decay=args.weight_decay)
            feature_names, label_names = model.inputs_specs
            train_data, val_data, test_data = split_generator(datas_dir_s, model_configs["skill_num"],
                                                              model_configs["problem_num"], model_configs["group_num"],
                                                              None, feature_names, label_names,
                                                              sample_num=256 if args.fast_run else -1,
                                                              max_len=args.max_len)
            train_dataloader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
            val_dataloader = DataLoader(val_data, batch_size=args.batch_size)
            if args.folds > 0:
                # With k-fold data the validation loader doubles as the test
                # loader, and split_generator's test_data is not used.
                test_dataloader = val_dataloader
            else:
                test_dataloader = DataLoader(test_data, batch_size=args.batch_size)
            test_logs = train(model, train_dataloader, test_dataloader, val_dataloader, logs_dir_s, **model_configs)
            all_results.append(test_logs)
            del model
        all_result_df = pd.DataFrame(all_results)
        mean_result = all_result_df.mean()
        # A standard deviation needs more than one fold; -1 marks "not available".
        for metric in ("auc", "acc", "rmse"):
            if len(all_result_df) > 1 and metric in all_result_df.columns:
                mean_result[f"{metric}_p"] = all_result_df[metric].std()
            else:
                mean_result[f"{metric}_p"] = -1
        pd.DataFrame([mean_result]).to_csv(os.path.join(log_dir, "test_result.csv"), index=False, encoding='utf-8',
                                           float_format='%.5f')
        final_result = mean_result.to_dict()
        print(final_result)
    except Exception:
        # Remove the partially written run directory, then surface the error.
        # Previously it was swallowed and the script exited with status 0.
        shutil.rmtree(log_dir, ignore_errors=True)
        raise