from pathlib import Path
import csv
import os
import shutil
import random
from collections import defaultdict, Counter


# Seed used for every random draw in this module, so repeated runs produce
# byte-identical outputs.
RANDOM_SEED: int = 1337
# Share of each group held out for the test split (i.e. an 80/20 train/test split).
TEST_FRACTION: float = 0.2


def _safe_float(x: str) -> float:
    try:
        v = float(x)
    except Exception as e:
        raise ValueError(f"Cannot parse float from value: {x}") from e
    if v != v:
        raise ValueError("Encountered NaN in labels")
    if v in (float("inf"), float("-inf")):
        raise ValueError("Encountered infinite value in labels")
    return v


def _parse_group_from_filename(fname: str) -> str:
    # Expected prefixes: AF, AM, CF, CM
    prefix = fname[:2]
    if prefix not in {"AF", "AM", "CF", "CM"}:
        raise ValueError(f"Unexpected filename prefix for group parsing: {fname}")
    return prefix


def _load_labels(labels_path: Path) -> list[dict]:
    """Read ``labels_path`` and return one record per non-blank line.

    Each line is split on whitespace: the first field is the image filename and the
    second its score (any further fields are ignored). Records have the keys
    ``"orig_name"``, ``"score"`` (a float within [1, 5]) and ``"group"`` (the
    filename's two-letter prefix).

    Raises:
        ValueError: for a line with fewer than two fields, a score that is not a finite
            float or lies outside [1, 5], or an unrecognized filename prefix.
    """
    records: list[dict] = []
    with labels_path.open("r", encoding="utf-8") as handle:
        stripped_lines = (ln.strip() for ln in handle)
        # filter(None, ...) drops lines that are empty after stripping.
        for content in filter(None, stripped_lines):
            fields = content.split()
            if len(fields) < 2:
                raise ValueError(f"Malformed line in labels.txt: {content}")
            fname, score_text = fields[0], fields[1]
            score = _safe_float(score_text)
            if not 1.0 <= score <= 5.0:
                raise ValueError(f"Label out of range [1,5] for {fname}: {score}")
            records.append(
                {
                    "orig_name": fname,
                    "score": score,
                    "group": _parse_group_from_filename(fname),
                }
            )
    return records


def _anonymized_name(idx: int) -> str:
    return f"img_{idx:07d}.jpg"


def _split_data(items: list[dict]) -> tuple[list[dict], list[dict]]:
    # Deterministic 80/20 split within each group to preserve group balance
    rnd = random.Random(RANDOM_SEED)
    grouped: dict[str, list[dict]] = defaultdict(list)
    for it in items:
        grouped[it["group"]].append(it)

    train, test = [], []
    for g, lst in grouped.items():
        # Deterministic shuffle
        lst_sorted = sorted(lst, key=lambda x: x["orig_name"])  # sort for determinism
        rnd.shuffle(lst_sorted)
        n = len(lst_sorted)
        n_test = int(round(n * TEST_FRACTION))
        n_test = max(1, min(n - 1, n_test))  # ensure both non-empty
        test.extend(lst_sorted[:n_test])
        train.extend(lst_sorted[n_test:])
    return train, test


def _write_csv(path: Path, header: list[str], rows: list[list[str]]) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(header)
        w.writerows(rows)


def _link_or_copy(src: Path, dst: Path) -> None:
    dst.parent.mkdir(parents=True, exist_ok=True)
    try:
        if dst.exists():
            dst.unlink()
        os.link(src, dst)  # hard link for speed if same filesystem
    except Exception:
        shutil.copy2(src, dst)


