from pathlib import Path
import shutil
from typing import Dict, List
import random
import csv

from tqdm.auto import tqdm


def _is_off_file_valid(path: Path) -> bool:
    try:
        with path.open("r", encoding="utf-8", errors="ignore") as f:
            first = f.readline().strip()
            return first in ("OFF", "NOFF", "COFF") or first.startswith("OFF ")
    except Exception:
        return False


def _generate_id(idx: int) -> str:
    return f"mesh_{idx:06d}.off"


def _hardlink_or_copy(src: Path, dst: Path):
    if dst.exists():
        return
    try:
        # Attempt hardlink for speed/storage
        os_link = getattr(shutil, "_fastcopy_sendfile", None)
        src.link_to  # type: ignore[attr-defined]
    except Exception:
        pass
    try:
        dst.hardlink_to(src)  # type: ignore[attr-defined]
    except Exception:
        shutil.copy2(src, dst)


def _write_csv(path: Path, header: List[str], rows) -> None:
    """Write ``header`` followed by every row in ``rows`` to ``path`` as UTF-8 CSV."""
    with path.open("w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(header)
        w.writerows(rows)


def _count_csv_rows(path: Path) -> int:
    """Return the number of data rows (header excluded) in the CSV at ``path``."""
    with path.open("r", newline="", encoding="utf-8") as f:
        return sum(1 for _ in csv.reader(f)) - 1


def _list_off_files(directory: Path) -> List[Path]:
    """Return the entries of ``directory`` with a ``.off`` suffix (any case), sorted."""
    return sorted(p for p in directory.iterdir() if p.suffix.lower() == ".off")


def prepare(raw: Path, public: Path, private: Path, *, seed: int = 0):
    """
    Prepare the ModelNet40 competition assets.

    Inputs:
    - raw: absolute path to raw/ directory that contains ModelNet40 folder as downloaded
    - public: absolute path to public/ output directory
    - private: absolute path to private/ output directory
    - seed: seeds the shuffle that decides which anonymized id each mesh gets,
      and the random labels written to sample_submission.csv

    Outputs in public/:
    - train/ and test/ OFF files with anonymized filenames
    - train.csv with columns [id, label]
    - test.csv with column [id]
    - sample_submission.csv with columns [id, label]
    - description.txt copied from root description (if exists) or generated

    Outputs in private/:
    - test_answer.csv with columns [id, label]

    Deterministic:
    - Files are gathered by label name, then by split, then by file name.
    - That list is shuffled with ``random.Random(seed)`` before ids are assigned.
    - The shuffle stops id order from revealing the class. With sequential ids,
      each class's test meshes would occupy one contiguous id range directly
      after that class's training ids, so train.csv alone would expose every
      test label.

    Raises:
    - AssertionError if the paths are not absolute, raw/ModelNet40 is missing,
      or any post-write consistency check fails.
    """
    assert raw.is_absolute() and public.is_absolute() and private.is_absolute(), "Please use absolute paths"

    src_root = raw / "ModelNet40"
    assert src_root.exists() and src_root.is_dir(), f"Missing source data at {src_root}"

    # Clean and recreate outputs
    if public.exists():
        shutil.rmtree(public)
    if private.exists():
        shutil.rmtree(private)
    public.mkdir(parents=True, exist_ok=True)
    private.mkdir(parents=True, exist_ok=True)

    train_dir = public / "train"
    test_dir = public / "test"
    train_dir.mkdir(parents=True, exist_ok=True)
    test_dir.mkdir(parents=True, exist_ok=True)

    # Gather (label, split, source path) triples in a deterministic order.
    labels = [d.name for d in sorted(src_root.iterdir()) if d.is_dir()]
    entries = []
    for lbl in labels:
        for split in ("train", "test"):
            split_dir = src_root / lbl / split
            if split_dir.exists():
                entries.extend((lbl, split, p) for p in _list_off_files(split_dir))

    # Seeded shuffle before numbering, so ids are reproducible but carry no
    # class information through their ordering.
    random.Random(seed).shuffle(entries)

    # Map ids and copy files
    id_to_label_train: Dict[str, str] = {}
    id_to_label_test: Dict[str, str] = {}
    for idx, (lbl, split, src_path) in enumerate(tqdm(entries, desc="Copying meshes"), start=1):
        new_id = _generate_id(idx)
        is_train = split == "train"
        dst_path = (train_dir if is_train else test_dir) / new_id
        # copyfile copies content only. Source timestamps may correlate with
        # class (presumably extraction order), so they are deliberately not
        # carried over, unlike with copy2.
        shutil.copyfile(src_path, dst_path)
        (id_to_label_train if is_train else id_to_label_test)[new_id] = lbl

    # Write CSVs (rows sorted by id; ids are zero-padded, so this is numeric order)
    train_csv = public / "train.csv"
    test_csv = public / "test.csv"
    test_answer_csv = private / "test_answer.csv"
    sample_submission_csv = public / "sample_submission.csv"

    train_ids = sorted(id_to_label_train)
    test_ids = sorted(id_to_label_test)
    _write_csv(train_csv, ["id", "label"], ([k, id_to_label_train[k]] for k in train_ids))
    _write_csv(test_csv, ["id"], ([k] for k in test_ids))
    _write_csv(test_answer_csv, ["id", "label"], ([k, id_to_label_test[k]] for k in test_ids))

    # Deterministic sample submission with a fixed random seed
    rng = random.Random(seed)
    train_label_set = sorted(set(id_to_label_train.values()))
    _write_csv(
        sample_submission_csv,
        ["id", "label"],
        ([k, rng.choice(train_label_set)] for k in test_ids),
    )

    # Copy description.txt into public if present at root
    root_description = raw.parent / "description.txt"
    if root_description.exists():
        shutil.copy2(root_description, public / "description.txt")
    else:
        # generate a minimal description
        with (public / "description.txt").open("w", encoding="utf-8") as f:
            f.write(
                "Title: 3D Mesh Classification on ModelNet40\n\n"
                "Files in this folder:\n"
                "- train.csv: id,label pairs for training OFF files in public/train/\n"
                "- test.csv: id list for testing OFF files in public/test/\n"
                "- sample_submission.csv: example submission format (id,label)\n"
                "- train/: OFF meshes for training\n"
                "- test/: OFF meshes for testing\n"
            )

    # Checks
    # 1) Directories exist
    assert train_dir.exists() and test_dir.exists(), "train/ and test/ must exist in public/"

    # 2) Counts match CSVs
    n_train_files = len(_list_off_files(train_dir))
    n_test_files = len(_list_off_files(test_dir))
    n_train_csv = _count_csv_rows(train_csv)
    n_test_csv = _count_csv_rows(test_csv)
    assert n_train_files == n_train_csv == len(id_to_label_train), "Mismatch in train counts"
    assert n_test_files == n_test_csv == len(id_to_label_test), "Mismatch in test counts"

    # 3) OFF header sanity check on the first few files of each directory
    sample_check = 25
    checked = sum(
        1
        for d in (train_dir, test_dir)
        for p in _list_off_files(d)[:sample_check]
        if _is_off_file_valid(p)
    )
    assert checked > 0, "No OFF files validated"

    # 4) No label leakage in filenames
    lower_labels = [lbl.lower() for lbl in labels]
    for d in (train_dir, test_dir):
        for name in (p.name for p in d.iterdir() if p.is_file()):
            lowname = name.lower()
            assert all(lbl not in lowname for lbl in lower_labels), f"Label leakage in filename: {name}"

    # 5) sample_submission ids match test ids and labels are valid
    with sample_submission_csv.open("r", newline="", encoding="utf-8") as f:
        rows = list(csv.DictReader(f))
    sub_ids = [row["id"] for row in rows]
    sub_labels = [row["label"] for row in rows]
    assert set(sub_ids) == set(test_ids), "sample_submission.csv does not match test ids"
    assert set(sub_labels).issubset(set(train_label_set)), "sample_submission has invalid labels"

    # 6) Public description file is present
    assert (public / "description.txt").exists(), "description.txt must be in public/"