import random

import torch
from matplotlib import pyplot as plt
from torch import optim
from torchvision.models import VGG16_Weights, vgg16
from torchvision.transforms import Compose, GaussianBlur, RandomRotation
from tqdm import tqdm


def visualize_tensor(tensor, title=None):
    """Display an image tensor with matplotlib.

    Args:
        tensor: image tensor. After squeezing singleton dimensions it is expected to be
            either (C, H, W) (channels are moved last for ``imshow``) or (H, W).
        title: optional plot title.
    """
    # Always detach and move to CPU: imshow converts the tensor to a NumPy array, which
    # fails for tensors that require grad or that live on any non-CPU device.
    image = tensor.detach().cpu().squeeze()
    if image.dim() == 3:
        # (C, H, W) -> (H, W, C), the layout imshow expects for colour images.
        image = image.permute(1, 2, 0)
    plt.imshow(image)
    if title:
        plt.title(title)
    plt.show()


def get_vgg_and_transform(weights=None):
    """Build a pretrained VGG16 together with the preprocessing that belongs to its weights.

    Args:
        weights: a ``VGG16_Weights`` entry; ``None`` selects ``VGG16_Weights.DEFAULT``.
            The same entry supplies both the parameters and the preprocessing transform,
            so the two cannot drift apart.

    Returns:
        (model, transforms): the VGG16 model in eval mode and the matching transform.
    """
    if weights is None:
        weights = VGG16_Weights.DEFAULT
    model = vgg16(weights=weights)
    # Eval mode disables train-time layers such as dropout, so repeated forward passes
    # on the same input give the same output.
    model.eval()
    return model, weights.transforms()


def random_input(model, transforms, visualize=True, batch_size=100):
    """Feed a batch of uniform-noise images through ``model`` and report output statistics.

    Args:
        model: callable applied to the transformed noise batch. Its output is assumed to be
            2-D (batch, features), since statistics are taken over dim 1.
        transforms: preprocessing applied to the (batch_size, 3, 224, 224) noise batch.
        visualize: if True, show the first noise image before it is transformed.
        batch_size: number of random images to draw (100 by default).

    Returns:
        (var, mean): per-sample variance and mean of the model output over dim 1.
        Both are also printed.
    """
    random_vec = torch.rand(batch_size, 3, 224, 224)
    if visualize:
        visualize_tensor(random_vec[0], title='random tensor')
    # Pure inference: without no_grad autograd would keep every intermediate activation
    # of the whole batch alive.
    with torch.no_grad():
        output = model(transforms(random_vec))
    var, mean = torch.var_mean(output, dim=1)
    print("variance: ", var)
    print("mean: ", mean)
    return var, mean


def get_transform(blur_kernel_size=3, rotation_degrees=180):
    """Build the random augmentation applied to the image at each optimization step.

    Args:
        blur_kernel_size: ``kernel_size`` passed to ``GaussianBlur`` (3 by default).
        rotation_degrees: ``degrees`` passed to ``RandomRotation`` (180 by default).

    Returns:
        A ``Compose`` of a Gaussian blur followed by a random rotation.
    """
    return Compose([
        GaussianBlur(kernel_size=blur_kernel_size),
        RandomRotation(degrees=rotation_degrees),
    ])


def get_criterion(name, **kwargs):
    """Return the optimization criterion registered under ``name``.

    Args:
        name: ``"max_logit_loss"`` or ``"cross_entropy_loss"``.
        **kwargs: forwarded to ``MaxLogitLossCriterion`` (e.g. ``lambda_l2``, ``lambda_l1``);
            ignored for ``"cross_entropy_loss"``.

    Raises:
        NotImplementedError: if ``name`` is not a known criterion.
    """
    if name == "max_logit_loss":
        return MaxLogitLossCriterion(**kwargs)
    if name == "cross_entropy_loss":
        return torch.nn.CrossEntropyLoss()
    raise NotImplementedError(
        f"Unknown criterion {name!r}; expected 'max_logit_loss' or 'cross_entropy_loss'")


