import os
import torch
import numpy as np
import pandas as pd
import argparse
import sphn
from tqdm import tqdm
import random
import pickle
from scipy import special

from transformers import (
    AutoProcessor,
    EncodecModel,
    MusicgenForConditionalGeneration,
)
from evals.main_wm_music import build_stream_ngrams_from_full_stream

from models.moshi.utils import bool_inst
from models.musicgen import MusicGenWMGen
from watermark.engine import get_wm_window_hash, GENERATOR

from training import get_validation_augs, get_dummy_augs


def get_binomial_pval(x, n, p):
    """
    Calculate the p-value for a one-sided binomial test (greater).

    Returns P(X >= x) for X ~ Binomial(n, p), computed through the regularized
    incomplete beta function: I_p(x, n - x + 1) == P(X >= x) for 1 <= x <= n.

    Args:
        x: The number of successes (e.g., number of green tokens).
        n: The number of trials (e.g., number of scored tokens).
        p: The hypothesized probability of success under the null hypothesis
            (e.g., gamma, the green-list fraction).
    Returns:
        The p-value as a float. ``x <= 0`` yields 1.0 (every outcome is at least
        as extreme) and ``x > n`` yields 0.0 (impossible outcome); both would
        otherwise pass parameters outside betainc's domain (a, b > 0).
    """
    # Equivalent to stats.binomtest(x, n, p=p, alternative='greater').pvalue
    if x <= 0:
        return 1.0
    if x > n:
        return 0.0
    return float(special.betainc(x, 1 + n - x, p))

def seed_all(seed):
    """Seed every RNG this script relies on and force deterministic cuDNN.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (CPU, plus
    every CUDA device when CUDA is available), then turns off cuDNN autotuning
    and requests deterministic cuDNN kernels.
    """
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # every visible GPU

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

def _greenlist_flags(seed, effective_vocab_size, gamma):
    """Return a list where entry ``i`` is True iff id ``i`` is green for ``seed``.

    Seeds the shared ``GENERATOR`` with ``seed``, draws a permutation of
    ``effective_vocab_size`` ids and marks its first ``int(gamma * V)`` entries.
    """
    GENERATOR.manual_seed(seed)
    vocab_perm = torch.randperm(effective_vocab_size, generator=GENERATOR)
    lookup = torch.zeros(effective_vocab_size, dtype=torch.bool)
    lookup[vocab_perm[:int(gamma * effective_vocab_size)]] = True
    return lookup.tolist()


def compute_watermark_scores(wm_stream, ngrams, audio_vocab_size, gamma, wm_seed, device='cpu', clustering_map=None):
    """Compute watermark scores for a given stream of tokens.

    ngrams may be:
      - shape (1, n): same ngram used for all tokens (backwards compatible)
      - shape (T, n): one ngram per token/time-step (use that per-token hash)
      - shape (1, 0): empty ngram (no context)

    Args:
        wm_stream: 1-D tensor of token ids (each element is read as a scalar).
        ngrams: context windows as described above (or None); hashed with
            ``get_wm_window_hash`` to obtain the greenlist seed(s).
        audio_vocab_size: greenlist universe size when ``clustering_map`` is None.
        gamma: fraction of the (token or cluster) universe that is green.
        wm_seed: watermark key forwarded to ``get_wm_window_hash``.
        device: unused; kept so existing callers keep working.
        clustering_map: optional tensor mapping token id -> cluster id. When
            given, greenlists are drawn over ``clustering_map.max() + 1``
            clusters and a token is green iff its cluster is green.

    Returns:
        (green_mask, to_score_mask): bool tensors on ``wm_stream``'s device.
        ``green_mask[i]`` tells whether token ``i`` is green;
        ``to_score_mask[i]`` is True only for the first occurrence of each
        distinct token value.
    """
    effective_vocab_size = int(clustering_map.max().item()) + 1 if clustering_map is not None else audio_vocab_size
    T = wm_stream.shape[-1]

    # Move per-token data to the CPU once: calling `.item()` on each element of
    # a CUDA tensor would force one device sync per token.
    token_vals = wm_stream.detach().cpu().tolist()
    cluster_of = clustering_map.detach().cpu().tolist() if clustering_map is not None else None

    per_timestep_hashes = None
    single_green = None
    if ngrams is not None and ngrams.numel() > 0 and ngrams.shape[0] == T:
        # h > 0: ngrams is T x n, one window hash (greenlist seed) per time step.
        per_timestep_hashes = get_wm_window_hash(ngrams, wm_seed, clustering_map=clustering_map).cpu()
    else:
        # h = 0 (single 1 x n or 1 x 0 window): the seed is identical for every
        # token, so the greenlist is built once instead of once per token.
        single_hash = get_wm_window_hash(ngrams, wm_seed, clustering_map=clustering_map)
        single_green = _greenlist_flags(int(single_hash[0].item()), effective_vocab_size, gamma)

    green_flags, score_flags = [], []
    seen_tokens = set()
    for ii, token_val in enumerate(token_vals):
        if per_timestep_hashes is not None:
            green = _greenlist_flags(int(per_timestep_hashes[ii].item()), effective_vocab_size, gamma)
        else:
            green = single_green

        unit = cluster_of[int(token_val)] if cluster_of is not None else token_val
        # Ids outside the greenlist universe (e.g. special tokens) are never green,
        # matching the previous membership test against the greenlist.
        green_flags.append(bool(0 <= unit < effective_vocab_size and green[int(unit)]))

        if token_val in seen_tokens:
            score_flags.append(False)
        else:
            seen_tokens.add(token_val)
            score_flags.append(True)

    # NOTE(review): deduplication is keyed on the raw token value. With a
    # clustering_map (and a fixed seed), distinct tokens of the same cluster share
    # one green/red outcome, so they are not independent trials for the binomial
    # test; consider deduplicating on (seed, cluster id) instead.
    green_mask = torch.tensor(green_flags, dtype=torch.bool, device=wm_stream.device)
    to_score_mask = torch.tensor(score_flags, dtype=torch.bool, device=wm_stream.device)
    return green_mask, to_score_mask

