"""
Beware of running out of memory: the results of every model are loaded into memory at once.
"""

import os
import numpy as np


def get_available_info(model_mapping_file):
    """Return the model names listed in a mapping file.

    Every non-blank line is split on whitespace and its first field is taken
    as the model name. When a line has exactly three fields, the third one is
    parsed as a float and the line is dropped if that value is below 0.1
    (presumably an accuracy-like score; confirm against the mapping file).

    Args:
        model_mapping_file: path of the text file to read.

    Returns:
        A list of model names, in file order.

    Raises:
        ValueError: if the third field of a three-field line is not a number.
    """
    model_names = []
    # The context manager guarantees the handle is closed, even on error.
    with open(model_mapping_file, 'r') as f:
        for line in f:
            # Whitespace splitting drops the trailing newline, so a name on a
            # single-field line does not end up with '\n' glued to it.
            model_info = line.split()
            if not model_info:
                # Blank line: nothing to record.
                continue

            if len(model_info) == 3 and float(model_info[2]) < 0.1:
                continue

            model_names.append(model_info[0])

    return model_names


def read_one_file(filename):
    """Load the dict stored in a ``.npy`` file.

    The file is read with ``allow_pickle=True`` and unwrapped with ``.item()``;
    the unwrapped object must be a ``dict``.

    NOTE: unpickling executes arbitrary code, so only load trusted files.

    Args:
        filename: path of the ``.npy`` file.

    Returns:
        The dict stored in the file.
    """
    stored = np.load(filename, allow_pickle=True)
    results = stored.item()
    assert isinstance(results, dict)
    return results


def process_one_image(model_results, image_name, model_names):
    """Average the normalized LIME weights and probabilities of several models.

    For every model and every label in its ``'lime_weights'`` dict, the
    per-segment weights are L1-normalized and summed across models; the sums
    are then divided by the number of models. Probabilities are averaged the
    same way; ``probabilities[i]`` of a model is paired with the i-th key of
    its ``'lime_weights'`` dict.

    Args:
        model_results: ``{model_name: {image_name: result}}`` where ``result``
            has the keys ``'lime_weights'`` (``{label: [(seg_id, weight), ...]}``)
            and ``'probabilities'`` (indexable by position).
        image_name: key of the image inside each model's results.
        model_names: the models to aggregate.

    Returns:
        A tuple ``(voted_lime_weights, probabilities)``:
        ``voted_lime_weights`` maps each label to a list of
        ``(seg_id, averaged_weight)`` sorted by decreasing absolute weight;
        ``probabilities`` maps each label to its averaged probability.
        Both averages divide by the total number of models, including models
        that did not report the label.
    """
    summed_weights = {}
    probabilities = {}
    num_models = 0

    for model_name in model_names:
        one_model_result = model_results[model_name][image_name]
        lime_weights = one_model_result['lime_weights']
        probability = one_model_result['probabilities']
        num_models += 1

        for i, predicted_label in enumerate(lime_weights):
            probabilities[predicted_label] = probabilities.get(predicted_label, 0.0) + probability[i]

            segment_weights = lime_weights[predicted_label]
            # L1 normalization (max or L2 norms would be alternatives).
            norm = sum(abs(seg_weight) for _, seg_weight in segment_weights)
            if norm == 0:
                # Every weight is zero; dividing by 1 keeps them at zero
                # instead of raising ZeroDivisionError / producing NaN.
                norm = 1.0

            # Built as a dict first so that a segment id repeated within one
            # model's list counts once (last value wins).
            normalized = {seg_label: seg_weight / norm for seg_label, seg_weight in segment_weights}

            totals = summed_weights.setdefault(predicted_label, {})
            for seg_label, value in normalized.items():
                # .get tolerates a segment that earlier models did not report.
                totals[seg_label] = totals.get(seg_label, 0.0) + value

    # transform to LIME local_weights format.
    voted_lime_weights = {}
    for predicted_label, totals in summed_weights.items():
        averaged = [(seg_index, total / num_models) for seg_index, total in totals.items()]
        averaged.sort(key=lambda item: np.abs(item[1]), reverse=True)
        voted_lime_weights[predicted_label] = averaged

    # normalize
    probabilities = {y: probabilities[y] / num_models for y in probabilities}

    return voted_lime_weights, probabilities


def compute_score_two_sets(set1, set2):
    """Return the cosine similarity of two sets of superpixel weights.

    Each set is scattered into a dense vector whose length is the number of
    elements in the set, using ``sp_id`` as the index, so every ``sp_id`` must
    be a valid index for that length.

    Args:
        set1: iterable of ``(sp_id, sp_weights)`` pairs.
        set2: iterable of ``(sp_id, sp_weights)`` pairs, same length as set1.

    Returns:
        A scalar: the cosine similarity of the two vectors, or 0.0 when either
        vector is all zeros or both sets are empty.

    Raises:
        AssertionError: if the two sets differ in length.
    """
    assert len(set1) == len(set2), f"{len(set1)} is not equal to {len(set2)}"
    if len(set1) == 0:
        # Averaging an empty array would yield NaN.
        return 0.0

    u = np.zeros(len(set1))
    v = np.zeros(len(set2))
    for sp_id, sp_weights in set1:
        u[sp_id] = sp_weights
    for sp_id, sp_weights in set2:
        v[sp_id] = sp_weights

    uv = np.average(u * v)
    uu = np.average(np.square(u))
    vv = np.average(np.square(v))
    denominator = np.sqrt(uu * vv)
    if denominator == 0:
        # At least one vector is all zeros: the 0/0 would be NaN and would
        # poison any sum the caller accumulates, so report "no similarity".
        return 0.0

    return uv / denominator


