import csv
import math
import os
import random
import re
import shutil
from pathlib import Path
from typing import List, Dict, Optional, Tuple

# Fixed seed for reproducibility: it seeds the module-level RNG at import
# time and is reused by _split_by_identity for its own random.Random instance.
SEED = 2024
random.seed(SEED)


def _find_pixel_root(raw: Path) -> Path:
    """Locate the Pixel/ directory within the provided raw directory.
    Tries the canonical path first, otherwise searches recursively.
    """
    canon = raw / "www.acmeai.tech Dataset - BMGF-LivestockWeight-CV" / "Pixel"
    if canon.exists():
        return canon
    # Fallback: search
    for root, dirs, files in os.walk(raw):
        if os.path.basename(root) == "Pixel":
            return Path(root)
    raise FileNotFoundError("Could not locate 'Pixel' data directory under raw/")


# File suffixes (leading dot included) that mark a file as an image; both
# lower- and upper-case spellings are listed explicitly.
IMG_EXTS = {".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"}


def _parse_filename(stem: str) -> Optional[Tuple[str, float]]:
    """Parse an original filename stem to derive an animal identity and weight.

    Strategy (robust across B2/B3/B4 naming):
      - Split by '_', find the first occurrence of view token 's' or 'r'.
      - Animal identity is the first token before the view token (conservative, avoids leakage).
      - Weight is the numeric token immediately after the view token.
    Returns (animal_id, weight_kg) or None if not parseable.
    """
    parts = stem.split("_")
    view_idx = None
    for i, p in enumerate(parts):
        if p in ("s", "r"):
            view_idx = i
            break
    if view_idx is None:
        return None
    if view_idx < 1:
        return None
    animal_id = parts[0]

    if view_idx + 1 >= len(parts):
        return None
    weight_token = parts[view_idx + 1]
    try:
        weight = float(weight_token)
    except Exception:
        cleaned = re.sub(r"[^0-9.+-]", "", weight_token)
        if cleaned == "":
            return None
        try:
            weight = float(cleaned)
        except Exception:
            return None
    return animal_id, float(weight)


def _collect_entries(pixel_root: Path) -> List[Dict]:
    """Gather every usable image found in 'images' folders under ``pixel_root``.

    Only directories named exactly 'images' are scanned. A file is kept when
    its suffix (compared case-insensitively) is in IMG_EXTS, its stem parses
    with _parse_filename, and the parsed weight lies strictly between 0 and
    1000. A mask is attached when the sibling 'annotations' directory holds a
    file named '<image filename>___fuse.png'.

    Directories and files are visited in sorted order so the returned list
    has the same order on every filesystem.

    Args:
        pixel_root: The Pixel/ directory of the raw dataset.

    Returns:
        A list of dicts with keys 'src_img' (Path), 'src_mask' (Path or
        None), 'animal_id' (str) and 'weight_kg' (float).

    Raises:
        RuntimeError: No file passed all the filters.
    """
    entries: List[Dict] = []
    for root, dirs, files in os.walk(pixel_root):
        # os.walk honours in-place edits of `dirs`; sorting pins the visit
        # order, which otherwise follows the filesystem's listing order.
        dirs.sort()
        # Use only image folders named 'images'
        if os.path.basename(root) != "images":
            continue
        img_dir = Path(root)
        ann_dir = img_dir.parent / "annotations"
        has_annotations = ann_dir.exists()
        for fname in sorted(files):
            fpath = img_dir / fname
            # Lower-casing also admits mixed-case suffixes such as '.Jpg'.
            if fpath.suffix.lower() not in IMG_EXTS:
                continue
            parsed = _parse_filename(fpath.stem)
            if parsed is None:
                continue
            animal_id, weight = parsed
            # basic sanity on weight
            if not (0.0 < weight < 1000.0):
                continue
            mask_path = None
            if has_annotations:
                # The mask name is built from the full image filename,
                # extension included.
                candidate = ann_dir / (fname + "___fuse.png")
                if candidate.exists():
                    mask_path = candidate
            entries.append({
                "src_img": fpath,
                "src_mask": mask_path,
                "animal_id": animal_id,
                "weight_kg": float(weight),
            })
    if not entries:
        raise RuntimeError("No valid images discovered under Pixel/ images directories.")
    return entries


