from __future__ import annotations

import hashlib
import os
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import Tuple

import numpy as np
import pandas as pd


@dataclass(frozen=True)
class _Cfg:
    """Immutable constants that keep the dataset build reproducible."""

    # Salt mixed into the SHA-1 digest that produces the anonymised media filenames.
    salt: str = "DEAM-Kaggle-Design-2025-09-13"
    # Seed for the NumPy random generators used by the split and the sample submission.
    seed: int = 20250913


# Module-wide configuration instance (default salt and seed) read by the helpers below.
CFG = _Cfg()


def _sha1_name(song_id: int, kind: str, ext: str) -> str:
    """Return the anonymised filename for one media file of a song.

    The name is built from the first 16 hex characters of the SHA-1 digest of
    ``"<CFG.salt>:<kind>:<song_id>"``. It is prefixed with ``a_`` when ``kind``
    is ``"audio"`` and with ``f_`` for every other kind, and ends with ``ext``.
    """
    payload = f"{CFG.salt}:{kind}:{song_id}".encode("utf-8")
    digest = hashlib.sha1(payload).hexdigest()

    tag = "f_"
    if kind == "audio":
        tag = "a_"
    return f"{tag}{digest[:16]}{ext}"


def _read_annotations(raw: Path) -> pd.DataFrame:
    ann1 = raw / "DEAM_Annotations" / "annotations" / "annotations averaged per song" / "song_level" / "static_annotations_averaged_songs_1_2000.csv"
    ann2 = raw / "DEAM_Annotations" / "annotations" / "annotations averaged per song" / "song_level" / "static_annotations_averaged_songs_2000_2058.csv"
    assert ann1.is_file(), f"Missing annotations file: {ann1}"
    assert ann2.is_file(), f"Missing annotations file: {ann2}"
    df1 = pd.read_csv(ann1)
    df2 = pd.read_csv(ann2)
    df1.columns = [c.strip() for c in df1.columns]
    df2.columns = [c.strip() for c in df2.columns]
    df = pd.concat([df1, df2], ignore_index=True)

    expected_cols = {"song_id", "valence_mean", "valence_std", "arousal_mean", "arousal_std"}
    missing = expected_cols - set(df.columns)
    assert not missing, f"Annotation columns missing: {missing}"

    df = df.drop_duplicates(subset=["song_id"]).copy()
    for c in ["song_id", "valence_mean", "valence_std", "arousal_mean", "arousal_std"]:
        df[c] = pd.to_numeric(df[c], errors="coerce")
    df = df.dropna(subset=["song_id", "valence_mean", "arousal_mean"]).copy()
    df["song_id"] = df["song_id"].astype(int)
    for c in ["valence_mean", "arousal_mean", "valence_std", "arousal_std"]:
        df[c] = df[c].clip(1.0, 9.0)
    return df


def _available_media_ids(raw: Path) -> Tuple[set[int], set[int]]:
    audio_dir = raw / "DEAM_audio" / "MEMD_audio"
    feat_dir = raw / "features" / "features"
    assert audio_dir.is_dir(), f"Missing audio dir: {audio_dir}"
    assert feat_dir.is_dir(), f"Missing features dir: {feat_dir}"

    audio_ids: set[int] = set()
    feat_ids: set[int] = set()

    for fn in os.listdir(audio_dir):
        if fn.lower().endswith(".mp3"):
            stem = Path(fn).stem
            if stem.isdigit():
                audio_ids.add(int(stem))
    for fn in os.listdir(feat_dir):
        if fn.lower().endswith(".csv"):
            stem = Path(fn).stem
            if stem.isdigit():
                feat_ids.add(int(stem))

    assert len(audio_ids) > 0, "No audio files detected"
    assert len(feat_ids) > 0, "No feature files detected"
    return audio_ids, feat_ids


