from pathlib import Path
import shutil
import hashlib
import random
from typing import Tuple

import pandas as pd
import numpy as np


# Fixed seed reused by the speaker split (random.Random), the numpy sampler for
# the sample submission, and as the salt in _make_id, so re-running the script
# reproduces identical outputs.
RNG_SEED = 20240919
# Canonical (lower-cased) intonation labels; the final sanity checks require
# both train and test answers to contain exactly this set.
VALID_INTONATIONS = {"neutral", "bored", "excited", "question"}
# Also seeds the global `random` module. The functions visible here use their
# own seeded generators, so this is presumably a safeguard for any other use of
# `random` — NOTE(review): confirm whether anything still relies on it.
random.seed(RNG_SEED)


# ------------------------- helpers -------------------------

def _ensure_dir(p: Path) -> None:
    """Make sure directory ``p`` exists, creating any missing parents.

    Calling it on an existing directory is a harmless no-op.
    """
    p.mkdir(exist_ok=True, parents=True)


def _link_or_copy(src: Path, dst: Path) -> None:
    """Place the file ``src`` at ``dst``, preferring a hard link over a copy.

    Parent directories of ``dst`` are created first, and any existing file or
    symlink at ``dst`` is removed. If hard-linking is not possible, the bytes
    of ``src`` are copied instead.

    Note: a hard link shares storage with ``src``, so editing ``dst`` in place
    would also modify the source file.
    """
    _ensure_dir(dst.parent)
    try:
        if dst.exists() or dst.is_symlink():
            dst.unlink()
        # Hard links are instant and use no extra space. They fail across
        # filesystems or where unsupported (OSError / NotImplementedError), and
        # Path.hardlink_to only exists on Python 3.10+ (AttributeError).
        dst.hardlink_to(src)
    except (OSError, AttributeError, NotImplementedError):
        # Only the expected link failures fall back to copying; other errors
        # (e.g. a wrong argument type) now propagate instead of being masked.
        shutil.copyfile(src, dst)


def _make_id(stem: str) -> str:
    """Return a stable anonymized id of the form ``snd_<12 hex chars>``.

    The stem is salted with ``RNG_SEED`` and hashed with SHA-1; the first 12
    hex digits of the digest form the suffix, so the same stem always maps to
    the same id.
    """
    salted = stem + str(RNG_SEED)
    digest = hashlib.sha1(salted.encode("utf-8")).hexdigest()
    return "snd_" + digest[:12]


def _canon_intonation(x) -> str:
    """Normalize an intonation label: stringify, trim whitespace, lower-case."""
    text = str(x)
    trimmed = text.strip()
    return trimmed.lower()


def _to_int(x):
    """Coerce a raw cell value to a Python ``int``, or return ``None``.

    Accepted inputs:
    * ints, including numpy integers;
    * integral floats such as ``7.0``;
    * strings holding either form, e.g. ``" 12 "`` or ``"7.0"``.

    ``None`` is returned for:
    * missing values (``None``/NaN/``pd.NA``) and blank strings;
    * non-numeric text;
    * infinities;
    * non-integral floats or numeric strings such as ``3.5`` / ``"3.5"``.

    Non-integral values are rejected rather than truncated, so a caller's
    validity check flags them instead of silently relabelling the row.
    """
    try:
        if pd.isna(x):
            return None
    except (TypeError, ValueError):
        # list-likes make pd.isna return an array with an ambiguous truth value
        return None

    if isinstance(x, str):
        x = x.strip()
        if not x:
            return None
        try:
            # exact parse first: keeps full precision for large integers
            return int(x)
        except ValueError:
            pass
        try:
            x = float(x)
        except ValueError:
            return None

    if isinstance(x, (float, np.floating)):
        x = float(x)
        # is_integer() is False for nan/inf as well as fractional values
        return int(x) if x.is_integer() else None

    try:
        return int(x)
    except (TypeError, ValueError, OverflowError):
        return None


# ------------------------- core -------------------------

