import torch
import logging
from ..layers import AdaptOp, MlpNewOp, SkipOp


# Module-level logger; `new` uses it to report when backbone weights are transferred.
_logger = logging.getLogger(__name__)

@torch.no_grad()
def init_weights(m):
    """Fill the weight of a ``Linear`` or ``Conv2d`` module with ones, in place.

    Intended to be passed to ``nn.Module.apply``. The check compares the exact
    class name, so subclasses of these layers (and every other module type) are
    left unchanged. Biases are never touched. Runs under ``torch.no_grad()`` so
    the in-place fill on a parameter is not recorded by autograd.
    """
    # Exact-name match on purpose: isinstance() would also catch subclasses.
    if type(m).__name__ not in ("Linear", "Conv2d"):
        return
    m.weight.fill_(1.0)


def new(layer, transfer_backbone=True, new_init_scale=None, **kwargs):
    """Build an ``MlpNewOp`` with the same dimensions as the MLP ``layer``.

    Args:
        layer: Source module exposing ``fc1`` and ``fc2`` sub-layers. Their
            ``in_features``/``out_features``, ``weight`` and ``bias`` are read.
            Bias is enabled on the new op iff ``layer.fc1.bias`` is not None.
        transfer_backbone: If True, ``layer``'s fc1/fc2 weights and biases are
            passed to ``MlpNewOp`` as ``init_mean_weights``/``init_mean_bias``.
            If False, both are None and ``new_init_scale`` is ignored (forced
            to None).
        new_init_scale: Forwarded to ``MlpNewOp`` as ``init_scale`` when
            ``transfer_backbone`` is True.
        **kwargs: Forwarded unchanged to ``MlpNewOp`` (including ``init``).
            ``kwargs["init"]``, if present, also selects a post-construction
            initialization; only ``"ones"`` is supported.

    Returns:
        The constructed ``MlpNewOp``.

    Raises:
        ValueError: If ``kwargs["init"]`` is set to anything other than ``"ones"``.
    """
    use_bias = layer.fc1.bias is not None
    if transfer_backbone:
        _logger.info("Transfering backbone weights")
        init_mean_weights = {
            "fc1": layer.fc1.weight,
            "fc2": layer.fc2.weight,
        }
        init_mean_bias = {
            "fc1": layer.fc1.bias,
            "fc2": layer.fc2.bias,
        }
    else:
        init_mean_weights = None
        init_mean_bias = None
        # Without backbone weights there is nothing to scale around.
        new_init_scale = None
    op = MlpNewOp(
        layer.fc1.in_features,
        layer.fc1.out_features,
        layer.fc2.out_features,
        use_bias,
        init_mean_weights=init_mean_weights,
        init_mean_bias=init_mean_bias,
        init_scale=new_init_scale,
        **kwargs,
    )
    init = kwargs.get("init", None)
    if init is not None:
        if init == "ones":
            op.apply(init_weights)
        else:
            raise ValueError(f"Init {init} not supported")
    return op


def skip(layer, **kwargs):
    """Build a ``SkipOp`` spanning ``layer``'s input and output widths.

    The input width comes from ``layer.fc1.in_features`` and the output width
    from ``layer.fc2.out_features``. All extra keyword arguments are forwarded
    to ``SkipOp`` unchanged.
    """
    return SkipOp(layer.fc1.in_features, layer.fc2.out_features, **kwargs)


def adapt(layer, downscale, residual=True, small_init=False, adapt_init_scale=1e-3, **kwargs):
    """Build an ``AdaptOp`` sized to ``layer.fc2.out_features``.

    Args:
        layer: Module exposing ``fc2``; only ``fc2.out_features`` is read.
        downscale: Forwarded positionally to ``AdaptOp``.
        residual: Forwarded as ``residual``.
        small_init: Forwarded as ``small_init``.
        adapt_init_scale: Forwarded as ``init_scale``.
        **kwargs: Forwarded unchanged to ``AdaptOp`` (including ``init``).
            ``kwargs["init"]``, if present, also selects a post-construction
            initialization; only ``"ones"`` is supported.

    Returns:
        The constructed ``AdaptOp`` (always created with ``bias=True``).

    Raises:
        ValueError: If ``kwargs["init"]`` is set to anything other than ``"ones"``.
    """
    op = AdaptOp(
        layer.fc2.out_features,
        downscale,
        bias=True,
        residual=residual,
        small_init=small_init,
        init_scale=adapt_init_scale,
        **kwargs,
    )
    requested = kwargs.get("init")
    if requested is None:
        return op
    if requested != "ones":
        raise ValueError(f"Init {requested} not supported")
    op.apply(init_weights)
    return op


# Op name -> factory that builds the op from an existing layer.
# "reuse" and "identity" have no factory in this module (None).
OP_FACTORIES = {
    "new": new,
    "adapt": adapt,
    "reuse": None,
    "skip": skip,
    "identity": None,
}
