import torch
from torch import nn

from theory.simplified_linear_mamba import create_ideal_model_weights
from theory.theory import create_embedding_matrix


def configure_mamba_weights(self, **weights_config):
    """Initialise and/or freeze the token embedding and LM head of a model.

    ``self`` must expose ``args.vocab_size``, ``args.d_model``, ``device``,
    ``embedding`` and ``lm_head``.

    Recognised ``weights_config`` keys (all optional):

    * ``init_E`` -- embedding type forwarded to ``create_embedding_matrix``.
      When not ``None``, the resulting matrix is copied into both
      ``self.embedding.weight`` and ``self.lm_head.weight``.
    * ``init_E_0__zero`` -- when truthy (and ``init_E`` is set), row 0 of the
      matrix is zeroed before it is copied.
    * ``freeze_E`` -- when truthy, gradients are disabled for every parameter
      of ``self.embedding`` and ``self.lm_head``.
    """
    # notations
    V = self.args.vocab_size
    D = self.args.d_model
    device = self.device

    # set
    if (init_type := weights_config.get("init_E", None)) is not None:
        # expected shape (V, D), matching both target weight matrices
        E = create_embedding_matrix(V=V, M=D, embedding_type=init_type, device=device)

        if weights_config.get("init_E_0__zero", False):
            # Zero the embedding of token 0 directly.  Multiplying by a masked
            # V x V identity needs O(V^2) memory and a matmul for the same
            # result; the clone keeps the helper's returned tensor untouched.
            E = E.clone()
            E[0] = 0

        # copy_ duplicates the data, so both weights can share the same source.
        with torch.no_grad():
            self.embedding.weight.copy_(E)
            self.lm_head.weight.copy_(E)

    # freeze
    if weights_config.get("freeze_E", False):
        for module in (self.embedding, self.lm_head):
            for p in module.parameters():
                p.requires_grad = False


def configure_mamba_block_weights(self, **weights_config):
    """Initialise and/or freeze the weights of a single Mamba mixer block.

    ``self`` must expose ``args`` (``d_model``, ``d_inner``, ``d_state``,
    ``expand``) and ``device``, plus whichever of ``A``, ``D``, ``in_proj_x``,
    ``out_proj_y``, ``conv1d``, ``x_proj_to_B`` and ``x_proj_to_C`` the
    requested options touch.

    Recognised ``weights_config`` keys (all optional):

    * ``init_A`` / ``init_D`` -- scalar; replaces ``self.A`` / ``self.D`` with
      a new ``nn.Parameter`` filled with that value, of shape
      ``(d_inner, d_state)`` / ``(d_inner,)``.
    * ``init_P_in__identity_identity`` -- ``in_proj_x.weight`` becomes two
      stacked ``d_model`` identities.
    * ``init_P_out__identity_zeros`` -- ``out_proj_y.weight`` becomes
      ``[0 | I]``.
    * ``init_W_conv__identity_shift`` -- ``conv1d.weight`` gets a two-tap
      kernel: the first ``d_model`` channels use tap 0, the rest use tap 1.
    * ``init_S`` -- embedding type forwarded to ``create_embedding_matrix``;
      the result initialises ``x_proj_to_B`` and ``x_proj_to_C``.
    * ``freeze_A``, ``freeze_D``, ``freeze_P_in``, ``freeze_P_out``,
      ``freeze_W``, ``freeze_S`` -- disable gradients for the matching
      weights.  ``freeze_A`` / ``freeze_D`` take effect whether or not the
      corresponding ``init_*`` key is given.

    Raises:
        AssertionError: if an identity-style initialisation is requested and
            ``args.expand != 2`` (or ``d_inner`` is not a multiple of
            ``d_model`` for the projection initialisations).
    """
    # notations
    args = self.args
    D = args.d_model
    D_in = args.d_inner
    N = args.d_state
    device = self.device

    # set

    if (init_value := weights_config.get("init_A", None)) is not None:
        values = torch.full((D_in, N), fill_value=float(init_value), device=device)
        self.A = nn.Parameter(values, requires_grad=not weights_config.get("freeze_A", False))

    if (init_value := weights_config.get("init_D", None)) is not None:
        values = torch.full((D_in,), fill_value=float(init_value), device=device)
        self.D = nn.Parameter(values, requires_grad=not weights_config.get("freeze_D", False))

    if weights_config.get("init_P_in__identity_identity", False):
        assert D_in % D == 0
        assert args.expand == 2  # required
        I = torch.eye(D, device=device)
        W = torch.vstack([I, I])  # shape (D_in, D)
        with torch.no_grad():
            self.in_proj_x.weight.copy_(W)

    if weights_config.get("init_P_out__identity_zeros", False):
        assert D_in % D == 0
        assert args.expand == 2  # required
        I = torch.eye(D, device=device)
        Z = torch.zeros((D, D), device=device)
        W = torch.vstack([Z, I]).T  # shape (D, D_in)
        with torch.no_grad():
            self.out_proj_y.weight.copy_(W)

    if weights_config.get("init_W_conv__identity_shift", False):
        # assert self.conv1d.bias is None
        assert args.expand == 2
        # Built on `device` like every other tensor in this function.
        O_col = torch.ones((D,), device=device)
        Z_col = torch.zeros((D,), device=device)
        W_v = torch.concat([Z_col, O_col])
        W_k = torch.concat([O_col, Z_col])
        W_conv = torch.stack([W_k, W_v]).T.unsqueeze(1)  # shape (2D, 1, 2)
        with torch.no_grad():
            self.conv1d.weight.copy_(W_conv)

    if (init_type := weights_config.get("init_S", None)) is not None:
        assert args.expand == 2  # required

        I = torch.eye(D, device=device)
        Z = torch.zeros((D, D), device=device)
        M_k = torch.hstack([I, Z])  # (D, 2D)  - picks the "key" half of x_t

        # expected shape (D, N)
        S = create_embedding_matrix(V=D, M=N, embedding_type=init_type, device=device)
        S_B = S.T @ M_k  # (N, 2D)
        S_C = S.T @ M_k  # (N, 2D)

        with torch.no_grad():
            self.x_proj_to_B.weight.copy_(S_B)
            self.x_proj_to_C.weight.copy_(S_C)

    # freeze

    # Attributes are looked up lazily so that an option which is not requested
    # never requires the corresponding attribute to exist.
    module_freezes = (
        ("freeze_P_in", ("in_proj_x",)),
        ("freeze_P_out", ("out_proj_y",)),
        ("freeze_W", ("conv1d",)),
        ("freeze_S", ("x_proj_to_B", "x_proj_to_C")),
    )
    for flag, attr_names in module_freezes:
        if not weights_config.get(flag, False):
            continue
        for attr_name in attr_names:
            for p in getattr(self, attr_name).parameters():
                p.requires_grad = False

    # A and D are bare tensors/parameters rather than modules.  Freezing them
    # here makes `freeze_A` / `freeze_D` effective without `init_A` / `init_D`
    # (a no-op when they were already created frozen above).  Only leaf
    # tensors can have their requires_grad flag changed in place.
    for flag, attr_name in (("freeze_A", "A"), ("freeze_D", "D")):
        if not weights_config.get(flag, False):
            continue
        tensor = getattr(self, attr_name, None)
        if isinstance(tensor, torch.Tensor) and tensor.is_leaf:
            tensor.requires_grad_(False)