def _load_attributes(raw: Path) -> tuple[pd.DataFrame, Path]:
    """Load the MLEndSND attribute table and resolve each row's source WAV file.

    Reads ``raw/MLEndSND_Audio_Attributes.csv`` and expects the audio files
    under ``raw/MLEndSND_Public/MLEndSND_Public``.

    Args:
        raw: Root directory of the raw dataset (``str`` or ``Path``).

    Returns:
        ``(df, audio_dir)``. ``df`` is the CSV with these changes:
        * column names stripped of surrounding whitespace;
        * a new ``stem`` column: the filename, trimmed, without any ``.wav``
          suffix, zero-padded to at least 5 characters;
        * a new ``src_wav`` column: ``audio_dir / f"{stem}.wav"``;
        * ``Intonation`` trimmed and lower-cased;
        * ``Numeral`` parsed to int;
        * ``Speaker`` cast to int.

    Raises:
        AssertionError: if any of the following holds:
            * the CSV or the audio folder is missing;
            * required columns are absent;
            * a referenced WAV does not exist;
            * a Numeral is not an integer;
            * an Intonation is not in ``VALID_INTONATIONS``.
        ValueError: propagated from pandas if ``Speaker`` cannot be cast to int.
    """
    raw = Path(raw)
    attr_csv = raw / "MLEndSND_Audio_Attributes.csv"
    audio_dir = raw / "MLEndSND_Public" / "MLEndSND_Public"

    # Explicit raises instead of ``assert`` so validation still runs under
    # ``python -O``; AssertionError is kept for compatibility with callers.
    if not attr_csv.is_file():
        raise AssertionError(f"Attributes CSV not found at {attr_csv}")
    if not audio_dir.is_dir():
        raise AssertionError(f"Audio directory not found at {audio_dir}")

    df = pd.read_csv(attr_csv)
    df.columns = [str(c).strip() for c in df.columns]
    required = {"Public filename", "Numeral", "Intonation", "Speaker"}
    missing = required - set(df.columns)
    if missing:
        raise AssertionError(f"Missing columns in attributes: {missing}")

    # File stems: tolerate stray whitespace and an explicit ".wav" suffix, then
    # zero-pad numeric names (e.g. "1" -> "00001").
    df["stem"] = (
        df["Public filename"].astype(str).str.strip()
        .str.replace(r"\.wav$", "", regex=True, case=False)
        .str.zfill(5)
    )
    df["src_wav"] = df["stem"].map(lambda s: audio_dir / f"{s}.wav")
    missing_wavs = [p for p in df["src_wav"] if not p.is_file()]
    if missing_wavs:
        preview = ", ".join(str(p) for p in missing_wavs[:5])
        raise AssertionError(
            "Some referenced audio files are missing. "
            f"{len(missing_wavs)} missing, e.g.: {preview}"
        )

    # labels
    df["Intonation"] = df["Intonation"].map(_canon_intonation)
    df["Numeral"] = df["Numeral"].map(_to_int)
    if not df["Numeral"].notna().all():
        raise AssertionError("Invalid Numeral values present.")
    # Fail fast: an unknown label would otherwise only be caught by the final
    # output checks, after all audio has been linked/copied.
    unknown = set(df["Intonation"]) - VALID_INTONATIONS
    if unknown:
        raise AssertionError(f"Unexpected Intonation values: {sorted(unknown)}")
    df["Speaker"] = df["Speaker"].astype(int)

    return df, audio_dir


