import os
import torch
import numpy as np
import pandas as pd
import argparse
import sphn
import json
import pickle
from tqdm import tqdm

from transformers import (
    AutoProcessor,
    EncodecModel,
    MusicgenForConditionalGeneration,
)

from evals.main_wm_select_new import get_binomial_pval, seed_all, compute_watermark_scores
from evals.main_wm_music import build_stream_ngrams_from_full_stream

from models.moshi.utils import bool_inst
from models.musicgen import MusicGenWMGen

from training import get_validation_augs, get_dummy_augs

# --- HELPER FUNCTIONS ---

def _parse_clustering_config(config_str):
    """Split ``"<method>_<count>_res<resolution>"`` into ``(method, count, resolution)``.

    Example: ``"leiden_10_res1.0"`` -> ``("leiden", 10, 1.0)``.
    """
    left, res_part = config_str.split("_res", 1)
    method, cnt_str = left.rsplit("_", 1)
    return method, int(cnt_str), float(res_part)


def _build_full_map(raw_list, vocab_size, device):
    """Expand a raw per-token cluster list into a ``[vocab_size]`` long tensor with no -1 entries.

    Tokens the raw list does not cover (positions past its end, or explicit -1
    entries) each receive a fresh, unique cluster id starting just above the
    largest id in the raw list (or at 0 when the raw list is empty).
    """
    cmap = torch.as_tensor(raw_list, device=device, dtype=torch.long)

    full_map = torch.full((vocab_size,), -1, device=device, dtype=torch.long)
    limit = min(len(cmap), vocab_size)
    full_map[:limit] = cmap[:limit]

    unmapped = full_map == -1
    n_unmapped = int(unmapped.sum().item())
    if n_unmapped:
        # An empty raw list has no max(); start the fresh ids at 0 in that case.
        start_id = int(cmap.max().item()) + 1 if cmap.numel() else 0
        full_map[unmapped] = torch.arange(
            start_id, start_id + n_unmapped, device=device, dtype=torch.long
        )
    return full_map


def load_clustering_maps(pkl_path, select_map, device='cpu', vocab_size=2048, channel_names=None):
    """
    Load only the clustering maps requested by ``select_map`` from a pickle file.

    For each stream, only the ``(count, resolution)`` configuration named in
    ``select_map`` is loaded. Results are keyed by INTEGER stream id so the
    generator/evaluator can find them.

    Args:
        pkl_path: path to a pickle whose top-level keys are either integer stream
            ids or ``"quantizer.layers.<id>"`` strings; each value maps
            ``(count, resolution)`` tuples to a per-token list of cluster ids.
        select_map: mapping ``stream id (str) -> "<method>_<count>_res<resolution>"``,
            e.g. ``{"0": "leiden_10_res1.0"}``.
        device: device for the returned tensors.
        vocab_size: length of each returned full map.
        channel_names: optional iterable of stream ids (str or int). When given,
            only those streams are loaded; ``None`` loads every entry of
            ``select_map``.

    Returns:
        ``maps[stream_id][method][count] = (raw_list, full_map)``, where
        ``full_map`` is a ``[vocab_size]`` long tensor. Requested configurations
        that are missing from the pickle are reported with a ``MISSING`` line
        and skipped.

    Raises:
        ValueError: if ``select_map`` is None.
    """
    if select_map is None:
        raise ValueError("select_map is required: map stream id -> '<method>_<count>_res<resolution>'")

    wanted = None if channel_names is None else {str(c) for c in channel_names}

    print(f"Loading maps from {pkl_path}...")
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(pkl_path, "rb") as f:
        data = pickle.load(f)

    # Output structure: maps[stream_int][method][cnt] = (raw, tensor)
    maps = {}

    for stream_str, config_str in select_map.items():
        if wanted is not None and str(stream_str) not in wanted:
            continue

        method, cnt, res = _parse_clustering_config(config_str)
        stream_id = int(stream_str)

        # Pickle keys may be ints or quantizer layer names; results always use int keys.
        pkl_key = stream_id if stream_id in data else f"quantizer.layers.{stream_str}"
        if pkl_key not in data or (cnt, res) not in data[pkl_key]:
            print(f"MISSING: Stream {stream_str} (key {pkl_key}) config {cnt}, {res}")
            continue

        raw_list = data[pkl_key][(cnt, res)]
        full_map = _build_full_map(raw_list, vocab_size, device)
        maps.setdefault(stream_id, {}).setdefault(method, {})[cnt] = (raw_list, full_map)

    return maps

