from pathlib import Path
import shutil
import csv
from typing import Tuple
import pandas as pd

# This prepare() builds the public/private artifacts for the CNN/DailyMail summarization task.
# Inputs (in raw/): cnn_dailymail/train.csv, cnn_dailymail/validation.csv, cnn_dailymail/test.csv
# Outputs:
#   public/
#     - train.csv            (columns: id, article, summary)  [train + validation]
#     - test.csv             (columns: id, article)
#     - sample_submission.csv (columns: id, summary)  [lead baseline: first 30 tokens of each article]
#     - description.txt      (copied from the directory containing this script, if present)
#   private/
#     - test_answer.csv      (columns: id, summary)

# Number of rows pulled per chunk when streaming the raw split CSVs through
# pd.read_csv(chunksize=...).
CHUNKSIZE = 2_000


def _safe_to_csv(df: pd.DataFrame, path: Path, mode: str, header: bool):
    # Ensure string dtype and no actual NaNs in outputs; allow empty strings
    for c in df.columns:
        df[c] = df[c].astype(str)
        df[c] = df[c].where(df[c].notna(), "")
    df.to_csv(path, index=False, mode=mode, header=header, quoting=csv.QUOTE_MINIMAL)


def _iter_split_chunks(src: Path):
    """Yield chunks of a raw CNN/DailyMail split as ``id, article, summary`` frames.

    The raw ``highlights`` column is renamed to ``summary``. Every field is read
    as text (``dtype=str``, ``keep_default_na=False``), so ids are never coerced
    to numbers (losing leading zeros) and strings such as "NA" are kept verbatim.
    Empty fields come through as empty strings.

    Raises:
        AssertionError: if a chunk lacks any of ``id``, ``article``, ``highlights``.
    """
    required = ["id", "article", "highlights"]
    # The reader is used as a context manager so the file handle is closed even
    # if the assertion below aborts iteration part-way through.
    with pd.read_csv(src, chunksize=CHUNKSIZE, dtype=str, keep_default_na=False) as reader:
        for chunk in reader:
            assert set(required).issubset(set(chunk.columns)), (
                f"Unexpected columns in {src}: {chunk.columns}")
            df = chunk[required].rename(columns={"highlights": "summary"})
            # Defensive: with keep_default_na=False no NaN should appear.
            yield df.fillna("")


def _write_train(sources: Tuple[Path, ...], train_out: Path) -> None:
    """Append every normalized chunk of each split in ``sources`` to ``train_out``."""
    header = True  # only the very first chunk across all sources gets a header
    for src in sources:
        for df in _iter_split_chunks(src):
            _safe_to_csv(df, train_out, mode="a", header=header)
            header = False


def _lead_baseline(article: str, n_tokens: int = 30) -> str:
    """Return the first ``n_tokens`` whitespace-separated tokens of ``article``."""
    return " ".join(str(article).split()[:n_tokens])


def _write_test_split(test_src: Path, test_out: Path, ans_out: Path, sample_out: Path) -> None:
    """Write public test inputs, private answers and a lead-baseline submission.

    All three files are appended from the same normalized chunks, so they share
    ids and row order.
    """
    header = True
    for df in _iter_split_chunks(test_src):
        # public/test.csv -- inputs only, no labels
        _safe_to_csv(df[["id", "article"]], test_out, mode="a", header=header)
        # private/test_answer.csv -- reference summaries
        _safe_to_csv(df[["id", "summary"]], ans_out, mode="a", header=header)
        # public/sample_submission.csv -- first 30 tokens of each article
        baseline = df[["id"]].assign(summary=df["article"].map(_lead_baseline))
        _safe_to_csv(baseline, sample_out, mode="a", header=header)
        header = False


def _read_ids(path: Path) -> pd.Series:
    """Read only the ``id`` column of ``path`` as verbatim strings."""
    # usecols avoids loading large article/summary columns; dtype=str and
    # keep_default_na=False keep every file's ids identical (no int parsing,
    # no empty id turned into a NaN that never compares equal).
    return pd.read_csv(path, usecols=["id"], dtype=str, keep_default_na=False)["id"]


