from pathlib import Path
import os
import csv
import hashlib
import random
from typing import Dict, List, Tuple, Optional

# Public interface required by the task
# Do not add a main() here; tests will call prepare(raw, public, private) directly.

# Closed set of emotion labels. prepare() validates normalized labels against
# this list and draws the sample-submission labels from it, so the order of
# the entries matters for the seeded sample output.
EMOTIONS = [
    "anger",
    "disgust",
    "sadness",
    "joy",
    "neutral",
    "surprise",
    "fear",
]


def _safe_makedirs(path: Path):
    """Create ``path`` and any missing parents; do nothing if it already exists."""
    os.makedirs(path, exist_ok=True)


def _read_csv_rows(path: Path) -> List[Dict[str, str]]:
    """Load every data row of a CSV file as a list of column -> value dicts.

    The file is decoded as UTF-8 with undecodable bytes replaced rather than
    raising. After reading, the header is checked for the annotation columns
    the rest of this module relies on.

    Raises:
        AssertionError: if ``path`` is not absolute or a required column is
            missing from the header.
    """
    assert path.is_absolute(), f"Path must be absolute: {path}"

    # Columns that downstream helpers index into directly.
    required_cols = {
        "Utterance",
        "Speaker",
        "Emotion",
        "Sentiment",
        "Dialogue_ID",
        "Utterance_ID",
        "Season",
        "Episode",
        "StartTime",
        "EndTime",
    }

    with path.open("r", encoding="utf-8", errors="replace", newline="") as handle:
        parser = csv.DictReader(handle)
        records = list(parser)
        # fieldnames is None for an empty file; treat that as "no columns".
        header = parser.fieldnames or []

    missing = required_cols.difference(header)
    assert not missing, f"Missing required columns in {path}: {missing}"
    return records


def _make_src_filename(row: Dict[str, str]) -> str:
    """Build the source clip filename ``dia<Dialogue_ID>_utt<Utterance_ID>.mp4``.

    Both ids are taken from ``row`` with surrounding whitespace stripped.
    """
    dialogue, utterance = (
        row[column].strip() for column in ("Dialogue_ID", "Utterance_ID")
    )
    return "dia{}_utt{}.mp4".format(dialogue, utterance)


def _resolve_video_path(split: str, row: Dict[str, str], src_root: Path) -> Optional[Path]:
    """Locate the source clip for ``row`` under ``src_root``.

    For ``split == "test"`` only the test clip directory is searched. Any
    other split value searches the train clip directory first and the dev
    clip directory second, returning the first path that exists.

    NOTE: because train is tried before dev, a row whose filename exists in
    both directories always resolves to the train copy.

    Returns:
        The existing clip path, or ``None`` when no candidate exists.
    """
    clip_name = _make_src_filename(row)

    if split == "test":
        candidates = [src_root / "test" / "output_repeated_splits_test" / clip_name]
    else:
        # train/dev merged into train
        candidates = [
            src_root / "train" / "train_splits" / clip_name,
            src_root / "dev" / "dev_splits_complete" / clip_name,
        ]

    return next((candidate for candidate in candidates if candidate.exists()), None)


def _src_key(row: Dict[str, str]) -> str:
    """Return a deterministic ``|``-joined key used to spot duplicate annotations.

    The key is built from the stripped dialogue/utterance ids, season,
    episode, start/end times, and the stripped, lower-cased emotion label.
    """
    verbatim_columns = (
        "Dialogue_ID",
        "Utterance_ID",
        "Season",
        "Episode",
        "StartTime",
        "EndTime",
    )
    parts = [row[column].strip() for column in verbatim_columns]
    # The emotion is case-folded so label casing does not defeat the match.
    parts.append(row["Emotion"].strip().lower())
    return "|".join(parts)


def _anon_id(row: Dict[str, str]) -> str:
    """Derive a deterministic anonymized id for ``row``.

    The stripped dialogue/utterance ids, season, episode and start/end times
    are joined with ``|``, SHA-1 hashed, and the first 12 hex digits are
    returned with a ``u`` prefix.
    """
    identity_columns = (
        "Dialogue_ID",
        "Utterance_ID",
        "Season",
        "Episode",
        "StartTime",
        "EndTime",
    )
    material = "|".join(row[column].strip() for column in identity_columns)
    digest = hashlib.sha1(material.encode("utf-8")).hexdigest()
    return "u" + digest[:12]


def _symlink_if_needed(src: Path, dst: Path):
    """Make ``dst`` a symlink to ``src``, replacing whatever is there if needed.

    The parent directory of ``dst`` is created first. If ``dst`` is already a
    symlink whose target equals ``src`` nothing is changed; any other existing
    entry at ``dst`` is unlinked before the new symlink is created.
    """
    dst.parent.mkdir(parents=True, exist_ok=True)

    occupied = dst.is_symlink() or dst.exists()
    if occupied:
        try:
            already_linked = dst.is_symlink() and Path(os.readlink(dst)) == src
        except OSError:
            # An unreadable link is treated like a stale one and replaced.
            already_linked = False
        if already_linked:
            return
        try:
            dst.unlink()
        except FileNotFoundError:
            # Entry vanished between the check and the unlink.
            pass

    os.symlink(src.as_posix(), dst.as_posix())


