import torch
from models.moshi.models.loaders import CheckpointInfo


DEVICE = "cpu"
OUTPUT_PATH = "mimi_rvq_embeddings.pt"

# Load the Moshi checkpoint, build the Mimi codec on DEVICE and switch it
# to eval mode.
ckpt = CheckpointInfo.from_hf_repo("kyutai/moshiko-pytorch-bf16")
mimi_model = ckpt.get_mimi(device=DEVICE)
mimi_model.eval()

# Maps "<quantizer name>_<layer index>" -> that layer's codebook tensor on CPU.
embeddings = {}

# Walk both residual quantizers in a fixed order so the keys come out as
# rvq_first_* followed by rvq_rest_*.
for rvq_name in ("rvq_first", "rvq_rest"):
    rvq = getattr(mimi_model.quantizer, rvq_name)
    for i, layer in enumerate(rvq.vq.layers):
        # NOTE(review): in upstream Moshi's EuclideanCodebook, `embedding_sum`
        # is the un-normalised accumulator buffer; the actual code vectors are
        # exposed by the `embedding` property
        # (embedding_sum / cluster_usage.clamp(min=epsilon)[:, None]).
        # Exporting `embedding` yields the vectors the quantizer really uses.
        # Confirm the vendored `models.moshi` copy exposes the same property.
        embeddings[f"{rvq_name}_{i}"] = layer._codebook.embedding.detach().cpu()

print(embeddings.keys())

# Save all embeddings into a single .pt file
torch.save(embeddings, OUTPUT_PATH)
print(f"Saved {len(embeddings)} embedding tensors to {OUTPUT_PATH}")

# The print(embeddings.keys()) call above prints:
# dict_keys(['rvq_first_0', 'rvq_rest_0', 'rvq_rest_1', 'rvq_rest_2', 'rvq_rest_3', 'rvq_rest_4', 'rvq_rest_5', 'rvq_rest_6', 'rvq_rest_7', 'rvq_rest_8', 'rvq_rest_9', 'rvq_rest_10', 'rvq_rest_11', 'rvq_rest_12', 'rvq_rest_13', 'rvq_rest_14', 'rvq_rest_15', 'rvq_rest_16', 'rvq_rest_17', 'rvq_rest_18', 'rvq_rest_19', 'rvq_rest_20', 'rvq_rest_21', 'rvq_rest_22', 'rvq_rest_23', 'rvq_rest_24', 'rvq_rest_25', 'rvq_rest_26', 'rvq_rest_27', 'rvq_rest_28', 'rvq_rest_29', 'rvq_rest_30'])