def realign_delayed_codebooks(tokens):
    """
    Undo MusicGen's per-codebook delay so the codes line up for EnCodec decoding.

    Codebook ``k`` is emitted ``k`` steps late during generation, so its row is
    moved ``k`` positions to the LEFT. The ``k`` trailing positions left empty
    by the shift are zero-filled; codebooks with ``k >= T`` are all zeros.

    Args:
        tokens: tensor of shape ``[B, K, T]``.

    Returns:
        A new tensor with the same shape and dtype as ``tokens``.
    """
    _, num_codebooks, num_steps = tokens.shape
    aligned = torch.zeros_like(tokens)
    for offset in range(min(num_codebooks, num_steps)):
        kept = num_steps - offset
        aligned[:, offset, :kept] = tokens[:, offset, offset:]
    return aligned

def add_delay_to_codebooks(tokens):
    """
    Re-apply MusicGen's per-codebook delay to time-aligned EnCodec codes.

    Codebook ``k`` is moved ``k`` positions to the RIGHT, the inverse of
    ``realign_delayed_codebooks``. The ``k`` leading positions left empty by the
    shift are zero-filled; codebooks with ``k >= T`` are all zeros.

    Args:
        tokens: tensor of shape ``[B, K, T]``.

    Returns:
        A new tensor with the same shape and dtype as ``tokens``.
    """
    _, num_codebooks, num_steps = tokens.shape
    delayed = torch.zeros_like(tokens)
    for offset in range(min(num_codebooks, num_steps)):
        delayed[:, offset, offset:] = tokens[:, offset, :num_steps - offset]
    return delayed

def _translate_encodec_key(key):
    """Rename a finetuned EnCodec state-dict key to the naming used by the HF ``EncodecModel`` loaded here."""
    return (
        key.replace("encoder.model", "encoder.layers")
        .replace("decoder.model", "decoder.layers")
        .replace("quantizer.vq", "quantizer.layers")
        .replace("conv.conv.", "conv.")
    )


def _get_shifted_ngrams(context_stream, n):
    """
    Build n-gram contexts so that row ``t`` only sees tokens strictly before ``t``.

    The context is first shifted right by one (a zero is prepended and the last
    element dropped), so index ``t`` holds ``S_{t-1}``; n-grams are then built on
    the shifted stream. This ensures Token[t] is hashed using History[0...t-1].

    Args:
        context_stream: 1-D token tensor used as the hashing context.
        n: n-gram order; ``n <= 0`` returns an empty ``(1, 0)`` tensor.

    Returns:
        The output of ``build_stream_ngrams_from_full_stream`` on the shifted
        stream (on CPU), or a ``(1, 0)`` zeros tensor when ``n <= 0``.
    """
    if n <= 0:
        return torch.zeros((1, 0), device='cpu')

    # [S0, S1, ...] -> [0, S0, S1, ...]: aligns index t with S_{t-1}.
    pad_val = torch.zeros(1, dtype=context_stream.dtype, device=context_stream.device)
    shifted_stream = torch.cat([pad_val, context_stream[:-1]])
    return build_stream_ngrams_from_full_stream(shifted_stream, n, device='cpu')


def _score_stream(stream, anchor, stream_id, args, clustering_map, vocab_size=2048):
    """
    Score one codebook stream and return ``(green_count, scored_count)``.

    Stream 0 is hashed with its own history. Other streams use ``anchor``
    (stream 0 of the same token set) when it is available, else their own history.
    """
    context = stream if stream_id == 0 or anchor is None else anchor
    ngrams = _get_shifted_ngrams(context, args.wm_ngram)
    g_mask, s_mask = compute_watermark_scores(
        stream, ngrams, vocab_size, args.wm_gamma, args.wm_seed, clustering_map=clustering_map
    )
    return (g_mask * s_mask).float().sum().item(), s_mask.float().sum().item()


