from pathlib import Path
import json
import csv
import shutil
import random
from collections import Counter
from typing import List, Tuple

# Seed the module-level RNG at import time so any later use of the global
# `random` functions starts from a fixed state.
random.seed(137)

# The closed set of class names; every resolved label must be one of these.
LABELS = [
    "NotHate",
    "Racist",
    "Sexist",
    "Homophobe",
    "Religion",
    "OtherHate",
]
# Tie-breaker priority from more harmful to less harmful, then NotHate.
# When several labels share the top vote count, the first one listed here wins.
TIE_PRIORITY = ["Racist", "Sexist", "Homophobe", "Religion", "OtherHate", "NotHate"]


def _safe_mkdir(path: Path):
    """Create ``path`` as a directory, building missing parents.

    Does nothing when the directory is already there.
    """
    path.mkdir(exist_ok=True, parents=True)


def _read_ids(path: Path) -> List[str]:
    """Return the first comma-separated field of every non-blank line.

    Each line is stripped of surrounding whitespace first; lines that are
    empty after stripping are ignored. Order of appearance is preserved.
    """
    with path.open("r", encoding="utf-8") as handle:
        stripped_lines = (raw_line.strip() for raw_line in handle)
        # partition(",")[0] is the text before the first comma (or the whole
        # entry when there is no comma).
        return [entry.partition(",")[0] for entry in stripped_lines if entry]


def _majority_vote(labels_str_list: List[str]) -> str:
    """Pick the most frequent label, breaking ties via ``TIE_PRIORITY``.

    When several labels share the highest count, the first of them that
    appears in ``TIE_PRIORITY`` is returned. If none of the tied labels is
    listed there, the alphabetically smallest one is returned.
    """
    tally = Counter(labels_str_list)
    top_votes = max(tally.values())
    tied = {label for label, votes in tally.items() if votes == top_votes}

    # Clear winner: no tie-breaking needed.
    if len(tied) == 1:
        return next(iter(tied))

    preferred = next((label for label in TIE_PRIORITY if label in tied), None)
    if preferred is not None:
        return preferred

    # None of the tied labels is ranked; fall back to lexical order.
    return min(tied)


def _sanitize_text(s: str) -> str:
    """Collapse every run of whitespace in ``s`` to one space and trim the ends.

    ``None`` yields an empty string; any other value is converted with
    ``str()`` before being normalized.
    """
    return "" if s is None else " ".join(str(s).split())


def _ensure_image_path(img_dir: Path, base_id: str) -> Path:
    """Locate the image file for ``base_id`` inside ``img_dir``.

    Extensions are probed in a fixed order and the first existing file wins.

    Raises:
        FileNotFoundError: if no file with any of the probed extensions exists.
    """
    for suffix in (".jpg", ".png", ".jpeg", ".JPG", ".PNG", ".JPEG"):
        candidate = img_dir / f"{base_id}{suffix}"
        if candidate.exists():
            return candidate
    raise FileNotFoundError(f"Image file for id {base_id} not found in {img_dir}")


def _get_ocr_text(ocr_json_path: Path) -> str:
    """Best-effort extraction of OCR text from a JSON file.

    For a JSON object, the first of the keys ``img_text``, ``text``, ``ocr``,
    ``ocr_text`` holding a string is returned (sanitized); otherwise all
    string values are sanitized and joined with spaces. For a JSON array, its
    string items are sanitized and joined. A missing file, any other JSON
    type, or any error while reading/parsing yields an empty string.
    """
    preferred_keys = ("img_text", "text", "ocr", "ocr_text")

    if ocr_json_path.exists():
        try:
            with ocr_json_path.open("r", encoding="utf-8") as handle:
                payload = json.load(handle)

            if isinstance(payload, list):
                return " ".join(
                    _sanitize_text(item) for item in payload if isinstance(item, str)
                )

            if isinstance(payload, dict):
                for name in preferred_keys:
                    value = payload.get(name)
                    if isinstance(value, str):
                        return _sanitize_text(value)
                # No preferred key matched: fall back to every string value.
                return " ".join(
                    _sanitize_text(value)
                    for value in payload.values()
                    if isinstance(value, str)
                )
        except Exception:
            # Deliberately lenient: unreadable or malformed OCR files simply
            # contribute no text.
            return ""

    return ""


