#!/usr/bin/env python
# -*- coding: utf-8 -*-

# libraries and imports
import sys
sys.path.append('../')

from lib.helpers import load_imagenet_validation
import json
import timm
import os
import gc
from lib.evaluators import Eval2DWAM, EvalImageBaselines



# command-line interface of the script
import argparse

parser = argparse.ArgumentParser(
    description='Evaluation of the 2D feature attribution methods'
)

# what is evaluated: the model, the attribution method and the metric
parser.add_argument(
    '--model',
    type=str,
    default='efficientnet',
    help="type of model to consider",
)
parser.add_argument(
    '--method',
    type=str,
    default="gradcampp",
    help="attribution method to consider",
)
parser.add_argument(
    '--metric',
    type=str,
    default=None,
    help="evaluation metric to consider",
)

# where the data is read from and where the results are written to
parser.add_argument(
    '--source_dir',
    type=str,
    default=None,
    help="path to the image net images and labels",
)
parser.add_argument(
    '--destination_dir',
    type=str,
    default='results',
    help="path to the folder where the results are stored",
)

# remaining knobs: sample count, device, batching and seeding
parser.add_argument(
    '--sample_size',
    type=int,
    default=1000,
    help="number of images to consider for the evaluation",
)
parser.add_argument(
    '--device',
    type=str,
    default='cuda',
    help="device",
)
parser.add_argument(
    '--batch_size',
    type=int,
    default=8,
    help="batch size",
)
parser.add_argument(
    '--random_seed',
    type=int,
    default=42,
    help="random seed",
)

# NOTE: the command line is parsed as soon as this module is executed
args = parser.parse_args()

# label shown in the banner below: omitting --metric means "every metric"
eval_metric = "All metrics" if args.metric is None else args.metric

print(
    'Evaluation is carried out for the method {}, using the model {}, for the metric: {}'.format(
        args.method, args.model, eval_metric
    )
)

# advanced parameters: edit them here, directly in the script
# settings meant for the 2D WAM evaluation
# NOTE(review): none of these names is read by the code visible in this file
wavelet = "haar"
J = 3
mode = "reflect"
transform = None
approx_coeffs = False
nb_samples = 25
stdev_spread = 0.25
random_seed = 42

# helper functions 
def evaluate_2d_wcam(metric, model, input_images, input_labels, batch_size):
    """Score the 2D WAM explainer on a set of samples, one batch at a time.

    The samples are processed in consecutive slices of at most `batch_size`
    elements to limit the memory used by a single evaluator call. Every
    sample is covered: when the number of samples is not a multiple of
    `batch_size`, the last batch is simply shorter.

    Parameters
    ----------
    metric : str or None
        "mu_fidelity", "insertion" or "deletion". None computes the three
        metrics for every batch.
    model :
        Model handed to `Eval2DWAM`.
    input_images :
        Images to explain; must support `len()` and slicing along the
        first axis.
    input_labels :
        Labels aligned with `input_images`; sliced with the same indices.
    batch_size : int
        Maximum number of samples per evaluator call (must be >= 1).

    Returns
    -------
    dict
        Maps each computed metric name to the flat list of the scores
        returned for all batches, in batch order.

    Raises
    ------
    ValueError
        If `batch_size` is smaller than 1 or `metric` is not supported.
    """
    supported_metrics = ("mu_fidelity", "insertion", "deletion")

    # validate the request before any expensive object is created
    if batch_size < 1:
        raise ValueError(
            "batch_size must be a positive integer, got {}".format(batch_size)
        )

    if metric is None:
        # compute all three scores at once
        metric_names = list(supported_metrics)
    elif metric in supported_metrics:
        metric_names = [metric]
    else:
        raise ValueError(
            "Unknown metric '{}': expected None or one of {}".format(
                metric, ", ".join(supported_metrics)
            )
        )

    # initialize the wcam test
    wcam_test = Eval2DWAM(model, method="integratedgrad", batch_size=batch_size)

    # dictionary that will be returned with the raw scores
    scores = {name: [] for name in metric_names}

    # stepping by batch_size visits every sample, including the trailing
    # ones that do not fill a complete batch
    for start in range(0, len(input_images), batch_size):
        image_batch = input_images[start:start + batch_size]
        label_batch = input_labels[start:start + batch_size]

        for name in metric_names:
            # the evaluator method carries the same name as the metric;
            # its result is assumed to be an iterable of per-sample
            # scores -- TODO confirm against Eval2DWAM
            batch_scores = getattr(wcam_test, name)(image_batch, label_batch)
            scores[name].extend(batch_scores)

    return scores

