from copy import copy

import torch

from theory.theory import create_embedding_matrix, EmbeddingType


def set_ideal_MQAR_model_weights(model, V, D, N, embedding_type: EmbeddingType):
    """Overwrite ``model``'s weights in place with hand-constructed matrices.

    Two matrices are obtained from ``create_embedding_matrix`` (``E`` for the
    token embedding / LM head and ``S`` for the B/C projections); every other
    tensor is a fixed pattern of identities, ones and zeros built here. All
    writes happen under ``torch.no_grad()`` via ``Tensor.copy_``, so the
    destination parameters must already exist with compatible shapes.

    Args:
        model: Object exposing ``embedding.weight``, ``lm_head.weight`` and
            ``layers[0].mixer`` with ``in_proj_x``, ``out_proj_y``,
            ``x_proj_to_B``, ``x_proj_to_C``, ``conv1d`` (each with a
            ``weight``) plus tensors ``A`` and ``D``.
        V: First dimension requested for ``E`` (presumably the vocabulary
            size -- confirm against the caller).
        D: Second dimension requested for ``E``; the mixer's inner width used
            here is ``2 * D``.
        N: Second dimension requested for ``S`` (presumably the SSM state
            size -- confirm against the model definition).
        embedding_type: Forwarded unchanged to ``create_embedding_matrix``.

    Returns:
        None. ``model`` is modified in place.
    """
    # generate weight matrices

    # Inner width: first D channels form the "k" half, last D the "v" half.
    D_in = 2 * D

    E = create_embedding_matrix(V=V, M=D, embedding_type=embedding_type)  # (V, D)
    S = create_embedding_matrix(V=D, M=N, embedding_type=embedding_type)  # (D, N)

    # Zero row 0 of E. Done on a clone so the tensor returned by
    # create_embedding_matrix is not mutated, and without materialising a
    # dense (V, V) mask matrix just to clear a single row.
    E = E.clone()
    E[0] = 0

    I_D = torch.eye(D)
    Z_D = torch.zeros((D, D))

    ones_D = torch.ones(D)
    zeros_D = torch.zeros(D)

    Z_DN = torch.zeros((D, N))

    # Input projection feeds the same D features to both halves;
    # output projection reads only from the second ("v") half.
    P_in = torch.vstack([I_D, I_D])  # (2D, D)
    P_out = torch.vstack([Z_D, I_D])  # (2D, D)

    # B and C projections share one matrix: S on the first half, zeros on the
    # second half.
    S_BC = torch.vstack([S, Z_DN])  # (2D, N)

    # Per-channel 2-tap conv kernel: taps [1, 0] for the first D channels and
    # [0, 1] for the last D channels.
    # NOTE(review): which tap is the "previous" vs "current" position depends
    # on the conv1d padding configured in the model -- confirm there.
    W_k = torch.vstack([ones_D, zeros_D])  # (2, D)
    W_v = torch.vstack([zeros_D, ones_D])  # (2, D)
    W = torch.hstack([W_k, W_v])  # (2, 2D)

    # TODO: A is currently an all-ones placeholder; an identity-based choice
    # was left as an open question by the original author.
    A = torch.ones((D_in, N))

    # All-zero D vector of length 2D.
    D_col = torch.zeros(2 * D)

    # set model weights

    mixer = model.layers[0].mixer

    with torch.no_grad():

        model.embedding.weight.copy_(E)
        model.lm_head.weight.copy_(E)

        mixer.in_proj_x.weight.copy_(P_in)
        mixer.out_proj_y.weight.copy_(P_out.T)  # (D, 2D)

        mixer.x_proj_to_B.weight.copy_(S_BC.T)  # (N, 2D)
        mixer.x_proj_to_C.weight.copy_(S_BC.T)  # (N, 2D)

        mixer.conv1d.weight.copy_(W.T.unsqueeze(1))  # (2D, 1, 2)

        mixer.A.copy_(A)
        mixer.D.copy_(D_col)
