import torch
import math


def clone_and_replace_matrix(
    base_sd: dict,
    target_sd: dict,
    bkey: str,
    tkey: str,
    snr_db=None,
    normalize=False
):
    """
    Expand the 2-D tensor base_sd[bkey] to the shape of target_sd[tkey] and
    store the result in target_sd[tkey].

    Let W = base_sd[bkey] and H = W / 2. Supported target shapes:
      - Case 1, rows doubled:    [H + N ; H - N]        (concatenated on dim 0)
      - Case 2, columns doubled: [H + N , H - N]        (concatenated on dim 1)
      - Case 3, both doubled:    [[H + N1 + N2, H + N1 - N2],
                                  [H - N1 + N2, H - N1 - N2]]
      - Same shape:              an unmodified copy of W.
    Paired blocks carry offsets of opposite sign; in Cases 1 and 2 the two
    halves therefore sum back to W.

    Args:
        base_sd: State dict of the smaller/base model (only read).
        target_sd: State dict of the larger model. Only the shape of
            target_sd[tkey] is read; the entry is then replaced.
        bkey, tkey: Param names in base_sd and target_sd respectively.
        snr_db: If None, all offsets are zero. Otherwise each offset is drawn
            with get_noise_with_snr(H, snr_db).
        normalize: If True, the expanded tensor is DIVIDED by sqrt(2)
            (Cases 1 & 2) or by 2 (Case 3). Ignored for same-shape copies.

    Raises:
        KeyError: if bkey is missing from base_sd or tkey from target_sd.
        ValueError: if the target shape matches none of the cases above.
    """
    W_src = base_sd[bkey]
    W_dst_shape = target_sd[tkey].shape
    old_rows, old_cols = W_src.shape
    new_rows, new_cols = W_dst_shape

    def _generate_offset(weight, snr_db):
        """Zero offset when snr_db is None, otherwise SNR-scaled Gaussian noise."""
        if snr_db is None:
            return torch.zeros_like(weight)
        return get_noise_with_snr(weight, snr_db)

    def _expand_along(dim):
        """Cases 1/2: concatenate (H + N) and (H - N) along `dim`."""
        with torch.no_grad():
            half_weight = W_src / 2.0
            offset = _generate_offset(half_weight, snr_db)
            expanded = torch.cat([half_weight + offset, half_weight - offset], dim=dim)
            if normalize:
                expanded /= math.sqrt(2.0)
            return expanded.to(W_src.device)

    # CASE 1: doubling rows only.
    if new_rows == 2 * old_rows and new_cols == old_cols:
        print(f"[hyperclone] Expanding rows (Case 1) with random offsets: {bkey}")
        target_sd[tkey] = _expand_along(0)
        return

    # CASE 2: doubling columns only.
    if new_cols == 2 * old_cols and new_rows == old_rows:
        print(f"[hyperclone] Expanding columns (Case 2) with random offsets: {bkey}")
        target_sd[tkey] = _expand_along(1)
        return

    # CASE 3: doubling both rows and columns.
    if (new_rows == 2 * old_rows) and (new_cols == 2 * old_cols):
        print(f"[hyperclone] Expanding rows & cols (Case 3) with two offsets: {bkey}")
        with torch.no_grad():
            half_weight = W_src / 2.0
            # Two independent offsets: offset1 flips sign between the row
            # halves, offset2 flips sign between the column halves.
            offset1 = _generate_offset(half_weight, snr_db)
            offset2 = _generate_offset(half_weight, snr_db)

            top_left  = half_weight + offset1 + offset2
            top_right = half_weight + offset1 - offset2
            bot_left  = half_weight - offset1 + offset2
            bot_right = half_weight - offset1 - offset2

            top_cat = torch.cat([top_left, top_right], dim=1)
            bot_cat = torch.cat([bot_left, bot_right], dim=1)
            expanded = torch.cat([top_cat, bot_cat], dim=0)

            if normalize:
                # Each block becomes W/4, so all four blocks together sum to W.
                expanded /= 2.0

            target_sd[tkey] = expanded.to(W_src.device)
        return

    # Same shape: nothing to expand, copy the base weight verbatim instead of
    # silently leaving the target's existing values in place.
    if (new_rows, new_cols) == (old_rows, old_cols):
        print(f"[hyperclone] Copying unchanged (same shape): {bkey}")
        with torch.no_grad():
            target_sd[tkey] = W_src.detach().clone()
        return

    raise ValueError(
        f"Unsupported expansion for {bkey} -> {tkey}: "
        f"{tuple(W_src.shape)} -> {tuple(W_dst_shape)}; "
        "only same shape or 2x rows and/or columns is supported"
    )