def _group_speaker_split(df: pd.DataFrame,
                         test_speaker_frac: float = 0.2,
                         min_test_size: int = 2000,
                         max_attempts: int = 200,
                         min_numeral_coverage: float = 0.7,
                         frequent_min_count: int = 10) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Split ``df`` into train/test sets with disjoint speakers.

    Speakers are shuffled with ``random.Random(RNG_SEED)``. The first
    ``k = max(1, round(test_speaker_frac * n_speakers))`` speakers form the
    test set.

    A candidate split is accepted when all of these hold:
    * test has at least ``min_test_size`` rows and train at least twice that;
    * both sides contain every intonation present in ``df``;
    * test contains at least ``min_numeral_coverage`` of the distinct numerals;
    * every numeral with at least ``frequent_min_count`` rows appears in test.

    Up to ``max_attempts`` shuffles are tried, each continuing from the
    previous order. If none is accepted, the last ``k`` speakers of the final
    shuffled order become the test set, and only intonation coverage is
    enforced.

    Returns:
        ``(train, test)`` as independent copies.

    Raises:
        AssertionError: if the fallback split lacks an intonation on either side.
    """
    rng = random.Random(RNG_SEED)
    speakers = sorted(df["Speaker"].unique())
    all_ints = set(df["Intonation"].unique())

    # Quantities that do not depend on the candidate split are computed once
    # here rather than on every attempt.
    k = max(1, int(round(test_speaker_frac * len(speakers))))
    nums_all = set(df["Numeral"].unique())
    vc = df["Numeral"].value_counts()
    frequent = set(vc[vc >= frequent_min_count].index)

    def split_by(te_spk: set) -> Tuple[pd.DataFrame, pd.DataFrame]:
        in_test = df["Speaker"].isin(te_spk)
        return df[~in_test], df[in_test]

    def ok(tr: pd.DataFrame, te: pd.DataFrame) -> bool:
        if len(te) < min_test_size or len(tr) < 2 * min_test_size:
            return False
        if set(tr["Intonation"].unique()) != all_ints:
            return False
        if set(te["Intonation"].unique()) != all_ints:
            return False
        nums_te = set(te["Numeral"].unique())
        if len(nums_te) < min_numeral_coverage * len(nums_all):
            return False
        return frequent.issubset(nums_te)

    # Each attempt reshuffles the already-shuffled list, so the sequence of
    # candidates (and hence the chosen split) is fixed by RNG_SEED.
    for _ in range(max_attempts):
        rng.shuffle(speakers)
        tr, te = split_by(set(speakers[:k]))
        if ok(tr, te):
            return tr.copy(), te.copy()

    # Fallback: last k speakers of the final shuffled order (not a sorted list).
    tr, te = split_by(set(speakers[-k:]))
    # Explicit raises so these checks survive ``python -O``.
    if set(tr["Intonation"].unique()) != all_ints:
        raise AssertionError("Train missing some intonations")
    if set(te["Intonation"].unique()) != all_ints:
        raise AssertionError("Test missing some intonations")
    return tr.copy(), te.copy()


def _assign_ids(df: pd.DataFrame, split: str) -> pd.DataFrame:
    """Return a copy of ``df`` with two new columns.

    * ``id``: ``_make_id(stem)``;
    * ``audio_path``: ``f"{split}_audio/{id}.wav"``.
    """
    d = df.copy()
    d["id"] = d["stem"].map(_make_id)
    d["audio_path"] = d["id"].map(lambda i: f"{split}_audio/{i}.wav")
    return d


def _materialize_audio(df: pd.DataFrame, audio_dir: Path) -> None:
    """Hard-link (or copy) each row's ``src_wav`` to ``audio_dir/<id>.wav``."""
    _ensure_dir(audio_dir)
    for row in df.itertuples(index=False):
        _link_or_copy(Path(row.src_wav), audio_dir / f"{row.id}.wav")


def _write_sample_submission(tr_out: pd.DataFrame, test_ids: pd.Series, path: Path) -> pd.DataFrame:
    """Write a sample submission whose labels follow the train label frequencies.

    Labels are drawn independently per column with ``default_rng(RNG_SEED)``:
    Numeral first, then Intonation. Returns the written frame.
    """
    rng = np.random.default_rng(RNG_SEED)

    def draw(labels: pd.Series, n: int) -> list:
        # sorted distinct labels and their empirical probabilities
        counts = labels.value_counts().sort_index()
        weights = counts.to_numpy(dtype=float)
        weights = weights / weights.sum()
        return list(rng.choice(counts.index.tolist(), size=n, p=weights))

    sample = pd.DataFrame({"id": test_ids})
    n = len(sample)
    sample["Numeral"] = draw(tr_out["Numeral"], n)
    sample["Intonation"] = draw(tr_out["Intonation"], n)
    sample["Numeral"] = sample["Numeral"].astype(int)
    sample["Intonation"] = sample["Intonation"].astype(str)
    sample.to_csv(path, index=False)
    return sample


