from pathlib import Path
import shutil
from typing import Tuple

import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit


# Target columns, in the order they appear in train.csv / test_answer.csv / sample_submission.csv.
_TARGET_COLS = ["e_score", "s_score", "g_score", "total_score"]
_TEXT_COL = "preprocessed_content"
_NER_COL = "ner_entities"  # optional input column
_SPLIT_SEED = 2024  # seeds both the train/test split and the sample-submission RNG
_TEST_SIZE = 0.2
_MAX_BINS = 10  # number of total_score quantile bins tried first for stratification

_DEFAULT_DESCRIPTION = """
Title: ESG Score Prediction from Sustainability Reports

Files in this public package:
- train.csv: training targets with columns [id, e_score, s_score, g_score, total_score]
- test.csv: test ids (targets withheld)
- sample_submission.csv: example submission with required columns
- text/train/*.txt and text/test/*.txt: cleaned report text per id
- ner/train/*.txt and ner/test/*.txt: extracted named entities per id

Evaluation: Submissions are scored by Mean Columnwise RMSE (lower is better) over the four targets.
Submission format must match columns: id, e_score, s_score, g_score, total_score and ids must match test.csv.
""".strip()


def _check_output_dirs(raw: Path, public: Path, private: Path) -> None:
    """Reject output locations whose cleanup would destroy inputs or leak answers.

    Both output directories are wiped with ``shutil.rmtree``. If either one
    equals or contains ``raw``, the raw data would be deleted. If ``private``
    equals or lies inside ``public``, ``test_answer.csv`` would end up in the
    public package.

    Raises:
        ValueError: if the directories overlap in one of those ways.
    """
    for out_dir in (public, private):
        if out_dir == raw or out_dir in raw.parents:
            raise ValueError(
                f"Output directory {out_dir} contains the raw input {raw}; refusing to delete it"
            )
    if private == public or public in private.parents:
        raise ValueError(
            f"Private directory {private} lies inside public directory {public}; "
            "test answers would leak into the public package"
        )


def _load_raw(raw: Path) -> pd.DataFrame:
    """Read ``raw/preprocessed_content.csv``, clean it and assign anonymised ids.

    Cleaning steps:
    - drop rows missing the text or any target;
    - coerce targets to numeric;
    - drop rows whose targets fail coercion.

    Each surviving row gets an ``id`` of the form ``id_doc_<row:06d>``,
    numbered in file order.
    """
    raw_csv = raw / "preprocessed_content.csv"
    assert raw_csv.exists(), f"Raw CSV not found: {raw_csv}"
    df = pd.read_csv(raw_csv)

    required_cols = [_TEXT_COL, *_TARGET_COLS]
    for c in required_cols:
        assert c in df.columns, f"Required column '{c}' missing from input CSV"

    df = df.dropna(subset=required_cols).copy()
    for c in _TARGET_COLS:
        df[c] = pd.to_numeric(df[c], errors="coerce")
    # Non-numeric targets became NaN above; drop them and renumber rows 0..n-1.
    df = df.dropna(subset=_TARGET_COLS).reset_index(drop=True)

    df["id"] = [f"id_doc_{i:06d}" for i in range(len(df))]
    assert df["id"].is_unique
    return df


def _quantile_bins(y: np.ndarray, n_bins: int) -> np.ndarray:
    """Label each value of ``y`` with the index of its quantile bin (at most ``n_bins`` labels)."""
    # Duplicate quantiles (ties) collapse, so fewer than n_bins labels may result.
    edges = np.unique(np.quantile(y, np.linspace(0, 1, n_bins + 1)))
    # Only interior edges are cut points; min/max fall into the first/last bin.
    return np.digitize(y, edges[1:-1], right=False)


def _stratified_split(
    y: np.ndarray, *, test_size: float, seed: int, max_bins: int = _MAX_BINS
) -> Tuple[np.ndarray, np.ndarray]:
    """Return ``(train_idx, test_idx)`` stratified on quantile bins of ``y``.

    Starts with ``max_bins`` bins, which gives the same split as a single
    fixed-bin split whenever that works. If sklearn rejects the labels
    (a bin with one member, or more bins than the train/test sets can hold),
    it retries with one fewer bin. With one bin the split is simply random.

    Raises:
        ValueError: if no bin count yields a valid split (fewer than 2 rows).
    """
    last_err = None
    for n_bins in range(max_bins, 0, -1):
        y_bins = _quantile_bins(y, n_bins)
        splitter = StratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=seed)
        try:
            # Only the labels drive the split; X is a placeholder of the right length.
            return next(splitter.split(np.zeros(len(y)), y_bins))
        except ValueError as err:
            last_err = err
    raise ValueError(f"Cannot split {len(y)} rows into train/test sets") from last_err


def _write_text(path: Path, content: object) -> None:
    """Write ``content`` as UTF-8 text; missing values (None/NaN/NA) become an empty file."""
    if pd.api.types.is_scalar(content) and pd.isna(content):
        content = ""
    path.write_text(str(content), encoding="utf-8")


def _write_documents(frame: pd.DataFrame, text_dir: Path, ner_dir: Path, ner_col) -> None:
    """Write one ``<id>.txt`` per row into ``text_dir`` and ``ner_dir``.

    ``ner_col`` is the NER column name, or None when the input has no NER
    column. In that case every NER file is written empty.
    """
    ner_values = frame[ner_col] if ner_col is not None else [""] * len(frame)
    for doc_id, text, ner in zip(frame["id"], frame[_TEXT_COL], ner_values):
        _write_text(text_dir / f"{doc_id}.txt", text)
        _write_text(ner_dir / f"{doc_id}.txt", ner)