def clone_and_replace_vector(
    base_sd: dict,
    target_sd: dict,
    bkey: str,
    tkey: str,
    snr_db=None,
    **kwargs,
) -> None:
    """
    Clone the 1-D tensor base_sd[bkey] into target_sd[tkey].

    - Same length: target_sd[tkey] becomes a copy of the source.
    - Longer target: the source is tiled dst_len // src_len times; if the
      lengths are not an integer multiple, the remaining tail is left as zeros.

    Args:
        base_sd: State dict of the smaller/base model.
        target_sd: State dict of the larger/target model. Only the shape of
            target_sd[tkey] is read; the entry is then replaced.
        bkey, tkey: Param names in base_sd and target_sd respectively.
        snr_db: Accepted for signature compatibility with the matrix cloner;
            no noise is currently added to vectors.
        **kwargs: Ignored (lets generic callers pass e.g. `normalize`).

    Returns:
        None. Does nothing if either key is missing.

    Raises:
        ValueError: if either tensor is not 1-D, or the target is shorter
            than the source (tiling would otherwise yield an all-zero vector).
    """
    if bkey not in base_sd or tkey not in target_sd:
        return
    v_src = base_sd[bkey]
    dst_shape = target_sd[tkey].shape
    if v_src.ndim != 1 or len(dst_shape) != 1:
        raise ValueError(f"Invalid shapes for {bkey} and {tkey}: {v_src.shape} and {dst_shape}")

    src_len, dst_len = v_src.shape[0], dst_shape[0]
    if dst_len == src_len:
        target_sd[tkey] = v_src.clone()
        return
    if dst_len < src_len:
        raise ValueError(
            f"Cannot shrink vector {bkey} -> {tkey}: {src_len} -> {dst_len}"
        )

    out = torch.zeros(dst_shape, dtype=v_src.dtype, device=v_src.device)
    repeated = v_src.repeat(dst_len // src_len)
    out[: repeated.shape[0]] = repeated
    target_sd[tkey] = out

def clone_positional_embedding(
    base_sd: dict,
    target_sd: dict,
    bkey: str = "transformer.wpe.weight",
    tkey: str = "transformer.wpe.weight",
    snr_db=None,
    **kwargs,
) -> None:
    """
    Clone the positional embedding base_sd[bkey] into target_sd[tkey],
    tiling its columns when the target is wider.

    Args:
        base_sd, target_sd: The base & destination state dicts.
        bkey, tkey: The keys for the positional embedding in each dict.
        snr_db: If not None and columns are tiled, noise is added with
            add_noise using the source shape as the block shape.
        **kwargs: Ignored.

    Returns:
        None. Does nothing if either key is missing.

    Raises:
        AssertionError: if the number of rows (block_size) differs.
        ValueError: if the target width is not a positive integer multiple
            of the source width (the result could not match the target shape).
    """
    if bkey not in base_sd or tkey not in target_sd:
        return
    src_wpe = base_sd[bkey]
    dst_wpe = target_sd[tkey]
    assert src_wpe.shape[0] == dst_wpe.shape[0], "block_size must match for wpe"

    src_dim, dst_dim = src_wpe.shape[1], dst_wpe.shape[1]
    if dst_dim < src_dim or dst_dim % src_dim != 0:
        raise ValueError(
            f"Cannot clone {bkey} -> {tkey}: width {dst_dim} is not a "
            f"multiple of {src_dim}"
        )

    col_repeat = dst_dim // src_dim
    if col_repeat > 1:
        expanded = src_wpe.repeat(1, col_repeat)
        if snr_db is not None:
            expanded = add_noise(expanded, src_wpe.shape, snr_db)
        target_sd[tkey] = expanded
    else:
        # Same width: plain copy.
        target_sd[tkey] = src_wpe.clone()

def scale_linear_layer(layer: torch.nn.Linear, scaler: float):
    """
    Multiply, in place, a linear layer's weight and (when present) its bias
    by 'scaler', so the layer's output is multiplied by 'scaler'.

    Arguments:
        layer:
            Linear layer whose parameters are rescaled.
        scaler:
            Factor applied to the layer output.

    Returns:
        None.
    """
    params = [layer.weight]
    if layer.bias is not None:
        params.append(layer.bias)
    for param in params:
        param.data.mul_(scaler)


def get_noise_with_snr(weight: torch.Tensor, snr_db: float):
    """
    Gaussian noise to be added to 'weight' so that the signal-to-noise
    ratio becomes 'snr_db'.

    The noise is rescaled so that its empirical mean power is exactly
    mean(weight**2) / 10**(snr_db / 10). An all-zero 'weight' therefore
    yields all-zero noise.

    Arguments:
        weight:
            Signal tensor.
        snr_db:
            Signal-to-noise ratio in decibels.

    Returns:
        Noise tensor with the same shape and dtype as 'weight'.
    """
    signal_power = torch.mean(weight**2)
    snr_linear = 10 ** (snr_db / 10)
    noise_power = signal_power / snr_linear
    noise = torch.randn_like(weight)
    # Normalize by the sample power so the target power is hit exactly,
    # not just in expectation.
    current_noise_power = torch.mean(noise**2)
    noise = noise * torch.sqrt(noise_power / current_noise_power)
    return noise.to(weight.dtype)


def add_noise(weight, block_shape, snr_db):
    """
    Add and subtract noise, in place, on pairs of 'block_shape' blocks of 'weight'.

    Along the second axis, column blocks are taken in consecutive pairs: the
    first block of each pair gets +N and the second gets -N. A fresh N is
    drawn for every (row block, pair). An odd trailing column block is left
    untouched. Tensors with more than two dimensions are handled by
    recursing over the leading axis with block_shape[1:].

    Examples 1 & 2, even repetition of columns:
    +-------+-------+        +-------+-------+
    |   W   |   W   |        | W+N1  | W-N1  |
    +-------+-------+   -->  +-------+-------+
    |   W   |   W   |        | W+N2  | W-N2  |
    +-------+-------+        +-------+-------+

    +-------+-------+-------+-------+        +-------+-------+-------+-------+
    |   W   |   W   |   W   |   W   |        | W+N1  | W-N1  | W+N2  | W-N2  |
    +-------+-------+-------+-------+   -->  +-------+-------+-------+-------+
    |   W   |   W   |   W   |   W   |        | W+N3  | W-N3  | W+N4  | W-N4  |
    +-------+-------+-------+-------+        +-------+-------+-------+-------+

    Example 3, odd repetition of columns:
    +-------+-------+-------+        +-------+-------+-------+
    |   W   |   W   |   W   |        | W+N1  | W-N1  |   W   |
    +-------+-------+-------+   -->  +-------+-------+-------+
    |   W   |   W   |   W   |        | W+N2  | W-N2  |   W   |
    +-------+-------+-------+        +-------+-------+-------+

    Arguments:
        weight:
            Signal tensor; modified in place.
        block_shape:
            Shape of the block to which noise is added or subtracted.
            The first two entries must divide weight.shape[0] and
            weight.shape[1] respectively.
        snr_db:
            Signal-to-noise ratio in decibels, measured against each
            "+N" block (see get_noise_with_snr).

    Returns:
        The same 'weight' object, now noisy.
    """
    assert weight.shape[0] % block_shape[0] == 0
    assert weight.shape[1] % block_shape[1] == 0

    if weight.ndim != 2:
        # Higher-rank tensor: peel off the leading axis and recurse on each
        # slice with the remaining block dimensions.
        for idx in range(weight.shape[0]):
            weight[idx] = add_noise(weight[idx], block_shape[1:], snr_db)
        return weight

    block_rows, block_cols = block_shape[0], block_shape[1]
    pair_count = (weight.shape[1] // block_cols) // 2
    for row_start in range(0, weight.shape[0], block_rows):
        rows = slice(row_start, row_start + block_rows)
        for pair in range(pair_count):
            plus_cols = slice(2 * pair * block_cols, (2 * pair + 1) * block_cols)
            minus_cols = slice((2 * pair + 1) * block_cols, (2 * pair + 2) * block_cols)
            noise = get_noise_with_snr(weight[rows, plus_cols], snr_db)
            weight[rows, plus_cols] += noise
            weight[rows, minus_cols] -= noise
    return weight


def clone_matrix(dst_weight_shape, src_weight, snr_db=None, normalize=True):
    """
    Clone 'src_weight' by tiling it into a matrix of shape 'dst_weight_shape'.

    Arguments:
        dst_weight_shape:
            (rows, cols) of the destination matrix. Each dimension must be
            an integer multiple of (and at least as large as) the
            corresponding dimension of src_weight.shape.
        src_weight:
            2-D source weight to be cloned.
        snr_db:
            Signal-to-noise ratio in case noise is to be added (via
            add_noise, with src_weight.shape as the block shape).
            Defaults to None (no noise added).
        normalize:
            If True, divide the tiled weight by the number of repetitions
            in the second dimension.

    Returns:
        Cloned matrix with shape 'dst_weight_shape'.

    Raises:
        AssertionError: if 'dst_weight_shape' is not a valid multiple of
            src_weight.shape.
    """
    out_features_old, in_features_old = src_weight.shape
    out_features_new, in_features_new = dst_weight_shape
    assert out_features_new >= out_features_old
    assert (
        out_features_new % out_features_old == 0
    ), f"{out_features_old} does not divide {out_features_new}"
    assert in_features_new >= in_features_old
    assert (
        in_features_new % in_features_old == 0
    ), f"{in_features_old} does not divide {in_features_new}"
    n_repeat_0 = out_features_new // out_features_old
    n_repeat_1 = in_features_new // in_features_old

    dst_weight = src_weight.data.repeat(n_repeat_0, n_repeat_1)
    if normalize:
        # Inputs are duplicated n_repeat_1 times, so dividing keeps
        # the summed contribution equal to the original.
        dst_weight = dst_weight / n_repeat_1
    if snr_db is not None:
        dst_weight = add_noise(dst_weight, src_weight.shape, snr_db)
    return dst_weight


# Lookup tables mapping a parameter role name to the function that clones it:
# matrix-shaped roles use clone_and_replace_matrix, "*_bias" roles use
# clone_and_replace_vector.
global_params_map = dict(
    embedding=clone_and_replace_matrix,
    unembedding=clone_and_replace_matrix,
)
attn_params_map = dict(
    attn_in=clone_and_replace_matrix,
    attn_in_bias=clone_and_replace_vector,
    attn_out=clone_and_replace_matrix,
    attn_out_bias=clone_and_replace_vector,
)
mlp_params_map = dict(
    mlp_in=clone_and_replace_matrix,
    mlp_in_bias=clone_and_replace_vector,
    mlp_out=clone_and_replace_matrix,
    mlp_out_bias=clone_and_replace_vector,
)