from pathlib import Path
import re
import csv
import uuid
import random
import shutil
from collections import defaultdict
from typing import List, Tuple

# Constants
# Seed for the process-wide `random` generator used by the shuffles and random
# choices in this module.
SEED = 2025
# NOTE: this runs once, at import time, and re-seeds the global `random` state
# for the whole process (not only for this module).
random.seed(SEED)

# Directory names of the form "surah_" followed by exactly three digits; the
# digits are captured as group 1.
SURAH_DIR_RE = re.compile(r"^surah_(\d{3})$")
# File names made of two three-digit groups followed by ".mp3" (matched
# case-insensitively), e.g. "001007.mp3".
# NOTE(review): presumably <surah number><ayah number> -- confirm against the
# dataset; the captured groups are not inspected anywhere in this module.
MP3_RE = re.compile(r"^(\d{3})(\d{3})\.mp3$", re.IGNORECASE)
# Labels accepted as surah ids.
VALID_SURAH_IDS = set(range(1, 115))  # 1..114 inclusive


def _list_all_mp3s(raw: Path) -> List[Tuple[Path, int]]:
    """Collect ``(mp3_path, surah_id)`` pairs from the surah folders under *raw*.

    Only direct sub-directories of *raw* whose name matches ``SURAH_DIR_RE`` and
    whose numeric id is in ``VALID_SURAH_IDS`` are visited. Inside such a
    folder, only regular files with an ``.mp3`` suffix whose name matches
    ``MP3_RE`` are kept; everything else is skipped silently.

    Folders and files are visited in sorted order, so the returned list is
    ordered by folder and then by file.
    """
    collected: List[Tuple[Path, int]] = []
    for surah_dir in sorted(raw.iterdir()):
        # Non-directories never reach the name check.
        dir_match = SURAH_DIR_RE.match(surah_dir.name) if surah_dir.is_dir() else None
        if dir_match is None:
            continue
        label = int(dir_match.group(1))
        if label not in VALID_SURAH_IDS:
            continue
        collected.extend(
            (audio, label)
            for audio in sorted(surah_dir.iterdir())
            if audio.is_file()
            and audio.suffix.lower() == ".mp3"
            and MP3_RE.match(audio.name)
        )
    return collected


def _stratified_split(
    items: List[Tuple[Path, int]],
    test_ratio: float = 0.2,
    rng: "random.Random | None" = None,
) -> Tuple[List[Tuple[Path, int]], List[Tuple[Path, int]]]:
    """Split ``(path, label)`` pairs into train and test lists, class by class.

    Each class is shuffled and roughly ``test_ratio`` of its samples go to the
    test split. The per-class test count is clamped to ``[1, n - 1]`` so that a
    class with two or more samples appears in both splits. A class with a
    single sample is placed in the test split only.

    Args:
        items: ``(path, label)`` pairs to split.
        test_ratio: Target fraction of each class assigned to the test split.
        rng: Optional ``random.Random`` instance used for all shuffling. When
            omitted, the module-level ``random`` functions (global state) are
            used, exactly as before this parameter existed.

    Returns:
        ``(train, test)``: two lists of ``(path, label)`` pairs, each shuffled.
    """
    shuffle = random.shuffle if rng is None else rng.shuffle

    by_class = defaultdict(list)
    for path, label in items:
        by_class[label].append(path)

    # Shuffle every class first (in first-seen label order) so the sequence of
    # RNG draws stays the same as it has always been for the default path.
    for paths in by_class.values():
        shuffle(paths)

    train: List[Tuple[Path, int]] = []
    test: List[Tuple[Path, int]] = []
    for label, paths in by_class.items():
        n = len(paths)
        # At least one test sample; at most n - 1 so one stays in train. For
        # n == 1 the lower bound wins and the lone sample goes to test.
        n_test = max(1, min(n - 1, int(round(n * test_ratio))))
        test.extend((p, label) for p in paths[:n_test])
        train.extend((p, label) for p in paths[n_test:])

    shuffle(train)
    shuffle(test)
    return train, test


def _generate_id() -> str:
    """Return a fresh random identifier: ``qa_`` plus 12 hex characters."""
    hex_digits = uuid.uuid4().hex
    return "qa_" + hex_digits[:12]


