"""
Minimal script that:
1) Downloads & saves the official EnCodec checkpoint to checkpoints/
2) Loads the MusicGen processor + model from HF
3) Loads the local EnCodec architecture, injects the fine-tuned weights from FT_CHECKPOINT_PATH
   (translating audiocraft/Meta key names to Hugging Face names), and swaps it in as model.audio_encoder
4) Generates audio and captures the audio token ids used by the MusicGen decode step
5) Retokenizes the produced waveform with the same EnCodec and computes the token-match rate between
   the generated token ids and the retokenized codes
6) Writes the generated waveform(s) to WAV files
"""

from pathlib import Path
import torch
import scipy.io
import numpy as np

from transformers import (
    AutoProcessor,
    AutoFeatureExtractor,
    EncodecModel,
    MusicgenForConditionalGeneration,
)

LOCAL_ENCODEC_DIR = Path("checkpoints/encodec_32khz")

# Fine-tuned EnCodec checkpoint whose weights are injected into the local EnCodec
# further down. Previously used checkpoints, kept for reference:
#   /home/wmar/wmar_audio/outputs/finetune/20260125-161511/checkpoint_epoch_99.pt
#   /home/wmar/wmar_audio/checkpoints/encodec_32khz/pytorch_model.bin  (saved official weights)
FT_CHECKPOINT_PATH = "/home/wmar/wmar_audio/outputs/finetune/model_1_100_sched_1e6/20260126-204028/checkpoint_epoch_19.pt"

LOCAL_ENCODEC_DIR.mkdir(parents=True, exist_ok=True)

# 1) Download & save the official EnCodec checkpoint to LOCAL_ENCODEC_DIR, but only
# when it is not there yet. The local EnCodec is later built from this directory
# with from_pretrained, which cannot work on the empty directory mkdir just created.
_has_encodec_config = (LOCAL_ENCODEC_DIR / "config.json").is_file()
_has_encodec_weights = any(
    (LOCAL_ENCODEC_DIR / weights_name).is_file()
    for weights_name in ("pytorch_model.bin", "model.safetensors")
)
if not (_has_encodec_config and _has_encodec_weights):
    print("Downloading official EnCodec checkpoint and saving to", LOCAL_ENCODEC_DIR)
    encodec = EncodecModel.from_pretrained("facebook/encodec_32khz")
    # save state_dict instead of using model.save_pretrained() to avoid unwrap_model/deepspeed import
    torch.save(encodec.state_dict(), LOCAL_ENCODEC_DIR / "pytorch_model.bin")
    encodec.config.save_pretrained(str(LOCAL_ENCODEC_DIR))
    AutoFeatureExtractor.from_pretrained("facebook/encodec_32khz").save_pretrained(str(LOCAL_ENCODEC_DIR))
    print("Saved EnCodec to", LOCAL_ENCODEC_DIR)
    del encodec  # only needed to populate the directory


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 2) Fetch the MusicGen processor (text tokenizer + audio feature extractor) and the model.
print("Loading MusicGen processor and model from Hugging Face (facebook/musicgen-medium)...")
musicgen_repo = "facebook/musicgen-medium"
processor = AutoProcessor.from_pretrained(musicgen_repo)
model = MusicgenForConditionalGeneration.from_pretrained(musicgen_repo)
model = model.to(device)
model.eval()  # inference only

# 3) Build the local EnCodec and read the fine-tuned checkpoint from disk.
print(f"Loading local EnCodec model architecture from {LOCAL_ENCODEC_DIR} ...")

# from_pretrained supplies the correct config/architecture; its weights get
# overwritten by the fine-tuned checkpoint below.
local_encodec = EncodecModel.from_pretrained(str(LOCAL_ENCODEC_DIR))
local_encodec = local_encodec.to(device)

raw_state_dict = torch.load(FT_CHECKPOINT_PATH, map_location=device)
print(f"Injecting weights from: {FT_CHECKPOINT_PATH}")

# Some checkpoints (presumably the fine-tuning script's) nest the weights under
# "model_state"; a bare state dict is used as-is.
raw_state_dict = raw_state_dict["model_state"] if "model_state" in raw_state_dict else raw_state_dict

# 3. TRANSLATION LOOP
# Rename audiocraft (Meta) EnCodec state-dict keys to the Hugging Face EncodecModel
# layout, e.g.
#   encoder.model.0.conv.conv.weight          -> encoder.layers.0.conv.weight
#   decoder.model.3.convtr.convtr.weight      -> decoder.layers.3.conv.weight
#   quantizer.vq.layers.0._codebook.embed     -> quantizer.layers.0.codebook.embed
print("Translating Meta keys to Hugging Face format...")

