
import os
import pickle
import numpy as np

def print_log_acc_bwt(acc, lss, output_path, file_name='logs.p'):
    """Print an accuracy/BWT summary and pickle the results to disk.

    Args:
        acc: accuracy matrix (any array-like; converted with ``np.asarray``).
            Row ``-1`` is used as the final accuracies and the diagonal as
            each task's accuracy right after it was trained, so rows are
            presumably training stages and columns tasks -- confirm with
            the caller.
        lss: loss record; stored unchanged under the ``'loss'`` key.
        output_path: directory the pickle file is written into (it is not
            created here).
        file_name: name of the pickle file inside ``output_path``.

    Returns:
        Tuple ``(avg_acc, gem_bwt)``: the mean of the last row of ``acc`` and
        the GEM-style backward transfer.
    """
    acc = np.asarray(acc)

    print('*'*100)
    if acc.shape[0] <= 10:  # looks awful if printing for 20 tasks
        print('Accuracies =')
        for row in acc:
            print('\t', end=',')
            for value in row:
                print('{:5.4f}% '.format(value), end=',')
            print()

    final_acc = acc[-1]              # task accuracies after the last training stage
    diag_acc = np.diag(acc)          # task accuracies right after training each task
    backward = final_acc - diag_acc  # per-task change since the task was learned

    avg_acc = np.mean(final_acc)
    print('ACC: {:5.4f}%'.format(avg_acc))
    # BWT calculated based on GEM paper (https://arxiv.org/abs/1706.08840):
    # averaged over the T-1 earlier tasks (the last entry of `backward` is 0).
    # With a single task there is no backward transfer; report 0 instead of
    # evaluating 0/0 (nan + RuntimeWarning).
    num_cols = len(final_acc)
    if num_cols > 1:
        gem_bwt = backward.sum() / (num_cols - 1)
    else:
        gem_bwt = 0.0
    # BWT calculated based on UCB paper (https://arxiv.org/abs/1906.02425)
    ucb_bwt = backward.mean()
    print('BWT: {:5.2f}%'.format(gem_bwt))
    print()
    print('BWT (UCB paper): {:5.2f}%'.format(ucb_bwt))

    print('*'*100)
    print('Done!')

    # save results
    logs = {
        'name': output_path,
        'acc': acc,
        'loss': lss,
        'gem_bwt': gem_bwt,
        'ucb_bwt': ucb_bwt,
        'rii': diag_acc,   # task accuracy after training
        'rij': final_acc,  # task accuracy after training on final task
    }

    # pickle
    path = os.path.join(output_path, file_name)
    with open(path, 'wb') as output:
        pickle.dump(logs, output)

    print("Log file saved in ", path)
    return avg_acc, gem_bwt

def load_pickle(filename):
    """Load and return the object pickled in ``filename``.

    Returns ``None`` (after printing a warning) when the file does not exist
    or its contents cannot be unpickled because they are truncated/empty
    (``EOFError``) or garbled (``pickle.UnpicklingError``).

    Security: ``pickle.load`` can execute arbitrary code while loading. Only
    use this on files written by this project, never on untrusted input.
    """
    if not os.path.exists(filename):
        print('Warning: file "%s" does not exist!' % filename)
        return None
    try:
        with open(filename, 'rb') as f:
            return pickle.load(f)
    except (EOFError, pickle.UnpicklingError):
        # Both signal a damaged file: EOFError for truncated/empty input,
        # UnpicklingError for invalid pickle data.
        print('Warning: log file corrupted!')
        return None

def save_pickle(data, path):
    """Pickle ``data`` to ``path`` using ``pickle.HIGHEST_PROTOCOL``.

    Best-effort: if serialization or writing fails, a message is printed and
    no exception is raised. ``data`` is serialized in memory before the file
    is opened. An unpicklable object therefore no longer truncates, and so
    destroys, an existing file at ``path``.
    """
    try:
        payload = pickle.dumps(data, pickle.HIGHEST_PROTOCOL)
        with open(path, 'wb') as f:
            f.write(payload)
    except Exception:
        # Narrower than a bare ``except:`` so KeyboardInterrupt/SystemExit
        # still propagate; pickling and I/O errors stay non-fatal.
        print('Could not save file %s' %(path))

def compute_gem_bwt(accs):
    """Return backward transfer (BWT) as defined in the GEM paper.

    BWT = sum(accs[-1] - diag(accs)) / (T - 1), where T = len(accs[-1]).
    See https://arxiv.org/abs/1706.08840. The last task's term is always 0,
    so this averages the change on the T-1 earlier tasks.

    Args:
        accs: accuracy matrix (any array-like; converted with ``np.asarray``).
            Row ``-1`` is taken as the final accuracies and the diagonal as
            each task's accuracy right after it was trained.

    Returns:
        The GEM BWT. With a single task there is no backward transfer, so
        0.0 is returned instead of evaluating 0/0 (nan + RuntimeWarning).
    """
    accs = np.asarray(accs)
    final = accs[-1]
    num_tasks = len(final)
    if num_tasks <= 1:
        return 0.0
    return (final - np.diag(accs)).sum() / (num_tasks - 1)