import os
import os.path as osp
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from lucent.optvis import param, render
from lucent.optvis.objectives import handle_batch, wrap_objective

from sde.legacy.model_visualization import *


@wrap_objective()
def label(n_channel, batch=None):
    """Visualize a single label while suppressing the rest.

    Builds a lucent objective over the ``labels`` layer. The objective is:

    - the mean activation of channel ``n_channel``, negated (so minimizing the
      objective pushes it up), plus
    - the mean activation of the channels before it, plus
    - the mean activation of the channels after it.

    Args:
        n_channel: index of the label channel to emphasize.
        batch: forwarded to lucent's ``handle_batch``.

    Returns:
        The objective function wrapped by ``handle_batch``.
    """

    @handle_batch(batch)
    def inner(model):
        # NOTE(review): indexing with [:, n_channel] assumes the "labels" output is
        # (batch, num_labels, ...) -- confirm against the hooked model.
        labels = model("labels")
        objective = -labels[:, n_channel].mean()
        # Slicing before the first channel or after the last one gives an empty
        # tensor whose mean is NaN. That NaN would poison the whole objective
        # (e.g. for n_channel == 0), so only the non-empty sides are added.
        preceding = labels[:, :n_channel]
        following = labels[:, n_channel + 1:]
        if preceding.numel() > 0:
            objective = objective + preceding.mean()
        if following.numel() > 0:
            objective = objective + following.mean()
        return objective

    return inner


def generate_opt_result_lucent(model, class_label, optimization_steps=256, device=None,
                               image_size=224, show_inline=True):
    """Optimize an input image for ``class_label`` with the lucent library.

    Args:
        model: network passed to ``render.render_vis``. The ``label`` objective
            reads its ``labels`` layer.
        class_label: index of the label channel to emphasize.
        optimization_steps: number of lucent optimization steps. The image is
            captured once, at this threshold.
        device: unused; kept for interface compatibility.
        image_size: side length passed to ``param.image``.
        show_inline: forwarded to ``render.render_vis``.

    Returns:
        A torch tensor built from the first image returned by ``render_vis``,
        permuted from channels-last to channels-first and squeezed.
    """
    objective = label(class_label)
    images = render.render_vis(
        model,
        objective,
        show_inline=show_inline,
        param_f=lambda: param.image(image_size),
        thresholds=(optimization_steps,),
    )
    # Only one threshold was requested, so only images[0] is used.
    # NOTE(review): assumes it is laid out (N, H, W, C); permute moves channels to dim 1.
    return torch.tensor(images[0]).permute(0, 3, 1, 2).squeeze()


def _build_region_mask(mask_region, image_size=224):
    """Return a (1, 3, image_size, image_size) mask: 1 inside ``mask_region``, 0 elsewhere."""
    mask = torch.zeros(1, 3, image_size, image_size)
    if mask_region is not None:
        # NOTE(review): "left"/"right" slice dim 2 and "bottom"/"top" slice dim 3.
        # For an NCHW tensor this maps left/right to the vertical axis, which
        # looks transposed relative to the key names. Kept as-is to preserve
        # existing results -- confirm the intended convention.
        mask[:, :, mask_region["left"]:mask_region["right"],
             mask_region["bottom"]:mask_region["top"]] = 1.
    return mask