class MaxLogitLossCriterion:
    """Loss that rewards a large logit for one class while penalizing logit magnitudes.

    loss = -mean_batch(logits[:, label])
           + lambda_l2 * mean(logits ** 2)
           + lambda_l1 * mean(|logits|)
    """

    def __init__(self, **params):
        # A missing weight defaults to 0, i.e. that regularizer is disabled.
        self.lamb_l2 = params.get("lambda_l2", 0.0)
        self.lamb_l1 = params.get("lambda_l1", 0.0)

    def __call__(self, logits, label):
        """Compute the loss.

        Args:
            logits: (batch, num_classes) model outputs.
            label: class index used for every row, either an int or a one-element index
                tensor (callers in this module pass ``torch.tensor([class_label])``).

        Returns:
            A 0-dim tensor, so ``loss.backward()`` works for any batch size.
        """
        zeros = torch.zeros_like(logits)
        # Reduce over the batch (and the one-element label dim) to get a scalar loss.
        target_logit = logits[:, label].mean()
        l2_penalty = torch.nn.functional.mse_loss(logits, zeros)
        l1_penalty = torch.nn.functional.l1_loss(logits, zeros)
        return -target_logit + self.lamb_l2 * l2_penalty + self.lamb_l1 * l1_penalty


class ZeroLogitCriterion:
    """Loss that pulls every logit towards zero: the mean squared logit."""

    def __call__(self, logits, label):
        # ``label`` exists only for call-compatibility with the other criteria; it is unused.
        del label
        target = torch.zeros_like(logits)
        return torch.nn.functional.mse_loss(logits, target)


def random_submask_index(mask, min_size=0, max_size=None):
    """Pick a random axis-aligned rectangle inside the bounding box of ``mask``'s active region.

    Entries equal to 1 are "active". The last two dimensions of ``mask`` are the spatial axes.

    Args:
        mask: tensor with at least two dimensions; entries equal to 1 are active.
        min_size: minimal side length of the rectangle along each spatial axis. The
            bounding box must be wider than ``min_size`` on both axes, otherwise
            ``random.randint`` raises ``ValueError``.
        max_size: optional maximal side length along each axis (``None`` = unbounded).

    Returns:
        (crop_left, crop_right, crop_top, crop_bottom): half-open bounds, with
        ``crop_left:crop_right`` along dim -2 and ``crop_top:crop_bottom`` along dim -1.

    Raises:
        AssertionError: if ``max_size < min_size``.
        ValueError: if ``mask`` has no active entry.
    """
    if max_size is not None:
        assert max_size >= min_size, 'maximal possible size smaller than minimal possible size!'

    active = (mask == 1.).nonzero()
    if active.numel() == 0:
        raise ValueError('mask has no active (== 1) entries to crop from')
    # Use the true bounding box over all active entries. The first/last nonzero index only
    # yields the bounding box when the active region is a perfect rectangle, which no
    # longer holds once parts of the mask have been pruned away.
    rows, cols = active[:, -2], active[:, -1]
    left, right = rows.min().item(), rows.max().item() + 1
    top, bottom = cols.min().item(), cols.max().item() + 1

    def pick_span(start, stop):
        # Random half-open [lo, hi) inside [start, stop) with min_size <= hi - lo (<= max_size).
        lo = random.randint(start, stop - min_size - 1)
        hi_limit = stop if max_size is None else min(lo + max_size, stop)
        return lo, random.randint(lo + min_size, hi_limit)

    # Same order of random draws as before: left, right, then top, bottom.
    crop_left, crop_right = pick_span(left, right)
    crop_top, crop_bottom = pick_span(top, bottom)
    return crop_left, crop_right, crop_top, crop_bottom


