import os
import re
import csv
import json
import random
import shutil
from pathlib import Path
from typing import List, Tuple

# Regex to extract the 4-digit year from filenames: e.g., "Some Title _1999_ - full transcript.txt"
# The year must be wrapped in underscores; group 1 captures the four digits.
YEAR_PATTERN = re.compile(r"_(\d{4})_")

# Inclusive bounds for sanity checks: files whose filename year falls outside
# this range are skipped, and sample-submission years are clamped to it.
MIN_YEAR = 1900
MAX_YEAR = 2026

# Target share of each decade bucket that is assigned to the training split.
TRAIN_FRACTION = 0.8
# Seed for the random.Random instance that drives shuffling and sample-submission noise.
RANDOM_SEED = 42


def _find_docs(root: Path) -> List[Tuple[Path, int]]:
    """Collect transcript files in `root` together with the year in their names.

    A file is kept when its name ends in ``.txt`` (case-insensitive), contains
    a year matching ``YEAR_PATTERN`` that lies within ``[MIN_YEAR, MAX_YEAR]``,
    and it is a regular file of at least 50 bytes. Only the top level of
    `root` is scanned.

    Args:
        root: Directory holding the raw transcript files.

    Returns:
        ``(path, year)`` pairs ordered by filename, so the result does not
        depend on the filesystem's listing order.

    Raises:
        AssertionError: If `root` is not an existing directory.
        RuntimeError: If no file passes all of the filters.
    """
    if not root.is_dir():
        # Raised explicitly rather than via `assert` so the check is not
        # stripped under `python -O`; the exception type is unchanged.
        raise AssertionError(f"Transcripts directory not found: {root}")

    entries: List[Tuple[Path, int]] = []
    # os.listdir() returns names in arbitrary order. Sorting gives the seeded
    # shuffle/split that consumes this list the same input on every machine.
    for name in sorted(os.listdir(root)):
        if not name.lower().endswith(".txt"):
            continue
        m = YEAR_PATTERN.search(name)
        if m is None:
            continue
        year = int(m.group(1))
        if not (MIN_YEAR <= year <= MAX_YEAR):
            continue
        full_path = root / name
        try:
            # Skip non-regular entries (e.g. a directory named "*.txt", which
            # could not be copied later) and near-empty files.
            if not full_path.is_file() or full_path.stat().st_size < 50:
                continue
        except OSError:
            # Unreadable entries are ignored on purpose (best-effort scan).
            continue
        entries.append((full_path, year))

    if not entries:
        raise RuntimeError("No valid transcript files found with parseable year in filenames.")
    return entries


def _decade(y: int) -> int:
    """Return the first year of the decade containing `y` (e.g. 1999 -> 1990)."""
    return y - (y % 10)


def _stratified_split_by_decade(items: List[Tuple[Path, int]], train_frac: float, rng: random.Random) -> Tuple[List[Tuple[Path, int]], List[Tuple[Path, int]]]:
    """Split `items` into train and test lists, stratifying by decade.

    Items are grouped by the decade of their year. Each group is shuffled with
    `rng` and cut so that roughly `train_frac` of it lands in train, while
    every group of two or more items contributes at least one item to each
    side. A group with a single item goes entirely to train.

    Args:
        items: ``(path, year)`` pairs to split.
        train_frac: Target fraction of each decade group placed in train.
        rng: Random generator used for all shuffling.

    Returns:
        ``(train, test)`` lists of ``(path, year)`` pairs, each shuffled.
    """
    by_decade: dict[int, List[Tuple[Path, int]]] = {}
    for path, year in items:
        key = _decade(year)
        if key not in by_decade:
            by_decade[key] = []
        by_decade[key].append((path, year))

    train_part: List[Tuple[Path, int]] = []
    test_part: List[Tuple[Path, int]] = []
    for group in by_decade.values():
        rng.shuffle(group)
        size = len(group)
        if size > 1:
            # Round to the nearest cut point, then clamp so both sides are non-empty.
            cut = int(round(size * train_frac))
            if cut < 1:
                cut = 1
            elif cut > size - 1:
                cut = size - 1
            train_part += group[:cut]
            test_part += group[cut:]
        else:
            # A lone sample cannot be split; it stays in train.
            train_part += group

    # Every decade had a single item: donate one train item so test is not empty.
    if not test_part and len(train_part) > 1:
        test_part.append(train_part.pop())

    rng.shuffle(train_part)
    rng.shuffle(test_part)
    return train_part, test_part


