import argparse
import os
import glob
import csv
from typing import List
import torch
import numpy as np
import torchaudio
from huggingface_hub import hf_hub_download
from collections import defaultdict

from models.moshi.models.loaders import get_mimi
from models.moshi.models import loaders
from training import get_validation_augs, get_dummy_augs


def list_audio_files(audio_dir: str, exts=("wav", "mp3", "ogg", "flac")) -> List[str]:
    """Return the sorted paths of audio files exactly two directory levels below ``audio_dir``.

    For each extension in ``exts`` the glob pattern ``<audio_dir>/*/*/*.<ext>``
    is expanded, so files placed directly in ``audio_dir`` or nested deeper
    than two levels are not picked up. Matching is case-sensitive as per
    :func:`glob.glob` on the host platform.
    """
    patterns = (os.path.join(audio_dir, f"*/*/*.{ext}") for ext in exts)
    return sorted(match for pattern in patterns for match in glob.glob(pattern))


def resample_if_needed(waveform: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
    """Resample ``waveform`` from ``src_sr`` to ``dst_sr`` Hz.

    When the two rates already agree, the input tensor object itself is
    returned untouched; otherwise a ``torchaudio.transforms.Resample`` is
    built for this rate pair and applied.
    """
    if src_sr != dst_sr:
        resampler = torchaudio.transforms.Resample(src_sr, dst_sr)
        waveform = resampler(waveform)
    return waveform


def single_channel_best_match(a: np.ndarray, b: np.ndarray) -> float:
    """Return the best fraction of equal tokens between two 1-D token sequences.

    - Either sequence empty: returns ``0.0``.
    - Equal lengths: plain position-wise agreement ``mean(a == b)`` (no shifting).
    - Different lengths: the shorter sequence is compared against every
      *contiguous* window of the longer one, and the highest agreement is
      returned. The score is normalised by the shorter length, so extra
      tokens in the longer sequence are not penalised.

    The argument order does not matter when the lengths differ.
    """
    if a.size == 0 or b.size == 0:
        return 0.0
    if a.shape[0] == b.shape[0]:
        return float((a == b).mean())
    if a.shape[0] < b.shape[0]:
        a, b = b, a
    # Only genuine alignments are considered. A circular shift (np.roll) would
    # also compare windows that splice the end of `a` onto its start, which can
    # only inflate the best score.
    windows = np.lib.stride_tricks.sliding_window_view(a, b.shape[0])  # [len(a) - len(b) + 1, len(b)]
    return float((windows == b).mean(axis=1).max())


def per_channel_token_match(toks1: torch.LongTensor, toks2: torch.LongTensor):
    """Compute one ``single_channel_best_match`` score per channel (row).

    3-D inputs first have their leading axis squeezed. This is a no-op unless
    that axis has size 1. The number of channels is taken from ``toks1``, so
    ``toks2`` is expected to have at least as many rows. Each row is moved to
    the CPU and cast to ``int64`` before comparison.

    Returns a list of floats, one per channel of ``toks1``.
    """
    def _drop_batch_axis(t):
        return t.squeeze(0) if t.dim() == 3 else t

    ref = _drop_batch_axis(toks1)
    other = _drop_batch_axis(toks2)

    rates = []
    for ch in range(ref.shape[0]):
        ref_row = ref[ch].cpu().numpy().astype(np.int64)
        other_row = other[ch].cpu().numpy().astype(np.int64)
        rates.append(single_channel_best_match(ref_row, other_row))
    return rates


def save_tokens_txt(path: str, codes: torch.LongTensor):
    """Write a token tensor to ``path`` as text, one line per channel (row).

    A 3-D ``codes`` tensor first has its leading axis squeezed, which is a
    no-op unless that axis has size 1. Each output line holds one row's token
    ids separated by single spaces and ends with a newline. The parent
    directory is created if needed. A bare filename, which has no directory
    part, is written to the current working directory.
    """
    if codes.dim() == 3:
        codes = codes.squeeze(0)
    parent = os.path.dirname(path)
    # dirname() is "" for a bare filename, and os.makedirs("") raises FileNotFoundError.
    if parent:
        os.makedirs(parent, exist_ok=True)
    # A single tolist() call avoids one device->host transfer per channel.
    lines = [" ".join(map(str, row)) + "\n" for row in codes.tolist()]
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(lines)


def tokens_to_string(codes: torch.LongTensor) -> str:
    """
    Render a token tensor as one line of text for logging/CSV output.

    Within a channel (row) the token ids are separated by single spaces, and
    channels are separated by " | ". A 3-D tensor first has its leading axis
    squeezed, which is a no-op unless that axis has size 1.
    """
    mat = codes.squeeze(0) if codes.dim() == 3 else codes
    rows = (mat[k].tolist() for k in range(mat.shape[0]))
    return " | ".join(" ".join(str(tok) for tok in row) for row in rows)


def safe_name(s: str) -> str:
    """Make ``s`` usable as a single filename component.

    Spaces, forward slashes and backslashes are each replaced with ``_``;
    every other character is kept as-is.
    """
    return s.translate(str.maketrans({" ": "_", "/": "_", "\\": "_"}))


def _load_batch(paths: List[str], target_sr: int) -> torch.Tensor:
    """Load ``paths`` as mono waveforms at ``target_sr`` and stack them into ``[B, 1, T]``.

    Multi-channel files are averaged down to one channel. Every waveform is
    zero-padded on the right to the length of the longest one in the batch.
    ``paths`` must be non-empty.
    """
    wavs = []
    for path in paths:
        wav, sr = torchaudio.load(path)
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)
        wavs.append(resample_if_needed(wav, sr, target_sr))
    max_len = max(w.shape[-1] for w in wavs)
    return torch.stack(
        [torch.nn.functional.pad(w, (0, max_len - w.shape[-1])) for w in wavs],
        dim=0,
    )


