from copy import deepcopy
import random
from typing import List, Tuple, Optional

import torch
import numpy as np
from torchvision import transforms


# Small stabilizer added to the per-sample divisors in the *_normalize helpers and to the spectral norm in
# intrinsic_dim, so that a zero mean/norm does not cause a division by zero.
EPSILON = 0.001
# NOTE(review): DX, DTHETA, BIAS_VALUE and N_COMPARISONS are not referenced in this part of the file.
# They presumably are default perturbation sizes / experiment settings used by callers — confirm before relying on them.
DX = 0.01
DTHETA = 0.01
BIAS_VALUE = 0.02
N_COMPARISONS = 30


def perturb_model(model: torch.nn.Module, a: float = 0.0, b: float = 1.0, dtheta: float = 0.01,
                  seed: int = 0) -> torch.nn.Module:
    """Return a deep copy of ``model`` in which every parameter tensor is rescaled by ``1 + dtheta`` or ``1 - dtheta``.

    For each parameter tensor of the copy (in ``parameters()`` order) a sign is drawn with
    ``random.choice((-1, 1))``. The whole tensor is then updated in place as ``param += sign * dtheta * param``.
    The original ``model`` is left untouched.

    :param model: model to perturb; it is deep-copied first.
    :param a: unused; kept for interface compatibility.
    :param b: unused; kept for interface compatibility.
    :param dtheta: relative size of the perturbation applied to each parameter tensor.
    :param seed: value passed to ``random.seed`` before drawing the signs, making the result reproducible.
        Note that this reseeds the module-level ``random`` generator, which also affects later draws elsewhere.
    :return: the perturbed copy (a single module, not a tuple).
    """
    random.seed(seed)
    pt_model = deepcopy(model)
    # no_grad: the parameters are leaf tensors that require grad, so in-place updates must not be tracked.
    with torch.no_grad():
        for param in pt_model.parameters():
            param += random.choice((-1, 1)) * dtheta * param
    return pt_model


def perturb_input(data: torch.Tensor, frac: float = 0.0, dx: float = 0.1, seed: int = 0) -> torch.Tensor:
    """Return a copy of ``data`` in which one randomly chosen element is shifted by ``+dx`` or ``-dx``.

    The module-level ``random`` generator is reseeded with ``seed``. Then one index per dimension is drawn
    uniformly, followed by a random sign, so the result is deterministic for a given ``seed``.
    ``data`` itself is not modified.

    :param data: tensor to perturb (any rank).
    :param frac: unused; kept for interface compatibility.
    :param dx: magnitude of the perturbation.
    :param seed: seed for ``random.seed`` (note: this reseeds the global generator).
    :return: the perturbed copy.
    """
    random.seed(seed)
    perturbed = deepcopy(data)
    # One uniformly drawn position along each axis, drawn in axis order.
    position = tuple(random.randint(0, extent - 1) for extent in data.shape)
    direction = random.choice((-1, 1))
    perturbed[position] += direction * dx
    return perturbed


def perturb_input_blur(data: torch.Tensor, d_blur: int = 1) -> torch.Tensor:
    """Blur each image in a batch by shrinking it by ``d_blur`` pixels per side length and scaling it back up.

    Each ``data[i]`` is processed in four steps:
      1. Convert it with ``to_pil_image``.
      2. Resize it to ``(H - d_blur, W - d_blur)``.
      3. Resize it back to ``(H, W)`` with torchvision's default interpolation.
      4. Convert it back with ``to_tensor``.

    ``H`` and ``W`` are the last two dimensions of ``data``. Explicit ``(height, width)`` sizes are used because
    an integer size only fixes the shorter edge. With an integer size, non-square images could come back with a
    different shape and fail to be written back into the batch. For square images the result is unchanged.

    NOTE(review): the PIL round trip quantizes each pixel to 8 bits. ``to_tensor`` returns values in [0, 1], so
    float input is presumably expected in [0, 1] as well — confirm with callers.

    :param data: batch of images, batch dimension first (presumably (N, C, H, W) or (N, H, W)).
    :param d_blur: number of pixels removed from each side length before upscaling again.
    :return: a new tensor with the blurred images; ``data`` is not modified.
    :raises ValueError: if ``d_blur`` is negative or not smaller than both image side lengths.
    """
    data_blur_perturb = deepcopy(data)

    height, width = data.shape[-2], data.shape[-1]
    if not 0 <= d_blur < min(height, width):
        raise ValueError(f"d_blur must be in [0, {min(height, width)}), got {d_blur}")
    small_size = (height - d_blur, width - d_blur)
    full_size = (height, width)

    for i in range(data.shape[0]):
        image = transforms.functional.to_pil_image(data[i], mode=None)
        smaller_image = transforms.functional.resize(image, small_size)
        blur_data = transforms.functional.to_tensor(transforms.functional.resize(smaller_image, full_size))
        data_blur_perturb[i] = blur_data

    return data_blur_perturb