def _score_token_sets(wm_tokens_orig, tokens_roundtrip, stream_ids, args, clustering_maps):
    """
    Score the original and round-tripped ``[K, T]`` token sets stream by stream.

    Returns:
        ``(orig_greens, orig_scored, greens, scored)``: four lists with one entry
        per stream in ``stream_ids``. Streams absent from ``tokens_roundtrip``
        contribute 0 green / 0 scored round-trip tokens.
    """
    anchor_orig = wm_tokens_orig[0, :] if wm_tokens_orig.shape[0] > 0 else None
    anchor_rt = tokens_roundtrip[0, :] if tokens_roundtrip.shape[0] > 0 else None

    orig_greens, orig_scored, greens, scored = [], [], [], []
    for stream_id in stream_ids:
        s_map = clustering_maps.get(stream_id) if clustering_maps else None

        g, s = _score_stream(wm_tokens_orig[stream_id, :], anchor_orig, stream_id, args, s_map)
        orig_greens.append(g)
        orig_scored.append(s)

        if stream_id < tokens_roundtrip.shape[0]:
            g, s = _score_stream(tokens_roundtrip[stream_id, :], anchor_rt, stream_id, args, s_map)
        else:
            g, s = 0, 0
        greens.append(g)
        scored.append(s)

    return orig_greens, orig_scored, greens, scored