def _ensure_clean_dir(path: Path):
    """Make `path` an empty directory, deleting any existing tree at that location first."""
    stale = path.exists()
    if stale:
        # Remove the old tree wholesale so no earlier output survives.
        shutil.rmtree(path)
    path.mkdir(exist_ok=True, parents=True)


def _write_csv(path: Path, rows: List[List]):
    """Write `rows` to `path` as a UTF-8 CSV file, replacing any existing file."""
    # newline="" leaves line-ending handling to the csv module.
    with open(path, mode="w", encoding="utf-8", newline="") as handle:
        emit = csv.writer(handle).writerow
        for record in rows:
            emit(record)


def _assign_anonymized_ids(items: List[Tuple[Path, int]]) -> dict:
    """Map each source path to an opaque id such as ``m000001``.

    Ids are numbered from 1 in the order of `items` and zero-padded to at
    least six digits, so output filenames carry no title or year.
    """
    id_width = max(6, len(str(len(items))))
    return {
        path: f"m{str(idx).zfill(id_width)}"
        for idx, (path, _year) in enumerate(items, start=1)
    }


def _copy_texts(items: List[Tuple[Path, int]], id_map: dict, dest_dir: Path) -> None:
    """Copy every transcript in `items` into `dest_dir` as ``<anonymized id>.txt``."""
    for path, _year in items:
        shutil.copy2(path, dest_dir / f"{id_map[path]}.txt")


def _sample_submission_rows(
    train_items: List[Tuple[Path, int]],
    test_items: List[Tuple[Path, int]],
    id_map: dict,
    rng: random.Random,
) -> List[List]:
    """Build sample-submission rows holding random but plausible years.

    One year per test item is drawn from a normal distribution whose mean and
    sample standard deviation come from the training years, then rounded and
    clamped to ``[MIN_YEAR, MAX_YEAR]``. The true test years are never read.
    """
    train_years = [year for _path, year in train_items]
    count = len(train_years)
    mu = sum(train_years) / count

    # Sample standard deviation; undefined for a single value.
    sd = 0.0
    if count > 1:
        sd = (sum((y - mu) ** 2 for y in train_years) / (count - 1)) ** 0.5
    if sd == 0:
        # Fallback spread when there is one sample or all years are identical.
        sd = 5.0

    rows: List[List] = [["id", "year"]]
    for path, _year in test_items:
        guess = int(round(rng.gauss(mu, sd)))
        guess = max(MIN_YEAR, min(MAX_YEAR, guess))
        rows.append([id_map[path], guess])
    return rows