# Ordered (old, new) substring substitutions applied to every key.
_META_TO_HF_KEY_SUBS = (
    # Top-level module lists: Meta calls them "model", HF calls them "layers".
    ("encoder.model", "encoder.layers"),
    ("decoder.model", "decoder.layers"),
    # Residual VQ: Meta nests the per-codebook modules one level deeper
    # (quantizer.vq.layers.N) and names the codebook "_codebook".
    # Mapping only "quantizer.vq" -> "quantizer.layers" would yield
    # "quantizer.layers.layers.N", which matches nothing.
    ("quantizer.vq.layers.", "quantizer.layers."),
    ("._codebook.", ".codebook."),
    # Meta wraps each nn.Conv1d / nn.ConvTranspose1d in a streamable wrapper
    # (conv.conv / convtr.convtr); HF exposes the wrapped layer as ".conv".
    ("conv.conv.", "conv."),
    ("convtr.convtr.", "conv."),
    # LSTM parameter names (lstm.weight_ih_l0, ...) are expected to match as-is.
)


def translate_meta_key(key):
    """Return *key* rewritten from audiocraft naming to HF EncodecModel naming.

    Keys that contain none of the Meta patterns (e.g. a checkpoint that is
    already in HF format) are returned unchanged.
    """
    for old, new in _META_TO_HF_KEY_SUBS:
        key = key.replace(old, new)
    return key


new_state_dict = {translate_meta_key(key): value for key, value in raw_state_dict.items()}

# 4. Inject Weights
print("Injecting translated weights...")
# strict=False keeps a partially matching checkpoint loadable, so the report
# below is the only signal that the fine-tuned weights actually landed.
missing, unexpected = local_encodec.load_state_dict(new_state_dict, strict=False)

_MAX_KEYS_SHOWN = 10
n_model_keys = len(local_encodec.state_dict())
print(f"Loaded {n_model_keys - len(missing)}/{n_model_keys} state-dict entries from the checkpoint.")

if missing:
    # Not present in the checkpoint: these keep the values loaded from LOCAL_ENCODEC_DIR.
    print(f"[Warning] Missing keys: {len(missing)} (kept from the base EnCodec, NOT fine-tuned)")
    for name in missing[:_MAX_KEYS_SHOWN]:
        print("    ", name)
if unexpected:
    # Checkpoint entries that matched no model entry; they were ignored.
    print(f"[Warning] Unexpected keys: {len(unexpected)} (ignored; check key translation / unwrapping)")
    for name in unexpected[:_MAX_KEYS_SHOWN]:
        print("    ", name)

local_encodec.eval()

# D. Install the fine-tuned EnCodec as MusicGen's audio codec and keep the
# composite config pointing at the matching sub-config.
model.config.audio_encoder = local_encodec.config
model.audio_encoder = local_encodec

print("Swap complete.")


# A single text prompt to condition generation on.
texts = ["80s pop track with bassy drums and synth"]

# Tokenize with the processor, then put every tensor entry on the model's device.
print("Preparing inputs with processor...")
inputs = processor(text=texts, padding=True, return_tensors="pt")
tensor_keys = [key for key, val in inputs.items() if isinstance(val, torch.Tensor)]
for key in tensor_keys:
    inputs[key] = inputs[key].to(device)

# 3b) Wrap model.audio_encoder.decode so each call records the token ids it receives.
token_ids_captured = []
orig_decode = model.audio_encoder.decode


def capture_decode(output_ids, audio_scales=None, **kwargs):
    """Store a detached CPU copy of *output_ids*, then defer to the real decode."""
    snapshot = output_ids.detach().cpu().clone()
    token_ids_captured.append(snapshot)
    return orig_decode(output_ids, audio_scales=audio_scales, **kwargs)


model.audio_encoder.decode = capture_decode

# 4) Generate audio. The final codes go through audio_encoder.decode, i.e.
# capture_decode, which fills token_ids_captured.
print("Generating audio (this may take a while)...")
# A small max_new_tokens keeps debug runs quick; raise it for longer clips.
generated = model.generate(**inputs, max_new_tokens=256)

# generate() returns the waveform tensor directly, or an output object carrying it
# in `.sequences` when return_dict_in_generate is used.
if isinstance(generated, torch.Tensor):
    waveform = generated
else:
    waveform = getattr(generated, "sequences", None)
    if not isinstance(waveform, torch.Tensor):
        raise RuntimeError("Could not extract waveform from generate() result; inspect `generated` object.")
audio_values = waveform.detach().cpu()  # presumably (batch, channels, seq_len) — confirm from the print