def run_watermark_eval(args, clustering_maps=None, config_name="standard"):
    """
    Generate watermarked MusicGen audio and evaluate watermark preservation.

    For every prompt, this:
      1. generates watermarked tokens;
      2. decodes them to audio with EnCodec;
      3. applies each validation augmentation at each strength;
      4. re-encodes the audio;
      5. scores both the original and the round-tripped tokens with a binomial test.

    Args:
        args: parsed CLI namespace built by ``main``.
        clustering_maps: optional ``{stream_id: cluster-map tensor}``. It is
            passed to the generator and to ``compute_watermark_scores``; ``None``
            disables clustering.
        config_name: label used in log output and in every output filename.

    Side effects:
        Writes the following files under ``args.output_dir``:
          * ``summary_<config>.pt``, ``summary_<config>.csv`` and ``results_<config>.csv``;
          * ``generated_texts_<config>.txt``, which is appended to;
          * WAVs under ``audio_<config>/`` for samples with index < ``args.save_audio``.
    """
    # 1. Load models
    device = args.device
    print(f"Loading MusicGen models on {device}...")
    processor = AutoProcessor.from_pretrained("facebook/musicgen-medium")
    model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-medium").to(device)
    encodec = EncodecModel.from_pretrained("facebook/encodec_32khz").to(device)

    # 2. Optionally swap in finetuned EnCodec weights (renamed to HF key names)
    if args.encodec_weight:
        print(f"Loading finetuned EnCodec weights from {args.encodec_weight}")
        raw_sd = torch.load(args.encodec_weight, map_location=device)
        if "model_state" in raw_sd:
            raw_sd = raw_sd["model_state"]
        translated_sd = {_translate_encodec_key(k): v for k, v in raw_sd.items()}
        # strict=False tolerates naming gaps; report them instead of ignoring silently.
        load_result = encodec.load_state_dict(translated_sd, strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            print(f"WARNING: EnCodec weight mismatch: {len(load_result.missing_keys)} missing keys, "
                  f"{len(load_result.unexpected_keys)} unexpected keys")
        model.audio_encoder = encodec

    # 3. Initialize watermarking generator
    lm_gen = MusicGenWMGen(
        model,
        temp=args.temperature,
        wm=args.wm_method,
        wm_ngram=args.wm_ngram,
        wm_seed=args.wm_seed,
        wm_streams=[int(s) for s in args.wm_streams],
        wm_aux_params={"delta": args.wm_delta, "gamma": args.wm_gamma, "clustering_maps": clustering_maps},
    )
    stream_ids = [int(s) for s in args.wm_streams]

    print(f"Watermarking config: method={lm_gen.wm}, streams={lm_gen.wm_streams}, "
          f"ngram={lm_gen.wm_ngram}, delta={lm_gen.wm_aux_params['delta']}")
    print(f"--- Running Configuration: {config_name} ---")

    # 4. Prompts
    with open(args.prompt_file, 'r', encoding='utf-8') as f:
        prompts = [line.strip() for line in f if line.strip()]
    if args.nsamples > 0:
        prompts = prompts[:args.nsamples]
    nsamples = len(prompts)

    # Augmentations do not depend on the batch: build them and move them to the device once.
    augs = get_validation_augs() if args.eval_aug else get_dummy_augs()
    for aug, _ in augs:
        aug.to(device)

    audio_output_dir = os.path.join(args.output_dir, f"audio_{config_name}")
    global_watermark_results = []

    for batch_start in tqdm(range(0, nsamples, args.batch_size)):
        batch_texts = prompts[batch_start:batch_start + args.batch_size]
        batch_size = len(batch_texts)

        inputs = processor(text=batch_texts, padding=True, return_tensors="pt").to(device)

        # 5. Generate [B, K, T] (raw delayed); realign only for audio decoding.
        wm_tokens_th = lm_gen.generate_watermarked(inputs, max_new_tokens=args.steps)
        wm_tokens_for_audio = realign_delayed_codebooks(wm_tokens_th)

        with torch.no_grad():
            decoded_outputs = encodec.decode(wm_tokens_for_audio[None, :], [None] * batch_size)
            batch_audio_saved = decoded_outputs.audio_values.clone()  # [B, 1, L]

        # 6. Evaluation: scoring only reads tensors, so skip autograd bookkeeping.
        with torch.no_grad():
            for validation_aug, strengths in augs:
                for strength in strengths:
                    batch_aug_audio, _ = validation_aug(batch_audio_saved, None, strength)

                    for idx in range(batch_size):
                        global_idx = batch_start + idx

                        # Re-encode augmented audio -> [K, T] aligned, then restore the delay.
                        tokens_roundtrip_aligned = (
                            encodec.encode(batch_aug_audio[idx:idx + 1]).audio_codes.squeeze(0).squeeze(0)
                        )
                        tokens_roundtrip = add_delay_to_codebooks(tokens_roundtrip_aligned.unsqueeze(0)).squeeze(0)
                        wm_tokens_orig = wm_tokens_th[idx]  # raw delayed [K, T]

                        orig_greens, orig_scored, greens, scored = _score_token_sets(
                            wm_tokens_orig, tokens_roundtrip, stream_ids, args, clustering_maps
                        )

                        orig_pval = get_binomial_pval(
                            float(sum(orig_greens)), float(sum(orig_scored)), args.wm_gamma
                        )
                        pval = get_binomial_pval(sum(greens), sum(scored), args.wm_gamma)

                        global_watermark_results.append({
                            "config": config_name,
                            "idx": global_idx,
                            "aug_name": str(validation_aug),
                            "strength": strength,
                            "original_greens": orig_greens,
                            "original_scored": orig_scored,
                            "original_ntoks": wm_tokens_orig.shape[-1],
                            "original_pval": orig_pval,
                            "greens": greens,
                            "scored": scored,
                            "ntoks": tokens_roundtrip.shape[-1],
                            "pval": pval,
                        })

                        print(f"Orig P: {orig_pval:.4e}, RT P: {pval:.4e}")

                        if args.save_audio > 0 and global_idx < args.save_audio:
                            os.makedirs(audio_output_dir, exist_ok=True)
                            aug_audio_np = batch_aug_audio[idx, 0].detach().cpu().numpy().astype(np.float32)
                            sphn.write_wav(
                                os.path.join(audio_output_dir, f'{validation_aug}_{strength}_{global_idx:03d}.wav'),
                                aug_audio_np, encodec.config.sampling_rate
                            )

        # Save text prompts
        with open(os.path.join(args.output_dir, f"generated_texts_{config_name}.txt"), "a", encoding="utf-8") as f:
            f.writelines(f"{batch_start + i:04d},{text}\n" for i, text in enumerate(batch_texts))

    # Save summary
    summary = {'config': vars(args), 'results': global_watermark_results}
    torch.save(summary, os.path.join(args.output_dir, f'summary_{config_name}.pt'))

    # Calculate statistics
    df = pd.DataFrame([
        {
            "idx": wmr["idx"],
            "aug_name": wmr["aug_name"],
            "strength": str(wmr["strength"]),
            "greens": sum(wmr["greens"]),
            "scored": sum(wmr["scored"]),
            "ntoks": wmr["ntoks"],
            "pval": wmr["pval"],
            "logpval": -np.log10(wmr["pval"]) if wmr["pval"] is not None and wmr["pval"] > 0 else None,
        }
        for wmr in global_watermark_results
    ])
    if df.empty:
        # groupby on a column-less frame raises KeyError; nothing to aggregate anyway.
        print(f"No watermark results for configuration {config_name}; skipping statistics.")
        return

    numeric_cols_for_mean = ["greens", "scored", "ntoks", "pval", "logpval"]
    cols_to_aggregate = [col for col in numeric_cols_for_mean if col in df.columns]

    mean_df = df.groupby(["aug_name", "strength"])[cols_to_aggregate].agg("mean")
    mean_df.to_csv(os.path.join(args.output_dir, f'summary_{config_name}.csv'))

    pd.set_option('display.max_rows', None)
    print(mean_df)
    df.to_csv(os.path.join(args.output_dir, f'results_{config_name}.csv'), index=False)


DEFAULT_CLUSTERING_SELECT_PATH = "/home/wmar/wmar_audio/evals/encodec_configs_new.json"
DEFAULT_CLUSTERS_PKL = "/home/wmar/wmar_audio/models/embeddings/encodec_leiden_clusterings_trainonly_allparams.pkl"


def main():
    """
    Parse CLI arguments and run the watermark evaluation for each configuration.

    The "standard" configuration (no clustering maps) always runs. With
    ``--wm_clustering true``, a second "selected" configuration is added. Its
    per-stream clustering maps are chosen by the JSON file at
    ``--clustering_select_path`` and loaded from ``--clusters_pkl``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.device_count() else "cpu")
    parser.add_argument("--seed", type=int, default=42424242)
    parser.add_argument("--steps", type=int, default=256)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--prompt_file", type=str, required=True, help="Path to txt file with prompts")
    parser.add_argument("--nsamples", type=int, default=-1)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--wm_method", type=str, default="maryland")
    parser.add_argument("--wm_streams", nargs='+', type=int, default=[0, 1, 2, 3], help="Stream indices")
    parser.add_argument("--wm_delta", type=float, default=2.0)
    parser.add_argument("--wm_gamma", type=float, default=0.25)
    parser.add_argument("--wm_ngram", type=int, default=0)
    parser.add_argument("--wm_seed", type=int, default=0)
    parser.add_argument("--wm_clustering", type=bool_inst, default=False)
    parser.add_argument("--clustering_select_path", type=str, default=DEFAULT_CLUSTERING_SELECT_PATH,
                        help="JSON mapping stream id -> '<method>_<count>_res<resolution>'")
    parser.add_argument("--clusters_pkl", type=str, default=DEFAULT_CLUSTERS_PKL,
                        help="Pickle with per-stream clustering results")
    parser.add_argument("--encodec_weight", type=str, default=None)
    parser.add_argument("--save_audio", type=int, default=10)
    parser.add_argument("--eval_aug", type=bool_inst, default=True)
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    seed_all(args.seed)

    # The unclustered baseline always runs.
    configs_to_run = [{"name": "standard", "maps": None}]

    if args.wm_clustering:
        if not args.clustering_select_path:
            parser.error("--wm_clustering requires --clustering_select_path")

        with open(args.clustering_select_path, "r", encoding="utf-8") as jf:
            select_map = json.load(jf)  # expect mapping stream id (str) -> "leiden_10_res1.0"

        all_maps = load_clustering_maps(args.clusters_pkl, select_map=select_map, device=args.device,
                                        vocab_size=2048, channel_names=[str(x) for x in args.wm_streams])

        # Build a single config from the JSON selections.
        current_config_maps = {}
        for s in args.wm_streams:
            s_str, s_int = str(s), int(s)

            # The selection is keyed by STRING stream id.
            sel = select_map.get(s_str)
            if sel is None:
                raise KeyError(f"No clustering selection for stream {s_str} in {args.clustering_select_path}")
            left, _res_part = sel.split("_res", 1)
            method, cnt_str = left.rsplit("_", 1)
            cnt = int(cnt_str)

            # Loaded maps are keyed by INTEGER stream id.
            try:
                _, tmap = all_maps[s_int][method][cnt]
            except KeyError as err:
                raise KeyError(
                    f"Clustering map for stream {s_int} ({sel}) was not loaded from {args.clusters_pkl}"
                ) from err
            current_config_maps[s_int] = tmap

        configs_to_run.append({"name": "selected", "maps": current_config_maps})

    print(f"Starting execution for {len(configs_to_run)} configurations")
    print([config["name"] for config in configs_to_run])

    for config in configs_to_run:
        run_watermark_eval(args, clustering_maps=config["maps"], config_name=config["name"])

if __name__ == "__main__":
    # Run the CLI entry point only when executed as a script, not when imported.
    main()
