import csv
import os
import random
import shutil
from pathlib import Path
from typing import Dict, List, Optional, Tuple

# Reproducibility: single seed used both for the global RNG below and for the
# local random.Random instance created inside prepare().
SEED = 42
# NOTE: this seeds the process-wide `random` state as a side effect of importing
# this module.
random.seed(SEED)

# Fixed non-overlapping segment length in seconds
SEGMENT_SECONDS = 10

# Full label space used as the fallback pool for sample-submission labels.
CLASSES = list(range(0, 7))  # 0 = no-cheat, 1..6 from gt.txt


# ----------------------------- helpers -----------------------------

def _ensure_dir(p: Path):
    """Create directory ``p`` together with any missing parents.

    Calling it on a directory that already exists is a no-op.
    """
    os.makedirs(p, exist_ok=True)


def _write_csv(path: Path, header: List[str], rows: List[List]):
    """Write ``header`` followed by ``rows`` as a UTF-8 CSV file at ``path``.

    The parent directory is created when missing; an existing file is overwritten.
    """
    _ensure_dir(path.parent)
    # newline="" lets the csv module control line endings itself.
    with path.open("w", newline="", encoding="utf-8") as handle:
        csv.writer(handle).writerows([header, *rows])


def _mmss_to_seconds(token: str) -> int:
    """Convert an ``mmss`` time token into a number of seconds.

    The token is read as a decimal number whose last two digits are seconds and
    whose leading digits are minutes (``"0130"`` -> 1 min 30 s -> 90). If the
    token is not a plain integer, every non-digit character is dropped first,
    so ``"1:30"`` is also read as 90.

    Args:
        token: Raw time token; surrounding whitespace is ignored.

    Returns:
        The non-negative time in seconds.

    Raises:
        ValueError: If the token is empty, contains no digits, is negative, or
            has a seconds field of 60 or more.
    """
    token = token.strip()
    if not token:
        raise ValueError("Empty time token")
    try:
        value = int(token)
    except ValueError:
        digits = "".join(ch for ch in token if ch.isdigit())
        if not digits:
            # Without this guard int("") fails with a message that hides the token.
            raise ValueError(f"No digits in mmss token: {token}") from None
        value = int(digits)
    if value < 0:
        # divmod on a negative value can yield a "valid" seconds field
        # (e.g. -150 -> (-2, 50)) and silently produce a negative time.
        raise ValueError(f"Negative mmss token: {token}")
    minutes, seconds = divmod(value, 100)
    if seconds >= 60:
        raise ValueError(f"Invalid seconds field in mmss token: {token}")
    return minutes * 60 + seconds


def _read_subject_intervals(subject_dir: Path) -> Tuple[List[Tuple[int, int, int]], int]:
    """Parse ``gt.txt`` of one subject into labelled time intervals.

    Each usable line holds at least three space/tab separated fields:
    start time, end time (both mmss tokens) and a label. Lines with fewer
    fields or with a label outside 1..6 are skipped; start/end are swapped when
    given in reverse order. Errors from parsing the time tokens propagate.

    Returns:
        ``(intervals, max_end)`` where ``intervals`` is a list of
        ``(start_sec, end_sec_inclusive, label)`` sorted by start then end, and
        ``max_end`` is the largest kept end time (0 when nothing was kept).

    Raises:
        FileNotFoundError: If ``gt.txt`` is missing from ``subject_dir``.
    """
    gt_file = subject_dir / "gt.txt"
    if not gt_file.exists():
        raise FileNotFoundError(f"Missing gt.txt for subject at {subject_dir}")

    parsed: List[Tuple[int, int, int]] = []
    latest_end = 0
    with gt_file.open("r", encoding="utf-8", errors="ignore") as handle:
        for raw_line in handle:
            fields = [tok for tok in raw_line.strip().replace("\t", " ").split(" ") if tok]
            if len(fields) < 3:
                # Covers blank lines as well as incomplete records.
                continue

            # Times are parsed before the label so time errors surface first.
            start, end = sorted((_mmss_to_seconds(fields[0]), _mmss_to_seconds(fields[1])))

            label_token = fields[2]
            try:
                label = int(label_token)
            except ValueError:
                # Fall back to whatever digits the token contains.
                label = int("".join(filter(str.isdigit, label_token)))
            if label < 1 or label > 6:
                continue

            parsed.append((start, end, label))
            latest_end = max(latest_end, end)

    parsed.sort(key=lambda interval: interval[:2])
    return parsed, latest_end


