from pathlib import Path
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F

def _fill_new_rows_normal(weight, old_vocab_size, average_per_dim):
    """Overwrite ``weight[old_vocab_size:]`` in place with Gaussian samples.

    The mean and standard deviation of the normal distribution are taken from
    ``weight[:old_vocab_size]``: one pair per column when ``average_per_dim``
    is true, otherwise a single pair computed over all of those entries.
    Shape and dtype are read from ``weight`` itself, so each matrix is filled
    according to its own dimensions.
    """
    num_new = weight.shape[0] - old_vocab_size
    if num_new <= 0:
        # No rows beyond the old vocabulary: nothing to initialise.
        return

    old_rows = weight[:old_vocab_size]
    if average_per_dim:
        mean, std = old_rows.mean(dim=0), old_rows.std(dim=0)
    else:
        mean, std = old_rows.mean(), old_rows.std()

    # Draw the samples in at least float32 and cast back afterwards, so that
    # half-precision weights do not require a half-precision sampling kernel.
    sample_dtype = torch.promote_types(weight.dtype, torch.float32)
    target_shape = (num_new, *weight.shape[1:])
    mean = mean.to(sample_dtype).expand(target_shape)
    std = std.to(sample_dtype).expand(target_shape)
    weight[old_vocab_size:] = torch.normal(mean=mean, std=std).to(weight.dtype)


@torch.no_grad()
def init_embeddings_normal(
    model,
    old_vocab_size,
    average_per_dim=False,
):
    """Randomly initialise the rows added after ``old_vocab_size``.

    Both ``model.model.embed_tokens.weight`` and ``model.lm_head.weight`` are
    modified in place: every row at index ``>= old_vocab_size`` is replaced by
    a sample from a normal distribution whose statistics come from the first
    ``old_vocab_size`` rows of the same matrix. Rows below ``old_vocab_size``
    are left untouched.

    Args:
        model: Object exposing ``model.embed_tokens.weight`` and
            ``lm_head.weight`` tensors.
        old_vocab_size: Number of leading rows that are kept and used to
            estimate the mean and standard deviation.
        average_per_dim: If true, estimate a separate mean/std for every
            column; otherwise use one scalar mean/std per matrix.

    Returns:
        None. The weights are updated in place.
    """
    # The embedding matrix is filled first, then the output head; each one
    # uses its own statistics, shape and dtype.
    for weight in (model.model.embed_tokens.weight.data, model.lm_head.weight.data):
        _fill_new_rows_normal(weight, old_vocab_size, average_per_dim)


def _save_embedding_heatmap(embeddings, title, path):
    """Plot the first 128 columns of ``embeddings`` as an image and save it to ``path``."""
    # detach()/cpu() let this work for tensors that require grad or live on
    # an accelerator; the float16 cast is kept from the original conversion
    # (presumably so dtypes NumPy cannot represent still convert).
    data = embeddings[:, :128].detach().cpu().to(torch.float16).numpy()

    fig, ax = plt.subplots()
    try:
        ax.set_title(title)
        ax.set_xlabel("d_model[:128]")
        ax.set_ylabel("vocab_size")
        ax.imshow(data, aspect="auto")
        fig.savefig(path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)


def draw(old_embeddings, new_embeddings, mode, save):
    """Save heatmaps of the old and new embedding matrices.

    Writes ``{mode}-old-embeddings.png`` and ``{mode}-new-embeddings.png`` to
    the current working directory, each showing the first 128 columns of the
    corresponding matrix.

    Args:
        old_embeddings: 2-D tensor plotted as the "old" embeddings.
        new_embeddings: 2-D tensor plotted as the "new" embeddings.
        mode: Label used in the plot titles and the output file names.
        save: If falsy, the function returns immediately without plotting.

    Returns:
        None.
    """
    if not save:
        return

    _save_embedding_heatmap(
        old_embeddings,
        f"{mode} mode: old embeddings[:, :128]",
        f"{mode}-old-embeddings.png",
    )
    _save_embedding_heatmap(
        new_embeddings,
        f"{mode} mode: new embeddings[:, :128]",
        f"{mode}-new-embeddings.png",
    )

def l2norm(t):
    """Return ``t`` rescaled to unit L2 norm along its last dimension."""
    unit = F.normalize(t, p=2, dim=-1)
    return unit


class CodebookEmbedding(nn.Module):
    """Learnable lookup table whose rows start out L2-normalised.

    The table has shape ``(num_tokens, codebook_dim)``. It is drawn from a
    standard normal distribution and each row is passed through ``l2norm``
    before being registered as the ``weight`` parameter.
    """

    def __init__(self, num_tokens, codebook_dim):
        """Create the codebook with ``num_tokens`` rows of size ``codebook_dim``."""
        super().__init__()
        self.num_tokens, self.codebook_dim = num_tokens, codebook_dim

        # Random Gaussian rows, normalised along the last dimension at init.
        initial_codes = l2norm(torch.randn(num_tokens, codebook_dim))
        self.weight = nn.Parameter(initial_codes)

    def forward(self, embed_id):
        """Look up the rows of ``weight`` selected by the indices in ``embed_id``."""
        codes = F.embedding(embed_id, self.weight)
        return codes