#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
script that computes the importance (measured as the sum of the importance
of each level normalized by the total number of components for this level)
across 1,000 images of ImageNet and for different models

model zoo features : 
              - RT : 'augmix', 'pixmix', 'sin' (highest accuracy on ImageNet-C), 
              - AT : 'adv_free', 'fast_adv' and 'adv',
              - ST : standard training ('baseline')

stores the results in a JSON dictionary keyed by model name and saves it in the results/
directory (default)
"""


import sys
sys.path.append('../')

import tqdm
from torchvision.models import resnet50
import numpy as np
import os
import torch
from lib.helpers import load_imagenet_validation
from lib.wam_2D import WaveletAttribution2D
import json


# setup: directories and experiment settings
data_dir = '../drafts/benchmark'          # passed to load_imagenet_validation
models_dir = '../drafts/model-weights'    # where the '<case>.pth' weights are read from
target_dir = 'results'                    # where the output json is written
batch_size = 64                           # number of images explained per explainer call
cases = ['adv', 'adv_free', 'fast_adv', 'baseline']  # models to evaluate, in this order
num_levels = 5                            # passed as J to the wavelet explainer


def sum_and_normalize(coeffs, area):
    """
    sums the coefficients over axes 1 and 2 and divides the
    result by area

    coeffs: array-like with at least three axes
    area: normalization constant (the callers pass the number
          of coefficients of the scale at hand)

    returns one value per entry along axis 0
    """
    summed = np.sum(coeffs, axis=(1, 2))
    return summed / area

def load_model(models_dir, case):
    """
    loads the ResNet-50 corresponding to the robustness case
    and returns it in evaluation mode

    models_dir: directory that contains the '<case>.pth' state dicts
                of the adversarially robust models
    case: one of
        'baseline': standard resnet50 ERM (torchvision pretrained weights)

        'adv_free': adversarially robust models, whose weights
        'fast_adv'  are read from models_dir
        'adv'

    raises a ValueError if the case is not one of the above
    """
    adversarial_cases = ('adv_free', 'fast_adv', 'adv')

    # fail early with an explicit message instead of reaching the return
    # statement with an unbound `model`
    if case != 'baseline' and case not in adversarial_cases:
        raise ValueError(
            "unknown case {!r}, expected 'baseline' or one of {}".format(
                case, adversarial_cases
            )
        )

    if case == 'baseline':
        return resnet50(pretrained=True).eval()

    # model backbone, filled with the weights stored on disk
    model = resnet50(pretrained=False)
    # map_location="cpu": the backbone above lives on the CPU, so the
    # checkpoint can be deserialized there whatever device it was saved from
    weights = torch.load(os.path.join(models_dir, "{}.pth".format(case)),
                         map_location="cpu")
    model.load_state_dict(weights)
    model.eval()

    return model

# fetch the evaluation images together with their labels
images, labels = load_imagenet_validation(data_dir, count=1000, seed=42)

# number of batches needed to go through the 1000 images
nb_batch = int(np.ceil(1000 / batch_size))

# boundaries used to slice the wavelet coefficients of each level:
# 0 followed by img_size / 2**i for i = num_levels, ..., 0
img_size = 224
level_indices = [0] + [int(img_size / 2**i) for i in reversed(range(num_levels + 1))]

# maps each case to the list of its per-level mean importances
results = {}

# number of images scored per model; must match the `count` requested from
# load_imagenet_validation and the value used to derive nb_batch
num_images = 1000

for case in cases:
    print('Computing the WCAMs for the case ............... {}'.format(case))

    # load the model
    model = load_model(models_dir, case)

    # set up the explainer
    explainer = WaveletAttribution2D(model,
                                     J=num_levels,
                                     device="cuda",
                                     method="integratedgrad",
                                     approx_coeffs=False)

    # matrix that stores the importance for the model (one row per level,
    # one column per image). Filled with NaN rather than left uninitialized
    # so that np.nanmean below ignores any column that was never written.
    total_importance = np.full((num_levels + 1, num_images), np.nan,
                               dtype=np.float32)

    for batch_index in tqdm.tqdm(range(nb_batch)):

        start_index = batch_index * batch_size
        end_index = min(start_index + batch_size, num_images)

        batch_images = images[start_index:end_index]
        batch_labels = labels[start_index:end_index]

        # compute the explanations
        # NOTE(review): assumed to be a (batch, height, width) array holding
        # the nested wavelet coefficients -- confirm against WaveletAttribution2D
        explanations = explainer(batch_images, batch_labels)

        # walk over consecutive pairs of boundaries, one pair per level
        level_bounds = zip(level_indices, level_indices[1:])
        for j, (level_start, level_end) in enumerate(level_bounds):

            if level_start == 0:
                # corresponds to approximation coefficients: the top-left
                # square. The test is on the level (not on the batch) and the
                # slice uses the level boundary on both spatial axes.
                coeffs = explanations[:, :level_end, :level_end]
                area = coeffs.shape[1] * coeffs.shape[2]

                # returns a (n_batch,) array with the importance
                # normalized by number of coefficients for the scale at hand
                level_importance = sum_and_normalize(coeffs, area)

            else:
                # area is the same for the three: past the first boundary,
                # each boundary is twice the previous one, so the three
                # slices below have the same number of coefficients
                diag_coeff = explanations[:, level_start:level_end, level_start:level_end]
                area = diag_coeff.shape[1] * diag_coeff.shape[2]

                horz_coeff = explanations[:, :level_start, level_start:level_end]
                vert_coeff = explanations[:, level_start:level_end, :level_start]

                level_importance = sum_and_normalize(diag_coeff, area) \
                                 + sum_and_normalize(horz_coeff, area) \
                                 + sum_and_normalize(vert_coeff, area)

            total_importance[j, start_index:end_index] = level_importance

    # compute the mean importance of each level
    components_importance = np.nanmean(total_importance, axis=1).tolist()
    results[case] = components_importance
    print('Case {} completed'.format(case))

# export the file
# create the target directory first: without it open() fails and every
# result computed above is lost
os.makedirs(target_dir, exist_ok=True)
with open(os.path.join(target_dir, 'robustness.json'), 'w') as f:
    json.dump(results, f)