# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
"""
Run watermark generation and evaluation (CSM Version).
(omitted docstring usage example)
"""

import os
import torch
import numpy as np
import pandas as pd
import argparse
from huggingface_hub import hf_hub_download
import sentencepiece
import sphn
import glob
from tqdm import tqdm
import time
import random
import pickle
from collections import defaultdict
from scipy import stats, special

# WMAR & Common Imports
from models.moshi.models import loaders
from models.moshi.utils import bool_inst

from training import get_validation_augs, get_dummy_augs
from watermark.engine import get_wm_window_hash, GENERATOR
from watermark.sync import SyncPattern

# CSM Imports
from models.csm.models import Model, ModelArgs

# Disable Triton compilation for compatibility by exporting NO_TORCH_COMPILE=1.
# NOTE(review): this assignment runs after the project imports above; if any of
# those modules reads NO_TORCH_COMPILE at import time the flag is set too late
# -- TODO confirm where the variable is consumed.
os.environ["NO_TORCH_COMPILE"] = "1"

def get_binomial_pval(x, n, p):
    """Return the one-sided binomial p-value ``P(X >= x)`` for ``X ~ Binomial(n, p)``.

    The tail probability is evaluated through the regularized incomplete beta
    function, ``P(X >= x) = I_p(x, n - x + 1)``, which also accepts
    non-integer counts.

    Args:
        x: Number of observed successes (green tokens).
        n: Number of trials (scored tokens).
        p: Success probability under the null hypothesis.

    Returns:
        The p-value. For scalar input the boundary cases are handled
        explicitly: ``1.0`` when ``x <= 0`` and ``0.0`` when ``x > n``.
    """
    if np.ndim(x) == 0 and np.ndim(n) == 0:
        # betainc is only defined for strictly positive shape parameters, so the
        # two degenerate tails are resolved here instead of relying on how a
        # particular SciPy version treats a zero or negative parameter.
        if x <= 0:
            return 1.0
        if x > n:
            return 0.0
    return special.betainc(x, 1 + n - x, p)