def _make_sample_submission(ids: pd.DataFrame, train_targets: pd.DataFrame, seed: int) -> pd.DataFrame:
    """Build a deterministic random submission for ``ids`` from training statistics.

    Each target is drawn from a normal distribution with the training mean and
    population std, then clipped to the training range. A constant column
    (std ~ 0) uses std 1.0 instead.
    """
    rng = np.random.RandomState(seed)
    stats = train_targets[_TARGET_COLS]
    mins = stats.min().values
    maxs = stats.max().values
    means = stats.mean().values
    stds = stats.std(ddof=0).values
    stds = np.where(stds <= 1e-8, 1.0, stds)
    samples = rng.normal(loc=means, scale=stds, size=(len(ids), len(_TARGET_COLS)))
    submission = ids.copy()
    submission[_TARGET_COLS] = np.clip(samples, mins, maxs)
    return submission


def _write_description(raw: Path, dest: Path) -> None:
    """Copy ``raw.parent/description.txt`` to ``dest`` if present, else write a default one."""
    root_desc = raw.parent / "description.txt"
    if root_desc.exists():
        shutil.copy(root_desc, dest)
    else:
        dest.write_text(_DEFAULT_DESCRIPTION, encoding="utf-8")


def _verify_outputs(public: Path, private: Path) -> None:
    """Sanity-check the written package.

    Checks:
    - all four CSVs exist;
    - test, answer and sample-submission ids agree;
    - train and test ids are disjoint;
    - each split has exactly one text file and one NER file per id.
    """
    expected_files = [
        public / "train.csv",
        public / "test.csv",
        public / "sample_submission.csv",
        private / "test_answer.csv",
    ]
    for fp in expected_files:
        assert fp.exists(), f"Missing file: {fp}"

    train_ids = set(pd.read_csv(public / "train.csv")["id"])
    test_ids = set(pd.read_csv(public / "test.csv")["id"])
    ans_ids = set(pd.read_csv(private / "test_answer.csv")["id"])
    ss_ids = set(pd.read_csv(public / "sample_submission.csv")["id"])
    assert test_ids == ans_ids == ss_ids
    assert train_ids.isdisjoint(test_ids)

    for split, ids in (("train", train_ids), ("test", test_ids)):
        t_files = {p.name for p in (public / "text" / split).glob("*.txt")}
        n_files = {p.name for p in (public / "ner" / split).glob("*.txt")}
        expected_ids = {f"{i}.txt" for i in ids}
        assert t_files == n_files == expected_ids, f"Mismatch files for {split}"


def prepare(raw: Path, public: Path, private: Path) -> None:
    """
    Complete data preparation process.

    - Read raw/preprocessed_content.csv (rows missing text or a numeric target are dropped)
    - Deterministically split into train/test (80/20), stratified by binned
      total_score. 10 quantile bins are tried first; if that stratification is
      infeasible (e.g. a bin with a single row), fewer bins are used.
    - Write the following files/directories (both output dirs are wiped first):
        public/
          - train.csv
          - test.csv
          - sample_submission.csv
          - description.txt
          - text/train/*.txt
          - text/test/*.txt
          - ner/train/*.txt
          - ner/test/*.txt
        private/
          - test_answer.csv

    Raises:
        ValueError:
            - if an output directory would delete ``raw`` or expose the private answers;
            - if fewer than 2 usable rows remain after cleaning.
        AssertionError: if the raw CSV or a required column is missing, or a
            final consistency check fails.
    """
    # Work with absolute paths so the overlap checks below compare like with like.
    raw = Path(str(raw)).resolve()
    public = Path(str(public)).resolve()
    private = Path(str(private)).resolve()
    _check_output_dirs(raw, public, private)

    # Load and validate before touching outputs, so a bad input leaves old outputs intact.
    df = _load_raw(raw)
    if len(df) < 2:
        raise ValueError(f"Need at least 2 usable rows to build a train/test split, found {len(df)}")
    ner_col = _NER_COL if _NER_COL in df.columns else None

    # Remove both trees before recreating either: if public is nested inside
    # private, removing private after recreating public would delete public again.
    for out_dir in (public, private):
        if out_dir.exists():
            shutil.rmtree(out_dir)
    for out_dir in (public, private):
        out_dir.mkdir(parents=True, exist_ok=True)

    y = df["total_score"].to_numpy(dtype=float)
    idx_train, idx_test = _stratified_split(y, test_size=_TEST_SIZE, seed=_SPLIT_SEED)
    train_df = df.iloc[idx_train].copy()
    test_df = df.iloc[idx_test].copy()

    for split, frame in (("train", train_df), ("test", test_df)):
        text_dir = public / "text" / split
        ner_dir = public / "ner" / split
        text_dir.mkdir(parents=True, exist_ok=True)
        ner_dir.mkdir(parents=True, exist_ok=True)
        _write_documents(frame, text_dir, ner_dir, ner_col)

    train_out = train_df[["id", *_TARGET_COLS]].copy()
    test_out = test_df[["id"]].copy()
    test_answer_out = test_df[["id", *_TARGET_COLS]].copy()

    train_out.to_csv(public / "train.csv", index=False)
    test_out.to_csv(public / "test.csv", index=False)
    test_answer_out.to_csv(private / "test_answer.csv", index=False)

    sample = _make_sample_submission(test_out, train_out, _SPLIT_SEED)
    sample.to_csv(public / "sample_submission.csv", index=False)

    _write_description(raw, public / "description.txt")
    _verify_outputs(public, private)