def perturb_input_contrast(data: torch.Tensor) -> torch.Tensor:
    """Slightly lower the contrast of every image in a batch.

    The contrast factor passed to ``adjust_contrast`` is ``1 - 1 / (0.75 * data.shape[2])``. For a 32-pixel
    dimension the reduction is about 0.0417. Each image goes through a PIL round trip (``to_pil_image`` /
    ``to_tensor``). ``data`` itself is not modified.

    :param data: batch of images, batch dimension first (presumably (N, C, H, W) — confirm with callers).
    :return: a new tensor holding the contrast-reduced images.
    """
    tf = transforms.functional
    result = deepcopy(data)
    contrast_factor = 1 - 1.0 / (0.75 * data.shape[2])

    for idx in range(data.shape[0]):
        pil_image = tf.to_pil_image(data[idx], mode=None)
        result[idx] = tf.to_tensor(tf.adjust_contrast(pil_image, contrast_factor))

    return result


def perturb_input_shift(data: torch.Tensor) -> torch.Tensor:
    """Translate every image in a batch by one pixel along each axis, in a randomly chosen direction.

    Two signs are drawn from the module-level ``random`` generator, first the x offset and then the y offset.
    The generator is not reseeded, so seed it beforehand for reproducibility. The same offset is applied to
    every image through ``affine`` with no rotation, shear or scaling. ``data`` itself is not modified.

    :param data: batch of images, batch dimension first (presumably (N, C, H, W) — confirm with callers).
    :return: a new tensor holding the shifted images.
    """
    tf = transforms.functional
    shifted = deepcopy(data)
    offset = (random.choice((-1, 1)), random.choice((-1, 1)))

    for idx in range(data.shape[0]):
        pil_image = tf.to_pil_image(data[idx], mode=None)
        moved = tf.affine(pil_image, angle=0, translate=offset, scale=1, shear=0)
        shifted[idx] = tf.to_tensor(moved)

    return shifted


def perturb_input_hue(data: torch.Tensor) -> torch.Tensor:
    """Increase the colour saturation of every image in a batch by a factor of 1.1.

    Despite the function name, this calls ``adjust_saturation``, not ``adjust_hue``. Each image goes through a
    PIL round trip (``to_pil_image`` / ``to_tensor``). ``data`` itself is not modified.

    :param data: batch of images, batch dimension first (presumably (N, C, H, W) — confirm with callers).
    :return: a new tensor holding the saturation-adjusted images.
    """
    fn = transforms.functional
    saturation_factor = 1.1
    output = deepcopy(data)

    for idx in range(data.shape[0]):
        as_pil = fn.to_pil_image(data[idx], mode=None)
        output[idx] = fn.to_tensor(fn.adjust_saturation(as_pil, saturation_factor))

    return output


def cov(m):
    """Sample covariance matrix of ``m``, treating rows as observations and columns as variables.

    Each column is centred by its mean over dim 0, and ``X^T X`` is divided by ``n_rows - 1`` (unbiased
    estimator). With a single row this divides by zero.

    :param m: 2-D tensor of shape (n_observations, n_variables).
    :return: (n_variables, n_variables) covariance tensor.
    """
    centred = m - m.mean(dim=0, keepdim=True)
    n_obs = centred.size(0)
    return centred.transpose(0, 1) @ centred / (n_obs - 1)


def match_n_dimensions(x, y):
    """Reshape ``y`` to ``(y.size(0), 1, ..., 1)`` with as many dimensions as ``x``, so it broadcasts against ``x``.

    ``y`` must contain exactly ``y.size(0)`` elements (e.g. a (B,) or (B, 1) tensor of per-sample values).

    :param x: tensor whose number of dimensions is matched.
    :param y: tensor of per-sample values.
    :return: ``y`` reshaped to ``x.dim()`` dimensions.
    """
    trailing_ones = (1,) * (len(x.shape) - 1)
    return y.reshape((y.size()[0],) + trailing_ones)