def _segment_subject(max_end_sec: int, segment_seconds: Optional[int] = None) -> List[Tuple[int, int]]:
    """Cut the timeline ``[0, max_end_sec]`` into consecutive half-open segments.

    Args:
        max_end_sec: Last second of the timeline (inclusive).
        segment_seconds: Length of each segment in seconds. Defaults to the
            module-level ``SEGMENT_SECONDS`` when omitted.

    Returns:
        A list of ``(start, end)`` pairs meaning ``[start, end)``. Segments are
        contiguous and non-overlapping; the last one is shorter when the
        timeline length is not a multiple of the segment length. The list is
        empty when ``max_end_sec`` is negative.

    Raises:
        ValueError: If the segment length is not positive (a zero-length
            segment would never advance through the timeline).
    """
    step = SEGMENT_SECONDS if segment_seconds is None else segment_seconds
    if step <= 0:
        raise ValueError(f"Segment length must be positive, got {step}")
    total = max_end_sec + 1  # inclusive end converted to half-open
    return [(start, min(start + step, total)) for start in range(0, total, step)]


def _assign_label_to_segment(seg: Tuple[int, int], intervals: List[Tuple[int, int, int]]) -> int:
    """Pick the label of the single interval that overlaps ``seg`` the most.

    ``seg`` is half-open ``[start, end)``; each interval is
    ``(start, end_inclusive, label)``. When several intervals share the largest
    overlap the smallest label wins. Returns 0 when nothing overlaps.
    """
    seg_start, seg_stop = seg
    candidates = []
    for start, end, label in intervals:
        # end + 1 turns the inclusive interval end into a half-open bound.
        overlap = min(seg_stop, end + 1) - max(seg_start, start)
        if overlap > 0:
            candidates.append((overlap, label))
    if not candidates:
        return 0
    # Largest overlap first; negating the label makes the smaller label win ties.
    return max(candidates, key=lambda c: (c[0], -c[1]))[1]


def _collect_subject_dirs(base_input_dir: Path) -> List[Path]:
    """Return the usable subject directories inside ``base_input_dir``.

    A directory qualifies when its name starts with "subject" (any case) and it
    contains at least one ``.wav`` file, at least two ``.avi`` files (extensions
    matched case-insensitively) and a ``gt.txt``. Results follow the sorted
    order of the directory names.

    Raises:
        AssertionError: If fewer than 10 qualifying directories are found.
    """
    found: List[Path] = []
    for entry in sorted(os.listdir(base_input_dir)):
        candidate = base_input_dir / entry
        if not candidate.is_dir() or not entry.lower().startswith("subject"):
            continue
        lowered = [fn.lower() for fn in os.listdir(candidate)]
        wav_count = sum(fn.endswith(".wav") for fn in lowered)
        avi_count = sum(fn.endswith(".avi") for fn in lowered)
        if wav_count >= 1 and avi_count >= 2 and (candidate / "gt.txt").exists():
            found.append(candidate)

    # A near-empty result most likely means the wrong directory was passed in.
    if len(found) < 10:
        raise AssertionError(f"Expected many subjects, found only {len(found)} in {base_input_dir}")
    return found


def _link_or_copy(src: Path, dst: Path) -> str:
    """Place ``src`` at ``dst`` as a hard link, falling back to a full copy.

    The parent directory of ``dst`` is created when missing and anything
    already present at ``dst`` (a file or a symlink, including a dangling one)
    is removed first.

    Args:
        src: Existing file to expose at ``dst``.
        dst: Destination path.

    Returns:
        ``"link"`` when a hard link was created, ``"copy"`` when the file had
        to be copied (e.g. linking across filesystems is not possible).

    Raises:
        OSError: If ``dst`` cannot be removed or the fallback copy fails too.
    """
    dst.parent.mkdir(parents=True, exist_ok=True)
    # exists() follows symlinks and is False for a dangling one; without the
    # is_symlink() check the copy fallback would write through the stale link.
    if dst.is_symlink() or dst.exists():
        dst.unlink()
    try:
        os.link(str(src), str(dst))
    except (OSError, NotImplementedError, AttributeError):
        # Hard links unavailable here (cross-device, unsupported filesystem or
        # platform, permissions): fall back to a real copy.
        shutil.copy2(str(src), str(dst))
        return "copy"
    return "link"