def realign_delayed_codebooks(tokens):
    """
    Undo MusicGen's per-codebook delay pattern so codes line up for EnCodec decoding.

    Codebook ``k`` of a [B, K, T] tensor was generated ``k`` steps late, so its
    values are moved ``k`` positions to the left. The ``k`` trailing positions
    left empty in codebook ``k`` are zero-filled; codebooks whose delay is not
    shorter than ``T`` end up entirely zero.
    """
    # NOTE(review): the zero fill produces code 0 in the tail frames, which is
    # presumably a valid codebook entry; the decoded audio and watermark scores
    # include these frames. Confirm this is acceptable.
    _, num_codebooks, num_steps = tokens.shape
    realigned = torch.zeros_like(tokens)
    for delay in range(min(num_codebooks, num_steps)):
        realigned[:, delay, : num_steps - delay] = tokens[:, delay, delay:]
    return realigned

def _load_musicgen_models(args, device):
    """Load the MusicGen processor/LM and the 32 kHz EnCodec codec onto ``device``.

    When ``args.encodec_weight`` is set, the checkpoint's parameter names are
    translated to the HF EnCodec layout, loaded non-strictly, and the codec is
    installed as MusicGen's ``audio_encoder``.
    """
    processor = AutoProcessor.from_pretrained("facebook/musicgen-medium")
    model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-medium").to(device)
    encodec = EncodecModel.from_pretrained("facebook/encodec_32khz").to(device)

    if args.encodec_weight:
        print(f"Loading finetuned EnCodec weights from {args.encodec_weight}")
        raw_sd = torch.load(args.encodec_weight, map_location=device)
        if "model_state" in raw_sd:
            raw_sd = raw_sd["model_state"]
        translated_sd = {
            k.replace("encoder.model", "encoder.layers")
            .replace("decoder.model", "decoder.layers")
            .replace("quantizer.vq", "quantizer.layers")
            .replace("conv.conv.", "conv."): v
            for k, v in raw_sd.items()
        }
        load_report = encodec.load_state_dict(translated_sd, strict=False)
        # strict=False silently ignores mismatches; surface them so a broken
        # name translation does not go unnoticed.
        if load_report.missing_keys or load_report.unexpected_keys:
            print(f"Warning: EnCodec load has {len(load_report.missing_keys)} missing and "
                  f"{len(load_report.unexpected_keys)} unexpected keys")
        model.audio_encoder = encodec  # Ensure MusicGen uses the swapped encoder
    return processor, model, encodec


