from copy import deepcopy

import torch
import numpy as np

from path_learning.utils.log import get_logger

# Module-level logger for this file. Note: intervention() takes its own
# ``logger`` parameter, which shadows this name inside that function.
logger = get_logger("learning_intervention")


def _log_norm_change(logger, name, param, norm_before):
    """Log the absolute change of ``param``'s norm relative to ``norm_before``."""
    logger.info(f"Norm changes of weight {name} is "
                f"{abs(torch.norm(param.data) - norm_before)}, "
                f"norm before: {norm_before}")


def _add_scaled(logger, name, param, w, scaling):
    """Replace ``param.data`` with ``param.data + scaling * w`` and log the norm change."""
    norm_before = torch.norm(param.data).detach()
    param.data = param.data + scaling * w
    _log_norm_change(logger, name, param, norm_before)


def _reset_to_ones(logger, name, param, device):
    """Overwrite ``param`` with ones placed on ``device`` and log the norm change."""
    w = torch.ones_like(param.data).to(device)
    logger.info(f"Weight change values before scaling: {w[:20]}, param.data: {param.data[:20]}")
    norm_before = torch.norm(param.data).detach()
    param.data = w
    _log_norm_change(logger, name, param, norm_before)


def _identity_target(param, device, rand_scaling=None, include_vectors=False):
    """Build an identity-like tensor with ``param``'s shape and dtype on ``device``.

    3- to 5-D weights get a Dirac delta (``torch.nn.init.dirac_``), 2-D weights
    the identity matrix and, only if ``include_vectors``, 1-D weights all ones.
    When ``rand_scaling`` is not None, ``rand_scaling * N(0, 1)`` noise is added
    (no random numbers are drawn otherwise). Returns None for any other rank.
    """
    ndim = param.dim()
    if 3 <= ndim <= 5:  # dirac_ only supports 3, 4 or 5 dimensions
        init_fn = torch.nn.init.dirac_
    elif ndim == 2:
        init_fn = torch.nn.init.eye_
    elif ndim == 1 and include_vectors:
        init_fn = torch.nn.init.ones_
    else:
        return None
    # Match the parameter's dtype so the later addition does not silently
    # change the parameter's dtype (e.g. fp16 -> fp32).
    w = torch.empty(param.size(), dtype=param.dtype, device=device)
    init_fn(w)
    if rand_scaling is not None:
        w += rand_scaling * torch.randn_like(w)
    return w


def _move_towards_identity(logger, name, param, device, identity_scaling, rand_scaling=None):
    """Reset 1-D weights to ones; add ``identity_scaling`` * identity-like tensor to 2- to 5-D weights."""
    if param.dim() == 1:
        _reset_to_ones(logger, name, param, device)
        return
    w = _identity_target(param, device, rand_scaling)
    if w is None:
        logger.info(f"Skipping weight: {name}")
        return
    _add_scaled(logger, name, param, w, identity_scaling)