# ------------------------- public API -------------------------

def prepare(raw: Path, public: Path, private: Path):
    """
    Prepare the OEP dataset into competition format.

    Subjects are shuffled with a fixed seed and split by subject into a test set
    of max(6, n // 4) subjects and a train set with the rest. Each subject's
    timeline is cut into fixed-length half-open segments
    [segment_start, segment_end) in seconds and every segment gets one label
    (0 = no-cheat, 1..6 from gt.txt).

    Inputs
    - raw: absolute path to the directory containing the extracted original data (must include 'OEP database')
    - public: absolute path to the directory where public files will be written
    - private: absolute path to the directory where private files (test_answer.csv) will be written

    Outputs (created inside 'public' and 'private'; both directories are wiped first)
    - public/train.csv (subject_id, segment_start, segment_end, label)
    - public/test.csv (subject_id, segment_start, segment_end)
    - public/sample_submission.csv (subject_id, segment_start, segment_end, label)
    - public/train/media/, public/test/media/ with one audio and two video files per
      subject, renamed to <subject_id>_audio.wav, <subject_id>_webcam.avi and
      <subject_id>_wearcam.avi (hard-linked when possible, copied otherwise)
    - public/description.txt, if a description.txt sits next to this module
    - private/test_answer.csv (subject_id, segment_start, segment_end, label)

    Raises
    - AssertionError: on non-absolute paths, a missing input directory, an
      incomplete subject folder, or a failed post-write sanity check
    """
    # Explicit raises rather than `assert`: these guards must survive `python -O`
    # because the output directories are deleted further down.
    if not (raw.is_absolute() and public.is_absolute() and private.is_absolute()):
        raise AssertionError("Paths must be absolute")
    input_root = raw / "OEP database"
    if not input_root.exists():
        raise AssertionError(f"Input directory not found: {input_root}")

    # Discover and parse everything BEFORE wiping outputs, so a broken input
    # does not destroy the result of a previous successful run.
    subjects = _collect_subject_dirs(input_root)

    # Build per-subject metadata
    meta = []
    for sub_path in subjects:
        intervals, max_end = _read_subject_intervals(sub_path)
        segments = _segment_subject(max_end)
        files = sorted(os.listdir(sub_path))
        wavs = [fn for fn in files if fn.lower().endswith(".wav")]
        avis = [fn for fn in files if fn.lower().endswith(".avi")]

        # Files ending in "1.avi" / "2.avi" are taken as webcam / wearcam.
        # Matching is case-insensitive, in line with the extension filter above.
        webcam = None
        wearcam = None
        for fn in avis:
            lowered = fn.lower()
            if lowered.endswith("1.avi"):
                webcam = fn
            elif lowered.endswith("2.avi"):
                wearcam = fn
        # Fall back to positional order when the naming convention is not followed.
        if webcam is None and avis:
            webcam = avis[0]
        if wearcam is None and len(avis) > 1:
            wearcam = [a for a in avis if a != webcam][0]
        assert wavs, f"No wav found in {sub_path}"
        assert webcam and wearcam, f"Expect two avi files in {sub_path}"
        meta.append(
            {
                "name": sub_path.name,  # e.g., subject1
                "path": sub_path,
                "intervals": intervals,
                "segments": segments,
                "audio": wavs[0],
                "webcam": webcam,
                "wearcam": wearcam,
            }
        )

    # Clean/create output dirs
    for out_dir in (public, private):
        if out_dir.exists():
            shutil.rmtree(out_dir)
        _ensure_dir(out_dir)

    # Split into train/test by subjects, using a local RNG so the result does
    # not depend on whatever else has consumed the global `random` state.
    rng = random.Random(SEED)
    rng.shuffle(meta)
    n = len(meta)
    n_test = max(6, n // 4)
    test_meta = meta[:n_test]
    train_meta = meta[n_test:]

    assert set(m["name"] for m in train_meta).isdisjoint(set(m["name"] for m in test_meta))
    assert len(train_meta) + len(test_meta) == n

    # Output dirs
    train_dir = public / "train"
    test_dir = public / "test"
    train_media_dir = train_dir / "media"
    test_media_dir = test_dir / "media"
    _ensure_dir(train_media_dir)
    _ensure_dir(test_media_dir)

    def export_split(split_meta: List[Dict], media_root: Path) -> List[List]:
        """Link the media of every subject into media_root; return labelled segment rows."""
        labelled: List[List] = []
        for m in split_meta:
            sid = m["name"]
            # Destination names carry only the subject id and the stream kind.
            for key, suffix in (("audio", "audio.wav"), ("webcam", "webcam.avi"), ("wearcam", "wearcam.avi")):
                _link_or_copy(m["path"] / m[key], media_root / f"{sid}_{suffix}")
            for seg in m["segments"]:
                lbl = _assign_label_to_segment(seg, m["intervals"])
                labelled.append([sid, seg[0], seg[1], lbl])
        return labelled

    train_rows = export_split(train_meta, train_media_dir)
    test_answer_rows = export_split(test_meta, test_media_dir)

    # Sort for determinism
    train_rows.sort(key=lambda r: (r[0], r[1], r[2]))
    test_answer_rows.sort(key=lambda r: (r[0], r[1], r[2]))
    # The public test rows are the answer rows with the label column removed.
    test_rows = [row[:3] for row in test_answer_rows]

    # Write CSVs to public/private
    train_csv = public / "train.csv"
    test_csv = public / "test.csv"
    test_answer_csv = private / "test_answer.csv"
    sample_sub_csv = public / "sample_submission.csv"

    _write_csv(train_csv, ["subject_id", "segment_start", "segment_end", "label"], train_rows)
    _write_csv(test_csv, ["subject_id", "segment_start", "segment_end"], test_rows)
    _write_csv(test_answer_csv, ["subject_id", "segment_start", "segment_end", "label"], test_answer_rows)

    # Sample submission with random valid labels drawn from observed labels.
    # Drawn from the seeded local RNG so the file is identical on every run.
    unique_labels = sorted({row[3] for row in test_answer_rows}) or CLASSES
    sample_rows = [r + [rng.choice(unique_labels)] for r in test_rows]
    _write_csv(sample_sub_csv, ["subject_id", "segment_start", "segment_end", "label"], sample_rows)

    # Copy description.txt into public
    desc_src = Path(os.path.abspath(os.path.join(os.path.dirname(__file__), "description.txt")))
    if desc_src.exists():
        shutil.copy2(str(desc_src), str(public / "description.txt"))

    # Checks
    # 1) Ensure each test row has an answer
    tkeys = {(r[0], r[1], r[2]) for r in test_rows}
    akeys = {(r[0], r[1], r[2]) for r in test_answer_rows}
    assert tkeys == akeys, "Mismatch between test.csv and test_answer.csv keys"

    # 2) Labels in train within range and contain at least class 0
    train_labels = [r[3] for r in train_rows]
    assert all((isinstance(v, int) and 0 <= v <= 6) for v in train_labels)
    assert 0 in train_labels, "Train set must include no-cheat class (0)"
    assert len(set(train_labels)) >= 3, "Train set should include multiple classes for learning"

    # 3) Media counts
    def _count_media(root: Path) -> int:
        return sum(1 for fn in os.listdir(root) if fn.endswith((".wav", ".avi")))
    assert _count_media(train_media_dir) == 3 * len(train_meta)
    assert _count_media(test_media_dir) == 3 * len(test_meta)

    # 4) gt.txt not leaked
    for d in [train_dir, test_dir]:
        for _, _, filenames in os.walk(d):
            assert "gt.txt" not in filenames

    # 5) sample submission keys match test
    with sample_sub_csv.open("r", encoding="utf-8") as f:
        rdr = csv.DictReader(f)
        sample_keys = {(r["subject_id"], int(r["segment_start"]), int(r["segment_end"])) for r in rdr}
    assert sample_keys == tkeys, "Sample submission keys mismatch test.csv"

    # 6) public/private directories must exist
    assert public.exists() and private.exists()
