import os
import sys; sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import numpy as np
import time
from democratization.core.lime import LIMECVInterpreter
from democratization.model_choices import *

import paddle.fluid as fluid


def prepare_paddle_model(model_fn, class_num=1000):
    """Wrap a model constructor into a function producing class probabilities.

    Args:
        model_fn: zero-argument callable returning a model object that exposes
            ``net(input=..., class_dim=...)``.
        class_num: number of output classes passed to ``net`` as ``class_dim``.
            Defaults to 1000 (ImageNet), the value previously hard-coded here.

    Returns:
        A callable ``paddle_model(image_input)``. Each call instantiates the
        model via ``model_fn()``, calls ``net`` on ``image_input`` and returns
        ``fluid.layers.softmax`` of the resulting logits over the last axis.
    """

    def paddle_model(image_input):
        model = model_fn()
        logits = model.net(input=image_input, class_dim=class_num)
        # Some models return a tuple of outputs; the last element is taken as
        # the final logits (assumed to be the main head — confirm per model).
        if isinstance(logits, tuple):
            logits = logits[-1]
        return fluid.layers.softmax(logits, axis=-1)

    return paddle_model


if __name__ == '__main__':
    # Script-only imports: kept inside the guard so that importing this module
    # (e.g. to reuse `prepare_paddle_model`) does not load the CLI args module.
    from lime_democratization.args import args
    import glob

    # Lower-cased file extensions treated as images; other files are ignored.
    image_exts = ('.jpeg', '.jpg', '.png')
    # Write an intermediate checkpoint after this many newly interpreted images.
    checkpoint_every = 100

    def save_results(results, path):
        """Write `results` to the .npz file at `path` without risking corruption.

        The archive is written to a sibling temporary file first and then
        moved over `path` with `os.replace`, so an interruption mid-write
        leaves any previous checkpoint intact for the resume logic below.
        """
        tmp_path = os.path.splitext(path)[0] + '.tmp.npz'
        np.savez(tmp_path, **results)
        os.replace(tmp_path, path)

    print(args)

    image_path = args.image_path
    if not os.path.isdir(image_path):
        # Explicit check rather than `assert`, so it still runs under `python -O`.
        raise NotADirectoryError(f'image_path is not a directory: {image_path}')

    trained_model_path = args.trained_model
    # Model key = last path component without the `_pretrained` suffix.
    # normpath drops a trailing separator that would otherwise yield ''.
    model_name = os.path.basename(os.path.normpath(trained_model_path)).replace('_pretrained', '')
    print(model_name, trained_model_path)

    model_dict = model_choice_all()
    model_fn = model_dict[model_name]
    paddle_model = prepare_paddle_model(model_fn)

    more_params = {
        'num_samples': args.num_samples,
        'batch_size': args.batch_size,
        'target_size': args.target_size,
    }
    save_dir = os.path.join(args.outdir, model_name)
    os.makedirs(args.outdir, exist_ok=True)

    image_list = sorted(glob.glob(os.path.join(image_path, '*')))
    if args.start is not None and args.end is not None:
        # Keep entries whose sorted index lies in the inclusive range [start, end].
        image_list = [p for idx, p in enumerate(image_list) if args.start <= idx <= args.end]

    save_path = (
        f'{save_dir}_imagenet_size_{args.target_size}'
        f'_lime_s{args.num_samples}_{args.start}_{args.end}.npz'
    )
    lime_interpreter = LIMECVInterpreter(
        paddle_model, trained_model_path, model_input_shape=[3, args.target_size, args.target_size]
    )

    if os.path.exists(save_path):
        # Resume from a previous run. Per-image result dicts are stored as
        # pickled object arrays, hence allow_pickle=True; the archive is
        # materialized into a dict and then closed.
        with np.load(save_path, allow_pickle=True) as archive:
            total_results = dict(archive)
    else:
        total_results = {'image_size': args.target_size}

    num_new = 0
    for i, image_i in enumerate(image_list):
        if not image_i.lower().endswith(image_exts):
            continue
        if image_i in total_results:
            # Already interpreted in a previous (resumed) run.
            continue

        print(f'{time.time()} processing {image_i}, {i}/{len(image_list)}')
        lime_weights = lime_interpreter.interpret(image_i, **more_params)
        total_results[image_i] = {
            'lime_weights': lime_weights,
            'probability': lime_interpreter.lime_intermediate_results['probability'],
        }

        num_new += 1
        if num_new % checkpoint_every == 0:
            save_results(total_results, save_path)

    save_results(total_results, save_path)