def _stratified_split(df: pd.DataFrame, test_size: float = 0.2, seed: int = CFG.seed) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Split ``df`` into train and test parts, stratified on binned valence and arousal.

    Each of ``valence_mean`` and ``arousal_mean`` is cut into up to five
    quantile bins (falling back to equal-width bins if quantile binning
    raises). The pair of bin labels forms a row's stratum, stored in a new
    ``strata`` column that is kept in both returned frames.

    Within every stratum the row labels are shuffled with a generator seeded
    by ``seed`` and the first ``max(1, round(test_size * n))`` go to the test
    part.

    NOTE: because at least one row per stratum is always taken, a stratum
    that contains a single row ends up entirely in the test part.

    If either part comes out empty, a plain shuffled split over all rows is
    used instead (continuing with the same generator).

    Args:
        df: Frame with ``song_id``, ``valence_mean`` and ``arousal_mean`` columns.
        test_size: Target fraction of rows for the test part.
        seed: Seed for ``numpy.random.default_rng``.

    Returns:
        ``(train_df, test_df)``, each sorted by ``song_id`` with a fresh index.
    """
    n_bins = 5

    def bin_values(values: pd.Series, count: int) -> pd.Series:
        # Quantile bins first; equal-width bins if that fails.
        try:
            binned = pd.qcut(values, q=count, duplicates="drop")
        except Exception:
            binned = pd.cut(values, bins=count)
        return binned

    valence_key = bin_values(df["valence_mean"], n_bins).astype(str)
    arousal_key = bin_values(df["arousal_mean"], n_bins).astype(str)
    work = df.copy()
    work["strata"] = valence_key + "|" + arousal_key

    rng = np.random.default_rng(seed)
    held_out: list[int] = []
    for _key, members in work.groupby("strata"):
        labels = members.index.to_list()
        quota = max(1, int(round(test_size * len(labels))))
        rng.shuffle(labels)
        held_out += labels[:quota]

    test_part = work.loc[sorted(set(held_out))].copy()
    train_part = work.drop(index=test_part.index).copy()

    if not len(train_part) or not len(test_part):
        # Stratification left one side empty: use a plain random split instead.
        everything = work.index.to_list()
        rng.shuffle(everything)
        cut = int(round(test_size * len(everything)))
        test_part = work.loc[everything[:cut]].copy()
        train_part = work.loc[everything[cut:]].copy()

    def by_song(frame: pd.DataFrame) -> pd.DataFrame:
        return frame.sort_values("song_id").reset_index(drop=True)

    return by_song(train_part), by_song(test_part)


def _ensure_dirs(public: Path):
    (public / "train_audio").mkdir(parents=True, exist_ok=True)
    (public / "test_audio").mkdir(parents=True, exist_ok=True)
    (public / "train_features").mkdir(parents=True, exist_ok=True)
    (public / "test_features").mkdir(parents=True, exist_ok=True)


def _copy_and_rename(raw: Path, public: Path, train_df: pd.DataFrame, test_df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Copy each song's audio and feature file into ``public`` under anonymised names.

    For every row of ``train_df`` / ``test_df`` the source files
    ``<song_id>.mp3`` and ``<song_id>.csv`` are copied into the split's
    ``*_audio`` / ``*_features`` directory, using the names produced by
    ``_sha1_name``. A destination that already exists is left untouched, so
    re-running does not re-copy files.

    Args:
        raw: Root directory of the raw dataset.
        public: Root directory of the public output.
        train_df: Training rows (``song_id``, ``valence_mean``, ``arousal_mean``).
        test_df: Test rows with the same columns.

    Returns:
        ``(train_out, test_out)``: one row per song with ``song_id``,
        ``audio_file``, ``feature_file``, ``valence_mean`` and ``arousal_mean``,
        sorted by ``song_id``. An empty split yields an empty frame that still
        has these columns.

    Raises:
        AssertionError: If a source audio or feature file is missing.
    """
    _ensure_dirs(public)
    audio_dir = raw / "DEAM_audio" / "MEMD_audio"
    feat_dir = raw / "features" / "features"
    out_cols = ["song_id", "audio_file", "feature_file", "valence_mean", "arousal_mean"]

    outputs: dict[str, pd.DataFrame] = {}
    for split, df_in in (("train", train_df), ("test", test_df)):
        audio_out_dir = public / f"{split}_audio"
        feat_out_dir = public / f"{split}_features"

        rows = []
        # Iterate the three needed columns directly instead of building a Series per row.
        for raw_id, valence, arousal in zip(df_in["song_id"], df_in["valence_mean"], df_in["arousal_mean"]):
            sid = int(raw_id)
            audio_src = audio_dir / f"{sid}.mp3"
            feat_src = feat_dir / f"{sid}.csv"
            assert audio_src.is_file(), f"Missing audio for song_id={sid}: {audio_src}"
            assert feat_src.is_file(), f"Missing features for song_id={sid}: {feat_src}"

            audio_dst = audio_out_dir / _sha1_name(sid, "audio", ".mp3")
            feat_dst = feat_out_dir / _sha1_name(sid, "feat", ".csv")

            if not audio_dst.exists():
                shutil.copy2(audio_src, audio_dst)
            if not feat_dst.exists():
                shutil.copy2(feat_src, feat_dst)

            rows.append(
                {
                    "song_id": sid,
                    "audio_file": audio_dst.name,
                    "feature_file": feat_dst.name,
                    "valence_mean": float(valence),
                    "arousal_mean": float(arousal),
                }
            )

        # Explicit columns keep the schema (and the sort key) present even when
        # the split has no rows; a bare DataFrame([]) would have no "song_id".
        outputs[split] = (
            pd.DataFrame(rows, columns=out_cols).sort_values("song_id").reset_index(drop=True)
        )

    return outputs["train"], outputs["test"]