def _validate_outputs(train_out: Path, test_out: Path, sample_out: Path, ans_out: Path) -> None:
    """Sanity-check the written artifacts; raise AssertionError on any problem."""
    # Existence
    assert train_out.exists(), "public/train.csv should exist"
    assert test_out.exists(), "public/test.csv should exist"
    assert sample_out.exists(), "public/sample_submission.csv should exist"
    assert ans_out.exists(), "private/test_answer.csv should exist"

    # Column schemas (nrows=0 parses just the header line)
    tr_cols = pd.read_csv(train_out, nrows=0).columns
    t_cols = pd.read_csv(test_out, nrows=0).columns
    a_cols = pd.read_csv(ans_out, nrows=0).columns
    s_cols = pd.read_csv(sample_out, nrows=0).columns

    assert list(tr_cols) == ["id", "article", "summary"], f"Unexpected columns in train.csv: {tr_cols}"
    assert list(t_cols) == ["id", "article"], f"Unexpected columns in test.csv: {t_cols}"
    assert list(a_cols) == ["id", "summary"], f"Unexpected columns in test_answer.csv: {a_cols}"
    assert list(s_cols) == ["id", "summary"], f"Unexpected columns in sample_submission.csv: {s_cols}"

    # Row-count and id-set consistency between test and the derived files
    test_ids = _read_ids(test_out)
    ans_ids = _read_ids(ans_out)
    sub_ids = _read_ids(sample_out)
    assert len(test_ids) == len(ans_ids) == len(sub_ids), (
        "Row count mismatch between test.csv, test_answer.csv and sample_submission.csv")
    assert set(test_ids) == set(ans_ids), "Mismatch of ids between test.csv and test_answer.csv"
    assert set(test_ids) == set(sub_ids), "Mismatch of ids between test.csv and sample_submission.csv"

    # No label leakage in public/test.csv
    assert "summary" not in t_cols, "Label leakage detected in public/test.csv"


def prepare(raw: Path, public: Path, private: Path):
    """Build the public/private artifacts for the CNN/DailyMail summarization task.

    Args:
        raw: Directory containing ``cnn_dailymail/{train,validation,test}.csv``,
            each with at least ``id``, ``article`` and ``highlights`` columns.
        public: Output directory for ``train.csv`` (train + validation),
            ``test.csv``, ``sample_submission.csv`` and, if available,
            ``description.txt``.
        private: Output directory for ``test_answer.csv``.

    Raises:
        AssertionError: if a raw file is missing, a raw chunk lacks a required
            column, or any post-write integrity check fails.
    """
    public.mkdir(parents=True, exist_ok=True)
    private.mkdir(parents=True, exist_ok=True)

    # Locate raw csvs
    src_dir = raw / "cnn_dailymail"
    train_src = src_dir / "train.csv"
    valid_src = src_dir / "validation.csv"
    test_src = src_dir / "test.csv"
    for src in (train_src, valid_src, test_src):
        assert src.exists(), f"Missing raw file: {src}"

    # Targets in public/private
    train_out = public / "train.csv"
    test_out = public / "test.csv"
    sample_out = public / "sample_submission.csv"
    ans_out = private / "test_answer.csv"

    # All writers append chunk by chunk, so leftovers from a previous run must
    # be removed first or rows would be duplicated.
    for p in (train_out, test_out, sample_out, ans_out):
        if p.exists():
            p.unlink()

    _write_train((train_src, valid_src), train_out)
    _write_test_split(test_src, test_out, ans_out, sample_out)

    # description.txt is looked up in the directory containing this script.
    desc_src = Path(__file__).resolve().parent / "description.txt"
    if desc_src.exists():
        shutil.copy(desc_src, public / "description.txt")

    _validate_outputs(train_out, test_out, sample_out, ans_out)
