from finetuner import *
from saliency_analysis import load_cached_results, cache_results
from tqdm import tqdm

def eval_waterbirds(mkey, trial=None):
    ''' Evaluate an already-finetuned model on the waterbirds test split.

    We assume that the model is already finetuned: the weights are loaded via
    ``FineTuner.restore_model``.

    Each batch yields ``(x, y, meta)``. A sample counts as "majority" when the
    first two columns of ``meta`` are equal and as "minority" otherwise
    (presumably these columns hold the label and the background attribute --
    TODO confirm against the dataset definition).

    Args:
        mkey: model key passed through to ``FineTuner``.
        trial: optional trial identifier passed through to ``FineTuner``.

    Returns:
        ``(minority_acc, majority_acc)`` as Python floats in [0, 1]. A group
        with no samples yields ``nan`` instead of raising.
    '''
    # Local import keeps this block self-contained; the module otherwise only
    # gets torch (if at all) through the star import from finetuner.
    import torch

    finetuner = FineTuner(mkey, dset='waterbirds', trial=trial)
    finetuner.restore_model()
    # Make sure dropout / batch-norm layers run in inference mode; harmless if
    # restore_model() already did this.
    finetuner.model.eval()

    val_loader = finetuner.loaders['test']

    # Plain Python ints: avoids mixing CPU/CUDA tensors and keeps device-bound
    # tensors out of whatever the caller caches.
    minority_cc, minority_ctr = 0, 0
    majority_cc, majority_ctr = 0, 0
    with torch.no_grad():  # pure inference, no autograd bookkeeping needed
        for x, y, meta in tqdm(val_loader):
            x, y = x.cuda(), y.cuda()

            # Build the group masks on the same device as the predictions.
            majority_inds = (meta[:, 0] == meta[:, 1]).to(y.device)
            minority_inds = (meta[:, 0] != meta[:, 1]).to(y.device)

            correct = finetuner.model(x).argmax(1) == y
            minority_cc += int(correct[minority_inds].sum())
            majority_cc += int(correct[majority_inds].sum())

            minority_ctr += int(minority_inds.sum())
            majority_ctr += int(majority_inds.sum())

    # True division on Python ints; guard against an empty group.
    minority_acc = minority_cc / minority_ctr if minority_ctr else float('nan')
    majority_acc = majority_cc / majority_ctr if majority_ctr else float('nan')
    return minority_acc, majority_acc


def eval_waterbirds_all_models():
    ''' Evaluate every configured (architecture, norm, epsilon) model on
    waterbirds, reusing cached results where available, and print a
    majority/minority accuracy report line per model.

    Newly computed results are written back to the cache file right after each
    model is evaluated.
    '''
    # An alternative cache holding all trials was previously kept at
    # './results/waterbirds_eval_best_val_saved_all_trials.pkl'.
    results_path = './results/waterbirds_eval_best_val_saved2.pkl'
    results = load_cached_results(results_path)
    trial = None

    # (norm, epsilons) pairs to sweep. Wider sweeps used previously:
    #   l2:   [0, 0.25, 0.5, 1, 3, 5]
    #   linf: [0.5, 1.0, 2.0, 4.0, 8.0]
    eps_by_norm = (
        ('l2', [0, 3]),
        ('linf', []),
    )
    # Other architectures evaluated before: resnet18, resnet50, wide_resnet50_2.
    arch_names = ['mobilenet', 'densenet', 'shufflenet', 'resnext50_32x4d', 'vgg16_bn']
    report_line = 'Model: {:<20s}, Majority Acc: {:.2f}, Minority Acc: {:.2f}, Gap: {:.2f}'

    for arch in arch_names:
        for adv_train_norm, epsilons in eps_by_norm:
            for adv_train_eps in epsilons:
                mkey = f'{arch}_{adv_train_norm}_eps{adv_train_eps}'

                # Only evaluate models that are missing from the cache.
                if mkey not in results:
                    min_acc, maj_acc = eval_waterbirds(mkey, trial)
                    results[mkey] = {'minority': min_acc, 'majority': maj_acc}
                    cache_results(results_path, results)

                # Report in percent.
                entry = results[mkey]
                min_acc = 100 * entry['minority']
                maj_acc = 100 * entry['majority']
                print(report_line.format(mkey, maj_acc, min_acc, maj_acc - min_acc))

# Script entry point: run the evaluation sweep over all configured models.
if __name__ == '__main__':
    eval_waterbirds_all_models()