def _write_csvs(public: Path, private: Path, train_df: pd.DataFrame, test_df: pd.DataFrame):
    # public train.csv with labels; public test.csv without labels
    train_cols = ["song_id", "audio_file", "feature_file", "valence_mean", "arousal_mean"]
    test_cols = ["song_id", "audio_file", "feature_file"]

    train_df[train_cols].to_csv(public / "train.csv", index=False)
    test_df[test_cols].to_csv(public / "test.csv", index=False)

    # private answer
    ans_cols = ["song_id", "valence_mean", "arousal_mean"]
    private.mkdir(parents=True, exist_ok=True)
    test_df[ans_cols].to_csv(private / "test_answer.csv", index=False)


def _write_sample_submission(public: Path, private: Path):
    """Write ``public/sample_submission.csv`` with one dummy prediction per test song.

    The ids come from ``private/test_answer.csv``. Each label column is filled
    with values drawn from a normal distribution (standard deviation 1.0,
    generator seeded with ``CFG.seed``), clipped to [1, 9] and rounded to four
    decimals.

    The distribution is centred on the mean of the *public* training labels
    (``public/train.csv``). Centring it on the private answers would let anyone
    recover the hidden test-set mean simply by averaging the sample
    submission. If the training table is unavailable or yields no finite mean,
    the midpoint of the rating scale (5.0) is used.

    Args:
        public: Root directory of the public output.
        private: Directory holding ``test_answer.csv``.
    """
    ans = pd.read_csv(private / "test_answer.csv")
    train_path = public / "train.csv"
    train = pd.read_csv(train_path) if train_path.is_file() else None
    rng = np.random.default_rng(CFG.seed)

    sample = ans.copy()
    for col in ("valence_mean", "arousal_mean"):
        centre = float("nan")
        if train is not None and col in train.columns:
            centre = float(train[col].mean())
        if not np.isfinite(centre):
            centre = 5.0  # midpoint of the 1-9 scale
        draws = rng.normal(loc=centre, scale=1.0, size=len(sample))
        sample[col] = np.clip(draws, 1.0, 9.0)

    sample = sample.round({"valence_mean": 4, "arousal_mean": 4})
    sample.to_csv(public / "sample_submission.csv", index=False)