def generate_opt_result(
        class_label,
        criterion_param,
        mask_region,
        optimization_steps,
        prune,
        prune_preserve,
        num_prune_steps,
        min_size,
        max_size,
        visualize,
        prompt,
        device):
    """Optimize an input that activates ``class_label`` and optionally prune it.

    Args:
        class_label: index of the target class.
        criterion_param: dict with these keys:
            - "name" and "params": passed to ``get_criterion``.
            - "method": optional; "lucent" switches to the lucent path.
        mask_region: dict with "left", "right", "bottom", "top" bounds of the
            region the optimization may affect, or None for an all-zero mask.
        optimization_steps: number of optimization steps.
        prune: if True, run ``generator.prune_optimization_result``.
        prune_preserve: ``preserve`` flag forwarded to pruning.
        num_prune_steps: number of pruning steps.
        min_size: forwarded to pruning.
        max_size: forwarded to pruning.
        visualize: show intermediate results with ``visualize_tensor``.
        prompt: print the model's ``class_label`` softmax score for the
            original input and for the optimization result.
        device: torch device for model evaluation.

    Returns:
        A tuple ``(generator, optimization_result, original_input, mask)``:
        - ``generator``: the ``ActivationMapGenerator`` that was used.
        - ``optimization_result``: the result, squeezed (pruned if ``prune``).
        - ``original_input``: the background input.
        - ``mask``: the region mask built from ``mask_region``.
    """
    mask = _build_region_mask(mask_region)
    criterion = get_criterion(criterion_param["name"], **criterion_param["params"])
    transform = get_transform()
    generator = ActivationMapGenerator(get_vgg_and_transform, criterion, transform, device=device)

    # TODO a simple work around to use lucent, can be improved later
    if criterion_param.get("method") == "lucent":
        print("Optimize with Lucent library")
        print("Run dummy generator for one step")
        original_input = generator.optimize_background(num_steps=optimization_steps)
        # The one-step result is discarded. The call is kept as the original
        # "dummy" run in case it prepares generator state -- TODO confirm it is needed.
        generator.optimize_maximal_visualization(class_label=class_label,
                                                 num_steps=1,
                                                 effected_region=mask)
        optimization_result = generate_opt_result_lucent(
            generator.model, class_label, optimization_steps=optimization_steps, device=device)
    else:
        optimization_result, original_input = generator.optimize_maximal_visualization(
            class_label=class_label,
            num_steps=optimization_steps,
            effected_region=mask)
    optimization_result = optimization_result.squeeze()
    if visualize:
        visualize_tensor(optimization_result, title='Optimization result')
    if prompt:
        # Report the score of the *target* class (not always index 0). The
        # forward passes are for reporting only, so no gradients are needed.
        with torch.no_grad():
            original_score = F.softmax(generator.model(original_input.to(device)), dim=1)[0][class_label]
            result_score = F.softmax(
                generator.model(optimization_result[None, :].to(device)), dim=1)[0][class_label]
        print('model_output given original input: {}'.format(original_score))
        print('model_output given optimization result: {}'.format(result_score))

    # prune optimization result
    if prune:
        pruned_result, _pruned_mask = generator.prune_optimization_result(optimization_result,
                                                                          original_input,
                                                                          mask,
                                                                          min_size=min_size,
                                                                          max_size=max_size,
                                                                          preserve=prune_preserve,
                                                                          num_prune_steps=num_prune_steps)
        if visualize:
            print("Show pruned result")
            visualize_tensor(pruned_result, title="pruned result")
        optimization_result = pruned_result.squeeze()
    return generator, optimization_result, original_input, mask


def evaluate_background(model, background):
    """Plot the model's logits for ``background`` against those for an all-zero input.

    Blue dots show ``model(background)[0]``; red dots show the logits for
    ``torch.zeros_like(background)``. The y axis is clipped to [-3, 3].

    Args:
        model: callable network returning logits with a leading batch dimension.
        background: input tensor; only the first sample's logits are plotted.
    """
    # Plotting only -- no autograd graph is needed for these forward passes.
    with torch.no_grad():
        background_logits = model(background)[0].cpu()
        zero_input_logits = model(torch.zeros_like(background))[0].cpu()
    class_indices = range(background_logits.shape[0])
    plt.plot(class_indices, background_logits, 'b', linestyle='None', marker=".")
    plt.plot(class_indices, zero_input_logits, 'r', linestyle='None', marker=".")
    plt.ylim(-3, 3)
    plt.title("logit on background")
    plt.show()
    # Close the figure so later plots (e.g. histograms) do not draw on top of it
    # when show() is a no-op on non-interactive backends.
    plt.close()


