#!/usr/bin/env python
# -*- coding: utf-8 -*-

# libraries and imports
import sys
sys.path.append('../')

# import stuff related to the parameters of the script
import argparse
import numpy as np
from lib.evaluators import Eval1DWAM, EvalAudioBaselines
import ptwt
import pywt
import torch
from lib.helpers import load_audio_model, load_sound
import os
import pandas as pd
import tqdm
import json

def _str2bool(value):
    """Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool.

    With ``type=str`` argparse would keep ``--noise False`` as the non-empty,
    and therefore truthy, string "False"; booleans are parsed explicitly here.

    Raises
    ------
    argparse.ArgumentTypeError if the value is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


# parameter set up
parser = argparse.ArgumentParser(description='Evaluation of the 1D feature attribution methods')

# set the model, method and metric
# the choices mirror the methods handled in main() and the metrics handled by
# the evaluation helpers, so typos are rejected before anything is loaded
parser.add_argument('--method', default='gradcam', help="attribution method to consider", type=str,
                    choices=['wcam', 'gradcam', 'integratedgrad', 'smoothgrad', 'saliency'])
parser.add_argument('--metric', default=None, help="evaluation metric to consider", type=str,
                    choices=['insertion', 'deletion', 'ff', 'inputfid'])  # None (default) means all metrics
parser.add_argument('--noise', default=False, help="whether we should consider noisy samples or not", type=_str2bool)

# directories
# source_dir has no usable default: it is joined into the dataset/model paths below
parser.add_argument('--source_dir', required=True, help="path to data to consider", type=str)
parser.add_argument('--destination_dir', default='results/audio', help="path to the folder where the results are stored", type=str)

# misc arguments
parser.add_argument('--sample_size', default=400, help="number of sounds to consider for the evaluation", type=int)
parser.add_argument('--device', default='cuda', help="device", type=str)
parser.add_argument('--batch_size', default=8, help="batch size", type=int)
parser.add_argument('--random_seed', default=42, help="random seed", type=int)


args = parser.parse_args()

eval_metric = "All metrics" if args.metric is None else args.metric

print('Evaluation is carried out for the method {} for the metric: {}. Noise added: {}'.format(args.method, eval_metric, args.noise))

# advanced parameters: to be modified directly into the script for simplicity
# overall parameters for the Eval1DWAM evaluator
wavelet = "haar"
J = 3
mode = "reflect"
approx_coeffs = False
n_mels = 128
n_fft = 1024
sample_rate = 44100
n_samples = 25
stdev_spread = 0.001  # after visual inspection
random_seed = args.random_seed  # honour --random_seed (default 42)

# misc
n_iter = 32  # number of steps to compute the auc for insertion and deletion

# method for the wcam
approach = "integratedgrad"  # should be smooth or integratedgrad
target = "melspec"  # should be wavelet or melspec


# directories to the dataset and to the model
# change here if paths are not found or if the structure
# is not the same as the one provided in the original directory
root_dir = os.path.join(args.source_dir, "L2I_code/datasets/ESC50")
root_model_dir = os.path.join(args.source_dir, "L2I_code/output/esc50_output")


# load the metadata and keep the sounds whose filename starts with '1'
# (presumably fold 1 of ESC-50 — TODO confirm against the 'fold' column)
df_sounds = pd.read_csv(os.path.join(root_dir, 'meta/esc50.csv'))
picked_samples = [name for name in df_sounds['filename'] if name.startswith('1')]

# never request more sounds than were picked: otherwise the trailing batches
# would be empty lists handed to load_sound
n_eval_samples = min(args.sample_size, len(picked_samples))
if n_eval_samples < args.sample_size:
    print('Only {} samples available: evaluating {} instead of the requested {}'.format(
        len(picked_samples), n_eval_samples, args.sample_size))

# batchify the implementation to avoid loading all sounds at once:
# consecutive slices of at most args.batch_size filenames
batched_samples = [
    picked_samples[start:min(start + args.batch_size, n_eval_samples)]
    for start in range(0, n_eval_samples, args.batch_size)
]
nb_batch = len(batched_samples)


# helper to evaluate the wcam

def evaluate_wcam(batched_samples, noise, model, target, metric, approach, n_iter):
    """
    Evaluate the WCAM (through Eval1DWAM) on batches of sounds.

    Parameters
    ----------
    batched_samples : list of lists of filenames; each inner list is loaded
        with load_sound and evaluated as one batch.
    noise : forwarded to load_sound.
    model : the audio model handed to Eval1DWAM.
    target : forwarded to every Eval1DWAM metric method.
    metric : one of 'insertion', 'deletion', 'ff', 'inputfid', or None to
        compute all four.
    approach : passed to Eval1DWAM as its ``method`` argument.
    n_iter : number of steps forwarded to insertion and deletion.

    Returns
    -------
    dict mapping each computed metric name to the flat list obtained by
    concatenating the per-batch lists returned by the evaluator.

    Raises
    ------
    ValueError if ``metric`` is neither None nor a supported metric name.
    """
    supported_metrics = ('insertion', 'deletion', 'ff', 'inputfid')
    # an unknown name would otherwise match no branch and silently yield an empty result
    if metric is not None and metric not in supported_metrics:
        raise ValueError('unsupported metric {!r}, expected one of {}'.format(metric, supported_metrics))
    metrics = list(supported_metrics) if metric is None else [metric]

    evaluator = Eval1DWAM(model, batch_size=args.batch_size, wavelet=wavelet, J=J, method=approach, mode=mode, device=args.device,
                          approx_coeffs=approx_coeffs, n_mels=n_mels, n_fft=n_fft, sample_rate=sample_rate, n_samples=n_samples,
                          stdev_spread=stdev_spread, random_seed=random_seed)

    # one callable per metric, each taking the 'x' and 'y' entries of a loaded batch
    metric_fns = {
        'insertion': lambda x, y: evaluator.insertion(x, y, target, n_iter),
        'deletion': lambda x, y: evaluator.deletion(x, y, target, n_iter),
        'ff': lambda x, y: evaluator.faithfulness_of_spectra(x, y, target),
        'inputfid': lambda x, y: evaluator.input_fidelity(x, y, target),
    }

    results = {name: [] for name in metrics}

    for batch_sample in tqdm.tqdm(batched_samples):
        # load the data
        data = load_sound(root_dir, batch_sample, noise=noise)

        for name in metrics:
            results[name].append(metric_fns[name](data['x'], data['y']))

        # release the batch before loading the next one
        del data

    # concatenate the per-batch lists in linear time
    # (sum(lists, []) is quadratic in the number of batches)
    return {
        name: [value for batch_values in per_batch for value in batch_values]
        for name, per_batch in results.items()
    }

# evaluation of the baselines
def evaluate_baseline(method_name, batched_samples, noise, model, metric, n_iter, batch_size):
    """
    Evaluate a baseline attribution method (through EvalAudioBaselines) on batches of sounds.

    Parameters
    ----------
    method_name : name of the baseline, handed to EvalAudioBaselines.
    batched_samples : list of lists of filenames; each inner list is loaded
        with load_sound and evaluated as one batch.
    noise : forwarded to load_sound.
    model : the audio model handed to EvalAudioBaselines.
    metric : one of 'insertion', 'deletion', 'ff', 'inputfid', or None to
        compute all four.
    n_iter : number of steps forwarded to insertion and deletion.
    batch_size : batch size handed to EvalAudioBaselines.

    Returns
    -------
    dict mapping each computed metric name to the flat list obtained by
    concatenating the per-batch lists returned by the evaluator.

    Raises
    ------
    ValueError if ``metric`` is neither None nor a supported metric name.
    """
    supported_metrics = ('insertion', 'deletion', 'ff', 'inputfid')
    # an unknown name would otherwise match no branch and silently yield an empty result
    if metric is not None and metric not in supported_metrics:
        raise ValueError('unsupported metric {!r}, expected one of {}'.format(metric, supported_metrics))
    metrics = list(supported_metrics) if metric is None else [metric]

    evaluator = EvalAudioBaselines(method_name, model,
                                   device=args.device, n_fft=n_fft, sample_rate=sample_rate, n_mels=n_mels, batch_size=batch_size)

    # one callable per metric, each taking the 'x' and 'y' entries of a loaded batch
    metric_fns = {
        'insertion': lambda x, y: evaluator.insertion(x, y, n_iter),
        'deletion': lambda x, y: evaluator.deletion(x, y, n_iter),
        'ff': lambda x, y: evaluator.faithfulness_of_spectra(x, y),
        'inputfid': lambda x, y: evaluator.input_fidelity(x, y),
    }

    results = {name: [] for name in metrics}

    for batch_sample in tqdm.tqdm(batched_samples):
        # load the data
        data = load_sound(root_dir, batch_sample, noise=noise)

        for name in metrics:
            results[name].append(metric_fns[name](data['x'], data['y']))

        # release the batch before loading the next one
        del data

    # concatenate the per-batch lists in linear time
    # (sum(lists, []) is quadratic in the number of batches)
    return {
        name: [value for batch_values in per_batch for value in batch_values]
        for name, per_batch in results.items()
    }



def main():
    """
    Load the model, evaluate the method given by ``--method`` and store the
    scores as a JSON file in ``--destination_dir``.

    Raises
    ------
    ValueError if ``args.method`` is not a supported method.
    """
    baseline_methods = ["gradcam", "integratedgrad", "smoothgrad", "saliency"]

    # validate before loading the model: an unknown method would otherwise
    # leave `results` unbound and fail only after the expensive setup
    if args.method != "wcam" and args.method not in baseline_methods:
        raise ValueError('unsupported method {!r}, expected "wcam" or one of {}'.format(args.method, baseline_methods))

    # load the model
    model = load_audio_model(root_model_dir, device=args.device)

    # evaluate
    if args.method == "wcam":
        results = evaluate_wcam(batched_samples, args.noise, model, target, args.metric, approach, n_iter)
    else:
        results = evaluate_baseline(args.method, batched_samples, args.noise, model, args.metric, n_iter, args.batch_size)

    # store the results in the specified directory; makedirs also creates
    # missing parents (the default 'results/audio' is a nested path)
    os.makedirs(args.destination_dir, exist_ok=True)

    if args.metric is not None:
        filename = 'results_evaluation_{}_{}_{}.json'.format(args.metric, args.method, args.noise)
    else:
        filename = "results_evaluation_all_metrics_{}_{}.json".format(args.method, args.noise)

    with open(os.path.join(args.destination_dir, filename), 'w') as f:
        json.dump(results, f)


if __name__ == '__main__':
    main()