def _split_by_identity(entries: List[Dict], test_ratio: float = 0.2) -> Tuple[List[Dict], List[Dict]]:
    """Split entries into train and test so that no animal appears in both.

    Entries are grouped by 'animal_id'. The sorted ids are shuffled with a
    random.Random seeded by SEED, and the first
    ``max(1, round(n_ids * test_ratio))`` ids form the test set. Both output
    lists follow the shuffled id order, with each animal's entries kept in
    their input order. The lists are deliberately not built by iterating a
    set of ids: set order for strings depends on per-process hash
    randomization, which would make the result differ between runs.

    Args:
        entries: Dicts carrying at least an 'animal_id' key.
        test_ratio: Fraction of identities (not images) assigned to test.

    Returns:
        (train, test) lists of the input dicts.

    Raises:
        AssertionError: Either split is empty, or an identity is in both.
    """
    by_id: Dict[str, List[Dict]] = {}
    for entry in entries:
        by_id.setdefault(entry["animal_id"], []).append(entry)

    # Sort before shuffling so the shuffle starts from a fixed order.
    ids = sorted(by_id)
    random.Random(SEED).shuffle(ids)
    n_test = max(1, int(round(len(ids) * test_ratio)))
    test_ids = ids[:n_test]
    train_ids = ids[n_test:]

    train = [entry for aid in train_ids for entry in by_id[aid]]
    test = [entry for aid in test_ids for entry in by_id[aid]]

    assert train and test, "Empty train or test split"
    assert {e["animal_id"] for e in train}.isdisjoint(
        {e["animal_id"] for e in test}
    ), "Identity leakage between train and test"
    return train, test


def _safe_reset_dir(d: Path):
    if d.exists():
        shutil.rmtree(d)
    d.mkdir(parents=True, exist_ok=True)


