#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Script that computes the randomization test from Adebayo et al. (2018)
for the two variants of the GradWCAM.
"""

## libraries
import sys
sys.path.append('../')


import copy
import json
import os

import numpy as np
import timm
import tqdm
from scipy.stats import spearmanr

from lib.helpers import load_imagenet_validation, load_vision_model
from lib.viewers import plot_wam
from lib.wam_2D import WaveletAttribution2D


# Configuration
#
# Runtime parameters, on-disk locations, and the data / models shared by the
# rest of the script.

# runtime parameters
batch_size=64
device="cuda"

# on-disk locations
source="../drafts"
data_dir="../drafts/benchmark"
export_dir='results'
weights_dir=os.path.join(source,"model-weights/pytorch_model.bin")

# benchmark images and their labels (1000 samples, fixed seed)
images, labels = load_imagenet_validation(data_dir, count=1000, seed=42)

# trained network restored from the checkpoint
model = load_vision_model(device=device, checkpoint_path=weights_dir)

# ResNet-18 created without pretrained weights; its layers are the source of
# the "random" weights used by the randomization test
random_model = timm.create_model('resnet18', pretrained=False)
random_model = random_model.to(device)

# helper function that gradually randomizes the model with the weights
# from the random model
def gradually_randomize(randomized_model, random_model, depth):
    """
    Return a copy of `randomized_model` randomized up to the specified depth.

    The top-level children of both models are walked in lockstep (they are
    paired by position, so both models are expected to expose the same
    children in the same order). Each child of the copy is replaced by a
    copy of the corresponding child of `random_model`, up to and including
    the child whose name equals `depth`.

    Neither input model is modified: the swap happens on a deep copy, so the
    caller can keep using `randomized_model` as the untouched reference and
    `random_model` never shares layers with the returned model.

    Parameters
    ----------
    randomized_model : model whose children are (partially) replaced.
    random_model : model providing the replacement children.
    depth : name of the last child to replace. If no child has this name,
        every paired child is replaced.

    Returns
    -------
    A new model with the selected children taken from `random_model`.
    """
    # Work on a copy: mutating the input in place would silently corrupt
    # the reference model for every later call / explanation.
    result = copy.deepcopy(randomized_model)

    # Iterate over the direct children of both models using `named_children`
    layer_pairs = zip(randomized_model.named_children(), random_model.named_children())
    for (name, _), (_, rand_layer) in layer_pairs:

        # Replace the layer in the copy with a copy of the random layer
        setattr(result, name, copy.deepcopy(rand_layer))
        if name == depth:
            break

    return result


# set up the layers
# "none" is the unrandomized baseline; the other entries are the names of the
# model children at which the progressive randomization stops
# (NOTE(review): expected to match the top-level children of the model, in
# order; confirm against the architecture actually loaded).
depths=["none","conv1","bn1","act1","maxpool","layer1","layer2","layer3","layer4","global_pool","fc"]

# number of images to explain; must match the `count` requested when the
# images were loaded
n_images=1000

# dictionary that stores the wcams, one (n_images, 224, 224) array per depth
# and per attribution variant
examples={
    d:{
        'smooth': np.empty((n_images,224,224), dtype=np.float32),
        'itegrad': np.empty((n_images,224,224), dtype=np.float32)
    } for d in depths
}

# set up the explainer for the non random model
explainer_baseline_smooth=WaveletAttribution2D(model, device=device)
explainer_baseline_ig=WaveletAttribution2D(model, device=device, method="integratedgrad")

# set up the explainers for the increasingly randomized models.
# They do not depend on the batch, so they are built once, up front.
# gradually_randomize is handed deep copies so that neither the trained
# `model` (still needed by the baseline explainers) nor `random_model` can be
# altered by the randomization.
explainers={"none": (explainer_baseline_smooth, explainer_baseline_ig)}
for d in depths[1:]:

    # generate the model
    model_randomized=gradually_randomize(copy.deepcopy(model),
                                         copy.deepcopy(random_model),
                                         d)

    # generate the explainers
    explainers[d]=(
        WaveletAttribution2D(model_randomized,device=device),
        WaveletAttribution2D(model_randomized,method="integratedgrad",device=device),
    )

# main loop:
nb_batch=int(np.ceil(n_images/batch_size))
for batch_index in tqdm.tqdm(range(nb_batch)):

    start_index=batch_index*batch_size
    end_index=min(n_images,(batch_index+1)*batch_size)

    # get the set of images
    x=images[start_index:end_index]
    y=labels[start_index:end_index]

    # compute the explanations for the baseline ("none") and for every
    # randomization depth
    for d in depths:
        explainer_smooth, explainer_itegrad=explainers[d]
        examples[d]['smooth'][start_index:end_index,:,:]=explainer_smooth(x,y)
        examples[d]['itegrad'][start_index:end_index,:,:]=explainer_itegrad(x,y)

print('Computation of the WCAMs completed')


# Spearman rank correlation between each baseline map (depth "none") and the
# map obtained at every randomization depth, for both attribution variants.
# Row = image index, column = position of the depth in `depths`.
n_maps = 1000
correlations_smooth = np.empty((n_maps, len(depths)), dtype=np.float32)
correlations_itegrad = np.empty((n_maps, len(depths)), dtype=np.float32)

for variant, table in (('smooth', correlations_smooth),
                       ('itegrad', correlations_itegrad)):
    reference_maps = examples["none"][variant]

    for column, d in enumerate(depths):
        candidate_maps = examples[d][variant]

        for row in range(n_maps):
            # maps are compared as flat vectors of pixel attributions
            result = spearmanr(reference_maps[row, :, :].ravel(),
                               candidate_maps[row, :, :].ravel())
            table[row, column] = result.statistic

# export one JSON file (list of rows) per attribution variant
for filename, table in (('results_smooth.json', correlations_smooth),
                        ('results_itegrad.json', correlations_itegrad)):
    with open(filename, "w") as f:
        json.dump(table.tolist(), f)