def intervention(task=None, model=None, logger=None, device='cuda', kwargs=None):
    """
    "Intervene" on a model by changing its parameters in place.

    Options are read from ``task.config`` when ``task`` is given, otherwise from
    ``kwargs``. Every enabled option is applied, in this order, to parameters
    whose name contains ``"weight"``:

    - ``make_identity``: 2- to 5-D weights += ``identity_scaling`` (default 0.05)
      * identity (``eye_`` / ``dirac_``); 1-D weights are set to ones.
    - ``bn_reset``: 1-D weights are set to ones; others are skipped.
    - ``make_identity_w_symmetry_breaking``: like ``make_identity``, but the
      identity gets ``rand_scaling`` (default 0.01) * N(0, 1) noise added.
    - ``make_identity_w_symmetry_breaking_no_bias``: like the previous option, but
      1-D weights also get += ``identity_scaling`` * (ones + noise) instead of a reset.
    - ``make_frozen``: for each selected output unit j of a layer (probability
      ``drop_row_col_pair_prob``, default 0.5) zero row j of that weight and
      input unit j of the next two weights (names containing "shortcut" are ignored).
    - ``make_single_frozen``: zero individual entries with probability
      ``drop_single_prob`` (default 0.5); names containing "shortcut" or "bn" are ignored.
    - ``make_specific_frozen``: replace the first two rows (and, for >= 2-D
      weights, the first two columns) with ``scale`` (default 0.0) * N(0, 1).
    - ``make_identity_only_bn``: 1-D weights += 1.
    - ``reset_head``: call ``model.reset_head()``.
    - ``randomize_model``: weights += ``random_scaling`` (default 0.1) * N(0, 1).

    :param task: object with a dict-like ``config``; takes precedence over ``kwargs``.
    :param model: ``torch.nn.Module`` to modify in place.
    :param logger: logger to report changes to; defaults to the logging logger
        named "learning_intervention".
    :param device: device on which identity / ones target tensors are created.
    :param kwargs: option dict used when ``task`` is None; None means "all defaults".
    :return: the same ``model`` object, modified in place.
    """
    if logger is None:
        import logging
        # NOTE(review): presumably the same logger object that
        # get_logger("learning_intervention") configures at module level -- confirm.
        logger = logging.getLogger("learning_intervention")

    # Get parameters: the task's config wins; otherwise the explicit kwargs dict.
    parameter_dict: dict = task.config if task is not None else (kwargs or {})

    # Standard intervention
    identity_scaling = parameter_dict.get("identity_scaling", 0.05)
    rand_scaling = parameter_dict.get("rand_scaling", 0.01)
    bn_reset = parameter_dict.get("bn_reset", False)

    make_identity = parameter_dict.get("make_identity", False)
    make_identity_w_symmetry_breaking = parameter_dict.get("make_identity_w_symmetry_breaking", False)
    make_identity_w_symmetry_breaking_no_bias = parameter_dict.get("make_identity_w_symmetry_breaking_no_bias", False)
    randomize_model = parameter_dict.get("randomize_model", False)
    reset_head = parameter_dict.get("reset_head", False)

    # Weight "freezing" - setting to 0
    make_frozen = parameter_dict.get("make_frozen", False)
    make_single_frozen = parameter_dict.get("make_single_frozen", False)
    make_specific_frozen = parameter_dict.get("make_specific_frozen", False)
    make_identity_only_bn = parameter_dict.get("make_identity_only_bn", False)

    if make_identity:
        logger.info(f"Moving weights towards identity to re-enable learning.")
        for name, param in model.named_parameters():
            if "weight" in name:
                logger.info(f"Scaling of weight {name} with size {param.size()} is {identity_scaling}")
                _move_towards_identity(logger, name, param, device, identity_scaling)

    if bn_reset:
        logger.info(f"Moving weights towards identity to re-enable learning.")
        for name, param in model.named_parameters():
            if "weight" in name:
                if param.dim() == 1:
                    _reset_to_ones(logger, name, param, device)
                else:
                    logger.info(f"Skipping weight: {name}")

    if make_identity_w_symmetry_breaking:
        logger.info(f"Moving weights towards identity to re-enable learning.")
        for name, param in model.named_parameters():
            if "weight" in name:
                logger.info(f"Scaling of weight {name} with size {param.size()} is {identity_scaling} "
                            f"and rand scaling: {rand_scaling}")
                _move_towards_identity(logger, name, param, device, identity_scaling, rand_scaling)

    if make_identity_w_symmetry_breaking_no_bias:
        logger.info(f"Moving weights towards identity to re-enable learning.")
        for name, param in model.named_parameters():
            if "weight" in name:
                logger.info(f"Scaling of weight {name} with size {param.size()} is {identity_scaling}")
                w = _identity_target(param, device, rand_scaling, include_vectors=True)
                if w is None:
                    logger.info(f"Skipping weight: {name}")
                    continue
                _add_scaled(logger, name, param, w, identity_scaling)

    if make_frozen:
        drop_row_col_pair_prob = parameter_dict.get("drop_row_col_pair_prob", 0.5)
        logger.info(f"Freezing some weights to reduce learningability.")
        # state_dict() tensors share storage with the parameters, so the in-place
        # writes below modify the model; the deep copy keeps originals for logging.
        model_dict = model.state_dict()
        model_dict_ref = deepcopy(model_dict)
        names = [key for key in model_dict if "weight" in key and "shortcut" not in key]
        indices = []
        # Layer i is paired with the weights at i + 1 and i + 2, so it stops before
        # i == len(names) - 3 (the long-standing boundary). Using range() also makes
        # models with fewer than three weights a no-op instead of an IndexError.
        # NOTE(review): the last addressable triple (i == len(names) - 3) is
        # skipped too -- confirm whether sparing it (e.g. the head) is intentional.
        for i in range(len(names) - 3):
            name = names[i]
            current_param = model_dict[name]
            followers = (model_dict[names[i + 1]], model_dict[names[i + 2]])

            # Randomly select rows; zero the row and the matching input unit of the next two weights.
            for j in range(current_param.size(0)):
                if int(np.random.binomial(1, drop_row_col_pair_prob)):
                    indices.append(f"{name}-" + str(j))
                    current_param[j] = 0
                    for follower in followers:
                        if follower.dim() > 1:
                            follower[:, j] = 0
                        else:
                            follower[j] = 0

            norm_after = torch.norm(current_param.data).detach()
            logger.info(f"Norm changes of weight {name} is "
                        f"{abs(torch.norm(model_dict_ref[name]) - norm_after)}, "
                        f"norm before: {torch.norm(model_dict_ref[name])}")

        logger.info(f"Columns set to zero: {indices}")

    if make_single_frozen:
        drop_prob = parameter_dict.get("drop_single_prob", 0.5)
        logger.info(f"Freezing non-paired weights to reduce learning ability.")
        model_dict = model.state_dict()
        model_dict_ref = deepcopy(model_dict)
        names = [key for key in model_dict
                 if "weight" in key and "shortcut" not in key and "bn" not in key]
        for name in names:
            current_param = model_dict[name]
            # Zero each entry independently with probability drop_prob (same
            # convention as drop_row_col_pair_prob). Multiply in place: the
            # state_dict tensor shares storage with the parameter, so this
            # actually reaches the model.
            keep_mask = torch.bernoulli((1.0 - drop_prob) * torch.ones_like(current_param))
            current_param.mul_(keep_mask)

            norm_after = torch.norm(current_param.data).detach()
            logger.info(f"Norm changes of weight {name} is "
                        f"{abs(torch.norm(model_dict_ref[name]) - norm_after)}, "
                        f"norm before: {torch.norm(model_dict_ref[name])}")

    if make_specific_frozen:
        logger.info(f"For visualization, set specific weights to zero. "
                    f"Freezing non-paired weights to reduce learning ability.")
        scale = parameter_dict.get("scale", 0.0)
        for name, param in model.named_parameters():
            if "weight" in name:
                norm_before = torch.norm(param.data).detach()
                # For analysis comparison we only set the first two filters to be "frozen"
                param.data[:2] = scale * torch.randn_like(param.data[:2])
                if param.data.dim() > 1:
                    param.data[:, :2] = scale * torch.randn_like(param.data[:, :2])
                _log_norm_change(logger, name, param, norm_before)

    if make_identity_only_bn:
        logger.info(f"Moving weights towards identity to re-enable learning.")
        for name, param in model.named_parameters():
            if "weight" in name:
                logger.info(f"Scaling of weight {name} with size {param.size()} is {identity_scaling}")
                if param.dim() != 1:
                    continue
                # NOTE(review): identity_scaling is logged but not applied; the
                # ones are added unscaled. Confirm this is intended.
                w = torch.ones_like(param.data).to(device)
                _add_scaled(logger, name, param, w, 1.0)

    if reset_head:
        logger.info(f"Resetting last layer of model to random initialization.")
        model.reset_head()

    if randomize_model:
        logger.info(f"Adding noise to all weights to re-enable learning.")
        scaling = parameter_dict.get("random_scaling", 0.1)
        for name, param in model.named_parameters():
            if "weight" in name:
                logger.info(f"Scaling of weight {name} is {scaling}")
                norm_before = torch.norm(param.data).detach()
                param.data = param.data + scaling * torch.randn_like(param.data)
                logger.info(f"Norm changes of weight {name} is "
                            f"{abs(torch.norm(param.data) - norm_before)}")
    return model