def generate_and_random_crop(
    class_label=None,
    criterion=None,
    mask_region=None,
    visualize=False,
    optimization_steps=None,
    prune=False,
    num_prune_steps=None,
    prune_preserve=True,
    crop_preserve=True,
    num_crop_test=None,
    show_statistic=False,
    min_size=0,
    max_size=None,
    prompt=False,
    device=None,
    positive_threshold=0.8,
):
    """Optimize an input for ``class_label``, then score repeated random crops of it.

    Steps:
      1. ``generate_opt_result`` produces the optimization result and background.
      2. The background is visualized and its logits plotted
         (``evaluate_background``).
      3. ``num_crop_test`` crops are drawn with ``crop_optimization_result``.
         The model's softmax score and logit for ``class_label`` are printed
         for each crop.

    A crop counts as positive when its softmax score exceeds
    ``positive_threshold``. The fraction of positive crops is printed at the end.

    Args:
        criterion: criterion-parameter dict forwarded to ``generate_opt_result``.
        num_crop_test: number of crops to evaluate; must be truthy (non-zero).
        show_statistic: if True, show a histogram of the crop scores.
        positive_threshold: softmax-score threshold for a positive crop.
        Remaining arguments are forwarded to ``generate_opt_result`` /
        ``crop_optimization_result``.

    Raises:
        ValueError: if ``num_crop_test`` is None or 0.
    """
    # Validate before the expensive optimization instead of after it.
    if not num_crop_test:
        raise ValueError("provide a valid number of test!")

    generator, optimization_result, original_input, mask = generate_opt_result(
        class_label,
        criterion,
        mask_region,
        optimization_steps,
        prune,
        prune_preserve,
        num_prune_steps,
        min_size,
        max_size,
        visualize,
        prompt,
        device,
    )

    # evaluate background
    visualize_tensor(original_input, title="optimized background")
    evaluate_background(generator.model, original_input)
    original_input = original_input.cpu()

    # crop optimization result
    positive = 0
    class_score_list = []
    for _ in range(num_crop_test):
        cropped_optimization_result, _ = crop_optimization_result(optimization_result,
                                                                  original_input,
                                                                  mask,
                                                                  min_size=min_size,
                                                                  max_size=max_size,
                                                                  preserve=crop_preserve,
                                                                  visualize=visualize)
        # One forward pass serves both the score and the logit; evaluation needs no gradients.
        with torch.no_grad():
            logits = generator.model(cropped_optimization_result.to(device))
        class_score = F.softmax(logits, dim=1)[0][class_label]
        class_logit = logits[0][class_label]
        print('model_output given cropped optimization result: {}'.format(class_score))
        print('class logit given cropped optimization result: {}'.format(class_logit))
        if class_score > positive_threshold:
            positive += 1
        class_score_list.append(class_score.item())
    if show_statistic:
        plt.hist(class_score_list)
        plt.title("Model Output Statistic")
        plt.show()
    print("Number of positive ratio: {}".format(positive / num_crop_test))


def _crop_sizes(min_size, max_size, stripe):
    """Return crop sizes min_size, min_size + stripe, ... that do not exceed max_size."""
    # ``max_size + 1`` keeps max_size inclusive without overshooting it when
    # (max_size - min_size) is not a multiple of ``stripe``.
    return list(range(min_size, max_size + 1, stripe))


def _save_current_figure(output_path, filename, show):
    """Save the current pyplot figure to ``output_path/filename``, optionally show it, then clear it."""
    plt.savefig(osp.join(output_path, filename))
    if show:
        plt.show()
    plt.clf()


