import torch
from transformers import MimiModel as HF_MimiModel
from models.moshi.models import loaders


DEVICE = "cpu"

# --------------------------------------------------
# Load official HuggingFace Mimi (reference implementation)
# --------------------------------------------------
print("Loading official HuggingFace Mimi...")
hf_mimi = HF_MimiModel.from_pretrained("kyutai/mimi")
hf_mimi = hf_mimi.to(DEVICE).eval()

# --------------------------------------------------
# Load Moshi Mimi: pretrained release plus two local finetunes
# --------------------------------------------------
print("Loading Moshi Mimi via CheckpointInfo...")
ckpt = loaders.CheckpointInfo.from_hf_repo("kyutai/moshiko-pytorch-bf16")
moshi_mimi = ckpt.get_mimi(device=DEVICE)

# NOTE: machine-specific absolute paths; adjust when running elsewhere.
finetuned_mimi = loaders.get_mimi(
    "/home/wmar/wmar_audio/checkpoints/finetunes/mimi_ft_noaug.pth", DEVICE
)
finetuned_mimi_aug = loaders.get_mimi(
    "/home/wmar/wmar_audio/checkpoints/finetunes/mimi_ft.pth", DEVICE
)

# --------------------------------------------------
# Define top-level modules to compare
# --------------------------------------------------
# Attribute names of the sub-modules compared on every model pair, in the
# order the comparison rows are printed.
_COMPONENT_NAMES = (
    "encoder",
    "decoder",
    "encoder_transformer",
    "decoder_transformer",
    "quantizer",
)


def _component_pairs(left, right):
    """Return ``[(name, left.<name>, right.<name>), ...]`` for each component name."""
    return [(part, getattr(left, part), getattr(right, part)) for part in _COMPONENT_NAMES]


# Heading -> list of (component name, first module, second module).
modules_to_compare = {
    "HF Mimi VS Moshi Mimi": _component_pairs(hf_mimi, moshi_mimi),
    "Finetuned Mimi VS HF Mimi": _component_pairs(finetuned_mimi, hf_mimi),
    "Finetuned Mimi VS Moshi Mimi": _component_pairs(finetuned_mimi, moshi_mimi),
    "Finetuned Mimi (aug) VS Moshi Mimi": _component_pairs(finetuned_mimi_aug, moshi_mimi),
}

# --------------------------------------------------
# Compare modules
# --------------------------------------------------
def compare_module(name, module1, module2, atol=1e-6, *, include_buffers=False, verbose=False):
    """Compare two modules' tensors position-by-position and print a summary line.

    Tensors are paired by their position in ``module.parameters()`` (followed,
    optionally, by ``module.buffers()``). The comparison is therefore only
    meaningful when both modules register their tensors in the same order.
    NOTE(review): HF and Moshi Mimi are independent implementations; confirm
    their registration order matches before trusting HF-vs-Moshi rows.

    Args:
        name: Label printed in front of every message for this comparison.
        module1, module2: ``torch.nn.Module`` instances to compare.
        atol: A non-identical pair whose max absolute difference is below this
            is counted as "close"; otherwise it is counted as "different".
        include_buffers: Also compare registered buffers. Off by default to keep
            the parameter-only behaviour. Turn it on when relevant state lives
            in buffers rather than parameters (presumably e.g. codebooks — confirm).
        verbose: Print the max absolute difference of every "different" pair.

    Returns:
        ``(exact, close, different)`` counts, or ``None`` when the modules hold
        a different number of tensors (nothing is paired in that case).
    """
    tensors1 = list(module1.parameters())
    tensors2 = list(module2.parameters())
    if include_buffers:
        tensors1 += list(module1.buffers())
        tensors2 += list(module2.buffers())

    if len(tensors1) != len(tensors2):
        print(f"[MISMATCH] {name}: different number of parameters {len(tensors1)} vs {len(tensors2)}")
        return None

    exact = close = different = 0

    for i, (t1, t2) in enumerate(zip(tensors1, tensors2)):
        if t1.shape != t2.shape:
            print(f"[SHAPE] {name} param {i}: {tuple(t1.shape)} vs {tuple(t2.shape)}")
            different += 1
            continue

        if t1.numel() == 0:
            # .max() raises on an empty tensor; two empty tensors of equal shape are identical.
            exact += 1
            continue

        # Subtract in a wide common dtype (float64, or complex128 for complex input).
        # This avoids extra rounding for bf16/fp16 and also works for int/bool
        # tensors, where a direct '-' either loses precision or raises.
        common = torch.promote_types(torch.promote_types(t1.dtype, t2.dtype), torch.float64)
        a = t1.detach().cpu().to(common)
        b = t2.detach().cpu().to(common)
        max_diff = (a - b).abs().max().item()

        if max_diff == 0.0:
            exact += 1
        elif max_diff < atol:
            close += 1
        else:
            # NaN differences also land here (NaN compares false against 0.0 and atol).
            if verbose:
                print(f"[DIFF] {name} param {i}: max |Δ| = {max_diff:.6e}")
            different += 1

    print(f"[{name}] exact={exact}, close={close}, different={different}")
    return exact, close, different


print("\n=== MODULE-LEVEL COMPARISON ===")
# Each heading maps to (component name, first module, second module) triples;
# print the heading, one summary line per component, then a blank separator.
for heading, triples in modules_to_compare.items():
    print(heading)
    for component, first, second in triples:
        compare_module(component, first, second)
    print()