def input_normalize(x: torch.Tensor, model_inputs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    """Divide ``x`` per sample by ``EPSILON`` plus the mean absolute difference between the two model inputs.

    For each sample, ``|model_inputs[1] - model_inputs[0]|`` is averaged over all non-batch elements. The
    resulting per-sample divisor is reshaped with ``match_n_dimensions`` so that it broadcasts against ``x``.

    :param x: batch-first tensor to normalize.
    :param model_inputs: pair of batch-first tensors (presumably the original and the perturbed input — confirm
        with callers), with the same batch size as ``x``.
    :return: ``x`` scaled per sample.
    :raises AssertionError: (via ``assert``) if the batch sizes of ``x`` and ``model_inputs[0]`` differ.
    """
    reference = model_inputs[0]
    assert x.size()[0] == reference.size()[0], (
        "Input tensors must be batches first and have the same number "
        f"of batches. Found shapes {x.size()} and {reference.size()}"
    )

    perturbed = model_inputs[1]
    per_sample_delta = (perturbed - reference).abs().view(reference.size()[0], -1)
    scale = EPSILON + per_sample_delta.mean(dim=1, keepdim=True)
    return x / match_n_dimensions(x, scale)


def output_normalize(x: torch.Tensor, pred_target0: torch.Tensor) -> torch.Tensor:
    """Divide ``x`` per sample by ``EPSILON`` plus the mean absolute value of ``pred_target0``.

    The mean is taken over all non-batch elements of each sample. The resulting per-sample divisor is reshaped
    with ``match_n_dimensions`` so that it broadcasts against ``x``.

    :param x: batch-first tensor to normalize.
    :param pred_target0: batch-first reference tensor with the same batch size as ``x`` (presumably the model's
        prediction on the unperturbed input — confirm with callers).
    :return: ``x`` scaled per sample.
    :raises AssertionError: (via ``assert``) if the batch sizes of ``x`` and ``pred_target0`` differ.
    """
    assert x.size()[0] == pred_target0.size()[0], (
        "Input tensors must be batches first and have the same number "
        f"of batches. Found shapes {x.size()} and {pred_target0.size()}"
    )

    magnitudes = abs(pred_target0).view(pred_target0.size(0), -1)
    scale = EPSILON + magnitudes.mean(dim=1, keepdim=True)
    return x / match_n_dimensions(x, scale)


def input_output_normalize(x: torch.Tensor, model_inputs: Tuple[torch.Tensor, torch.Tensor],
                           pred_target0: torch.Tensor) -> torch.Tensor:
    """Divide ``x`` per sample by the product of an input-change scale and an output scale.

    Per sample, the divisor is
    ``(EPSILON + mean|model_inputs[1] - model_inputs[0]|) * (EPSILON + mean|pred_target0|)``.
    Each mean is taken over all non-batch elements of that sample. The divisor is reshaped with
    ``match_n_dimensions`` so that it broadcasts against ``x``.

    :param x: batch-first tensor to normalize.
    :param model_inputs: pair of batch-first tensors (presumably the original and the perturbed input — confirm
        with callers).
    :param pred_target0: batch-first reference output tensor (presumably the prediction on the unperturbed input).
    :return: ``x`` scaled per sample.
    :raises AssertionError: (via ``assert``, matching the other normalizers) if the batch sizes of ``x``,
        ``model_inputs[0]`` and ``pred_target0`` are not all equal.
    """
    assert x.size()[0] == model_inputs[0].size()[0] == pred_target0.size()[0], (
        "Input tensors must be batches first and have the same number "
        f"of batches. Found shapes {x.size()}, {model_inputs[0].size()} and {pred_target0.size()}"
    )

    # Per-sample scale of the model output.
    vec_target0: torch.Tensor = abs(pred_target0).view(pred_target0.size()[0], -1)
    divisor_out = EPSILON + torch.mean(vec_target0, dim=1, keepdim=True)
    # Per-sample scale of the change between the two model inputs.
    input_delta: torch.Tensor = abs(model_inputs[1] - model_inputs[0]).view(model_inputs[0].size()[0], -1)
    divisor_in = EPSILON + torch.mean(input_delta, dim=1, keepdim=True)

    divisor = torch.mul(divisor_in, divisor_out)
    divisor = match_n_dimensions(x, divisor)
    return x / divisor


def index_module(module: torch.nn.Module, coordinates: Optional[List] = None, location: Tuple = ()):
    """
    Recursively collect the type name and child-index path of every leaf submodule of ``module``.

    Each entry appended to ``coordinates`` is ``(type(leaf).__name__, path)``. ``path`` is a tuple of integer
    positions in successive ``.children()`` lists, starting from the root; ``retrieve_module_by_index(root, path)``
    resolves it back to the leaf. Walking ``.children()`` explicitly is necessary because automatic unpacking via
    ``.modules()`` does not work for some custom modules that contain submodules, such as the Fire submodule
    from the torchvision SqueezeNet.

    A module counts as a leaf only when it has no children. Every container, including one with a single child,
    is descended into, so that e.g. a convolution wrapped in a one-element ``Sequential`` is still reported.

    :param module: module to index.
    :param coordinates: list to append results to; a new list is created when ``None``.
    :param location: index path of ``module`` relative to the root (``()`` for the root itself).
    :return: ``coordinates``, with the leaves of ``module`` appended in depth-first order.
    """
    if coordinates is None:
        coordinates = []

    children = list(module.children())
    if children:
        for idx, child in enumerate(children):
            index_module(child, coordinates, location + (idx,))
    else:
        coordinates.append((type(module).__name__, location))

    return coordinates


def retrieve_module_by_index(module: torch.nn.Module, index: Tuple[int, ...]) -> torch.nn.Module:
    """
    Return the submodule of ``module`` addressed by an index path such as those produced by ``index_module``.

    Each entry of ``index`` selects a position in the ``.children()`` list of the module reached so far.
    The result is therefore a reference into ``module`` (not a copy), and an empty path returns ``module`` itself.

    :param module: root module.
    :param index: sequence of child positions, outermost first (any length, hence ``Tuple[int, ...]``).
    :return: the addressed submodule.
    :raises IndexError: if a position is out of range for the module reached at that depth.
    """
    for dim in index:
        module = list(module.children())[dim]

    return module


def perturb_conv_layer(model: torch.nn.Module, a: float = 0.0, b: float = 1.0, dtheta: float = 0.001):
    """Perturb one randomly chosen slice of one randomly chosen convolution layer's weight, in place.

    The function proceeds in three random steps:
      1. Pick a submodule uniformly at random among the leaves reported by ``index_module`` whose class name
         contains ``"Conv"``.
      2. Pick one index along dim 0 of that module's ``weight``.
      3. Add ``+dtheta`` or ``-dtheta`` to every element of the selected slice.

    All draws use the module-level ``random`` generator without reseeding, so seed it beforehand for
    reproducibility.

    Note: unlike ``perturb_model``, ``model`` is modified in place and the same object is returned.

    :param model: model to perturb.
    :param a: unused; kept for interface compatibility.
    :param b: unused; kept for interface compatibility.
    :param dtheta: amount added to (or subtracted from) each element of the selected slice.
    :return: ``model`` itself, after perturbation.
    :raises ValueError: if the model has no convolution layer, or the selected layer has no ``weight`` parameter.
    """
    coordinates = index_module(model, None, ())
    conv_coordinates = [location for name, location in coordinates if "Conv" in name]
    if not conv_coordinates:
        # random.choice would otherwise fail with an uninformative IndexError on the empty list.
        raise ValueError(f"Could not find a convolution layer in torch.nn.Module {type(model).__name__}")
    perturb_index = random.choice(conv_coordinates)

    with torch.no_grad():
        perturb_module = retrieve_module_by_index(model, perturb_index)
        for idx, (name, weight) in enumerate(perturb_module.named_parameters()):
            if name == "weight":
                # Select a slice along dim 0 (the output-channel axis for torch.nn.ConvNd weights).
                perturb_filter_index = random.choice(range(weight.size(0)))
                print(f"parameter list, idx: {idx}, "
                      f"params: {weight[perturb_filter_index].size()}")

                # Indexing yields a view, so this in-place add modifies the model's parameter.
                weight[perturb_filter_index] += random.choice((-1, 1)) * dtheta
                return model

        raise ValueError(f"Could not find weight parameter in torch.nn.Module {perturb_module}")


def intrinsic_dim(matrix):
    """Ratio of the trace to the largest singular value of the row covariance of ``matrix``.

    Rows are treated as variables and columns as observations, the opposite convention to ``cov``.
    Each row is centred by its mean over dim 1, and ``C = M M^T / n_cols`` is formed (population normalization).
    The function returns ``trace(C) / (EPSILON + s_max)``, where ``s_max`` is the first (largest) singular value
    from ``torch.svd``. The covariance size, trace, spectral norm and Frobenius norm are printed as diagnostics.

    :param matrix: 2-D tensor of shape (n_variables, n_observations).
    :return: 0-dim tensor holding the ratio.
    """
    centred = matrix - matrix.mean(dim=1, keepdim=True)
    n_cols = centred.size(1)
    gram = centred @ centred.transpose(0, 1) / n_cols
    trace_val = gram.trace()
    print(f"cov matrix size {gram.size()}")

    singular_values = torch.svd(gram, compute_uv=False)[1]
    top_singular = singular_values[0].item()
    frobenius = torch.norm(gram, p='fro')
    print(f"Trace {trace_val}, spectral norm: {top_singular}, frobenius: {frobenius}")

    return trace_val / (EPSILON + top_singular)

