import torch
from param import args
from utils import set_seed, Maximizer
from data import load_data, create_seqs, CrossValidation, init_qst_clr, format_seqs
from experiment import experiment, comp_entropy, sinkhorn
import numpy as np
import time
import os

def _write_clusters(path, clusters, indexed=False):
    """Write a question -> cluster assignment to *path*, one question per line.

    Each line is ``<cluster>``, or ``<question>,<cluster>`` when *indexed* is
    true, where ``<question>`` is the element's position in *clusters*.
    """
    if indexed:
        lines = [f'{qst},{clr}\n' for qst, clr in enumerate(clusters)]
    else:
        lines = [f'{clr}\n' for clr in clusters]
    with open(path, 'w', newline='') as file:
        file.writelines(lines)


def _format_summary(iteration, auc, acc, rmse, dur):
    """Return the one-line summary for one outer iteration.

    *auc*, *acc* and *rmse* are ``(mean, std)`` pairs over the CV folds,
    *iteration* is the 1-based iteration number and *dur* is in seconds.
    """
    return '  '.join((
        'iter_count: %-3d' % iteration,
        'auc: %-.4f ± %-.4f' % auc,
        'acc: %-.4f ± %-.4f' % acc,
        'rmse: %-.4f ± %-.4f' % rmse,
        'dur: %-.2fs' % dur,
    ))


def main():
    """Alternate between per-fold training and re-clustering of questions.

    Each of ``args.iter_count`` outer iterations:
      1. trains one model per cross-validation fold with the current
         question clustering (``experiment``),
      2. sums the ``(qst_num, clr_num)`` scores returned by ``comp_entropy``
         over the folds,
      3. derives the next clustering as the row-wise argmax of
         ``sinkhorn(entropy, ...)``.
    The clustering whose folds reached the best mean evaluation AUC is
    reported (``args.result``) and optionally written to ``args.save_clr``.
    """
    set_seed(args.seed)
    device = torch.device(args.device)
    path = '%s/%s' % (args.data_path, args.data_name)
    qst_num, usr_num, skl_num, qst_skl, records = load_data(path)
    seqs = create_seqs(usr_num, records)

    qst_clr = init_qst_clr(qst_num, args.clr_num)
    best_qst_clr = qst_clr

    fin_auc_max = Maximizer()
    fin_res = None

    if args.save_each_clr is not None:
        os.makedirs(args.save_each_clr, exist_ok=True)
        _write_clusters(os.path.join(args.save_each_clr, 'qst_clr_0.csv'), qst_clr)

    def print_detail(epoch, loss, trn_res, evl_res, max_res, time):
        # Per-epoch callback handed to `experiment`. The last parameter keeps
        # its original name `time` (shadowing the module only inside this
        # function) in case `experiment` passes it by keyword.
        if not args.detail:
            return
        (trn_auc, _, _), (evl_auc, evl_acc, _), (max_auc, max_acc, _) = trn_res, evl_res, max_res
        print('  '.join((
            'epoch: %-4d' % epoch,
            'loss: %-.4f' % loss,
            'trn_auc: %-.4f' % trn_auc,
            'auc: %-.4f/%-.4f' % (evl_auc, max_auc),
            'acc: %-.4f/%-.4f' % (evl_acc, max_acc),
            'dur: %-.2fs' % time,
        )))

    for i in range(args.iter_count):
        start = time.perf_counter()

        entropy = torch.zeros((qst_num, args.clr_num), device=device)
        max_res_list = []
        folds = CrossValidation(seqs, args.fold, args.seq_len)
        for fold_num, (trn_seqs, evl_seqs) in enumerate(folds, start=1):
            if args.detail:
                print('================I%dF%d===============' % (i + 1, fold_num))

            model, max_res = experiment(
                qst_num,
                args.clr_num,
                qst_clr,
                args.hidden_dim,
                args.dropout,
                args.wd,
                args.lr,
                args.batch_size,
                args.quit_epoch,
                trn_seqs,
                evl_seqs,
                device,
                evl_trn_seq=True,
                each_round_fn=print_detail,
            )
            max_res_list.append(max_res)
            entropy += comp_entropy(model, trn_seqs, qst_num, args.clr_num, qst_clr, args.ce_batch_size, device)

        # The fold scores in `max_res_list` were obtained with this clustering;
        # keep it before it is replaced by the re-derived one below.
        used_qst_clr = qst_clr
        P = sinkhorn(entropy, args.lamb, args.itr_num)
        qst_clr = torch.argmax(P, dim=1).cpu().numpy().tolist()

        if args.save_each_clr is not None:
            _write_clusters(os.path.join(args.save_each_clr, 'qst_clr_%d.csv' % (i + 1)), qst_clr)

        fold_res = np.array(max_res_list)
        auc_mea, acc_mea, rmse_mea = np.mean(fold_res, axis=0)
        auc_std, acc_std, rmse_std = np.std(fold_res, axis=0)

        if fin_auc_max.join(auc_mea):
            fin_res = i, (auc_mea, auc_std), (acc_mea, acc_std), (rmse_mea, rmse_std)
            # Save the clustering these scores were measured with, not the
            # freshly derived one, which has not been evaluated yet.
            best_qst_clr = used_qst_clr

        dur = time.perf_counter() - start
        summary = _format_summary(
            i + 1, (auc_mea, auc_std), (acc_mea, acc_std), (rmse_mea, rmse_std), dur
        )

        if args.detail:
            print('---------------------------------')

        if args.detail or args.summary:
            print(summary)

        if args.save_sum is not None:
            # The summary contains '±', so pin the encoding instead of relying
            # on the platform default.
            with open(args.save_sum, 'a', newline='', encoding='utf-8') as file:
                file.write(f'{summary}\n')

        if args.detail:
            print('---------------------------------')

    if args.result:
        if args.detail or args.summary:
            print('---------------------------------')
        if fin_res is None:
            # No iteration ran, or none registered a best AUC.
            print('no result: no iteration produced a best AUC')
        else:
            best_i, (auc, auc_std), (acc, acc_std), (rmse, rmse_std) = fin_res
            print('iter_count: %d' % (best_i + 1))
            print('auc: %-.4f ± %-.4f' % (auc, auc_std))
            print('acc: %-.4f ± %-.4f' % (acc, acc_std))
            print('rmse: %-.4f ± %-.4f' % (rmse, rmse_std))

    if args.save_clr is not None:
        _write_clusters(args.save_clr, best_qst_clr, indexed=True)


if __name__ == '__main__':
    main()