def _write_csv(path: Path, rows, header: List[str]):
    """Write *header* followed by every row of *rows* to *path* as UTF-8 CSV.

    The file is opened in write mode, so existing content is replaced.
    """
    with path.open("w", newline="", encoding="utf-8") as handle:
        out = csv.writer(handle)
        out.writerow(header)
        out.writerows(rows)


def _copy_description_to_public(project_root: Path, public: Path):
    """Copy ``description.txt`` from *project_root* into *public*, if present.

    Does nothing when the source does not exist. Otherwise the destination
    directory is created as needed and the file is copied with ``shutil.copy2``.
    """
    source = project_root / "description.txt"
    if not source.exists():
        return
    target = public / "description.txt"
    target.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(source, target)


def prepare(raw: Path, public: Path, private: Path):
    """
    Complete data preparation process.

    Inputs:
    - raw: Path to the raw data directory that contains surah_XXX folders with .mp3 files
    - public: Path to the public directory to write train/test data and sample_submission
    - private: Path to the private directory to write the hidden test_answer.csv

    WARNING: existing `public` and `private` directories are deleted and
    recreated. The raw data is read and validated first, so a bad `raw` path
    fails before anything is deleted.

    This function will create the following files:
    - public/train.csv with columns ['id','audio_path','surah_id']
    - public/test.csv with columns ['id','audio_path']
    - public/sample_submission.csv with columns ['id','surah_id']
    - private/test_answer.csv with columns ['id','surah_id']
    - public/description.txt, copied from the directory containing this module
      when a description.txt exists there

    Audio files will be copied to:
    - public/train/audio/<id>.mp3
    - public/test/audio/<id>.mp3

    The train/test assignment is driven by the global `random` generator,
    which is re-seeded with SEED on entry. The <id> values come from uuid4 and
    therefore differ from run to run.

    Raises:
    - ValueError: if `raw` is the same as, or located inside, `public` or
      `private` (wiping the output would delete the input).
    - AssertionError: if no mp3 files are found under `raw`, or if one of the
      post-write sanity checks fails.
    """
    raw, public, private = Path(raw), Path(public), Path(private)
    project_root = Path(__file__).parent.resolve()

    # The output directories are wiped below; refuse to run when that would
    # also delete the input data.
    raw_resolved = raw.resolve()
    for out_dir in (public, private):
        out_resolved = out_dir.resolve()
        if out_resolved == raw_resolved or out_resolved in raw_resolved.parents:
            raise ValueError(
                f"Output directory {out_dir} contains the raw data directory {raw}; "
                "refusing to delete it"
            )

    # Re-seed so the split does not depend on whatever consumed the global
    # RNG between module import and this call (or on earlier prepare() calls).
    random.seed(SEED)

    # Read and validate the input BEFORE touching any existing output.
    items = _list_all_mp3s(raw)
    if not items:
        # Raised explicitly (not via `assert`) so the check survives `python -O`;
        # AssertionError is kept for backward compatibility.
        raise AssertionError(f"No mp3 files found under {raw}")

    # Sanity: only valid labels
    assert all(label in VALID_SURAH_IDS for _, label in items), "Found invalid surah id"

    # Stratified split
    train_items, test_items = _stratified_split(items, test_ratio=0.2)

    # Clean and create target directories
    for out_dir in (public, private):
        if out_dir.exists():
            shutil.rmtree(out_dir)
    (public / "train" / "audio").mkdir(parents=True, exist_ok=True)
    (public / "test" / "audio").mkdir(parents=True, exist_ok=True)
    private.mkdir(parents=True, exist_ok=True)

    # Create anonymized ids and copy files
    used_ids = set()

    def new_unique_id():
        for _ in range(1_000_000):
            i = _generate_id()
            if i not in used_ids:
                used_ids.add(i)
                return i
        raise RuntimeError("Failed to generate unique id")

    def copy_split(split_items, split_name):
        """Copy one split into public/<split_name>/audio; return (id, rel_path, label) rows."""
        rows = []
        audio_dir = public / split_name / "audio"
        for src, label in split_items:
            _id = new_unique_id()
            dst = audio_dir / f"{_id}.mp3"
            shutil.copy2(src, dst)
            rows.append((_id, dst.relative_to(public).as_posix(), int(label)))
        return rows

    train_rows = copy_split(train_items, "train")        # id, audio_path, surah_id
    test_full_rows = copy_split(test_items, "test")
    test_rows = [(i, p) for i, p, _ in test_full_rows]    # id, audio_path
    ans_rows = [(i, lbl) for i, _, lbl in test_full_rows]  # id, surah_id

    # Write CSVs
    _write_csv(public / "train.csv", train_rows, header=["id", "audio_path", "surah_id"])
    _write_csv(public / "test.csv", test_rows, header=["id", "audio_path"])
    _write_csv(private / "test_answer.csv", ans_rows, header=["id", "surah_id"])

    # Sample submission: a random label per test id, drawn from the labels
    # present in train (falling back to every valid id if train is empty).
    labels_pool = sorted({row[2] for row in train_rows}) or sorted(VALID_SURAH_IDS)
    sample_rows = [(tid, int(random.choice(labels_pool))) for tid, _ in test_rows]
    _write_csv(public / "sample_submission.csv", sample_rows, header=["id", "surah_id"])

    # Copy description.txt to public (if exists)
    _copy_description_to_public(project_root, public)

    # Checks
    train_ids = {r[0] for r in train_rows}
    test_ids = {r[0] for r in test_rows}

    # 1) No overlap between train and test ids
    assert train_ids.isdisjoint(test_ids), "Train and test ids overlap"

    # 2) Paths exist and point to the right folders
    for _, p, _ in train_rows:
        ap = public / p
        assert ap.is_file(), f"Missing train audio file: {ap}"
        assert ap.parent.as_posix().endswith("train/audio"), f"Train audio not in train/audio: {ap}"
    for _, p in test_rows:
        ap = public / p
        assert ap.is_file(), f"Missing test audio file: {ap}"
        assert ap.parent.as_posix().endswith("test/audio"), f"Test audio not in test/audio: {ap}"

    # 3) CSV lengths and total counts match
    n_train = len(train_rows)
    n_test = len(test_rows)
    total = len(items)
    assert n_train + n_test == total, "Train+Test count mismatch with total files"

    # 4) Label set and ranges
    train_labels = [r[2] for r in train_rows]
    test_labels = [r[1] for r in ans_rows]
    assert all((1 <= x <= 114) for x in train_labels), "Train labels out of range"
    assert all((1 <= x <= 114) for x in test_labels), "Test labels out of range"

    # 5) Each class appears in both splits when possible
    per_class_counts = defaultdict(int)
    per_class_train = defaultdict(int)
    per_class_test = defaultdict(int)
    for _, lbl in items:
        per_class_counts[lbl] += 1
    for lbl in train_labels:
        per_class_train[lbl] += 1
    for lbl in test_labels:
        per_class_test[lbl] += 1
    for lbl, c in per_class_counts.items():
        if c > 1:
            assert per_class_train[lbl] >= 1, f"Class {lbl} missing in train"
            assert per_class_test[lbl] >= 1, f"Class {lbl} missing in test"

    # 6) ID uniqueness and alignment
    assert len(train_ids) == n_train, "Duplicate ids in train.csv"
    assert len(test_ids) == n_test, "Duplicate ids in test.csv"
    ans_ids = {r[0] for r in ans_rows}
    assert ans_ids == test_ids, "test_answer.csv ids must match test.csv ids exactly"

    # 7) Sample submission alignment (read back from disk)
    with (public / "sample_submission.csv").open("r", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        subs = list(reader)
    assert set(r["id"] for r in subs) == test_ids, "sample_submission ids mismatch test ids"
    assert all(str(r["surah_id"]).isdigit() for r in subs), "sample_submission must contain integer labels"

    # 8) Ensure the CSV outputs exist on disk
    assert (public / "train.csv").is_file(), "public/train.csv should exist"
    assert (public / "test.csv").is_file(), "public/test.csv should exist"
    assert (private / "test_answer.csv").is_file(), "private/test_answer.csv should exist"

    # Print summary (optional for debugging)
    print("Preparation complete:")
    print(f"  Train: {n_train} samples in {public / 'train' / 'audio'}")
    print(f"  Test:  {n_test} samples in {public / 'test' / 'audio'}")
    print(f"  Wrote: {public / 'train.csv'}, {public / 'test.csv'}, {private / 'test_answer.csv'}, {public / 'sample_submission.csv'}")