def _write_summary_csv(csv_path: str, csv_rows: List[dict]) -> None:
    """Write one CSV line per (file, augmentation, strength) result in ``csv_rows``.

    The ``orig_tokens`` / ``attacked_tokens`` columns are always in the header.
    They are left empty for rows that carry no such entries. One
    ``tm_rate_<k>`` column is emitted per channel, up to the largest channel
    count seen in any row.
    """
    max_k = max((len(r["per_channel"]) for r in csv_rows), default=0)
    fieldnames = [
        "idx", "audio_file", "aug_name", "aug_strength", "orig_len", "rt_len",
        "mean_match", "rcc_error", "orig_tokens", "attacked_tokens",
    ] + [f"tm_rate_{i}" for i in range(max_k)]

    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for r in csv_rows:
            per_channel = r["per_channel"]
            row = {
                "idx": r["idx"],
                "audio_file": r["audio_file"],
                "aug_name": r.get("aug_name", ""),
                "aug_strength": r.get("aug_strength", ""),
                "orig_len": r["orig_len"],
                "rt_len": r["rt_len"],
                "mean_match": f"{r['mean_match']:.4f}",
                "rcc_error": f"{r['rcc_error']:.4f}",
                "orig_tokens": r.get("orig_tokens", ""),
                "attacked_tokens": r.get("attacked_tokens", ""),
            }
            for i in range(max_k):
                row[f"tm_rate_{i}"] = f"{per_channel[i]:.4f}" if i < len(per_channel) else ""
            writer.writerow(row)


def _write_channel_averages(path: str, per_channel_matches) -> None:
    """Write ``channel_<k>: <mean match rate>`` per channel, skipping channels with no data."""
    with open(path, "w") as f:
        for ch in sorted(per_channel_matches):
            rates = per_channel_matches[ch]
            if not rates:
                continue
            f.write(f"channel_{ch}: {float(np.mean(rates)):.6f}\n")