def set_mamba_ideal_MQAR_weights(model: nn.Module):
    """Overwrite the model's weights with those from ``create_ideal_model_weights``.

    Sets the embedding, the LM head and every weight of the first layer's
    mixer (input/output projections, convolution, ``A``, ``D`` and the B/C
    projections).  Weights are only set, never frozen: no ``requires_grad``
    flag is touched.
    """
    device = model.device  # read as in the original; not used below

    vocab_size, d_model, d_state = model.V, model.D, model.N
    d_inner = int(2 * d_model)

    (
        E_in,
        P_in,
        W,
        S_B,
        S_C,
        P_out,
        E_out,
    ) = create_ideal_model_weights(V=vocab_size, D=d_model, N=d_state)

    # Constant state matrix (all ones) and skip term (all zeros).
    ones_A = torch.ones((d_inner, d_state))
    zeros_D = torch.zeros(d_inner)

    with torch.no_grad():
        # token embedding / un-embedding
        model.embedding.weight.copy_(E_in.T)
        model.lm_head.weight.copy_(E_out)

        mixer = model.layers[0].mixer

        # input projection and convolution
        mixer.in_proj_x.weight.copy_(P_in)
        mixer.conv1d.weight.copy_(W.unsqueeze(1))

        # state-space weights
        mixer.A.copy_(ones_A)
        mixer.x_proj_to_B.weight.copy_(S_B)
        mixer.x_proj_to_C.weight.copy_(S_C)
        mixer.D.copy_(zeros_D)

        # output projection
        mixer.out_proj_y.weight.copy_(P_out)


def set_mamba_ideal_MQAR_projection_weights(model: nn.Module, freeze: bool = True):
    """Set (and optionally freeze) the first mixer's projection-side weights.

    Copies the ``P_in``, ``P_out`` and ``W`` matrices returned by
    ``create_ideal_model_weights`` into ``in_proj_x``, ``out_proj_y`` and
    ``conv1d`` of ``model.layers[0].mixer``, fills ``mixer.A`` with ones and
    ``mixer.D`` with zeros.  The embedding, LM head and B/C projections are
    left untouched.

    Args:
        model: model exposing ``V``, ``D``, ``N`` and ``layers[0].mixer``.
        freeze: when true (default), gradients are disabled for all weights
            set by this function, including ``mixer.A`` and ``mixer.D``.
    """
    # print(f"\nset_mamba_ideal_MQAR_projection_weights: called with {freeze=}")

    V = model.V
    D = model.D
    N = model.N

    D_in = int(2 * D)

    # Only the projection and convolution matrices are needed here.
    _, P_in, W, _, _, P_out, _ = create_ideal_model_weights(V=V, D=D, N=N)

    A = torch.ones((D_in, N))
    D_col = torch.zeros(D_in)

    mixer = model.layers[0].mixer

    # set weights
    with torch.no_grad():

        mixer.in_proj_x.weight.copy_(P_in)
        mixer.out_proj_y.weight.copy_(P_out)

        mixer.conv1d.weight.copy_(W.unsqueeze(1))

        mixer.A.copy_(A)
        mixer.D.copy_(D_col)

    if not freeze:
        return

    # freeze weights
    # 1) modules: use .parameters()
    for mod in (mixer.in_proj_x, mixer.out_proj_y, mixer.conv1d):
        for p in mod.parameters():
            p.requires_grad = False

    # 2) bare tensors/parameters.  A discarded `t.detach()` freezes nothing
    #    because it only returns a new tensor, so the flag is cleared in
    #    place.  That is only permitted on leaf tensors; a non-leaf tensor is
    #    rebound to its detached version instead.
    for attr_name in ("A", "D"):
        t = getattr(mixer, attr_name)
        if t.is_leaf:
            t.requires_grad_(False)
        else:
            setattr(mixer, attr_name, t.detach())