def _write_csv(path: Path, fieldnames: List[str], rows: List[dict]):
    """Write ``rows`` to ``path`` as UTF-8 CSV with a header row.

    String values have their whitespace runs collapsed to single spaces (and
    are trimmed) before being written; values of other types are passed to
    the CSV writer unchanged.
    """

    def _flatten(value):
        if isinstance(value, str):
            return " ".join(str(value).split())
        return value

    with path.open("w", newline="", encoding="utf-8") as handle:
        out = csv.DictWriter(handle, fieldnames=fieldnames)
        out.writeheader()
        for row in rows:
            out.writerow({column: _flatten(cell) for column, cell in row.items()})


def prepare(raw: Path, public: Path, private: Path):
    """Build the public/private dataset layout from a raw MMHS150K directory.

    Inputs read from ``raw``: ``MMHS150K_GT.json``, ``splits/train_ids.txt``,
    ``splits/val_ids.txt``, ``splits/test_ids.txt``, ``img_resized/`` and,
    when present, per-id OCR files under ``img_txt/``.

    Train and val ids are merged into the training set; test ids form the
    test set. Ids missing from the ground-truth JSON or without an image file
    are skipped. Remaining samples are sorted by original id and renamed
    ``MMHS_000001``, ``MMHS_000002``, ... across both sets.

    Outputs written:
        * ``public/train_images``, ``public/test_images``: copied images.
        * ``public/train_text``, ``public/test_text``: ``{"img_text": ...}`` JSON.
        * ``public/train.csv``, ``public/test.csv``, ``public/sample_submission.csv``.
        * ``public/description.txt`` if a ``description.txt`` sits next to this module.
        * ``private/test_answer.csv``.

    Args:
        raw: Absolute path to the raw dataset directory.
        public: Absolute path of the public output directory (created if needed).
        private: Absolute path of the private output directory (created if needed).

    Raises:
        AssertionError: if any input validation or post-write sanity check fails.
    """
    # Validate inputs are absolute paths as per requirement
    assert raw.is_absolute() and public.is_absolute() and private.is_absolute(), "Paths must be absolute"

    # raw structure expectations
    img_dir = raw / "img_resized"
    ocr_dir = raw / "img_txt"
    gt_json = raw / "MMHS150K_GT.json"
    splits_dir = raw / "splits"

    assert img_dir.is_dir(), f"Image dir not found: {img_dir}"
    assert gt_json.is_file(), f"Ground truth JSON not found: {gt_json}"
    assert splits_dir.is_dir(), f"Splits dir not found: {splits_dir}"

    train_ids_path = splits_dir / "train_ids.txt"
    val_ids_path = splits_dir / "val_ids.txt"
    test_ids_path = splits_dir / "test_ids.txt"

    for p in [train_ids_path, val_ids_path, test_ids_path]:
        assert p.is_file(), f"Missing split file: {p}"

    # Load GT json
    with gt_json.open("r", encoding="utf-8") as f:
        gt = json.load(f)
    assert isinstance(gt, dict) and len(gt) > 0

    # Read splits
    tr_ids = _read_ids(train_ids_path)
    va_ids = _read_ids(val_ids_path)
    te_ids = _read_ids(test_ids_path)

    # Uniqueness checks within each split file
    def _assert_unique(name: str, lst: List[str]):
        assert len(lst) == len(set(lst)), f"Duplicate ids found in {name}"

    _assert_unique("train_ids.txt", tr_ids)
    _assert_unique("val_ids.txt", va_ids)
    _assert_unique("test_ids.txt", te_ids)

    # Combine train+val for training. An id listed in both files would
    # otherwise be collected twice and consume two anonymized indices (leaving
    # a gap in the numbering), so duplicates are dropped in first-seen order.
    build_train_ids = list(dict.fromkeys(tr_ids + va_ids))

    assert set(build_train_ids).isdisjoint(set(te_ids)), "Found overlap between train/val and test"

    # Filter ids to those present in GT
    build_train_ids = [i for i in build_train_ids if i in gt]
    te_ids = [i for i in te_ids if i in gt]

    # Resolve one label per id
    map_idx = {0: "NotHate", 1: "Racist", 2: "Sexist", 3: "Homophobe", 4: "Religion", 5: "OtherHate"}

    def _resolve(id_: str) -> Tuple[str, str]:
        """Return (majority label, sanitized tweet text) for one GT entry."""
        meta = gt[id_]
        lbls = meta.get("labels_str") or []
        if not isinstance(lbls, list) or len(lbls) == 0:
            # No usable string labels: fall back to the numeric annotations,
            # and to NotHate when those are absent too.
            num = meta.get("labels") or []
            lbls = [map_idx.get(int(x), "NotHate") for x in num if isinstance(x, (int, float, str))]
            if not lbls:
                lbls = ["NotHate"]
        final_label = _majority_vote(lbls)
        assert final_label in LABELS, f"Resolved label {final_label} not in allowed set for id {id_}"
        tweet_text = _sanitize_text(meta.get("tweet_text", ""))
        return final_label, tweet_text

    # Collect used items
    used: List[tuple] = []
    for id_list, is_train in [(build_train_ids, True), (te_ids, False)]:
        for id_ in id_list:
            try:
                img_path = _ensure_image_path(img_dir, id_)
            except FileNotFoundError:
                # skip if no image
                continue
            label, tweet_text = _resolve(id_)
            ocr_text = _get_ocr_text(ocr_dir / f"{id_}.json")
            used.append((id_, img_path, label, tweet_text, ocr_text, is_train))

    assert used, "No samples found to build dataset."

    # Sort by original id for deterministic anonymization
    used.sort(key=lambda x: x[0])

    # Build mapping and outputs
    id_map = {}
    for idx, (orig_id, img_path, label, tweet_text, ocr_text, is_train) in enumerate(used, start=1):
        id_map[orig_id] = {
            "anon_id": f"MMHS_{idx:06d}",
            "img_src": img_path,
            "label": label,
            "tweet_text": tweet_text,
            "ocr_text": ocr_text,
            "is_train": is_train,
        }

    # Create public/private structure
    _safe_mkdir(public)
    _safe_mkdir(private)

    train_images_dir = public / "train_images"
    test_images_dir = public / "test_images"
    train_text_dir = public / "train_text"
    test_text_dir = public / "test_text"

    _safe_mkdir(train_images_dir)
    _safe_mkdir(test_images_dir)
    _safe_mkdir(train_text_dir)
    _safe_mkdir(test_text_dir)

    def _write_ocr_json(dst_path: Path, text_value: str):
        """Write the standardized OCR payload via a temp file, then rename it into place."""
        payload = {"img_text": _sanitize_text(text_value)}
        tmp = dst_path.with_suffix(dst_path.suffix + ".tmp")
        with tmp.open("w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False)
        tmp.replace(dst_path)

    train_rows = []
    test_rows = []
    test_answers = []

    for info in id_map.values():
        anon_id = info["anon_id"]
        # NOTE(review): the copy is always named .jpg, even when the source
        # image was matched under .png/.jpeg — confirm this is intended.
        image_name = f"{anon_id}.jpg"
        text_name = f"{anon_id}.json"

        if info["is_train"]:
            dst_img_dir = train_images_dir
            dst_txt_dir = train_text_dir
        else:
            dst_img_dir = test_images_dir
            dst_txt_dir = test_text_dir

        shutil.copy2(info["img_src"], dst_img_dir / image_name)
        _write_ocr_json(dst_txt_dir / text_name, info["ocr_text"])

        if info["is_train"]:
            train_rows.append({
                "id": anon_id,
                "image": image_name,
                "tweet_text": info["tweet_text"],
                "ocr_text": info["ocr_text"],
                "label": info["label"],
            })
        else:
            test_rows.append({
                "id": anon_id,
                "image": image_name,
                "tweet_text": info["tweet_text"],
                "ocr_text": info["ocr_text"],
            })
            test_answers.append({"id": anon_id, "label": info["label"]})

    # Write CSVs in required locations
    _write_csv(public / "train.csv", ["id", "image", "tweet_text", "ocr_text", "label"], train_rows)
    _write_csv(public / "test.csv", ["id", "image", "tweet_text", "ocr_text"], test_rows)
    _write_csv(private / "test_answer.csv", ["id", "label"], test_answers)

    # Sample submission in public. A private generator with a fixed seed keeps
    # the file identical on every call, independent of whatever has consumed
    # the global RNG beforehand.
    rng = random.Random(137)
    rnd_rows = [{"id": r["id"], "label": rng.choice(LABELS)} for r in test_rows]
    _write_csv(public / "sample_submission.csv", ["id", "label"], rnd_rows)

    # Copy description.txt into public if available at root
    root_desc = Path(__file__).resolve().parent / "description.txt"
    if root_desc.exists():
        shutil.copy2(root_desc, public / "description.txt")

    # Sanity checks
    # Files existence
    assert (public / "train.csv").is_file(), "Missing public/train.csv"
    assert (public / "test.csv").is_file(), "Missing public/test.csv"
    assert (public / "sample_submission.csv").is_file(), "Missing public/sample_submission.csv"
    assert (private / "test_answer.csv").is_file(), "Missing private/test_answer.csv"

    # Folders non-empty
    for d in [train_images_dir, test_images_dir, train_text_dir, test_text_dir]:
        assert d.is_dir() and any(d.iterdir()), f"Empty or missing dir: {d}"

    # Id sets
    train_ids_set = {r["id"] for r in train_rows}
    test_ids_set = {r["id"] for r in test_rows}
    assert train_ids_set.isdisjoint(test_ids_set), "Train/Test id overlap detected"

    ans_ids = {r["id"] for r in test_answers}
    assert ans_ids == test_ids_set, "Mismatch between test ids and test answers"

    # Label sets validation and coverage
    train_labels = {r["label"] for r in train_rows}
    test_labels = {r["label"] for r in test_answers}
    assert train_labels.issubset(set(LABELS)), "Unexpected labels in training set"
    assert test_labels.issubset(set(LABELS)), "Unexpected labels in test set"
    missing_in_train = test_labels - train_labels
    assert not missing_in_train, f"Test labels missing in train: {missing_in_train}"

    # Ensure CSV references assets that exist
    for r in train_rows:
        p_img = train_images_dir / r["image"]
        p_txt = train_text_dir / f"{r['id']}.json"
        assert p_img.exists(), f"Missing image referenced in train.csv: {p_img}"
        assert p_txt.exists(), f"Missing ocr json for train id: {p_txt}"

    for r in test_rows:
        p_img = test_images_dir / r["image"]
        p_txt = test_text_dir / f"{r['id']}.json"
        assert p_img.exists(), f"Missing image referenced in test.csv: {p_img}"
        assert p_txt.exists(), f"Missing ocr json for test id: {p_txt}"

    # Ensure no data paths in CSVs
    for r in train_rows:
        assert "/" not in r["image"] and "\\" not in r["image"], "Image path leakage in train.csv"
    for r in test_rows:
        assert "/" not in r["image"] and "\\" not in r["image"], "Image path leakage in test.csv"

    # Ensure no label leakage in ids and image names (both train and test,
    # since test names are the ones that must not reveal the answer)
    for r in train_rows + test_rows:
        name = (r["id"] + " " + r["image"]).lower()
        for lab in LABELS:
            assert lab.lower() not in name, "Potential label leakage in id/image name"