def _validate_outputs(public: Path, private: Path, train_df: pd.DataFrame, test_df: pd.DataFrame):
    # Paths
    assert (public / "train.csv").is_file(), "public/train.csv missing"
    assert (public / "test.csv").is_file(), "public/test.csv missing"
    assert (private / "test_answer.csv").is_file(), "private/test_answer.csv missing"
    assert (public / "sample_submission.csv").is_file(), "public/sample_submission.csv missing"

    # No overlap
    assert set(train_df["song_id"]).isdisjoint(test_df["song_id"]), "train/test ids should not overlap"

    # check files exist and anonymized names
    for split, df in [("train", train_df), ("test", test_df)]:
        a_dir = public / ("train_audio" if split == "train" else "test_audio")
        f_dir = public / ("train_features" if split == "train" else "test_features")
        for _, r in df.iterrows():
            af = a_dir / r["audio_file"]
            ff = f_dir / r["feature_file"]
            assert af.is_file(), f"Missing audio file: {af}"
            assert ff.is_file(), f"Missing feature file: {ff}"
            assert r["audio_file"].startswith("a_") and r["audio_file"].endswith(".mp3"), "Unexpected audio filename"
            assert r["feature_file"].startswith("f_") and r["feature_file"].endswith(".csv"), "Unexpected feature filename"
            assert not Path(r["audio_file"]).stem.isdigit(), "Filename should not leak id"
            assert not Path(r["feature_file"]).stem.isdigit(), "Filename should not leak id"

    # ensure csv columns
    for c in ["song_id", "audio_file", "feature_file", "valence_mean", "arousal_mean"]:
        assert c in train_df.columns
    for c in ["song_id", "audio_file", "feature_file"]:
        assert c in test_df.columns

    # range checks
    for c in ["valence_mean", "arousal_mean"]:
        assert train_df[c].between(1.0, 9.0).all()
        assert test_df[c].between(1.0, 9.0).all()

    # answer <-> test id match
    ans = pd.read_csv(private / "test_answer.csv")
    te = pd.read_csv(public / "test.csv")
    assert set(ans["song_id"]) == set(te["song_id"]) and len(ans) == len(te)


def prepare(raw: Path, public: Path, private: Path):
    """Build the public and private dataset trees from the raw DEAM download.

    Steps: read the annotations, keep songs that have both an audio and a
    feature file, split them 80/20 with stratification, copy the media under
    anonymised names, write the CSV tables and the sample submission, and
    finally validate everything that was written.

    Args:
        raw: Root directory of the raw dataset.
        public: Output directory for participant-visible files.
        private: Output directory for the hidden answer key.

    Raises:
        AssertionError: If fewer than 100 usable songs remain, or if any
            helper's consistency check fails.
    """
    # Seed the legacy global NumPy RNG as well, for deterministic behaviour.
    np.random.seed(CFG.seed)

    raw, public, private = (Path(p).resolve() for p in (raw, public, private))
    for out_dir in (public, private):
        out_dir.mkdir(parents=True, exist_ok=True)

    # Load annotations and restrict them to songs with both media files.
    annotations = _read_annotations(raw)
    audio_ids, feat_ids = _available_media_ids(raw)
    usable_ids = audio_ids & feat_ids
    usable = annotations[annotations["song_id"].isin(usable_ids)].copy()
    assert len(usable) >= 100, f"Too few usable items after intersection: {len(usable)}"

    # Stratified train/test split.
    train_part, test_part = _stratified_split(usable, test_size=0.2, seed=CFG.seed)

    # Copy the media files under anonymised names.
    train_out, test_out = _copy_and_rename(raw, public, train_part, test_part)

    # Tables and sample submission.
    _write_csvs(public, private, train_out, test_out)
    _write_sample_submission(public, private)

    # Final consistency checks.
    _validate_outputs(public, private, train_out, test_out)

    # Nothing is returned; all results are on disk.