def run_pipeline(
    mimi_weight_ori: str,
    audio_dir: str,
    output_dir: str,
    nsamples: int = 0,
    batch_size: int = 1,
    eval_aug: bool = False,
    include_tokens: bool = False,
):
    """Measure MIMI token stability under a decode -> augment -> re-encode round trip.

    Each audio file found by ``list_audio_files(audio_dir)`` goes through these steps:

    1. Encode the file to reference tokens.
    2. Decode the tokens back to audio.
    3. Augment that audio with every (augmentation, strength) pair.
    4. Re-encode the augmented audio.
    5. Compare the re-encoded tokens to the reference, per channel, with
       ``per_channel_token_match``.

    Args:
        mimi_weight_ori: Checkpoint path passed to ``get_mimi``.
        audio_dir: Root directory scanned by ``list_audio_files``.
        output_dir: Output root, created if missing. It receives
            ``samples/<stem>/orig_tokens.txt``,
            ``samples/<stem>/attacked_tokens_<aug>_str<strength>.txt``,
            ``rcc_summary.csv`` and ``avg_per_channel_match.txt``.
        nsamples: If > 0, only the first ``nsamples`` files (in sorted order) are used.
        batch_size: Number of files encoded together. Must be >= 1.
        eval_aug: Use ``get_validation_augs()`` instead of ``get_dummy_augs()``.
        include_tokens: If True, fill the ``orig_tokens`` / ``attacked_tokens``
            CSV columns with the full token strings, which can be large. By
            default those columns are left empty.

    Raises:
        ValueError: If ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(output_dir, exist_ok=True)

    print("Loading MIMI...")
    mimi = get_mimi(mimi_weight_ori, device=device)

    audio_files = list_audio_files(audio_dir)
    if nsamples > 0:
        audio_files = audio_files[:nsamples]

    # The augmentation set does not depend on the batch, so build it and move
    # it to the device once rather than once per batch.
    augs = get_validation_augs() if eval_aug else get_dummy_augs()
    for aug, _ in augs:
        aug.to(device)

    # Match rates per channel, pooled over every file, augmentation and strength.
    per_channel_matches = defaultdict(list)
    csv_rows = []

    for batch_start in range(0, len(audio_files), batch_size):
        batch_files = audio_files[batch_start: batch_start + batch_size]
        # NOTE(review): with batch_size > 1, shorter files are zero-padded to the
        # longest one before encoding. Their token sequences (and orig_len/rt_len)
        # therefore include frames for that padding. Use batch_size=1 for exact
        # per-file numbers.
        batch_wavs = _load_batch(batch_files, mimi.sample_rate)  # [B, 1, T]

        # NOTE(review): output dirs are keyed by file stem only. Files with the same
        # stem in different sub-directories (or with different extensions) share
        # a directory and overwrite each other's token files.
        sample_dirs = [
            os.path.join(output_dir, "samples", os.path.splitext(os.path.basename(p))[0])
            for p in batch_files
        ]

        with torch.no_grad():
            # Reference tokens and their clean reconstruction.
            orig_tokens = mimi.encode(batch_wavs.to(device))  # [B, K, T]
            decoded_audio_clean = mimi.decode(orig_tokens)  # [B, 1, S]

            # The reference tokens do not depend on the augmentation, so write
            # them once per file rather than once per (aug, strength) pair.
            for i, out_dir in enumerate(sample_dirs):
                os.makedirs(out_dir, exist_ok=True)
                save_tokens_txt(os.path.join(out_dir, "orig_tokens.txt"), orig_tokens[i])

            for validation_aug, strengths in augs:
                aug_name = validation_aug.__class__.__name__
                aug_safe = safe_name(aug_name)
                for strength in strengths:
                    # Augment the decoded audio and re-encode it to check token stability.
                    batch_aug_audio, _ = validation_aug(decoded_audio_clean, None, strength)
                    rt_tokens = mimi.encode(batch_aug_audio)  # [B, K, T_rt]

                    # Aug name and strength in the filename keep results from overwriting each other.
                    strength_str = str(strength).replace(".", "_")
                    attacked_fname = f"attacked_tokens_{aug_safe}_str{strength_str}.txt"

                    for i, (audio_path, out_dir) in enumerate(zip(batch_files, sample_dirs)):
                        idx = batch_start + i
                        orig = orig_tokens[i]
                        rt = rt_tokens[i]

                        save_tokens_txt(os.path.join(out_dir, attacked_fname), rt)

                        rates = per_channel_token_match(orig, rt)
                        mean_match = float(np.mean(rates)) if rates else 0.0
                        for ch, rate in enumerate(rates):
                            per_channel_matches[ch].append(rate)

                        row = {
                            "idx": idx,
                            "audio_file": os.path.basename(audio_path),
                            "aug_name": aug_name,
                            "aug_strength": strength,
                            "orig_len": orig.shape[-1],
                            "rt_len": rt.shape[-1],
                            "mean_match": mean_match,
                            "rcc_error": 1.0 - mean_match,
                            "per_channel": rates,
                        }
                        # Token strings are only built when requested. They can be
                        # large and were otherwise never written out.
                        if include_tokens:
                            row["orig_tokens"] = tokens_to_string(orig)
                            row["attacked_tokens"] = tokens_to_string(rt)
                        csv_rows.append(row)

                        print(f"[{idx}] {os.path.basename(audio_path)} aug={aug_name} str={strength} mean_match={mean_match:.4f}")

    csv_path = os.path.join(output_dir, "rcc_summary.csv")
    _write_summary_csv(csv_path, csv_rows)
    print(f"Done. Summary CSV: {csv_path}")

    avg_match_path = os.path.join(output_dir, "avg_per_channel_match.txt")
    _write_channel_averages(avg_match_path, per_channel_matches)
    print(f"Per-channel matches: {avg_match_path}")


if __name__ == "__main__":
    # Command-line entry point: parse options, resolve MIMI weights, run the evaluation.
    cli = argparse.ArgumentParser()
    for required_flag in ("--audio_dir", "--output_dir"):
        cli.add_argument(required_flag, required=True)
    cli.add_argument("--mimi_weight_ori", default=None)
    cli.add_argument("--nsamples", type=int, default=0)
    cli.add_argument("--batch_size", type=int, default=1)
    cli.add_argument("--eval_aug", action="store_true", help="Enable validation augmentations")
    opts = cli.parse_args()

    # Without an explicit checkpoint, fetch the MIMI weights from the Hugging Face Hub.
    weights_path = opts.mimi_weight_ori
    if weights_path is None:
        weights_path = hf_hub_download("kyutai/moshiko-pytorch-bf16", loaders.MIMI_NAME)

    run_pipeline(
        mimi_weight_ori=weights_path,
        audio_dir=opts.audio_dir,
        output_dir=opts.output_dir,
        nsamples=opts.nsamples,
        batch_size=opts.batch_size,
        eval_aug=opts.eval_aug,
    )