import json
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List

import pandas as pd


def _read_metadata(raw: Path) -> List[Dict[str, Any]]:
    lavdf_root = raw / "LAV-DF"
    meta_min = lavdf_root / "metadata.min.json"
    meta_full = lavdf_root / "metadata.json"

    meta_path = meta_min if meta_min.exists() else meta_full
    assert meta_path.exists(), f"Metadata JSON not found at: {meta_path}"
    with open(meta_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    assert isinstance(data, list), "Metadata must be a list of records"
    return data


def _safe_link_or_copy(src: Path, dst: Path):
    dst.parent.mkdir(parents=True, exist_ok=True)
    try:
        if dst.exists():
            dst.unlink()
        # use hardlink for efficiency if possible
        dst.hardlink_to(src)
    except Exception:
        shutil.copy2(src, dst)


def _materialize_split(
    recs: List[Dict[str, Any]],
    lavdf_root: Path,
    out_dir: Path,
    split: str,
) -> List[Dict[str, Any]]:
    """Place each record's video in ``out_dir`` and return its label rows.

    Each video at ``lavdf_root / record["file"]`` is hard-linked (or copied)
    to ``out_dir / <basename>``; the basename becomes the row's ``video_id``.
    A label column is 1 when the record's ``modify_video`` / ``modify_audio``
    field is truthy and 0 otherwise. Rows are sorted by ``video_id`` so the
    generated CSVs are deterministic.

    Raises:
        AssertionError: two records in this split share a basename; they would
            overwrite each other's file in ``out_dir`` and yield duplicate rows.
        FileNotFoundError: a referenced source video does not exist.
    """
    rows: List[Dict[str, Any]] = []
    first_seen: Dict[str, Path] = {}
    for r in recs:
        rel = Path(r["file"])  # e.g., train/000123.mp4
        vid = rel.name
        assert vid not in first_seen, (
            f"Duplicate {split} video id {vid!r}: {first_seen[vid]} and {rel}"
        )
        first_seen[vid] = rel

        src = lavdf_root / rel
        if not src.exists():
            raise FileNotFoundError(f"Missing source {split} video: {src}")
        _safe_link_or_copy(src, out_dir / vid)

        rows.append({
            "video_id": vid,
            "label_video_fake": 1 if bool(r.get("modify_video", False)) else 0,
            "label_audio_fake": 1 if bool(r.get("modify_audio", False)) else 0,
        })
    return sorted(rows, key=lambda row: row["video_id"])


def _validate_outputs(public: Path, train_df: pd.DataFrame, test_answer_df: pd.DataFrame) -> None:
    """Sanity-check the artifacts written by ``prepare``; raise AssertionError on failure."""
    train_dir = public / "train_videos"
    test_dir = public / "test_videos"
    assert train_dir.exists(), "public/train_videos should exist"
    assert test_dir.exists(), "public/test_videos should exist"

    # Every CSV row must point at a materialized video file.
    for vid in train_df["video_id"]:
        assert (train_dir / vid).exists(), f"Missing train video: {vid}"
    for vid in test_answer_df["video_id"]:
        assert (test_dir / vid).exists(), f"Missing test video: {vid}"

    test_ids = set(test_answer_df["video_id"])
    assert set(train_df["video_id"]).isdisjoint(test_ids), "Train and test ids should not overlap"

    label_cols = ["label_video_fake", "label_audio_fake"]
    for col in label_cols:
        assert set(train_df[col].unique()).issubset({0, 1}), f"Train labels in {col} must be 0/1"
        assert set(test_answer_df[col].unique()).issubset({0, 1}), f"Test labels in {col} must be 0/1"

    # Both classes must be present in each split for each label.
    def has_pos_neg(df: pd.DataFrame, col: str) -> bool:
        positives = df[col].sum()
        return positives > 0 and (len(df) - positives) > 0

    assert has_pos_neg(train_df, "label_video_fake"), "Train split must have pos/neg for video label"
    assert has_pos_neg(train_df, "label_audio_fake"), "Train split must have pos/neg for audio label"
    assert has_pos_neg(test_answer_df, "label_video_fake"), "Test split must have pos/neg for video label"
    assert has_pos_neg(test_answer_df, "label_audio_fake"), "Test split must have pos/neg for audio label"

    # Re-read the written files to validate what is actually on disk.
    ss = pd.read_csv(public / "sample_submission.csv")
    assert set(ss["video_id"]) == test_ids, "sample_submission ids must match test ids"

    tpub = pd.read_csv(public / "test.csv")
    assert tpub.columns.tolist() == ["video_id"], "public/test.csv must contain only [video_id]"
    assert set(tpub["video_id"]) == test_ids, "public/test.csv ids must match test ids"


def prepare(raw: Path, public: Path, private: Path):
    """
    Prepare the LAV-DF dataset for competition use.

    Inputs:
      - raw: absolute path to the directory containing the original data (expects raw/LAV-DF/...)
      - public: absolute path to the directory for public files
      - private: absolute path to the directory for private files (test answers)

    Artifacts created:
      public/
        - train_videos/          hard links (or copies) of the train .mp4 files
        - test_videos/           hard links (or copies) of the test .mp4 files
        - train.csv              video_id, label_video_fake, label_audio_fake
        - test.csv               video_id only
        - sample_submission.csv  video_id, p_video_fake, p_audio_fake (all 0.5)
        - description.txt        only if description.txt sits next to this script
      private/
        - test_answer.csv        video_id, label_video_fake, label_audio_fake

    Raises AssertionError when a precondition or post-write check fails, and
    FileNotFoundError when the metadata references a video that does not exist.
    """
    # Preconditions
    assert raw.is_absolute() and public.is_absolute() and private.is_absolute(), "Paths must be absolute"

    records = _read_metadata(raw)
    lavdf_root = raw / "LAV-DF"

    # Partition by the metadata's split field; only .mp4 entries are used.
    def select(split: str) -> List[Dict[str, Any]]:
        return [
            r for r in records
            if r.get("split") == split and str(r.get("file", "")).endswith(".mp4")
        ]

    train_recs = select("train")
    test_recs = select("test")
    assert len(train_recs) > 0, "No training records found in metadata"
    assert len(test_recs) > 0, "No test records found in metadata"

    train_dir = public / "train_videos"
    test_dir = public / "test_videos"
    train_dir.mkdir(parents=True, exist_ok=True)
    test_dir.mkdir(parents=True, exist_ok=True)
    private.mkdir(parents=True, exist_ok=True)

    # Link/copy videos and build sorted label rows per split.
    train_rows = _materialize_split(train_recs, lavdf_root, train_dir, "train")
    test_rows_with_labels = _materialize_split(test_recs, lavdf_root, test_dir, "test")

    columns = ["video_id", "label_video_fake", "label_audio_fake"]
    train_df = pd.DataFrame(train_rows, columns=columns)
    test_answer_df = pd.DataFrame(test_rows_with_labels, columns=columns)

    train_df.to_csv(public / "train.csv", index=False)
    test_answer_df.to_csv(private / "test_answer.csv", index=False)

    # Public test.csv exposes ids only; labels stay in private/.
    pd.DataFrame({"video_id": test_answer_df["video_id"]}).to_csv(public / "test.csv", index=False)

    # Sample submission template (dummy 0.5 probs)
    sample_sub = pd.DataFrame({
        "video_id": test_answer_df["video_id"],
        "p_video_fake": 0.5,
        "p_audio_fake": 0.5,
    })
    sample_sub.to_csv(public / "sample_submission.csv", index=False)

    # description.txt is optional: copied only when present beside this script.
    desc_src = Path(__file__).resolve().parent / "description.txt"
    if desc_src.exists():
        shutil.copy2(desc_src, public / "description.txt")

    _validate_outputs(public, train_df, test_answer_df)