def crop_optimization_result(
        optimization_result, original_input, mask, min_size=50, max_size=None, preserve=True, visualize=False):
    """Blend a random rectangle of an optimized image with the original input.

    A rectangle is drawn inside the bounding box of ``mask`` via ``random_submask_index``.
    With ``preserve=True`` only the rectangle keeps optimized pixels. With ``preserve=False``
    the rectangle is reverted to ``original_input`` and the rest keeps optimized pixels.
    In both cases pixels outside ``mask`` come from ``original_input``.

    Args:
        optimization_result: optimized image, (C, H, W) or (1, C, H, W); it is broadcast
            against ``mask``.
        original_input: image used wherever optimized pixels are not kept.
        mask: 4-D region mask; 1 marks where optimized pixels may appear.
        min_size, max_size: side-length limits forwarded to ``random_submask_index``.
        preserve: keep (True) or revert (False) the rectangle.
        visualize: show the blended image before the final ``mask`` is applied.

    Returns:
        (blended image, crop_mask). With ``preserve`` the crop mask is the rectangle.
        Otherwise it is the part of ``mask`` that was reverted (``mask`` ∩ rectangle,
        for 0/1 masks).
    """
    # Follow the shape and device of ``mask`` instead of assuming (1, 3, 224, 224) on CPU.
    crop_mask = torch.zeros(mask.shape, device=mask.device)
    crop_left, crop_right, crop_top, crop_bottom = random_submask_index(mask, min_size, max_size)
    crop_mask[..., crop_left:crop_right, crop_top:crop_bottom] = 1.

    if not preserve:
        crop_mask = 1 - crop_mask
    # Plain broadcasting accepts both (C, H, W) and (1, C, H, W) results without adding
    # an extra leading dimension.
    blended = crop_mask * optimization_result + (1 - crop_mask) * original_input
    if visualize:
        visualize_tensor(blended.squeeze(), title='cropped image')

    if not preserve:
        # Report the region that was actually reverted: mask - (1 - rect), clipped to [0, 1].
        crop_mask = torch.clip(mask - crop_mask, min=0, max=1)

    # Outside ``mask`` always fall back to the original input.
    blended = mask * blended + (1 - mask) * original_input
    return blended, crop_mask