def _check_outputs(tr_out: pd.DataFrame, te_out: pd.DataFrame, ans_out: pd.DataFrame,
                   sample: pd.DataFrame, train_audio_dir: Path, test_audio_dir: Path) -> None:
    """Assert post-write invariants of the public/private outputs.

    Checks audio counts, label sets, id alignment, submission shape and file
    names.
    """
    import re

    train_files = list(train_audio_dir.glob("*.wav"))
    test_files = list(test_audio_dir.glob("*.wav"))

    # one audio file per row in each split
    assert len(train_files) == len(tr_out)
    assert len(test_files) == len(te_out)

    # both splits carry exactly the canonical intonation set
    assert set(tr_out["Intonation"].unique()) == set(ans_out["Intonation"].unique()) == VALID_INTONATIONS

    # test inputs and hidden answers describe the same ids
    assert set(te_out["id"]) == set(ans_out["id"]) and len(te_out) == len(ans_out)

    # sample submission format
    assert sample.columns.tolist() == ["id", "Numeral", "Intonation"], "sample_submission columns must be [id, Numeral, Intonation]"
    assert set(sample["id"]) == set(te_out["id"]) and sample.notna().all().all()

    # anonymized filenames
    pat = re.compile(r"^snd_[0-9a-f]{12}\.wav$")
    for p in train_files + test_files:
        assert pat.match(p.name), f"Unexpected filename pattern: {p.name}"


def _write_outputs(df_tr: pd.DataFrame, df_te: pd.DataFrame, public: Path, private: Path) -> None:
    """Materialize the public/private dataset layout for a train/test split.

    Deletes and recreates ``public`` and ``private``, then writes:
    * ``public/train_audio/<id>.wav`` and ``public/test_audio/<id>.wav``
      (hard links or copies of ``src_wav``);
    * ``public/train.csv``: id, audio_path, Numeral, Intonation, Speaker;
    * ``public/test.csv``: id, audio_path;
    * ``private/test_answer.csv``: id, Numeral, Intonation;
    * ``public/sample_submission.csv``: labels sampled from train frequencies.

    Ends with assert-based sanity checks; these are skipped under ``python -O``.
    """
    # clean old outputs; remove both before recreating either, so nested
    # public/private layouts end up with both directories present
    for d in (public, private):
        if d.exists():
            shutil.rmtree(d)
    for d in (public, private):
        _ensure_dir(d)

    df_tr = _assign_ids(df_tr, "train")
    df_te = _assign_ids(df_te, "test")
    assert df_tr["id"].is_unique and df_te["id"].is_unique, "Duplicate ids"
    assert set(df_tr["id"]).isdisjoint(set(df_te["id"])), "Train/Test ids overlap"

    train_audio_dir = public / "train_audio"
    test_audio_dir = public / "test_audio"
    _materialize_audio(df_tr, train_audio_dir)
    _materialize_audio(df_te, test_audio_dir)

    tr_out = df_tr[["id", "audio_path", "Numeral", "Intonation", "Speaker"]].copy()
    te_out = df_te[["id", "audio_path"]].copy()
    ans_out = df_te[["id", "Numeral", "Intonation"]].copy()
    tr_out.to_csv(public / "train.csv", index=False)
    te_out.to_csv(public / "test.csv", index=False)
    ans_out.to_csv(private / "test_answer.csv", index=False)

    sample = _write_sample_submission(tr_out, te_out["id"], public / "sample_submission.csv")

    _check_outputs(tr_out, te_out, ans_out, sample, train_audio_dir, test_audio_dir)


def prepare(raw: Path, public: Path, private: Path):
    """
    Complete preparation process.

    - Reads raw data from `raw/`.
    - Writes train.csv, test.csv, sample_submission.csv, audio under `public/`.
    - Writes test_answer.csv under `private/`.
    - Copies description.txt from this script's directory into `public/`,
      if that file exists.

    All three arguments may be given as ``str`` or ``Path``.
    """
    # Coerce so callers may pass plain strings (``raw / "..."`` needs a Path).
    raw, public, private = Path(raw), Path(public), Path(private)

    df, _ = _load_attributes(raw)
    df_tr, df_te = _group_speaker_split(df)
    _write_outputs(df_tr, df_te, public, private)

    # description.txt is looked up next to this script; copying is
    # best-effort and silently skipped when the file is absent.
    repo_desc = (Path(__file__).parent / "description.txt").resolve()
    if repo_desc.is_file():
        shutil.copyfile(repo_desc, public / "description.txt")