def _load_prompts(prompt_file, nsamples):
    """Read stripped, non-empty prompt lines; keep only the first ``nsamples`` when ``nsamples > 0``."""
    with open(prompt_file, 'r', encoding="utf-8") as f:
        prompts = [line.strip() for line in f if line.strip()]
    return prompts[:nsamples] if nsamples > 0 else prompts


def _stream_green_counts(stream, wm_ngram, vocab_size, gamma, wm_seed, clustering_map):
    """Score one codebook stream; return (#green among scored tokens, #scored tokens) as floats."""
    ngrams = build_stream_ngrams_from_full_stream(stream, wm_ngram, device='cpu')
    g_mask, s_mask = compute_watermark_scores(
        stream, ngrams, vocab_size, gamma, wm_seed, clustering_map=clustering_map
    )
    return (g_mask * s_mask).float().sum().item(), s_mask.float().sum().item()


def _write_eval_summaries(args, results, config_name):
    """Persist raw results (.pt), per-(aug, strength) means (CSV) and per-sample rows (CSV)."""
    summary = {'config': vars(args), 'results': results}
    torch.save(summary, os.path.join(args.output_dir, f'summary_{config_name}.pt'))

    df = pd.DataFrame([
        {
            "idx": wmr["idx"],
            "aug_name": wmr["aug_name"],
            "strength": str(wmr["strength"]),
            "greens": sum(wmr["greens"]),
            "scored": sum(wmr["scored"]),
            "ntoks": wmr["ntoks"],
            "pval": wmr["pval"],
            "logpval": -np.log10(wmr["pval"]) if wmr["pval"] is not None and wmr["pval"] > 0 else None,
        }
        for wmr in results
    ])
    if df.empty:
        # groupby on a frame without columns would raise KeyError.
        print(f"No results to summarize for configuration {config_name}")
        return

    numeric_cols_for_mean = ["greens", "scored", "ntoks", "pval", "logpval"]
    cols_to_aggregate = [col for col in numeric_cols_for_mean if col in df.columns]

    mean_df = df.groupby(["aug_name", "strength"])[cols_to_aggregate].agg("mean")
    mean_df.to_csv(os.path.join(args.output_dir, f'summary_{config_name}.csv'))

    pd.set_option('display.max_rows', None)
    print(mean_df)
    df.to_csv(os.path.join(args.output_dir, f'results_{config_name}.csv'), index=False)