class ActivationMapGenerator:
    """Optimize input images that drive a classifier towards a class, and prune them to key regions.

    Args:
        get_model_and_transform: zero-argument callable returning ``(model, transforms)``.
        criterion: callable ``criterion(logits, label) -> loss`` minimized by
            ``optimize_maximal_visualization``.
        transforms: optional random augmentation applied to the image at each optimization step.
        device: device holding the model and the optimized tensors.
    """

    def __init__(self, get_model_and_transform=None, criterion=None, transforms=None, device=torch.device('cpu')):
        assert criterion is not None, 'Please give a valid optimization criterion'
        assert get_model_and_transform is not None, \
            'Please give a function that returns a model and a transform for the input of ' 'the model.'
        self.criterion = criterion
        self.random_transforms = transforms
        # NOTE(review): ``self.transforms`` (the model's preprocessing) is stored but not applied
        # by the optimization methods below; confirm raw [0, 1] inputs are intended.
        self.model, self.transforms = get_model_and_transform()
        self.device = device
        self.model.to(device)
        # The model is only used for inference. Eval mode disables train-time layers such as
        # dropout, so optimization targets and pruning scores are deterministic. Freezing the
        # weights stops backward() from accumulating unused parameter gradients; gradients
        # still flow to the optimized input.
        self.model.eval()
        self.model.requires_grad_(False)
        self.background = None

    def _augment(self, image):
        """Apply the random augmentation if one was configured, else return ``image`` unchanged."""
        return self.random_transforms(image) if self.random_transforms is not None else image

    def optimize_background(self, lr=1e-1, num_steps=100):
        """Optimize one RGB colour whose uniform 224x224 image yields logits close to zero.

        Returns:
            The (1, 3, 224, 224) background image, detached from autograd. It is also stored
            in ``self.background``.
        """
        background_criterion = ZeroLogitCriterion()
        color = torch.rand(1, 3, requires_grad=True, device=self.device)
        optimizer = optim.Adam([color], lr=lr)
        for _ in tqdm(range(num_steps)):
            optimizer.zero_grad()
            expanded_input = color[:, :, None, None].expand(-1, -1, 224, 224)
            output = self.model(self._augment(expanded_input))
            # For background optimization the class label is not needed; pass a dummy argument.
            loss = background_criterion(output, None)
            loss.backward()
            optimizer.step()
        # Detach so the stored background neither requires grad nor keeps the graph alive.
        self.background = color.detach()[:, :, None, None].expand(-1, -1, 224, 224)
        return self.background

    def optimize_maximal_visualization(self, class_label=None, lr=1e-1, num_steps=100, effected_region=None):
        """Optimize a random image so that ``self.criterion`` for ``class_label`` is minimized.

        Args:
            class_label: target class index (required).
            lr: Adam learning rate.
            num_steps: number of optimization steps.
            effected_region: optional mask broadcastable to (1, 3, 224, 224). The model only sees
                the optimized pixels where the mask is 1; elsewhere it sees fresh noise each step.

        Returns:
            (optimized image, initial random image), both detached and on the CPU. With
            ``effected_region`` the optimized image keeps its initial values outside the region.
        """
        assert class_label is not None, "please provide a valid class label!"
        input_to_modify = torch.rand(1, 3, 224, 224, requires_grad=True, device=self.device)
        # Snapshot of the starting point, outside the autograd graph.
        init_input = input_to_modify.detach().clone()
        # Loop-invariant tensors are created / moved once.
        label = torch.tensor([class_label], device=self.device)
        if effected_region is not None:
            effected_region = effected_region.to(self.device)
        optimizer = optim.Adam([input_to_modify], lr=lr)
        for _ in tqdm(range(num_steps)):
            optimizer.zero_grad()
            transformed_input_to_modify = self._augment(input_to_modify)
            if effected_region is not None:
                # Constrain the optimization region. Fresh random noise outside it each step
                # prevents over-optimization at the region's edge.
                noise = torch.rand(1, 3, 224, 224, device=self.device)
                model_input = effected_region * transformed_input_to_modify + (1 - effected_region) * noise
            else:
                model_input = transformed_input_to_modify
            loss = self.criterion(self.model(model_input), label)
            loss.backward()
            optimizer.step()
        result = input_to_modify.detach()
        if effected_region is not None:
            result = effected_region * result + (1 - effected_region) * init_input
        return result.cpu(), init_input.cpu()

    def prune_optimization_result(
            self,
            optimization_result,
            original_input,
            mask,
            min_size=0,
            max_size=None,
            num_prune_steps=None,
            preserve=True,
            debug=False,
            class_label=0):
        """Repeatedly crop random rectangles and drop those that do not matter for the class score.

        Each step calls ``crop_optimization_result`` on the current mask and scores the result
        with the softmax probability of ``class_label``:
          * ``preserve=True``: the rectangle is removed from the mask when keeping only it
            gives a score < 0.01;
          * ``preserve=False``: the reverted region is removed when the score stays > 0.95.

        Args:
            optimization_result, original_input: images blended by the mask.
            mask: initial region mask (cloned, not modified).
            min_size, max_size: rectangle side-length limits.
            num_prune_steps: number of crop/score iterations (required, > 0).
            preserve: pruning mode, see above.
            debug: visualize intermediate images and masks.
            class_label: class whose probability decides pruning. Defaults to 0 for backward
                compatibility; pass the label the image was optimized for.

        Returns:
            (pruned image, pruned mask).
        """
        assert num_prune_steps, 'Please provide a valid step for pruning'
        pruned_mask = torch.clone(mask)
        for _ in range(num_prune_steps):
            cropped_optimization_result, cropped_mask = crop_optimization_result(
                optimization_result,
                original_input,
                pruned_mask,
                min_size=min_size,
                max_size=max_size,
                preserve=preserve)
            if debug:
                visualize_tensor(cropped_optimization_result, title="cropped_optimization_result")
            # Scoring only: no autograd graph is needed.
            with torch.no_grad():
                logits = self.model(cropped_optimization_result.to(self.device))
            class_score = torch.nn.functional.softmax(logits, dim=1)[0, class_label].item()
            # NOTE(review): this block was marked buggy by its author. The hard-coded class index
            # is now ``class_label``. Also confirm the preserve-mode rule: a low score with only the
            # rectangle kept shows it is insufficient alone, not necessarily that it is unneeded.
            if preserve and class_score < 0.01:
                print("Preserve mode, class score: {}, prune".format(class_score))
                pruned_mask = torch.clip(pruned_mask - cropped_mask, min=0., max=1.)
            elif not preserve and class_score > 0.95:
                print("Non preserve mode, class score: {}, prune".format(class_score))
                pruned_mask = torch.clip(pruned_mask - cropped_mask, min=0., max=1.)
            if debug:
                visualize_tensor(cropped_mask, title='cropped mask')
                visualize_tensor(pruned_mask, title='pruned mask')
        if debug:
            visualize_tensor(pruned_mask, title='pruned mask')
        pruned_opt_result = pruned_mask * optimization_result + (1 - pruned_mask) * original_input
        return pruned_opt_result, pruned_mask
