import os

from utils.func import *

def read_model(eval_path):
    """Load the models stored under an evaluation directory.

    Args:
        eval_path: Path that must be an existing directory.

    Returns:
        Whatever ``model_load(eval_path, 'end')`` returns.
        NOTE(review): the ``'end'`` tag presumably selects the final
        checkpoint -- confirm against ``model_load``.

    Raises:
        SystemExit: With exit status 1 when ``eval_path`` is not a directory.
    """
    if not os.path.isdir(eval_path):
        print('Invalid evaluation path ...')
        # The bare ``exit()`` helper is injected by ``site`` for interactive
        # use and terminates with status 0; raise SystemExit directly so the
        # failure is visible to the calling shell and works without ``site``.
        raise SystemExit(1)

    return model_load(eval_path, 'end')


def accuracy(models, loader, device, log):
    """Log per-group accuracy of every model that has a ``'net'`` entry.

    For each ``models[key]['net']`` the loader is iterated once; a sample is
    counted as correct when the arg-max over dimension 1 of the model output
    equals the label. Samples are split by the third element of each batch:
    value 0 is reported as ``'major'`` and value 1 as ``'minor'``. Accuracies
    are logged as percentages; nothing is returned.

    Args:
        models: Mapping from model name to a mapping; entries without a
            ``'net'`` key are skipped.
        loader: Iterable of batches indexed as ``data[0]`` (inputs),
            ``data[1]`` (labels) and ``data[2]`` (group flag per sample).
        device: Device the inputs and labels are moved to.
        log: Object exposing ``info(message)``.

    NOTE(review): the networks are used in whatever train/eval mode they are
    already in; this function does not call ``model.eval()`` -- confirm that
    ``model_load`` (or the caller) sets the mode.
    """
    log.info('-'*50)

    ret = {}
    for key, entry in models.items():
        if 'net' not in entry:
            continue

        model = entry['net']
        corr_batches = []
        bias_batches = []
        # Evaluation only: disabling autograd avoids building a graph for
        # every forward pass.
        with t.no_grad():
            for data in loader:
                x, y = data[0].to(device), data[1].to(device)

                logit = model(x)
                corr_batches.append((logit.max(1)[1] == y).float().cpu())
                bias_batches.append(
                    t.as_tensor(data[2]).reshape(-1).int().cpu())

        # t.cat rejects an empty list, so fall back to empty tensors; the
        # means below then evaluate to nan, as they would for an empty group.
        if corr_batches:
            res_corr = t.cat(corr_batches)
            res_bias = t.cat(bias_batches)
        else:
            res_corr = t.zeros(0)
            res_bias = t.zeros(0, dtype=t.int32)

        ret[key] = {'major': t.mean(res_corr[res_bias == 0]).item() * 100,
                    'minor': t.mean(res_corr[res_bias == 1]).item() * 100}

    log.info('*** Accuracy ***')
    # ret only holds models that had a 'net' entry, in the order of `models`.
    for key_model, scores in ret.items():
        out_str = '[Model: %5s] \t ' % (key_model,)
        for key_data_type, value in scores.items():
            out_str += '%s: %.2f \t' % (key_data_type, value)
        log.info(out_str)
    log.info('-'*50)
    

def evaluate(loaders, eval_path, device, log):
    """Load the models under ``eval_path`` and log their test-set accuracy.

    Args:
        loaders: Mapping that provides the evaluation loader under ``'test'``.
        eval_path: Directory handed to ``read_model``.
        device: Device forwarded to ``accuracy``.
        log: Logger forwarded to ``accuracy``.
    """
    # Models are loaded first, so an invalid path is reported before the
    # 'test' loader is looked up.
    loaded_models = read_model(eval_path)
    test_loader = loaders['test']

    accuracy(loaded_models, test_loader, device, log)