def run_watermark_eval(args, clustering_maps=None, config_name="standard"):
    """Generate audio with watermarks and evaluate watermark preservation.

    For each prompt batch this function:
      1. generates watermarked codes and realigns the codebook delays,
      2. decodes them to audio,
      3. applies every validation augmentation at every strength,
      4. re-encodes the result with EnCodec,
      5. runs a one-sided binomial test on the green-token counts of both the
         original and the round-tripped codes.

    Args:
        args: parsed CLI namespace (see ``main``).
        clustering_maps: optional {stream_id: token->cluster tensor} used for
            cluster-level greenlists; None for the token-level scheme.
        config_name: label stored in each result and used in output file names.

    Writes audio samples, the prompt list, ``summary_*.pt``, ``summary_*.csv``
    and ``results_*.csv`` under ``args.output_dir``.
    """
    device = args.device
    print(f"Loading MusicGen models on {device}...")
    processor, model, encodec = _load_musicgen_models(args, device)
    # Codec codebook size (previously hard-coded as 2048 for encodec_32khz).
    vocab_size = getattr(encodec.config, "codebook_size", 2048)
    wm_streams = [int(s) for s in args.wm_streams]

    lm_gen = MusicGenWMGen(
        model,
        temp=args.temperature,
        wm=args.wm_method,
        wm_ngram=args.wm_ngram,
        wm_seed=args.wm_seed,
        wm_streams=list(wm_streams),
        wm_aux_params={"delta": args.wm_delta, "gamma": args.wm_gamma, "clustering_maps": clustering_maps}
    )

    print(f"Watermarking config: method={lm_gen.wm}, streams={lm_gen.wm_streams}, "
          f"ngram={lm_gen.wm_ngram}, delta={lm_gen.wm_aux_params['delta']}")
    print(f"--- Running Configuration: {config_name} ---")

    prompts = _load_prompts(args.prompt_file, args.nsamples)
    nsamples = len(prompts)
    global_watermark_results = []

    for batch_start in tqdm(range(0, nsamples, args.batch_size)):
        batch_texts = prompts[batch_start:batch_start + args.batch_size]
        batch_size = len(batch_texts)

        inputs = processor(text=batch_texts, padding=True, return_tensors="pt").to(device)

        wm_tokens_th = lm_gen.generate_watermarked(inputs, max_new_tokens=args.steps)  # [B, K, T]
        wm_tokens_th = realign_delayed_codebooks(wm_tokens_th)

        with torch.no_grad():
            decoded_outputs = encodec.decode(wm_tokens_th[None, :], [None] * batch_size)
            batch_all_audio = decoded_outputs.audio_values  # [B, 1, L]

        augs = get_validation_augs() if args.eval_aug else get_dummy_augs()
        for aug, _ in augs:
            aug.to(args.device)

        batch_audio_saved = batch_all_audio.clone()
        # Scores of the un-augmented codes do not depend on the augmentation, so
        # they are computed once per sample (at first use) and reused.
        orig_cache = {}

        for validation_aug, strengths in augs:
            for strength in strengths:
                batch_aug_audio, _ = validation_aug(batch_audio_saved, None, strength)

                for idx in range(batch_size):
                    synced_audio = batch_aug_audio[idx:idx + 1]

                    # Roundtrip: re-encode the augmented audio into [K, T'] codes.
                    # no_grad avoids building an autograd graph through EnCodec.
                    with torch.no_grad():
                        tokens_roundtrip = encodec.encode(synced_audio).audio_codes.squeeze(0).squeeze(0)

                    wm_tokens_orig = wm_tokens_th[idx]  # [K, T]

                    if idx not in orig_cache:
                        orig_greens, orig_scored = [], []
                        for stream_id in wm_streams:
                            s_map = clustering_maps.get(stream_id) if clustering_maps else None
                            g, s = _stream_green_counts(
                                wm_tokens_orig[stream_id, :], args.wm_ngram, vocab_size,
                                args.wm_gamma, args.wm_seed, s_map,
                            )
                            orig_greens.append(g)
                            orig_scored.append(s)
                        orig_pval = get_binomial_pval(float(sum(orig_greens)), float(sum(orig_scored)), args.wm_gamma)
                        orig_cache[idx] = (orig_greens, orig_scored, orig_pval)
                    orig_greens, orig_scored, orig_pval = orig_cache[idx]

                    greens, scored = [], []
                    for stream_id in wm_streams:
                        if stream_id < tokens_roundtrip.shape[0]:
                            s_map = clustering_maps.get(stream_id) if clustering_maps else None
                            g, s = _stream_green_counts(
                                tokens_roundtrip[stream_id, :], args.wm_ngram, vocab_size,
                                args.wm_gamma, args.wm_seed, s_map,
                            )
                        else:
                            g, s = 0, 0  # stream absent after re-encoding: nothing to score
                        greens.append(g)
                        scored.append(s)

                    pval = get_binomial_pval(sum(greens), sum(scored), args.wm_gamma)

                    global_idx = batch_start + idx
                    result = {
                        "config": config_name,
                        "idx": global_idx,
                        "aug_name": str(validation_aug),
                        "strength": strength,
                        "original_greens": list(orig_greens),
                        "original_scored": list(orig_scored),
                        "original_ntoks": wm_tokens_orig.shape[-1],
                        "original_pval": orig_pval,
                        "greens": greens,
                        "scored": scored,
                        "ntoks": tokens_roundtrip.shape[-1],
                        "pval": pval,
                    }
                    global_watermark_results.append(result)

                    print(orig_pval, pval)

                    if args.save_audio > 0 and global_idx < args.save_audio:
                        audio_output_dir = os.path.join(args.output_dir, f"audio_{config_name}")
                        os.makedirs(audio_output_dir, exist_ok=True)
                        aug_audio_np = batch_aug_audio[idx, 0].detach().cpu().numpy().astype(np.float32)
                        sphn.write_wav(
                            os.path.join(audio_output_dir, f'{validation_aug}_{strength}_{global_idx:03d}.wav'),
                            aug_audio_np, encodec.config.sampling_rate
                        )

        # Append this batch's prompts with their global indices.
        with open(os.path.join(args.output_dir, f"generated_texts_{config_name}.txt"), "a", encoding="utf-8") as f:
            for idx in range(batch_size):
                f.write(f"{idx + batch_start:04d},{batch_texts[idx]}\n")

    _write_eval_summaries(args, global_watermark_results, config_name)