def find_effective_activation_region(
        class_label,
        criterion,
        mask_region,
        visualize,
        optimization_steps,
        prune,
        num_prune_steps,
        num_crop_test,
        min_size,
        max_size,
        show_statistic,
        prompt,
        device,
        prune_preserve,
        stripe=10,
        output_path=None,
        positive_threshold=0.8):
    """Crop the optimization result at different sizes and observe the model's response.

    The objective of this test is to crop images at different sizes, then
    observe the outcome to see whether a cropped image is still representative
    for the model.

    The experiment runs for each ``preserve`` flag in (False, True) and each
    crop size from ``min_size`` to ``max_size`` (inclusive, step ``stripe``):

    - ``num_crop_test`` crops are drawn with ``crop_optimization_result`` and
      scored by the model.
    - Two figures are saved to ``output_path`` per size: a histogram of the
      target-class softmax scores and a plot of the mean logits.

    After all sizes, summary plots are saved per ``preserve`` flag: one for the
    class, mean and max logits, and one for the positive ratio.

    Args:
        class_label: index of the target class.
        num_crop_test: number of crops per (preserve, size); must be non-zero.
        min_size: smallest crop size in the sweep.
        max_size: largest crop size in the sweep (inclusive).
        stripe: step between crop sizes.
        output_path: directory for the figures; created if missing. Defaults to
            the current directory when None or empty.
        positive_threshold: a crop is positive when its target-class softmax
            score exceeds this value.
        Remaining arguments are forwarded to ``generate_opt_result`` /
        ``crop_optimization_result``.

    Raises:
        ValueError: if ``num_crop_test`` is None or 0.
    """
    if not num_crop_test:
        raise ValueError("provide a valid number of test!")
    output_path = output_path or "."
    os.makedirs(output_path, exist_ok=True)

    generator, optimization_result, original_input, mask = generate_opt_result(
        class_label,
        criterion,
        mask_region,
        optimization_steps,
        prune,
        prune_preserve,
        num_prune_steps,
        min_size,
        max_size,
        visualize,
        prompt,
        device,
    )

    # print prediction result of the original input
    with torch.no_grad():
        class_score = F.softmax(generator.model(optimization_result[None, :].to(device)), dim=1)[0][class_label]
    print("Original Prediction Score: {}".format(class_score))

    # create result dict for output storage
    result_dict = {
        "preserve_false": {
            "positive_ratio": [], "class_logit": [], "max_logit": [], "mean_logit": []
        },
        "preserve_true": {
            "positive_ratio": [], "class_logit": [], "max_logit": [], "mean_logit": []
        },
    }
    sizes = _crop_sizes(min_size, max_size, stripe)

    # crop optimization result
    for preserve in [False, True]:
        curr_dict = result_dict["preserve_true"] if preserve else result_dict["preserve_false"]
        for size in sizes:
            positive = 0
            class_score_list, class_logit_list, mean_class_logit_list, max_logit_list, logit_list = [], [], [], [], []
            for _ in range(num_crop_test):
                cropped_optimization_result, _ = crop_optimization_result(optimization_result,
                                                                          original_input,
                                                                          mask,
                                                                          min_size=size,
                                                                          max_size=size,
                                                                          preserve=preserve,
                                                                          visualize=visualize)
                # Evaluation only. Without no_grad, each stored logit could keep
                # its autograd graph (and saved activations) alive for all trials.
                with torch.no_grad():
                    logit = generator.model(cropped_optimization_result.to(device))
                # Score the target class, consistent with "Original Prediction Score".
                class_score = F.softmax(logit, dim=1)[0][class_label]
                class_logit = logit[0][class_label]
                mean_class_logit = logit[0].mean()
                max_logit = logit[0].max()
                if prompt:
                    print('model_output given cropped optimization result: {}'.format(class_score))
                    print('class logit given cropped optimization result: {}'.format(class_logit))
                    print('mean class logit given cropped optimization result: {}'.format(mean_class_logit))
                    print('max_logit logit given cropped optimization result: {}'.format(max_logit))
                if class_score > positive_threshold:
                    positive += 1
                class_score_list.append(class_score.item())
                class_logit_list.append(class_logit.item())
                mean_class_logit_list.append(mean_class_logit.item())
                max_logit_list.append(max_logit.item())
                logit_list.append(logit)

            # plot model output statistic for the target class
            plt.hist(class_score_list, range=(0, 1))
            plt.title("Target Output Statistic")
            _save_current_figure(output_path,
                                 "target_output_preserve_{}_size_{}.png".format(preserve, size),
                                 show_statistic)

            # plot avg logits of all trials; the x axis follows the model's class count
            mean_logits = torch.cat(logit_list).mean(0).detach().cpu()
            plt.plot(range(mean_logits.shape[0]), mean_logits)
            plt.title("Model Logit Statistic Crop Preserve {} Size {}".format(preserve, size))
            _save_current_figure(output_path,
                                 "logit_preserve_{}_size_{}.png".format(preserve, size),
                                 show_statistic)

            positive_ratio = positive / num_crop_test
            class_logit = np.mean(np.array(class_logit_list)).item()
            mean_logit = np.mean(np.array(mean_class_logit_list)).item()
            max_logit = np.mean(np.array(max_logit_list)).item()
            print(
                "crop preserve: {}, crop size: {}, Number of positive ratio: {}".format(preserve, size, positive_ratio))
            print("crop preserve: {}, crop size: {}, mean class logit: {}".format(preserve, size, class_logit))
            print("crop preserve: {}, crop size: {}, mean average class logit: {}".format(preserve, size, mean_logit))
            print("crop preserve: {}, crop size: {}, mean average max logit: {}".format(preserve, size, max_logit))
            curr_dict["positive_ratio"].append(positive_ratio)
            curr_dict["class_logit"].append(class_logit)
            curr_dict["mean_logit"].append(mean_logit)
            curr_dict["max_logit"].append(max_logit)

    for preserve_key, value in result_dict.items():
        # class / mean / max logit curves share one figure per preserve setting
        for output_name, output_list in value.items():
            if output_name == "positive_ratio":
                continue
            plt.plot(sizes, output_list, label=output_name)
        plt.legend()
        plt.title(preserve_key)
        _save_current_figure(output_path, preserve_key + ".png", show=False)

        # plot positive ratio
        plt.plot(sizes, value["positive_ratio"])
        plt.title(preserve_key + "_positive_ratio")
        _save_current_figure(output_path, preserve_key + "_positive_ratio.png", show=False)