def _check_year_bounds(csv_path: Path) -> None:
    """Assert that every ``year`` value in `csv_path` lies in ``[MIN_YEAR, MAX_YEAR]``."""
    with open(csv_path, newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            year = int(row["year"])
            assert MIN_YEAR <= year <= MAX_YEAR, f"Year {year} out of bounds in {csv_path.name}"


def _verify_outputs(
    public: Path,
    private: Path,
    train_rows: List[List],
    test_rows: List[List],
    test_answer_rows: List[List],
) -> None:
    """Run consistency checks on the written dataset.

    Each ``*_rows`` argument includes its header row. Raises AssertionError on
    the first check that fails.
    """
    train_texts_dir = public / "train_texts"
    test_texts_dir = public / "test_texts"
    assert train_texts_dir.exists(), "public/train_texts should exist"
    assert test_texts_dir.exists(), "public/test_texts should exist"

    # File counts must match the CSV row counts (minus the header row).
    n_train_files = sum(1 for n in os.listdir(train_texts_dir) if n.endswith(".txt"))
    n_test_files = sum(1 for n in os.listdir(test_texts_dir) if n.endswith(".txt"))
    assert n_train_files == len(train_rows) - 1, "Mismatch between train.csv and train_texts files count"
    assert n_test_files == len(test_rows) - 1, "Mismatch between test.csv and test_texts files count"

    # No overlap between train/test ids
    train_ids = {r[0] for r in train_rows[1:]}
    test_ids = {r[0] for r in test_rows[1:]}
    assert train_ids.isdisjoint(test_ids), "Train/Test ID overlap detected"

    # test_answer ids match test.csv ids
    ta_ids = {r[0] for r in test_answer_rows[1:]}
    assert ta_ids == test_ids, "test_answer.csv and test.csv ids mismatch"

    # Avoid year leakage in filenames
    for fn in os.listdir(train_texts_dir):
        assert YEAR_PATTERN.search(fn) is None, "Leakage: year-like token found in train filename"
    for fn in os.listdir(test_texts_dir):
        assert YEAR_PATTERN.search(fn) is None, "Leakage: year-like token found in test filename"

    # Year bounds in CSVs
    _check_year_bounds(public / "train.csv")
    _check_year_bounds(private / "test_answer.csv")


def prepare(raw: Path, public: Path, private: Path):
    """
    Prepare the Movie Year Prediction dataset.

    Transcripts found under ``raw/transcripts`` are split into train and test
    sets stratified by decade, copied under anonymized ids, and described by
    CSV files. Any existing contents of `public` and `private` are deleted
    before the new files are written.

    Written to `public`: ``train_texts/``, ``test_texts/``, ``train.csv``
    (id, year), ``test.csv`` (id), ``sample_submission.csv`` (id, year),
    ``prep_meta.json`` and, if a ``description.txt`` sits next to this module,
    a copy of it. Written to `private`: ``test_answer.csv`` (id, year).

    Args:
        raw: Path to raw directory, expected to contain a `transcripts/` folder of .txt files.
        public: Path to public directory where train/test files and sample submission are written.
        private: Path to private directory where test answers are written.

    Raises:
        AssertionError: If one of the consistency checks on the written
            output fails.
    """
    rng = random.Random(RANDOM_SEED)

    docs = _find_docs(raw / "transcripts")
    rng.shuffle(docs)
    train_items, test_items = _stratified_split_by_decade(docs, TRAIN_FRACTION, rng)

    # Ensure output structure
    _ensure_clean_dir(public)
    _ensure_clean_dir(private)
    train_texts_dir = public / "train_texts"
    test_texts_dir = public / "test_texts"
    _ensure_clean_dir(train_texts_dir)
    _ensure_clean_dir(test_texts_dir)

    # Assign anonymized ids: train items are numbered first, then test items.
    id_map = _assign_anonymized_ids(train_items + test_items)

    # Copy files
    _copy_texts(train_items, id_map, train_texts_dir)
    _copy_texts(test_items, id_map, test_texts_dir)

    # Build CSVs; only train.csv and the private answers carry the year.
    train_rows = [["id", "year"]] + [[id_map[p], y] for p, y in train_items]
    test_rows = [["id"]] + [[id_map[p]] for p, _y in test_items]
    test_answer_rows = [["id", "year"]] + [[id_map[p], y] for p, y in test_items]

    _write_csv(public / "train.csv", train_rows)
    _write_csv(public / "test.csv", test_rows)
    _write_csv(private / "test_answer.csv", test_answer_rows)

    # Sample submission: generate plausible numeric years from train distribution
    sample_rows = _sample_submission_rows(train_items, test_items, id_map, rng)
    _write_csv(public / "sample_submission.csv", sample_rows)

    # Copy description.txt to public for participants
    root_desc = Path(__file__).with_name("description.txt")
    if root_desc.exists():
        shutil.copy2(root_desc, public / "description.txt")

    # Save metadata for debugging (in public for transparency about process, contains no labels)
    meta = {
        "random_seed": RANDOM_SEED,
        "train_fraction": TRAIN_FRACTION,
        "min_year": MIN_YEAR,
        "max_year": MAX_YEAR,
        "counts": {"train": len(train_items), "test": len(test_items), "total": len(docs)},
    }
    with open(public / "prep_meta.json", "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)

    # Checks
    _verify_outputs(public, private, train_rows, test_rows, test_answer_rows)