import torch
import sys

def analyze_raw_structure(name, path):
    """Print a human-readable summary of a saved PyTorch file's layout.

    Loads the file at ``path`` with tensors mapped to the CPU and reports on
    stdout: the type of the loaded object, its root keys, whether the root
    looks like a flat state dict or a checkpoint wrapping one, a sample of
    the weight keys, and a prefix check on the first sampled key.

    Args:
        name: Label shown in the banner line.
        path: Location of the file handed to ``torch.load``.

    Returns:
        None. All findings are printed. A load failure is reported and
        swallowed so the caller can go on to analyze other files.
    """
    print(f"\n{'='*20} ANALYZING: {name} {'='*20}")
    print(f"Path: {path}")

    try:
        # Tensors are remapped to the CPU so the file can be inspected on a
        # machine that lacks the device it was saved from.
        # NOTE(review): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        obj = torch.load(path, map_location='cpu')
    except Exception as e:
        print(f"CRITICAL ERROR LOADING FILE: {e}")
        return

    print(f"Type: {type(obj)}")

    if not isinstance(obj, dict):
        print("Structure: Not a dictionary (likely a raw TorchScript object or list).")
        return

    # 1. Print Root Keys
    keys = list(obj.keys())
    print(f"Root Keys ({len(keys)}): {keys[:5]} ...")

    # 2. Heuristic: Is this a checkpoint or a state dict?
    # The root is treated as the state dict itself when one of its first 20
    # keys names an encoder/quantizer layer. Dict keys are not guaranteed to
    # be strings (e.g. integer ids), so non-string keys are skipped instead
    # of crashing the substring test.
    looks_flat = any(
        isinstance(k, str) and ("encoder" in k or "quantizer" in k)
        for k in keys[:20]
    )

    if looks_flat:
        print("\n>>> STRUCTURE TYPE: [Flat State Dict]")
        print("    (This file IS the weights. Ready for load_state_dict)")
        sample_keys = keys[:5]
    else:
        print("\n>>> STRUCTURE TYPE: [Nested Checkpoint]")
        print("    (This file CONTAINS the weights, usually mixed with optimizer states)")

        # Try to find the weights sub-dictionary; the first matching name wins.
        common_subkeys = ["model_state", "state_dict", "model", "model_state_dict", "student"]
        for subkey in common_subkeys:
            if subkey in obj and isinstance(obj[subkey], dict):
                print(f"    -> Weights found inside key: ['{subkey}']")
                sample_keys = list(obj[subkey].keys())[:5]
                break
        else:
            print("    [ERROR] Could not automatically identify which key holds the weights.")
            sample_keys = []

    # 3. Print Sample Keys (Prefix Check)
    print("\n>>> KEY FORMAT SAMPLE (First 5):")
    for k in sample_keys:
        print(f"    {k}")

    # 4. Check specific EnCodec prefixes (judged from the first sampled key only)
    if sample_keys:
        print("\n>>> COMPATIBILITY CHECK:")
        # str() keeps the prefix tests safe when the key is not a string.
        first = str(sample_keys[0])
        if first.startswith("module."):
            # NOTE(review): load_state_dict(strict=False) tolerates mismatched
            # keys rather than renaming them; stripping the prefix is the
            # option that actually loads the weights - confirm with the caller.
            print("    [WARN] Keys have 'module.' prefix (Saved from DDP). Needs strict=False or cleaning.")
        elif first.startswith(("encoder.", "quantizer.")):
            print("    [OK] Keys look standard (start with encoder/quantizer).")
        else:
            print(f"    [WARN] Unknown prefix format: '{first.split('.')[0]}...'")

if __name__ == "__main__":
    # Checkpoints inspected when no paths are given on the command line.
    default_targets = [
        ("PATH A (Original)", "/home/wmar/wmar_audio/checkpoints/encodec_32khz/pytorch_model.bin"),
        ("PATH B (Finetuned)", "/home/wmar/wmar_audio/outputs/finetune/20260125-161511/checkpoint_epoch_99.pt"),
    ]

    # Any command-line arguments are treated as checkpoint paths and replace
    # the defaults, so the script can be run on other machines and files.
    cli_paths = sys.argv[1:]
    if cli_paths:
        targets = [(f"PATH {i}", p) for i, p in enumerate(cli_paths, start=1)]
    else:
        targets = default_targets

    for label, ckpt_path in targets:
        analyze_raw_structure(label, ckpt_path)