def prepare(raw: Path, public: Path, private: Path):
    """
    Prepare competition files.

    Reads ``raw/labels.txt`` and the images in ``raw/Images/Images``, splits them per
    group into train/test, and writes:

    - ``public/train/`` and ``public/test/``: the images under anonymized ids
    - ``public/train.csv`` (image_id, beauty) and ``public/test.csv`` (image_id)
    - ``public/sample_submission.csv`` (image_id, beauty) with seeded values in [1, 5]
    - ``private/test_answer.csv`` (image_id, beauty, group)
    - ``public/description.txt`` if a ``description.txt`` sits next to this file

    Anonymized ids are assigned in a seeded random order, not in filename order. All
    CSV rows are ordered by image_id. As a result, neither an id nor a row position
    reveals an image's original filename or group.

    Args:
        raw: absolute path to the raw/ directory containing original data files
             expected structure: raw/Images/Images/*.jpg and raw/labels.txt
        public: absolute path to the public/ directory to create (deleted first if present)
        private: absolute path to the private/ directory to create (deleted first if present)

    Raises:
        AssertionError: if an input precondition or a post-creation check fails.
        ValueError: if labels.txt is malformed, if public/private would delete raw,
            or if private is the same as or inside public.
    """
    assert raw.is_absolute() and public.is_absolute() and private.is_absolute(), "Please provide absolute paths."

    # Input paths
    raw_img_dir = raw / "Images" / "Images"
    labels_path = raw / "labels.txt"

    # Pre-conditions
    assert raw.exists() and raw.is_dir(), f"Raw dir missing: {raw}"
    assert raw_img_dir.exists() and raw_img_dir.is_dir(), f"Raw image directory missing: {raw_img_dir}"
    assert labels_path.exists() and labels_path.is_file(), f"labels.txt missing: {labels_path}"

    # Both output dirs are rmtree'd below. Refuse layouts where that would delete the
    # raw data, or where the hidden answers would land inside the public directory.
    # These use real exceptions (not asserts) so they still fire under `python -O`.
    raw_res, public_res, private_res = raw.resolve(), public.resolve(), private.resolve()
    for out_res in (public_res, private_res):
        if out_res == raw_res or out_res in raw_res.parents:
            raise ValueError(f"Output directory {out_res} would delete the raw data in {raw_res}")
    if private_res == public_res or public_res in private_res.parents:
        raise ValueError(f"Private directory {private_res} must not be inside public directory {public_res}")

    # Load and validate all inputs before touching the outputs. A bad input then leaves
    # previously prepared directories intact.
    items = _load_labels(labels_path)
    assert len(items) == 5500, f"Expected 5500 labeled items, got {len(items)}"
    # A repeated filename could be split into both train and test while the mapping
    # below gives both copies the same id, i.e. a train/test leak.
    name_counts = Counter(it["orig_name"] for it in items)
    duplicates = sorted(name for name, count in name_counts.items() if count > 1)
    assert not duplicates, f"Duplicate filenames in labels.txt, examples: {duplicates[:5]}"
    missing = [it["orig_name"] for it in items if not (raw_img_dir / it["orig_name"]).is_file()]
    assert not missing, f"Missing {len(missing)} image files, examples: {missing[:5]}"

    # Reset output dirs. Both are removed before either is recreated, in case one is
    # nested inside the other.
    for out in (public, private):
        if out.exists():
            shutil.rmtree(out)
    public.mkdir(parents=True, exist_ok=True)
    private.mkdir(parents=True, exist_ok=True)

    # Split
    train_items, test_items = _split_data(items)

    # Anonymized ids. Numbering in sorted-filename order would make ids monotonic in
    # the original names (every AF* id below every AM* id, ...). That exposes each
    # image's group and, to anyone holding the original filename list, its exact source
    # file and label. Shuffle with a dedicated seeded RNG instead. Starting from a
    # sorted list keeps the result independent of the labels.txt line order.
    id_order = sorted(items, key=lambda x: x["orig_name"])
    random.Random(RANDOM_SEED + 1).shuffle(id_order)
    mapping: dict[str, str] = {it["orig_name"]: _anonymized_name(i) for i, it in enumerate(id_order, start=1)}
    for img_id in mapping.values():
        for leak in ("AF", "AM", "CF", "CM"):
            assert leak not in img_id, f"Leakage via filename: {img_id}"

    # _split_data appends items group by group, so rows in that order would still reveal
    # each row's group by position. Order both splits by anonymized id instead
    # (zero-padded ids sort lexicographically in numeric order).
    train_items.sort(key=lambda it: mapping[it["orig_name"]])
    test_items.sort(key=lambda it: mapping[it["orig_name"]])

    # Copy files into train/test with anonymized names
    public_train_dir = public / "train"
    public_test_dir = public / "test"
    for split_dir, split_items in ((public_train_dir, train_items), (public_test_dir, test_items)):
        for it in split_items:
            _link_or_copy(raw_img_dir / it["orig_name"], split_dir / mapping[it["orig_name"]])

    # Sanity checks
    assert sum(1 for _ in public_train_dir.iterdir()) == len(train_items), "Mismatch in number of train images after copy"
    assert sum(1 for _ in public_test_dir.iterdir()) == len(test_items), "Mismatch in number of test images after copy"

    # Write train.csv and test.csv (public)
    _write_csv(
        public / "train.csv",
        ["image_id", "beauty"],
        [[mapping[it["orig_name"]], f"{it['score']:.6f}"] for it in train_items],
    )
    _write_csv(public / "test.csv", ["image_id"], [[mapping[it["orig_name"]]] for it in test_items])

    # Write hidden answers (private), in the same row order as public/test.csv
    _write_csv(
        private / "test_answer.csv",
        ["image_id", "beauty", "group"],
        [[mapping[it["orig_name"]], f"{it['score']:.6f}", it["group"]] for it in test_items],
    )

    # Sample submission in public with deterministic values in [1,5]
    rnd = random.Random(RANDOM_SEED)
    _write_csv(
        public / "sample_submission.csv",
        ["image_id", "beauty"],
        [[mapping[it["orig_name"]], f"{1.0 + 4.0 * rnd.random():.6f}"] for it in test_items],
    )

    # Copy description.txt to public/
    repo_desc = Path(__file__).resolve().parent / "description.txt"
    if repo_desc.exists():
        shutil.copy2(repo_desc, public / "description.txt")

    # Post-creation checks
    n_total = len(train_items) + len(test_items)
    assert n_total == 5500, f"Total count mismatch: {n_total}"

    # Ensure no slashes in CSV image_ids
    for path in [public / "train.csv", public / "test.csv", private / "test_answer.csv", public / "sample_submission.csv"]:
        with path.open("r", encoding="utf-8") as f:
            text = f.read()
            assert "/" not in text, f"Found a '/' in CSV file {path} (should not include paths)"

    # Groups coverage: all groups present in both train and test
    train_groups = Counter([it["group"] for it in train_items])
    test_groups = Counter([it["group"] for it in test_items])
    assert set(train_groups.keys()) == {"AF", "AM", "CF", "CM"}, f"Train groups missing: {train_groups}"
    assert set(test_groups.keys()) == {"AF", "AM", "CF", "CM"}, f"Test groups missing: {test_groups}"

    # Ensure correspondence between public/test.csv and private/test_answer.csv
    with (public / "test.csv").open("r", encoding="utf-8") as f1, (private / "test_answer.csv").open("r", encoding="utf-8") as f2:
        r1 = list(csv.DictReader(f1))
        r2 = list(csv.DictReader(f2))
    ids1 = [r["image_id"] for r in r1]
    ids2 = [r["image_id"] for r in r2]
    assert ids1 == ids2, "public/test.csv and private/test_answer.csv image_id rows must be in the same order"

    # Ensure sample_submission aligns with test.csv
    with (public / "sample_submission.csv").open("r", encoding="utf-8") as f:
        r3 = list(csv.DictReader(f))
    ids3 = [r["image_id"] for r in r3]
    assert ids1 == ids3, "public/sample_submission.csv image_id rows must match public/test.csv in order"