print("Generated waveform shape:", audio_values.shape)

# 5) Pull out the token ids that capture_decode recorded during generation.
if not token_ids_captured:
    raise RuntimeError("Token ids were not captured. Generation may not have called audio_encoder.decode as expected.")

captured_ids = token_ids_captured[0]  # first (and expected only) decode call
print("Captured token ids tensor shape:", captured_ids.shape)

# MusicGen's generate prepends a frame axis (output_ids[None, ...]) before decoding,
# so this is expected to be (1, batch, codebooks, seq_len) — confirm from the print above.
captured_ids_np = captured_ids.numpy()

# 6) Retokenize the generated waveform with the same (fine-tuned) EnCodec.
audio_for_encode = audio_values.to(device)

print("Running EnCodec.encode on generated waveform to retokenize...")
# Pure inference: no_grad avoids recording an autograd graph for the encoder pass.
with torch.no_grad():
    encode_outputs = local_encodec.encode(audio_for_encode, return_dict=True)
# Expected shape: (frames, batch, codebooks, seq_len).
audio_codes = encode_outputs.audio_codes
print("Retokenized audio_codes shape:", audio_codes.shape)

# CPU numpy copy for the comparison below.
audio_codes_np = audio_codes.detach().cpu().numpy()

# Compare generated token ids with retokenized codes elementwise. Both arrays are
# expected to be (frames, batch, codebooks, seq_len); every axis is trimmed to the
# shorter of the two (minimal safe alignment) before comparing.
if captured_ids_np.ndim != 4 or audio_codes_np.ndim != 4:
    raise ValueError(
        "Expected 4-D (frames, batch, codebooks, seq_len) token arrays, got "
        f"{captured_ids_np.shape} and {audio_codes_np.shape}"
    )
min_frames, min_batch, min_codebooks, min_seq = (
    min(a, b) for a, b in zip(captured_ids_np.shape, audio_codes_np.shape)
)

cap = captured_ids_np[:min_frames, :min_batch, :min_codebooks, :min_seq]
retok = audio_codes_np[:min_frames, :min_batch, :min_codebooks, :min_seq]
print("Compared token array shape (frames, batch, codebooks, seq_len):", cap.shape)

matches = cap == retok
if matches.size == 0:
    # The mean of an empty array is nan (with a RuntimeWarning); say so explicitly.
    print("[Warning] No overlapping tokens to compare; token match rate is undefined.")
    token_match_rate = float("nan")
else:
    token_match_rate = float(matches.mean())
print(f"Token match rate (fraction equal) between generated token ids and retokenized codes: {token_match_rate:.6f}")

if matches.size:
    # Agreement per codebook (averaged over frames, batch and time).
    for codebook_idx, rate in enumerate(matches.mean(axis=(0, 1, 3))):
        print(f"  codebook {codebook_idx}: {float(rate):.6f}")

# Save waveform(s) to WAV files and the captured / retokenized token ids to .npy files.
sr = int(getattr(local_encodec.config, "sampling_rate", None) or 32000)
print("Using sampling rate:", sr)

# audio_values may be (batch, channels, seq_len), (batch, seq_len) or (seq_len,);
# normalize to (batch, channels, seq_len).
if audio_values.ndim == 2:
    audio_values = audio_values[:, None, :]
elif audio_values.ndim == 1:
    audio_values = audio_values[None, None, :]

for i, clip in enumerate(audio_values):
    # scipy.io.wavfile.write takes multi-channel data as (n_samples, n_channels),
    # so transpose each (channels, samples) clip.
    wav_np = clip.numpy().T
    out_wav_path = Path(f"musicgen_out_{i}.wav")
    scipy.io.wavfile.write(str(out_wav_path), rate=sr, data=wav_np)
    print("Saved waveform to", out_wav_path)

# Full (untrimmed) token arrays for offline analysis.
for out_npy_path, tokens in (
    (Path("musicgen_captured_token_ids.npy"), captured_ids_np),
    (Path("musicgen_retokenized_codes.npy"), audio_codes_np),
):
    np.save(out_npy_path, tokens)
    print("Saved tokens to", out_npy_path)

# Debug aid: show the first 32 tokens (frame 0, batch item 0, codebook 0) of both
# the generated ids and the retokenized codes.
for sample_label, sample_tokens in (
    ("Captured token ids sample (first frame, first batch, first codebook, first 32 tokens):", cap),
    ("Retokenized codes sample (first frame, first batch, first codebook, first 32 tokens):", retok),
):
    print(sample_label)
    print(sample_tokens[0, 0, 0, :32])