def _write_csv(path: Path, fieldnames: List[str], rows: List[Dict[str, str]]):
    """Write ``rows`` to ``path`` as UTF-8 CSV with a header of ``fieldnames``.

    Raises:
        AssertionError: if ``path`` is not absolute.
    """
    assert path.is_absolute()
    with path.open("w", encoding="utf-8", newline="") as handle:
        out = csv.DictWriter(handle, fieldnames=fieldnames)
        out.writeheader()
        out.writerows(rows)


def prepare(raw: Path, public: Path, private: Path):
    """
    Complete preparation process for MERC (MELD) dataset.

    Merges the train and dev annotation CSVs into a single training set
    (exact duplicate annotations are dropped, train entries win), keeps the
    test CSV separate, symlinks every referenced clip under public/media/,
    and writes the output CSVs. Rows whose source clip cannot be found are
    skipped silently.

    Inputs (absolute paths):
    - raw: path to the raw/ directory that contains MELD-RAW/MELD.Raw
    - public: path to populate with train.csv, test.csv, sample_submission.csv, media/, and description.txt
    - private: path to populate with test_answer.csv only

    Raises:
        AssertionError: if a path is not absolute, an id collides, an emotion
            label is not in EMOTIONS, a media link could not be created, or
            one of the post-write consistency checks fails.
    """
    assert raw.is_absolute() and public.is_absolute() and private.is_absolute(), (
        "All paths must be absolute"
    )

    src_root = raw / "MELD-RAW" / "MELD.Raw"
    train_csv = src_root / "train" / "train_sent_emo.csv"
    dev_csv = src_root / "dev_sent_emo.csv"
    test_csv = src_root / "test_sent_emo.csv"
    # Directory holding the dev clips (same location _resolve_video_path
    # uses as its second candidate for non-test splits).
    dev_clip_dir = src_root / "dev" / "dev_splits_complete"

    # Output locations
    out_train_csv = public / "train.csv"
    out_test_csv = public / "test.csv"
    out_test_ans = private / "test_answer.csv"
    out_sample = public / "sample_submission.csv"
    media_train_dir = public / "media" / "train"
    media_test_dir = public / "media" / "test"

    # Prepare directories
    _safe_makedirs(public)
    _safe_makedirs(private)
    _safe_makedirs(media_train_dir)
    _safe_makedirs(media_test_dir)

    # Read CSVs
    train_rows = _read_csv_rows(train_csv)
    dev_rows = _read_csv_rows(dev_csv)
    test_rows = _read_csv_rows(test_csv)

    # Combine train + dev with de-duplication (prefer train entries).
    # Each kept row remembers which CSV it came from so its clip can be
    # looked up in the matching source directory.
    combined_train_src_keys = set()
    combined_rows: List[Tuple[Dict[str, str], str]] = []

    def _add_rows(rows: List[Dict[str, str]], origin: str):
        for row in rows:
            key = _src_key(row)
            if key in combined_train_src_keys:
                continue
            combined_train_src_keys.add(key)
            combined_rows.append((row, origin))

    _add_rows(train_rows, "train")
    _add_rows(dev_rows, "dev")

    seen_ids = set()

    def _normalize_row(
        row: Dict[str, str], split: str, origin: Optional[str] = None
    ) -> Tuple[Optional[Dict[str, str]], Optional[Path], Optional[Path]]:
        """Map a raw row to (output row, clip source, link destination).

        Returns (None, None, None) when the source clip does not exist.
        """
        _id = _anon_id(row)
        if split != "test" and _id in seen_ids:
            # Add suffix if required for uniqueness
            suffix = hashlib.sha1(row["Utterance"].encode("utf-8")).hexdigest()[:6]
            _id = f"{_id}_{suffix}"
        if origin == "dev":
            # Clip filenames are built only from Dialogue_ID/Utterance_ID, so
            # a dev row can share its filename with an unrelated train clip.
            # Resolve dev rows strictly against the dev directory; going
            # through the train-first lookup would attach the train clip to
            # the dev label.
            dev_clip = dev_clip_dir / _make_src_filename(row)
            video_src = dev_clip if dev_clip.exists() else None
        else:
            video_src = _resolve_video_path(split, row, src_root)
        if video_src is None:
            return None, None, None
        video_rel = Path("media") / split / f"{_id}.mp4"
        video_dst = public / video_rel
        nrow = {
            "id": _id,
            "utterance": row["Utterance"],
            "speaker": row["Speaker"],
            "dialogue_id": row["Dialogue_ID"],
            "utterance_id": row["Utterance_ID"],
            "season": row["Season"],
            "episode": row["Episode"],
            "start_time": row["StartTime"],
            "end_time": row["EndTime"],
            "video": video_rel.as_posix(),
            "emotion": row["Emotion"].strip().lower(),
        }
        return nrow, video_src, video_dst

    # Process combined train
    out_train_rows: List[Dict[str, str]] = []
    for row, origin in combined_rows:
        nrow, src_path, dst_path = _normalize_row(row, "train", origin)
        if nrow is None:
            continue
        if nrow["id"] in seen_ids:
            raise AssertionError(f"Duplicate id in train: {nrow['id']}")
        seen_ids.add(nrow["id"])
        # Link media to public
        _symlink_if_needed(src_path, dst_path)  # type: ignore[arg-type]
        assert dst_path.exists() or dst_path.is_symlink(), f"Failed to create media link: {dst_path}"
        assert nrow["emotion"] in EMOTIONS, f"Invalid emotion: {nrow['emotion']}"
        out_train_rows.append(nrow)

    # Process test
    out_test_rows: List[Dict[str, str]] = []
    out_test_ans_rows: List[Dict[str, str]] = []
    seen_test_ids = set()
    for row in test_rows:
        nrow, src_path, dst_path = _normalize_row(row, "test")
        if nrow is None:
            continue
        # ID uniqueness is validated before any link or output row is created
        assert nrow["id"] not in seen_ids, f"ID collision between train and test: {nrow['id']}"
        assert nrow["id"] not in seen_test_ids, f"Duplicate id in test: {nrow['id']}"
        seen_test_ids.add(nrow["id"])
        # Link media
        _symlink_if_needed(src_path, dst_path)  # type: ignore[arg-type]
        assert dst_path.exists() or dst_path.is_symlink(), f"Failed to create media link: {dst_path}"
        # Save test.csv without emotion
        out_test_rows.append({k: v for k, v in nrow.items() if k != "emotion"})
        # Save answers
        out_test_ans_rows.append({"id": nrow["id"], "emotion": nrow["emotion"]})

    # Write CSVs
    train_fields = [
        "id",
        "utterance",
        "speaker",
        "dialogue_id",
        "utterance_id",
        "season",
        "episode",
        "start_time",
        "end_time",
        "video",
        "emotion",
    ]
    # test.csv carries the same columns minus the label
    test_fields = [name for name in train_fields if name != "emotion"]
    _write_csv(out_train_csv, train_fields, out_train_rows)
    _write_csv(out_test_csv, test_fields, out_test_rows)
    _write_csv(out_test_ans, ["id", "emotion"], out_test_ans_rows)

    # Create sample submission with random valid labels. A private generator
    # seeded with 1337 yields the same label sequence as seeding the module
    # generator would, without resetting the process-wide random state.
    rng = random.Random(1337)
    sample_rows = [{"id": r["id"], "emotion": rng.choice(EMOTIONS)} for r in out_test_rows]
    _write_csv(out_sample, ["id", "emotion"], sample_rows)

    # Copy description.txt into public (without leaking answers)
    root_description = (public.parent if public.name else public) / "description.txt"
    if root_description.exists():
        (public / "description.txt").write_text(root_description.read_text(encoding="utf-8"), encoding="utf-8")

    # Checks
    # Counts
    assert out_train_csv.exists(), "public/train.csv should exist"
    assert out_test_csv.exists(), "public/test.csv should exist"
    assert out_test_ans.exists(), "private/test_answer.csv should exist"
    # ID alignment between test.csv and test_answer.csv
    with out_test_csv.open("r", encoding="utf-8", newline="") as f1, out_test_ans.open(
        "r", encoding="utf-8", newline=""
    ) as f2:
        test_ids = [row["id"] for row in csv.DictReader(f1)]
        ans_ids = [row["id"] for row in csv.DictReader(f2)]
    assert set(test_ids) == set(ans_ids), "Mismatch between test.csv and test_answer.csv IDs"
    assert len(test_ids) == len(ans_ids), "ID counts must match"
    # Ensure all labels valid
    with out_test_ans.open("r", encoding="utf-8", newline="") as f:
        for r in csv.DictReader(f):
            assert r["emotion"].strip().lower() in EMOTIONS

    # Ensure media paths exist for a small sample to keep runtime reasonable
    sample_check = 50

    def _assert_media_sample(csv_path: Path):
        """Check that the first ``sample_check`` rows reference existing media."""
        with csv_path.open("r", encoding="utf-8", newline="") as f:
            for i, row in enumerate(csv.DictReader(f)):
                if i >= sample_check:
                    break
                v = public / row["video"]
                assert v.exists() or v.is_symlink(), f"Missing media referenced in {csv_path.name}: {v}"

    _assert_media_sample(out_train_csv)
    _assert_media_sample(out_test_csv)

    # Sanity: columns
    with out_sample.open("r", encoding="utf-8", newline="") as f:
        sample_header = list(csv.DictReader(f).fieldnames or [])
    assert sample_header == ["id", "emotion"], "sample_submission.csv must have exactly id,emotion"