def seed_all(seed):
    """Seed every random number generator used by the script.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU, plus
    the CUDA generators when CUDA is available), then switches cuDNN to
    deterministic mode with autotuning benchmarks turned off.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Host-side generators.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Device-side generators, only touched when a CUDA runtime is present.
    if torch.cuda.is_available():
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def load_clustering_maps(clustering_dir, target_min_count=None, device="cpu", vocab_size=2048):
    """Load token-to-cluster maps from pickled clustering files.

    For each channel (0-3) and each method ("leiden", "louvain") the first
    file (in sorted order) matching ``*{method}*{suffix}*.pkl`` inside
    ``clustering_dir`` is unpickled. The unpickled object is used as a mapping
    from a key to a per-token cluster assignment.

    Args:
        clustering_dir: Directory searched for the pickle files.
        target_min_count: When given, only this key is loaded; if the file does
            not contain it, the middle entry of the sorted keys is used and a
            notice is printed. When ``None`` every key is loaded.
        device: Device on which the resulting tensors are created.
        vocab_size: Length of every returned map. Longer assignments are
            truncated; positions left at -1 each receive a fresh cluster id.

    Returns:
        Nested mapping ``maps[channel][method][key] = (num_clusters, full_map)``
        where ``full_map`` is a long tensor of length ``vocab_size``.
        Combinations without a matching file are absent; a file that fails
        while loading or processing is reported on stdout and skipped.
    """
    maps = defaultdict(lambda: defaultdict(dict))
    suffix_by_channel = ((0, "first_0"), (1, "rest_0"), (2, "rest_1"), (3, "rest_2"))

    for channel, suffix in suffix_by_channel:
        for method in ("leiden", "louvain"):
            pattern = os.path.join(clustering_dir, f"*{method}*{suffix}*.pkl")
            candidates = sorted(glob.glob(pattern))
            if len(candidates) == 0:
                continue
            path = candidates[0]
            try:
                # NOTE: unpickling runs code embedded in the file, so the
                # directory must only contain trusted clustering dumps.
                with open(path, "rb") as handle:
                    data = pickle.load(handle)

                if target_min_count is None:
                    selected = sorted(data.keys())
                elif target_min_count in data:
                    selected = [target_min_count]
                else:
                    selected = []
                    ordered = sorted(data.keys())
                    if ordered:
                        fallback = ordered[len(ordered) // 2]
                        print(f"  > Ch {channel} {method}: Target {target_min_count} missing. Using {fallback}")
                        selected.append(fallback)

                for key in selected:
                    assignments = torch.as_tensor(data[key], device=device, dtype=torch.long)
                    dense = torch.full((vocab_size,), -1, device=device, dtype=torch.long)
                    n_copy = min(len(assignments), vocab_size)
                    dense[:n_copy] = assignments[:n_copy]
                    holes = dense == -1
                    next_id = int(assignments.max().item()) + 1
                    if holes.any():
                        # Every token without a cluster becomes its own cluster,
                        # numbered after the ids found in the file.
                        dense[holes] = torch.arange(next_id, next_id + holes.sum(), device=device)
                    maps[channel][method][key] = (int(dense.max().item()) + 1, dense)
            except Exception as e:
                print(f"Failed to load {path}: {e}")
    return maps

def compute_watermark_scores(wm_stream, ngrams, audio_vocab_size, gamma, wm_seed, device='cpu', clustering_map=None):
    """Flag which tokens of one stream are green and which ones count for scoring.

    For every token the shared ``GENERATOR`` is re-seeded from the window hash
    and a permutation of the effective vocabulary is drawn; its first
    ``int(gamma * effective_vocab_size)`` entries form the green list. When a
    ``clustering_map`` is given, the token's cluster id is tested against the
    green list and the effective vocabulary is the number of cluster ids;
    otherwise the raw token value is tested against ``audio_vocab_size`` ids.
    Only the first occurrence of each token value is marked for scoring.

    Args:
        wm_stream: Token tensor to evaluate; iterated element by element
            (assumes a 1-D stream -- TODO confirm against callers).
        ngrams: Context window forwarded to ``get_wm_window_hash``; it is not
            advanced inside this function.
        audio_vocab_size: Vocabulary size used when no clustering map is given.
        gamma: Fraction of the effective vocabulary placed on the green list.
        wm_seed: Watermark seed forwarded to ``get_wm_window_hash``.
        device: Accepted for interface compatibility; not used by the body.
        clustering_map: Optional tensor mapping token id to cluster id.

    Returns:
        ``(green_mask, to_score_mask)``: two boolean tensors shaped like
        ``wm_stream``.
    """
    use_clusters = clustering_map is not None
    if use_clusters:
        effective_vocab_size = int(clustering_map.max().item()) + 1
    else:
        effective_vocab_size = audio_vocab_size
    n_green = int(gamma * effective_vocab_size)

    green_mask = torch.zeros_like(wm_stream, dtype=torch.bool)
    to_score_mask = torch.zeros_like(wm_stream, dtype=torch.bool)
    already_seen = set()

    for position, token in enumerate(wm_stream):
        window_hash = get_wm_window_hash(ngrams, wm_seed, clustering_map=clustering_map)
        GENERATOR.manual_seed(window_hash[0].item())
        greenlist = torch.randperm(effective_vocab_size, generator=GENERATOR)[:n_green]

        token_val = token.cpu().item()
        candidate = clustering_map[token.long()].item() if use_clusters else token_val
        green_mask[position] = candidate in greenlist

        # Repeated token values are excluded from scoring; deduplication is by
        # raw token value even when clusters are in use.
        if token_val in already_seen:
            continue
        already_seen.add(token_val)
        to_score_mask[position] = 1
    return green_mask, to_score_mask

def _collect_prompt_files(args):
    """Return ``(audio_files, nsamples)``: the prompt files to use and how many samples to run."""
    if not (args.use_prompts and args.audio_dir):
        return [], args.nsamples

    audio_files = []
    for ext in ('*.wav', '*.mp3', '*.ogg', '*.flac'):
        audio_files.extend(glob.glob(os.path.join(args.audio_dir, ext)))
    audio_files = sorted(audio_files)
    if args.nsamples > 0:
        audio_files = audio_files[:args.nsamples]
    # A non-positive --nsamples means "use every prompt found". Taking
    # min(nsamples, len(files)) here would yield -1 for the default value and
    # the batch loop would never execute.
    return audio_files, len(audio_files)


def _score_wm_streams(stream_tokens, stream_ids, clustering_maps, ngrams, audio_vocab_size, gamma, wm_seed):
    """Score every watermarked stream and return per-stream ``(green counts, scored counts)``."""
    greens, scored = [], []
    for stream_id, wm_stream in zip(stream_ids, stream_tokens):
        s_map = clustering_maps.get(stream_id) if clustering_maps else None
        green_mask, to_score_mask = compute_watermark_scores(
            wm_stream, ngrams, audio_vocab_size, gamma, wm_seed, clustering_map=s_map
        )
        greens.append((green_mask & to_score_mask).float().sum().item())
        scored.append(to_score_mask.float().sum().item())
    return greens, scored


def run_watermark_eval(args, clustering_maps=None, config_name="standard"):
    """Generate audio with watermarks and evaluate watermark preservation (CSM Backend).

    The function loads the two Mimi codecs and the CSM model, generates
    ``args.steps`` frames per sample (optionally continuing audio prompts),
    decodes them to audio, applies every validation augmentation, re-encodes
    the augmented audio and computes green-token statistics and binomial
    p-values for both the generated and the round-tripped tokens.

    Files written to ``args.output_dir``:
        * ``summary_{config_name}.pt``: raw per-sample results;
        * ``summary_{config_name}.csv``: means per augmentation and strength;
        * ``results_{config_name}.csv``: one row per sample and augmentation;
        * ``audio_{config_name}/*.wav``: augmented audio of the first
          ``args.save_audio`` samples.

    Args:
        args: Parsed command-line namespace. ``args.wm_streams`` is converted
            to a list of ints in place.
        clustering_maps: Optional mapping from 1-based stream number to a
            token-to-cluster tensor, used both for generation and for scoring.
        config_name: Label used in result records and output file names.

    Raises:
        ValueError: If a watermark stream index is smaller than 1, or if there
            is nothing to generate (no prompt files found, or a non-positive
            ``--nsamples`` without prompts).
    """
    args.wm_streams = [int(x) for x in args.wm_streams]
    if any(s < 1 for s in args.wm_streams):
        # Streams are 1-based; 0 or negatives would silently wrap around to the
        # last codebooks through negative indexing.
        raise ValueError(f"--wm_streams must be 1-based positive indices, got {args.wm_streams}")

    # Resolve the workload before loading any model so that a bad configuration fails fast.
    audio_files, nsamples = _collect_prompt_files(args)
    if nsamples <= 0:
        if args.use_prompts and args.audio_dir:
            raise ValueError(f"No audio prompts found in {args.audio_dir}")
        raise ValueError("--nsamples must be positive when no audio prompts are used")

    # Load finetuned Mimi(s) — keep them at their native weights and then force their num_codebooks to match the CSM model
    mimi = loaders.get_mimi(args.mimi_weight, args.device)
    mimi_ori = loaders.get_mimi(args.mimi_weight_ori, args.device)

    # Load CSM model
    print("Loading CSM model via from_pretrained...")
    csm_model = Model.from_pretrained(args.checkpoint_path)

    # CSM was trained with 32 streams; read the model's config (default to 32)
    K_model = int(getattr(csm_model.config, "audio_num_codebooks", 32))
    print(f"CSM model reports audio_num_codebooks = {K_model}")

    # Force the finetuned Mimi(s) to present K_model codebooks so they match the CSM token layout.
    # This must happen BEFORE we call setup_caches() on the model.
    try:
        mimi.set_num_codebooks(K_model)
        mimi_ori.set_num_codebooks(K_model)
        print(f"Set Mimi tokenizer(s) num_codebooks -> {K_model}")
    except Exception as e:
        # If set_num_codebooks fails, warn and proceed — but audio quality will likely suffer.
        print(f"Warning: failed to set Mimi num_codebooks to {K_model}: {e}")

    # Move model to device and eval
    csm_model.to(args.device)
    csm_model.eval()

    # Clear any existing KV caches so setup_caches will re-allocate with current K_model
    try:
        csm_model.reset_caches()
    except Exception:
        # If reset_caches isn't supported or fails, warn. We will still attempt setup_caches and hope for the best.
        print("Warning: csm_model.reset_caches() failed or not available.")

    # Watermark parameters
    if args.wm_method.lower() == "none":
        watermark_params = None
    else:
        watermark_params = {
            "method": args.wm_method,
            "streams": args.wm_streams,
            "ngram": args.wm_ngram,
            "seed": args.wm_seed,
            "aux_params": {
                "delta": args.wm_delta,
                "gamma": args.wm_gamma,
            }
        }
    if clustering_maps and watermark_params:
        watermark_params["aux_params"]["clustering_maps"] = clustering_maps

    audio_vocab_size = 2048
    topk = 50

    sync_pattern = None
    if args.wm_sync:
        sync_pattern = SyncPattern()
        sync_pattern.to(args.device)

    global_watermark_results = []

    # Use K_model as the canonical K (CSM-trained)
    K = K_model
    print(f"Using K (codebooks) = {K}")

    # User-provided streams are interpreted as 1-based Mimi codebook indices.
    wm_indices_mimi = [s - 1 for s in args.wm_streams]
    ngrams = torch.zeros((1, 0), device='cpu')

    for batch_start in tqdm(range(0, nsamples, args.batch_size)):
        batch_size = min(args.batch_size, nsamples - batch_start)

        # --- Prompt Processing: encode with the original Mimi and left-pad to a common length ---
        if args.use_prompts and audio_files:
            batch_files = audio_files[batch_start:batch_start + batch_size]
            prompt_codes_list = []
            for audio_path in batch_files:
                sample_pcm, sample_sr = sphn.read(audio_path, duration_sec=args.duration_sec)
                sample_pcm = sphn.resample(sample_pcm, src_sample_rate=sample_sr, dst_sample_rate=mimi.sample_rate)
                sample_pcm = torch.tensor(sample_pcm, device=args.device).unsqueeze(0)
                with torch.no_grad():
                    prompt_code = mimi_ori.encode(sample_pcm)  # (1, K, T)
                    prompt_codes_list.append(prompt_code)

            max_len = max(c.shape[-1] for c in prompt_codes_list)
            padded_list = []
            for c in prompt_codes_list:
                pad_amt = max_len - c.shape[-1]
                if pad_amt > 0:
                    pad_block = torch.zeros(1, c.shape[-2], pad_amt, device=c.device, dtype=c.dtype)
                    c = torch.cat([pad_block, c], dim=-1)
                padded_list.append(c)
            current_prompt_codes = torch.cat(padded_list, dim=0)  # (B, K, T)
        else:
            current_prompt_codes = torch.zeros(batch_size, K, 1, device=args.device, dtype=torch.long)

        # Build CSM token frames: (B, T, K+1)
        B, C, T = current_prompt_codes.shape
        audio_part = current_prompt_codes.permute(0, 2, 1)  # (B, T, K)
        text_part = torch.ones(B, T, 1, device=args.device, dtype=torch.long)
        curr_tokens = torch.cat([audio_part, text_part], dim=2)  # (B, T, K+1)

        with torch.no_grad():
            # Ensure caches are setup for this batch AFTER Mimi was configured
            # NOTE(review): caches are set up again for every batch but only reset once,
            # before the first setup -- confirm that setup_caches() restarts the KV cache
            # at position 0 for each new batch.
            csm_model.setup_caches(batch_size)

            # PREFILL
            input_pos = torch.arange(0, T, device=args.device).unsqueeze(0).expand(B, T)
            mask = torch.ones_like(curr_tokens, dtype=torch.bool)

            first_gen_frame = csm_model.generate_frame(curr_tokens, mask, input_pos, args.temperature, topk, watermark_params)

            # GENERATE
            generated_frames = [first_gen_frame]
            new_text = torch.ones(B, 1, device=args.device, dtype=torch.long)
            new_frame_full = torch.cat([first_gen_frame, new_text], dim=1).unsqueeze(1)  # (B, 1, K+1)
            curr_tokens = torch.cat([curr_tokens, new_frame_full], dim=1)

            # NOTE(review): every step passes the whole token history together with a
            # single position index -- TODO confirm this matches what
            # Model.generate_frame expects once the prompt has been prefilled.
            for _ in range(args.steps - 1):
                curr_t = curr_tokens.shape[1] - 1
                pos = torch.tensor([curr_t], device=args.device).expand(B, 1)
                mask = torch.ones_like(curr_tokens, dtype=torch.bool)

                next_frame = csm_model.generate_frame(curr_tokens, mask, pos, args.temperature, topk, watermark_params)
                generated_frames.append(next_frame)

                new_frame_full = torch.cat([next_frame, new_text], dim=1).unsqueeze(1)
                curr_tokens = torch.cat([curr_tokens, new_frame_full], dim=1)

            gen_codes = torch.stack(generated_frames, dim=2)  # (B, K, S)
            # Decode using the finetuned Mimi (already set to K codebooks)
            batch_all_audio = mimi.decode(gen_codes)

        # --- Evaluation Logic ---
        if args.wm_sync:
            batch_all_audio = sync_pattern.get_sync_wm(batch_all_audio, alpha=0.5)

        # Scores of the generated tokens do not depend on the augmentation, so they
        # are computed once per sample instead of once per (augmentation, strength).
        original_scores = []
        for idx in range(batch_size):
            wm_tokens = gen_codes[idx, wm_indices_mimi, :]  # (W, S)
            orig_greens, orig_scored = _score_wm_streams(
                wm_tokens, args.wm_streams, clustering_maps, ngrams,
                audio_vocab_size, args.wm_gamma, args.wm_seed,
            )
            orig_pval = get_binomial_pval(float(sum(orig_greens)), float(sum(orig_scored)), args.wm_gamma)
            original_scores.append((orig_greens, wm_tokens.shape[-1], orig_pval))

        augs = get_validation_augs() if args.eval_aug else get_dummy_augs()
        for aug, _ in augs:
            aug.to(args.device)
        batch_audio_saved = batch_all_audio.clone()

        for validation_aug, strengths in augs:
            for strength in strengths:
                batch_aug_audio, _ = validation_aug(batch_audio_saved, None, strength)
                if args.wm_sync:
                    detection_results = sync_pattern.detect_sync_wm(batch_aug_audio)  # b s

                for idx in range(batch_size):
                    synced_audio = batch_aug_audio[idx:idx+1]
                    if args.wm_sync:
                        # float() works for CPU and CUDA scalars alike, whereas
                        # np.abs() cannot consume a CUDA tensor.
                        detection_score = float(detection_results[idx].mean())
                        if abs(detection_score - 0.5) < 0.25:
                            speedup, shift = sync_pattern.get_speedup_and_shift(detection_results[idx])
                            synced_audio = sync_pattern.invert(synced_audio, speedup, shift)

                    # Roundtrip encoding with the finetuned Mimi; no gradients are needed for evaluation.
                    with torch.no_grad():
                        tokens_roundtrip = mimi.encode(synced_audio)

                    wm_tokens_roundtrip = tokens_roundtrip[0, wm_indices_mimi, :] if tokens_roundtrip is not None else None

                    greens, scored = [], []
                    pval = None
                    if wm_tokens_roundtrip is not None:
                        greens, scored = _score_wm_streams(
                            wm_tokens_roundtrip, args.wm_streams, clustering_maps, ngrams,
                            audio_vocab_size, args.wm_gamma, args.wm_seed,
                        )
                        pval = get_binomial_pval(sum(greens), sum(scored), args.wm_gamma)

                    orig_greens, orig_ntoks, orig_pval = original_scores[idx]
                    global_idx = batch_start + idx
                    result = {
                        "config": config_name,
                        "idx": global_idx,
                        "aug_name": str(validation_aug),
                        "strength": strength,
                        "original_greens": list(orig_greens),
                        "original_ntoks": orig_ntoks,
                        "original_pval": orig_pval,
                        "greens": greens,
                        "scored": scored,
                        "ntoks": wm_tokens_roundtrip.shape[-1] if wm_tokens_roundtrip is not None else 0,
                        "pval": pval,
                    }
                    global_watermark_results.append(result)

                    if args.save_audio > 0 and global_idx < args.save_audio:
                        audio_output_dir = os.path.join(args.output_dir, f"audio_{config_name}")
                        os.makedirs(audio_output_dir, exist_ok=True)
                        aug_audio_np = batch_aug_audio[idx, 0].detach().cpu().numpy().astype(np.float32)
                        sphn.write_wav(
                            os.path.join(audio_output_dir, f'{validation_aug}_{strength}_{global_idx:03d}.wav'),
                            aug_audio_np,
                            mimi.sample_rate,
                        )

    summary = {'config': config_name, 'results': global_watermark_results}
    torch.save(summary, os.path.join(args.output_dir, f'summary_{config_name}.pt'))

    if not global_watermark_results:
        # Grouping an empty frame by missing columns would raise a KeyError.
        print(f"Warning: no results were produced for config '{config_name}'; skipping CSV summaries.")
        return

    df_data = [
        {
            "idx": wmr["idx"],
            "aug_name": wmr["aug_name"],
            "strength": str(wmr["strength"]),
            "greens": sum(wmr["greens"]),
            "scored": sum(wmr["scored"]),
            "ntoks": wmr["ntoks"],
            "pval": wmr["pval"],
            "logpval": -np.log10(wmr["pval"]) if wmr["pval"] is not None and wmr["pval"] > 0 else None,
        }
        for wmr in global_watermark_results
    ]

    df = pd.DataFrame(df_data)
    numeric_cols_for_mean = ["greens", "scored", "ntoks", "pval", "logpval"]
    cols_to_aggregate = [col for col in numeric_cols_for_mean if col in df.columns]
    # A column holding only None is inferred as object dtype and cannot be
    # averaged; coercing turns the missing entries into NaN.
    df[cols_to_aggregate] = df[cols_to_aggregate].apply(pd.to_numeric, errors="coerce")

    mean_df = df.groupby(["aug_name", "strength"])[cols_to_aggregate].agg("mean")
    mean_df.to_csv(os.path.join(args.output_dir, f'summary_{config_name}.csv'))

    pd.set_option('display.max_rows', None)
    print(mean_df)
    df.to_csv(os.path.join(args.output_dir, f'results_{config_name}.csv'), index=False)

def main():
    """Parse the command line and run the evaluation for every configuration.

    The "standard" configuration (no clustering) always runs. With
    ``--wm_clustering`` enabled, one extra configuration is added for every
    (method, key) pair found for channel 0 for which a clustering map exists
    for each requested watermark stream.

    Missing Mimi weights and the text tokenizer are downloaded from
    ``--hf_repo``; the string "none" is treated as "not provided" for both
    Mimi weight options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.device_count() else "cpu")
    parser.add_argument("--seed", type=int, default=42424242)
    parser.add_argument("--hf_repo", type=str, default="kyutai/moshiko-pytorch-bf16")
    parser.add_argument("--steps", type=int, default=200)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--use_prompts", type=bool_inst, default=True)
    parser.add_argument("--audio_dir", type=str, help="Directory containing audio files for prompts")
    parser.add_argument("--duration_sec", type=float, default=None)
    parser.add_argument("--nsamples", type=int, default=-1)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--wm_method", type=str, default="maryland")
    # Streams are 1-based codebook numbers; parsing them as int rejects typos up front.
    parser.add_argument("--wm_streams", nargs='+', type=int, default=[1])
    parser.add_argument("--wm_delta", type=float, default=8.0)
    parser.add_argument("--wm_gamma", type=float, default=0.25)
    parser.add_argument("--wm_ngram", type=int, default=0)
    parser.add_argument("--wm_seed", type=int, default=0)
    parser.add_argument("--wm_sync", type=bool_inst, default=False)
    parser.add_argument("--wm_clustering", type=bool_inst, default=False)
    parser.add_argument(
        "--clustering_dir",
        type=str,
        default="/home/AlignedIS-dev/models/embeddings/clusterings/mimi_*_rvq_*.pkl",
        help="Directory with clustering pickles (a glob pattern is reduced to its directory)",
    )
    parser.add_argument("--tokenizer", type=str)
    parser.add_argument("--checkpoint_path", type=str, help="Path to CSM model checkpoint or HF Hub ID")
    parser.add_argument("--model_size", type=str, default="1B", choices=["1B", "100M"])
    parser.add_argument("--mimi_weight", type=str)
    parser.add_argument("--mimi_weight_ori", type=str)
    parser.add_argument("--save_audio", type=int, default=10)
    parser.add_argument("--save_tokens", type=int, default=0)
    parser.add_argument("--eval_aug", type=bool_inst, default=True)

    args = parser.parse_args()

    # Fall back to the reference Mimi weights when none (or "none") were given.
    for weight_attr in ("mimi_weight", "mimi_weight_ori"):
        weight = getattr(args, weight_attr)
        if weight is None or weight.lower() == "none":
            setattr(args, weight_attr, hf_hub_download(args.hf_repo, loaders.MIMI_NAME))
    if args.tokenizer is None:
        args.tokenizer = hf_hub_download(args.hf_repo, loaders.TEXT_TOKENIZER_NAME)

    os.makedirs(args.output_dir, exist_ok=True)
    seed_all(args.seed)

    # The un-clustered baseline always runs first.
    configs_to_run = [{"method": None, "maps": None}]

    if args.wm_clustering:
        clustering_dir = args.clustering_dir
        if '*' in clustering_dir:
            clustering_dir = os.path.dirname(clustering_dir)

        all_maps = load_clustering_maps(clustering_dir, device=args.device)

        if all_maps and 0 in all_maps:
            # Channel 0 defines which (method, key) combinations are tried.
            for method, keyed_maps in list(all_maps[0].items()):
                for key in sorted(keyed_maps):
                    current_config_maps = {}
                    for s in args.wm_streams:
                        channel = int(s) - 1
                        has_map = (
                            channel in all_maps
                            and method in all_maps[channel]
                            and key in all_maps[channel][method]
                        )
                        if not has_map:
                            # A stream without a map invalidates the whole configuration.
                            current_config_maps = None
                            break
                        _, token_map = all_maps[channel][method][key]
                        current_config_maps[int(s)] = token_map
                    if current_config_maps is not None:
                        configs_to_run.append({"method": method, "min_count": key, "maps": current_config_maps})
        else:
            print(f"Warning: no clustering maps found in {clustering_dir}; running the standard configuration only.")

    print(f"Starting execution for {len(configs_to_run)} configurations")

    for config in configs_to_run:
        if config["method"] is None:
            config_name = "standard"
        else:
            config_name = f"{config['method']}_{config['min_count']}"

        run_watermark_eval(args, clustering_maps=config["maps"], config_name=config_name)

# Run the CLI only when the file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