def compute_score_one_image(model_results, image_name, model_names=None):
    """Score each model against the cross-model consensus for one image.

    The consensus is the voted LIME weights (from ``process_one_image``) of the
    label with the highest averaged probability. Each model's own LIME weights
    are compared to it with ``compute_score_two_sets``; the model's weights for
    the consensus label are used when it has them, otherwise the weights of the
    last label in its ``'lime_weights'`` dict.

    Args:
        model_results: ``{model_name: {image_name: result}}``.
        image_name: key of the image; also looked up in the module-level
            ``val12_gt`` to get the true label.
        model_names: models to score; defaults to every key of model_results.

    Returns:
        A tuple ``(scores, performances)`` of dicts keyed by model name.
        ``performances`` holds 1 when the last label of the model's
        ``'lime_weights'`` equals the true label, else 0. Skipped models keep
        0.0 in both dicts.
    """
    if model_names is None:
        model_names = list(model_results.keys())

    voted_lime_weights, probabilities = process_one_image(model_results, image_name, model_names)
    top_label = max(probabilities, key=probabilities.get)
    consensus_weights = voted_lime_weights[top_label]

    true_label = val12_gt[image_name]

    excluded_models = ['MobileNetV3_small_x1_0']  # bad performance

    scores = dict.fromkeys(model_names, 0.0)
    performances = dict.fromkeys(model_names, 0.0)
    for model_name in model_names:
        if model_name in excluded_models:
            continue

        per_image = model_results[model_name][image_name]
        assert isinstance(per_image, dict)
        label_to_weights = per_image['lime_weights']

        # The last key is compared to the ground truth for the performance.
        last_label = list(label_to_weights.keys())[-1]
        performances[model_name] = 1 if last_label == true_label else 0

        # Fall back to the model's last label when it lacks the consensus one.
        if top_label in label_to_weights.keys():
            compared_label = top_label
        else:
            compared_label = last_label
        scores[model_name] = compute_score_two_sets(label_to_weights[compared_label], consensus_weights)

    return scores, performances


# Input locations and the file-name suffix of the per-model result files.
mapping = '../local_scripts/imagenet_model_mapping_input224.txt'
suffix = '_lime_s3000.npy'
load_dir = '../imagenet_lime_results/'

model_names = get_available_info(mapping)
model_names.sort()
print(len(model_names))

# Ground-truth labels, keyed by image name.
val12_gt = np.load('../local_scripts/val12_gt.npy', allow_pickle=True).item()

# Load the results of every model into memory.
model_results = {}
for model_name in model_names:
    model_results[model_name] = read_one_file(os.path.join(load_dir, model_name + suffix))

# compute score and performance of all models.
from multiprocessing import Pool

# The images to process are the ones present in AlexNet's results.
image_paths = list(model_results['AlexNet'])
chosen_model_names = list(model_names)


def pool_wrapper_compute_score_one_image(i):
    """Pool worker: score the i-th entry of ``image_paths``.

    The image name is the part of the path after the last ``'/'``.

    Args:
        i: index into the module-level ``image_paths`` list.

    Returns:
        The ``(score, performance)`` pair from ``compute_score_one_image``, or
        ``None`` when the image name is not a key of the AlexNet results.
    """
    image_name = image_paths[i].split('/')[-1]
    if image_name not in model_results['AlexNet']:
        return None

    scores, performances = compute_score_one_image(model_results, image_name, chosen_model_names)
    return scores, performances

# Score every image in parallel; workers return None for images they skip.
pool = Pool(processes=16)
results = [
    outcome
    for outcome in pool.map(pool_wrapper_compute_score_one_image, range(len(image_paths)))
    if outcome is not None
]
pool.close()

# Sum the per-image scores and performances for each model.
model_to_scores = {}
model_to_performances = {}
for score, performance in results:
    # Fresh dicts are built from the previous totals at every step.
    next_scores = {}
    next_performances = {}
    for model_name in chosen_model_names:
        next_scores[model_name] = score.get(model_name, 0.0) + model_to_scores.get(model_name, 0.0)
        next_performances[model_name] = (
            performance.get(model_name, 0.0) + model_to_performances.get(model_name, 0.0)
        )
    model_to_scores = next_scores
    model_to_performances = next_performances

print(model_to_scores)
print(model_to_performances)