def main():
    """Run the crop experiments with the hard-coded settings below.

    When ``seed`` is not None, seeds ``random``, NumPy and PyTorch. Then runs
    ``find_effective_activation_region`` over crop sizes ``test_min_size`` to
    ``test_max_size`` and writes its figures to ``output_path``.

    Set ``run_random_crop_test`` to True to also run
    ``generate_and_random_crop``. That test uses ``min_size``, ``max_size`` and
    ``crop_preserve``.
    """
    seed = 42
    class_label = 0
    optimization_criterion = {
        "name": "max_logit_loss",
        "params": {
            "lambda_l1": 0.01,
            "lambda_l2": 0.01,
        },
        "method": "lucent",  # if use "lucent", we optimize with external library
    }
    prune = False
    visualize = False
    optimization_steps = 256
    prune_preserve = False
    num_prune_steps = 30
    crop_preserve = True
    num_crop_test = 100
    mask_region = {"left": 50, "right": 200, "bottom": 50, "top": 200}
    min_size = 120
    max_size = 120
    prompt = False
    show_statistic = False
    output_path = "output/"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Set to True to also run the random-crop model-behavior test.
    run_random_crop_test = False

    # ``is not None`` so that a seed of 0 still seeds the generators; NumPy is
    # seeded too so any NumPy-based randomness is reproducible as well.
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    ######################################
    # test model behavior
    if run_random_crop_test:
        generate_and_random_crop(class_label=class_label,
                                 criterion=optimization_criterion,
                                 mask_region=mask_region,
                                 visualize=visualize,
                                 optimization_steps=optimization_steps,
                                 prune=prune,
                                 num_prune_steps=num_prune_steps,
                                 num_crop_test=num_crop_test,
                                 min_size=min_size,
                                 max_size=max_size,
                                 show_statistic=show_statistic,
                                 prune_preserve=prune_preserve,
                                 crop_preserve=crop_preserve,
                                 prompt=prompt,
                                 device=device)

    #######################################
    # find effective activation region test
    test_min_size = 0
    test_max_size = 140
    find_effective_activation_region(
        class_label,
        optimization_criterion,
        mask_region,
        visualize,
        optimization_steps,
        prune,
        num_prune_steps,
        num_crop_test,
        test_min_size,
        test_max_size,
        show_statistic,
        prompt,
        device,
        prune_preserve,
        stripe=10,
        output_path=output_path)


if __name__ == "__main__":
    main()
