import regex as re


def do_transpose(key_s):
    """Return True when ``key_s`` names a kernel that gets transposed.

    ``key_s`` is searched (not full-matched) against a fixed list of kernel
    name patterns: MLP ``fc<N>``, ``down``/``up``/``gate`` projections,
    ``linear``, self-attention projections and the LM head.

    NOTE: the ``.`` in each pattern is an unescaped regex dot, so it matches
    any single character, not only a literal period.
    """
    kernel_patterns = (
        r"mlp.fc[0-9]+.kernel",
        r"down_proj.kernel",
        r"up_proj.kernel",
        r"gate_proj.kernel",
        r"linear.kernel",
        r"self_attn.[a-z]+_proj.kernel",
        r"lm_head.kernel",
    )
    # First matching pattern decides; no match at all means "leave as is".
    return any(re.search(pattern, key_s) for pattern in kernel_patterns)


def apply_trans(MAPPING, key_s, pt_state):
    """Fetch the torch tensor that corresponds to a flattened JAX key.

    Expects flattened np_state and torch pt_state.

    Args:
        MAPPING: dict from JAX key templates to torch key templates. Per-layer
            entries use ``residual_block_{x}`` on the JAX side and ``.{x}.``
            on the torch side as the layer-number placeholder.
        key_s: flattened JAX parameter key; a ``residual_block_<N>`` segment,
            if present, supplies the layer number.
        pt_state: mapping from torch keys to tensors. Values must expose
            ``.T`` for keys that ``do_transpose`` selects.

    Returns:
        ``(tensor, torch_key)``, with the tensor transposed when
        ``do_transpose(key_s)`` is true, or ``(None, None)`` when the key has
        no entry in ``MAPPING``.

    Raises:
        KeyError: if the resolved torch key is missing from ``pt_state``.
    """
    # Transform layer naming: swap the concrete layer number for the "{x}"
    # placeholder used by MAPPING. A key containing "residual_block_" with no
    # number after it is treated as an ordinary (non per-layer) key instead
    # of crashing on a failed regex match.
    layer_nr = None
    mapping_key = key_s
    block = re.search(r"residual_block_[0-9]+", key_s)
    if block is not None:
        res = block.group(0)
        layer_nr = res.split("_")[-1]
        mapping_key = key_s.replace(res, "residual_block_{x}")

    if mapping_key not in MAPPING:
        return None, None

    torch_key = MAPPING[mapping_key]
    if layer_nr is not None:
        # Fill the torch-side placeholder with the layer number found above.
        torch_key = torch_key.replace(".{x}.", f".{layer_nr}.")

    # Same transform for all Dense layers.
    transpose = do_transpose(key_s)
    tmp = pt_state[torch_key]
    if transpose:
        tmp = tmp.T

    print(
        "Original key:",
        key_s,
        "Jax layer nr: ",
        layer_nr,
        "torch layer nr: ",
        torch_key,
        "transpose:",
        transpose,
    )
    return tmp, torch_key