def _build_full_cluster_map(labels, vocab_size, device):
    """Expand per-token cluster labels into a dense (vocab_size,) long tensor on ``device``.

    Labels past ``vocab_size`` are dropped. Token ids without a label (the tail
    beyond ``len(labels)``, or an explicit -1) each receive a fresh singleton
    cluster id above the largest existing label.
    """
    cmap_tensor = torch.as_tensor(np.asarray(labels), device=device, dtype=torch.long)
    full_map = torch.full((vocab_size,), -1, device=device, dtype=torch.long)
    limit = min(len(cmap_tensor), vocab_size)
    full_map[:limit] = cmap_tensor[:limit]
    unmapped = full_map == -1
    n_unmapped = int(unmapped.sum().item())
    if n_unmapped:
        # An empty label array has no max; new ids then start at 0.
        k_clusters = max(int(cmap_tensor.max().item()) + 1, 0) if cmap_tensor.numel() else 0
        full_map[unmapped] = torch.arange(k_clusters, k_clusters + n_unmapped, device=device)
    return full_map


def main():
    """Parse CLI arguments and run the standard and (optionally) clustered watermark evaluations."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.device_count() else "cpu")
    parser.add_argument("--seed", type=int, default=42424242)
    parser.add_argument("--steps", type=int, default=256)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--prompt_file", type=str, required=True, help="Path to txt file with prompts")
    parser.add_argument("--nsamples", type=int, default=-1)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--wm_method", type=str, default="maryland")
    # type=int so CLI-provided values match the integer default.
    parser.add_argument("--wm_streams", nargs='+', type=int, default=[0, 1, 2, 3], help="Stream indices (0-3 for MusicGen)")
    parser.add_argument("--wm_delta", type=float, default=2.0)
    parser.add_argument("--wm_gamma", type=float, default=0.25)
    parser.add_argument("--wm_ngram", type=int, default=0)
    parser.add_argument("--wm_seed", type=int, default=0)
    parser.add_argument("--wm_clustering", type=bool_inst, default=False)
    parser.add_argument(
        "--clusters_pkl", type=str,
        default="/home/wmar/wmar_audio/models/embeddings/encodec_leiden_clusterings_trainonly_allparams.pkl",
        help="Pickle with per-stream clusterings: {stream_id: {(min_count, res): labels}}",
    )
    parser.add_argument("--encodec_weight", type=str, default=None)
    parser.add_argument("--save_audio", type=int, default=10)
    parser.add_argument("--eval_aug", type=bool_inst, default=True)
    parser.add_argument("--skip_existing", type=bool_inst, default=False,
                        help="Skip configurations whose summary CSV already exists")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    seed_all(args.seed)

    configs_to_run = [{"method": None, "maps": None}]  # standard (token-level) scheme

    if args.wm_clustering:
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(args.clusters_pkl, "rb") as f:
            clusterings = pickle.load(f)  # expected structure: {channel: {(min_count, res): labels, ...}, ...}

        # (min_count, resolution) keys available for the first channel entry
        keys = sorted(clusterings[0].keys())

        vocab_size = 2048  # same vocabulary size used when scoring
        method = "leiden"

        for key in keys:
            cnt, res = key  # intentionally let this raise if key is not a 2-tuple
            # channel names in the pickle are the integer stream ids;
            # a missing (stream, key) entry raises KeyError on purpose
            current_config_maps = {
                int(s): _build_full_cluster_map(clusterings[int(s)][key], vocab_size, args.device)
                for s in args.wm_streams
            }
            configs_to_run.append({
                "method": method,
                "min_count": int(cnt),
                "res": float(res),
                "maps": current_config_maps
            })

    print(f"Starting execution for {len(configs_to_run)} configurations")

    for config in configs_to_run:
        if config["method"] is None:
            config_name = "standard"
        else:
            config_name = f"{config['method']}_{config['min_count']}_res{config['res']}"

        if args.skip_existing and os.path.exists(os.path.join(args.output_dir, f'summary_{config_name}.csv')):
            print(f"Skipping {config_name}: summary already exists")
            continue

        # Reseed per configuration so each configuration's generations are
        # reproducible on their own, independent of which configurations ran
        # (or were skipped) before it.
        seed_all(args.seed)
        run_watermark_eval(args, clustering_maps=config["maps"], config_name=config_name)


if __name__ == "__main__":
    # Script entry point: parse CLI arguments and run every configured evaluation.
    main()