def evaluate_baseline(method, metric, model, input_images, input_labels, batch_size,layers=None, reshape_transform=None):
    """Score a baseline attribution method on a set of samples, batch by batch.

    The samples are processed in consecutive slices of at most `batch_size`
    elements to limit the memory used by a single evaluator call. Every
    sample is covered: when the number of samples is not a multiple of
    `batch_size`, the last batch is simply shorter. The garbage collector
    is triggered after each batch.

    Parameters
    ----------
    method : str
        Name of the attribution method, forwarded to `EvalImageBaselines`.
    metric : str or None
        "mu_fidelity", "insertion" or "deletion". None computes the three
        metrics for every batch.
    model :
        Model handed to `EvalImageBaselines`.
    input_images :
        Images to explain; must support `len()` and slicing along the
        first axis.
    input_labels :
        Labels aligned with `input_images`; sliced with the same indices.
    batch_size : int
        Maximum number of samples per evaluator call (must be >= 1).
    layers : optional
        Forwarded as the `layers` argument of `EvalImageBaselines`.
    reshape_transform : optional
        Accepted for interface compatibility; it is not used by this
        function.

    Returns
    -------
    dict
        Maps each computed metric name to the flat list of the scores
        returned for all batches, in batch order.

    Raises
    ------
    ValueError
        If `batch_size` is smaller than 1 or `metric` is not supported.
    """
    supported_metrics = ("mu_fidelity", "insertion", "deletion")

    # validate the request before any expensive object is created
    if batch_size < 1:
        raise ValueError(
            "batch_size must be a positive integer, got {}".format(batch_size)
        )

    if metric is None:
        # compute all three scores at once
        metric_names = list(supported_metrics)
    elif metric in supported_metrics:
        metric_names = [metric]
    else:
        raise ValueError(
            "Unknown metric '{}': expected None or one of {}".format(
                metric, ", ".join(supported_metrics)
            )
        )

    # initialize the explanation method test
    baseline_test = EvalImageBaselines(method, model, layers=layers, batch_size=batch_size)

    # dictionary that will be returned with the raw scores
    scores = {name: [] for name in metric_names}

    # stepping by batch_size visits every sample, including the trailing
    # ones that do not fill a complete batch
    for start in range(0, len(input_images), batch_size):
        image_batch = input_images[start:start + batch_size]
        label_batch = input_labels[start:start + batch_size]

        for name in metric_names:
            # the evaluator method carries the same name as the metric;
            # its result is assumed to be an iterable of per-sample
            # scores -- TODO confirm against EvalImageBaselines
            batch_scores = getattr(baseline_test, name)(image_batch, label_batch)
            scores[name].extend(batch_scores)

        # release what the batch left behind before starting the next one
        gc.collect()

    return scores


def main():
    """Run the evaluation requested on the command line and save it as JSON.

    Reads the module-level `args`: builds the requested timm model, loads
    the validation samples, evaluates the chosen attribution method and
    writes the resulting scores to a JSON file in `args.destination_dir`.

    Raises
    ------
    ValueError
        If `args.model` is not one of the supported model names.
    """

    # load the model together with the layer handed to the baseline methods
    if args.model=="resnet":
        model=timm.create_model('resnet18', pretrained=True)
        layer=model.layer4[-1]

    elif args.model=='convnext':
        # ConvNeXt-small, timm checkpoint 'fb_in22k_ft_in1k_384'
        model = timm.create_model('convnext_small.fb_in22k_ft_in1k_384', pretrained=True)
        layer=model.norm_pre

    elif args.model=="vit":
        # DeiT-tiny, timm checkpoint 'fb_in1k'
        model = timm.create_model('deit_tiny_patch16_224.fb_in1k', pretrained=True)
        layer=model.blocks[-1].norm1

    elif args.model=='efficientnet':
        model = timm.create_model('tf_efficientnet_b0.ns_jft_in1k', pretrained=True)
        layer=model.conv_head

    else:
        # fail with a clear message instead of an unbound `model` further down
        raise ValueError(
            "Unknown model '{}': expected one of 'resnet', 'convnext', "
            "'vit' or 'efficientnet'".format(args.model)
        )

    model.eval().to(args.device)

    # load the images and the labels
    # look for the directory specified by the user
    images, labels=load_imagenet_validation(args.source_dir,
                                           count=args.sample_size,
                                           seed=args.random_seed)

    if args.method=="wcam":
        # the WAM evaluation does not use the selected layer
        results=evaluate_2d_wcam(args.metric, model, images, labels, args.batch_size)

    else:
        results=evaluate_baseline(args.method,
                                  args.metric,
                                  model,
                                  images,
                                  labels,
                                  args.batch_size,
                                  layers=layer,
                                  )

    # store the results in the specified directory; makedirs also creates
    # missing parent folders and tolerates an already existing directory
    os.makedirs(args.destination_dir, exist_ok=True)

    # save the results
    if args.metric is not None:
        filename='results_evaluation_{}_{}_{}.json'.format(args.metric, args.method, args.model)
    else:
        filename="results_evaluation_all_metrics_{}_{}.json".format(args.method,args.model)

    with open(os.path.join(args.destination_dir, filename), 'w') as f:
        json.dump(results, f)

        
# run the script: main() is only called when this file is executed directly

if __name__ == '__main__':
    main()