
from collections import defaultdict
import torch

def reset_optimizer(optimizer, model, plugin):
    """Reset the optimizer to update the list of learnable parameters.

    All optimizer state is discarded and param group 0 is re-pointed at
    ``model.parameters()``. When ``plugin`` carries both ``unc1`` and
    ``unc2``, those (and ``unc3`` if present and extra groups exist) are
    replaced by fresh one-element tensors initialised to ``-0.7`` and
    registered as their own param groups.

    .. warning::
        Only group 0 is rebuilt from ``model``. Any other groups are assumed
        to be the plugin's ``unc*`` groups and are removed by index.

    :param optimizer: Optimizer to reset in place.
    :param model: Module whose parameters become param group 0.
    :param plugin: Object that may expose ``unc1``/``unc2``/``unc3``; its
        ``device`` attribute is used to allocate the replacement tensors.
    :return: None
    """
    # Discard per-parameter state (it is keyed by the old parameter objects).
    optimizer.state = defaultdict(dict)
    main_group = optimizer.param_groups[0]
    main_group["params"] = list(model.parameters())

    if not (hasattr(plugin, 'unc1') and hasattr(plugin, 'unc2')):
        return

    def make_unc():
        # One-element learnable float32 tensor on the plugin's device.
        return torch.tensor([-0.7], requires_grad=True,
                            dtype=torch.float32, device=plugin.device)

    if len(optimizer.param_groups) > 1:
        # Groups shift left after each deletion, so index 1 is reused.
        for _ in range(2):
            remove_param_from_optimizer(optimizer, 1)
        if hasattr(plugin, 'unc3'):
            remove_param_from_optimizer(optimizer, 1)
            # NOTE(review): unc3 is re-added before unc1/unc2, so after a
            # reset the groups are ordered [model, unc3, unc1, unc2] —
            # confirm nothing depends on their positions.
            # NOTE(review): unc3 is only re-registered when extra groups
            # already existed; confirm this is intended on a first call.
            plugin.unc3 = make_unc()
            optimizer.add_param_group({"params": plugin.unc3})

    plugin.unc1 = make_unc()
    plugin.unc2 = make_unc()
    optimizer.add_param_group({"params": plugin.unc1})
    optimizer.add_param_group({"params": plugin.unc2})


def update_optimizer(optimizer, old_params, new_params, reset_state=True):
    """Update the optimizer by substituting old_params with new_params.

    Parameters are paired positionally via ``zip``, so any extra entries in
    the longer list are ignored. Each ``old_p`` is matched by identity
    against the parameters of every group, in order, and only its first
    occurrence is replaced with the paired ``new_p``.

    :param optimizer: Optimizer whose ``param_groups`` are modified in place.
    :param old_params: List of old trainable parameters.
    :param new_params: List of new trainable parameters.
    :param reset_state: Whether to reset the optimizer's state.
        Defaults to True.
    :raises Exception: If a parameter from ``old_params`` is not found in any
        group. Substitutions already made for earlier pairs are kept.
    :return: None
    """
    for old_p, new_p in zip(old_params, new_params):
        for group in optimizer.param_groups:
            params = group["params"]
            # Tensors hash by id(), so identity is the exact intent of the
            # former hash comparison, without relying on hash values.
            idx = next(
                (i for i, curr_p in enumerate(params) if curr_p is old_p),
                None,
            )
            if idx is not None:
                params[idx] = new_p
                break
        else:
            # Report the parameter that is actually missing, not the list.
            raise Exception(
                f"Parameter {old_p} not found in the "
                f"current optimizer."
            )
    if reset_state:
        # State contains parameter-specific information keyed by the old
        # parameter objects; reset it because the model has (probably)
        # changed.
        optimizer.state = defaultdict(dict)


def add_new_params_to_optimizer(optimizer, new_params):
    """Register ``new_params`` with ``optimizer`` as one extra param group.

    The group is created with only the ``"params"`` key, so every other
    hyper-parameter is filled in from the optimizer's defaults by
    ``add_param_group``.

    :param optimizer: Optimizer that receives the new group.
    :param new_params: list of trainable parameters
    """
    group_spec = dict(params=new_params)
    optimizer.add_param_group(group_spec)

def remove_param_from_optimizer(optim, pg_index):
    """Delete param group ``pg_index`` from ``optim`` together with its state.

    :param optim: Optimizer to modify in place.
    :param pg_index: Index into ``optim.param_groups`` of the group to drop.
    """
    doomed_group = optim.param_groups[pg_index]
    # Drop per-parameter state first so no entry outlives its group.
    for tensor in doomed_group['params']:
        optim.state.pop(tensor, None)
    del optim.param_groups[pg_index]
