import torch
import torchaudio
import argparse
from pathlib import Path
from encodec import EncodecModel
from encodec.utils import convert_audio
import numpy as np
from tqdm import tqdm
import random

def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds, in order: PyTorch (CPU), PyTorch CUDA (current device), NumPy's
    global generator, and Python's ``random`` module.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)

def get_token_match(model, audio, device):
    """Run the idempotence (round-trip) test on one clip.

    1. Encode the original audio         -> codes_1
    2. Decode codes_1                    -> reconstructed audio
    3. Encode the reconstructed audio    -> codes_2
    4. Compare codes_1 with codes_2

    Args:
        model: EnCodec-style model. ``encode`` returns a list of frames whose
            first element is a code tensor; ``decode`` accepts that list.
        audio: waveform tensor already on the model's device (as built by
            ``main``, presumably ``[1, channels, time]``).
        device: unused; kept so existing callers keep working.

    Returns:
        Fraction (float in [0, 1]) of code positions, over all codebooks and
        the time steps both encodings cover, where the two encodings agree.
    """
    with torch.no_grad():
        # 1. First encode.
        frames_1 = model.encode(audio)
        codes_1 = torch.cat([frame[0] for frame in frames_1], dim=-1)  # [B, n_q, T]

        # 2. Decode. The decoder can emit more samples than were encoded
        # (the encoder pads its input); EnCodec's own forward() trims its
        # output to the input length. Do the same so the second encode sees
        # the original clip, not a padded tail.
        rec_audio = model.decode(frames_1)[..., : audio.shape[-1]]

        # 3. Second encode (round trip).
        frames_2 = model.encode(rec_audio)
        codes_2 = torch.cat([frame[0] for frame in frames_2], dim=-1)

        # 4. Compare. This is defensive: if the frame counts still differ,
        # compare only the common prefix.
        min_len = min(codes_1.shape[-1], codes_2.shape[-1])
        c1 = codes_1[..., :min_len]
        c2 = codes_2[..., :min_len]

        return (c1 == c2).float().mean().item()

def _configure_tf32(disable_tf32):
    """Print the precision mode and toggle TF32 for CUDA matmuls and cuDNN."""
    if disable_tf32:
        print(">>> MODE: STRICT FLOAT32 (TF32 DISABLED)")
    else:
        print(">>> MODE: DEFAULT (TF32 ENABLED if available)")
    # NOTE: PyTorch's own default for matmul TF32 has been False since 1.12,
    # so this "DEFAULT" mode explicitly turns TF32 on. It does not merely
    # leave the library defaults alone.
    allow = not disable_tf32
    torch.backends.cuda.matmul.allow_tf32 = allow
    torch.backends.cudnn.allow_tf32 = allow


def _collect_audio_paths(data_dir, patterns=("*.wav", "*.mp3")):
    """Return every file under *data_dir* matching one of *patterns*, sorted."""
    root = Path(data_dir)
    # rglob order depends on the filesystem; sorting keeps the seeded shuffle
    # in main() reproducible across machines.
    return sorted(path for pattern in patterns for path in root.rglob(pattern))


def _fit_length(wav, target_len):
    """Crop *wav* (keeping its start) or right-pad it with zeros to *target_len* samples."""
    length = wav.shape[-1]
    if length > target_len:
        return wav[..., :target_len]
    if length < target_len:
        return torch.nn.functional.pad(wav, (0, target_len - length))
    return wav


def main():
    """Measure EnCodec token idempotence over a random subset of audio files.

    Steps:
      1. Parse the CLI arguments and configure TF32.
      2. Load the 24 kHz EnCodec model at 6 kbps.
      3. Score up to ``--samples`` shuffled files with ``get_token_match``.
         Each file is cropped or zero-padded to 5 s first.
      4. Print how many files were scored or skipped, and the average match.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="/storage/data/FMA/fma_medium/098", help="Path to LibriTTS or FMA folder")
    parser.add_argument("--samples", type=int, default=50, help="Number of files to check")
    parser.add_argument("--disable_tf32", action="store_true", help="Force strict Float32 (Turn off TF32 acceleration)")
    parser.add_argument("--device", type=str, default="cuda")
    args = parser.parse_args()
    device = args.device

    # --- 1. HARDWARE SETTINGS ---
    print(f"\n{'='*40}")
    # get_device_name() raises without CUDA, so only query it for CUDA devices.
    if torch.device(device).type == "cuda" and torch.cuda.is_available():
        print(f"Running on: {torch.cuda.get_device_name(torch.device(device))}")
    else:
        print(f"Running on: {device}")
    _configure_tf32(args.disable_tf32)
    print(f"{'='*40}\n")

    set_seed(42)

    # --- 2. LOAD MODEL ---
    print("Loading EnCodec (24khz)...")
    model = EncodecModel.encodec_model_24khz()
    model.set_target_bandwidth(6.0)
    model.to(device)
    model.eval()

    # --- 3. DATA COLLECTION ---
    paths = _collect_audio_paths(args.data_dir)
    if not paths:
        print("No audio files found.")
        return

    # Shuffle (deterministically, thanks to set_seed and sorting) and pick a subset.
    random.shuffle(paths)
    selected_paths = paths[: args.samples]

    # --- 4. EVALUATION LOOP ---
    # Fixed 5-second clips keep VRAM bounded and make scores comparable.
    target_len = int(5.0 * model.sample_rate)
    scores = []
    skipped = []
    print(f"Evaluating Token Match on {len(selected_paths)} files...")

    for path in tqdm(selected_paths):
        try:
            wav, sr = torchaudio.load(str(path))
            # Resample and remix to the model's rate and channel count.
            wav = convert_audio(wav, sr, model.sample_rate, model.channels)
            wav = _fit_length(wav, target_len).unsqueeze(0).to(device)
            scores.append(get_token_match(model, wav, device))
        except Exception as exc:
            # Best-effort: skip files that fail, but never silently. A wave of
            # skips usually means a model/device problem, not bad audio.
            skipped.append(path)
            tqdm.write(f"Skipping {path.name}: {exc}")

    # --- 5. REPORT ---
    print(f"\n{'-'*30}")
    print(f"Results for {len(scores)} of {len(selected_paths)} files ({len(skipped)} skipped):")
    if scores:
        print(f"Avg Token Match: {sum(scores) / len(scores):.4f}")
    else:
        print("Avg Token Match: n/a (no file was evaluated successfully)")
    print(f"{'-'*30}\n")

# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