def prepare(raw: Path, public: Path, private: Path):
    """Prepare dataset artifacts from raw/ into public/ and private/.

    Outputs:
      - public/: train.csv, test.csv, sample_submission.csv, train_images/,
        test_images/, train_masks/ and test_masks/ (the mask folders are
        always created and stay empty when no mask was found), plus
        description.txt when such a file sits next to this script.
      - private/: test_answer.csv (ground truth for the test images).

    The split is made at identity level by _split_by_identity. The raw data
    is scanned and split before the output directories are wiped, so an
    unusable raw/ raises while any previous output is still intact.

    Args:
        raw: Directory holding the extracted raw dataset.
        public: Output directory for participant-visible files. Its previous
            content is deleted.
        private: Output directory for the hidden test labels. Its previous
            content is deleted.

    Raises:
        FileNotFoundError: No Pixel/ directory under ``raw``.
        RuntimeError: No valid image was discovered.
        AssertionError: A split or post-write consistency check failed.
    """
    raw = Path(raw).resolve()
    public = Path(public).resolve()
    private = Path(private).resolve()

    # Collect and split before touching any output directory.
    pixel_root = _find_pixel_root(raw)
    entries = _collect_entries(pixel_root)
    train_raw, test_raw = _split_by_identity(entries, test_ratio=0.2)

    # Prepare output directories
    _safe_reset_dir(public)
    _safe_reset_dir(private)

    train_img_dir = public / "train_images"
    test_img_dir = public / "test_images"
    train_mask_dir = public / "train_masks"
    test_mask_dir = public / "test_masks"
    for d in (train_img_dir, test_img_dir, train_mask_dir, test_mask_dir):
        d.mkdir(parents=True, exist_ok=True)

    def copy_block(block: List[Dict], dst_img: Path, dst_mask: Path, start_idx: int) -> Tuple[List[Dict], int]:
        """Copy one split under anonymized names; return its records and the next free index."""
        out: List[Dict] = []
        idx = start_idx
        for e in block:
            # The file is copied as-is; only its name is normalized to
            # 'img_NNNNNN.jpg', whatever the source suffix was.
            new_name = f"img_{idx:06d}.jpg"
            shutil.copy2(e["src_img"], dst_img / new_name)
            out_e = {
                "filename": new_name,
                "weight_kg": e["weight_kg"],
                "mask_filename": None,
                "animal_id": e["animal_id"],
            }
            if e["src_mask"] is not None and Path(e["src_mask"]).exists():
                mask_new = f"img_{idx:06d}_mask.png"
                shutil.copy2(e["src_mask"], dst_mask / mask_new)
                out_e["mask_filename"] = mask_new
            out.append(out_e)
            idx += 1
        return out, idx

    def write_csv(path: Path, header: List[str], rows: List[List[str]]) -> None:
        """Write a header line followed by ``rows`` to ``path`` as UTF-8 CSV."""
        with open(path, "w", newline="", encoding="utf-8") as f:
            w = csv.writer(f)
            w.writerow(header)
            w.writerows(rows)

    def labelled_rows(block: List[Dict]) -> List[List[str]]:
        """Return (filename, weight with six decimals) rows for ``block``."""
        return [[e["filename"], f"{e['weight_kg']:.6f}"] for e in block]

    # Indices continue from train into test, so names never collide.
    train_entries, next_idx = copy_block(train_raw, train_img_dir, train_mask_dir, 0)
    test_entries, _ = copy_block(test_raw, test_img_dir, test_mask_dir, next_idx)

    # Write CSVs
    train_csv = public / "train.csv"
    test_csv = public / "test.csv"
    test_answer_csv = private / "test_answer.csv"
    sample_sub_csv = public / "sample_submission.csv"

    write_csv(train_csv, ["filename", "weight_kg"], labelled_rows(train_entries))
    write_csv(test_csv, ["filename"], [[e["filename"]] for e in test_entries])
    write_csv(test_answer_csv, ["filename", "weight_kg"], labelled_rows(test_entries))

    # Sample submission: the train mean, clamped to [30, 700], on every row
    train_weights = [e["weight_kg"] for e in train_entries]
    mean_w = sum(train_weights) / max(1, len(train_weights))
    baseline = f"{float(max(30.0, min(700.0, mean_w))):.6f}"
    write_csv(
        sample_sub_csv,
        ["filename", "weight_kg"],
        [[e["filename"], baseline] for e in test_entries],
    )

    # Copy description.txt into public
    top_desc = (Path(__file__).resolve().parent / "description.txt").resolve()
    if top_desc.exists():
        shutil.copy2(top_desc, public / "description.txt")

    # Checks (no leakage, consistency)
    train_names = [e["filename"] for e in train_entries]
    test_names = [e["filename"] for e in test_entries]
    assert set(train_names).isdisjoint(set(test_names)), "Train/Test filenames overlap"
    assert len(train_names) == len(set(train_names)), "Duplicate filenames in train"
    assert len(test_names) == len(set(test_names)), "Duplicate filenames in test"

    # Ensure files exist
    for n in train_names:
        assert (train_img_dir / n).exists(), f"Missing train image {n}"
    for n in test_names:
        assert (test_img_dir / n).exists(), f"Missing test image {n}"

    # CSV format expectations: read each file back and check its header
    with open(train_csv, "r", encoding="utf-8") as f:
        rows = list(csv.DictReader(f))
    assert rows and set(rows[0].keys()) == {"filename", "weight_kg"}

    with open(test_csv, "r", encoding="utf-8") as f:
        rows = list(csv.DictReader(f))
    assert rows and list(rows[0].keys()) == ["filename"]

    with open(test_answer_csv, "r", encoding="utf-8") as f:
        ans_rows = list(csv.DictReader(f))
    assert ans_rows and set(ans_rows[0].keys()) == {"filename", "weight_kg"}

    # Mask presence check when recorded
    for e in train_entries:
        if e.get("mask_filename"):
            assert (train_mask_dir / e["mask_filename"]).exists()
    for e in test_entries:
        if e.get("mask_filename"):
            assert (test_mask_dir / e["mask_filename"]).exists()

    # Ensure private labels are not in public
    assert not (public / "test_answer.csv").exists(), "test_answer.csv must be private only"
