from pathlib import Path
from typing import List, Tuple
import pandas as pd
import numpy as np

# This module exposes exactly one function as required:
# def prepare(raw: Path, public: Path, private: Path)
# It reads raw data from raw/, creates train/test splits and writes
# - public/train.csv
# - public/test.csv
# - public/sample_submission.csv
# - private/test_answer.csv
# - public/description.txt
# No main() entry point is included.

# Fixed seed meant to make any randomized step of this module reproducible.
# NOTE(review): confirm every random draw in this file is actually driven by it.
RANDOM_SEED = 42

# Salary-related columns. The target is derived from normalized_salary, so
# prepare() drops every one of these from the features (when present) before
# writing the train/test files.
SALARY_LEAK_COLS = [
    "max_salary",
    "med_salary",
    "min_salary",
    "pay_period",
    "currency",
    "compensation_type",
    "normalized_salary",
]
# URL columns, dropped from the features together with the salary columns.
URL_COLS = ["job_posting_url", "application_url"]

# Preferred feature columns to keep files compact but informative.
# prepare() keeps only the entries that exist in the data, in this order.
PREFERRED_FEATURE_COLS = [
    "job_id",
    "company_id",
    "company_name",
    "title",
    "description",
    "skills_desc",
    "formatted_work_type",
    "work_type",
    "formatted_experience_level",
    "location",
    "zip_code",
    "fips",
    "remote_allowed",
    "views",
    "applies",
    "posting_domain",
    "application_type",
    "sponsored",
    "listed_time",
    "original_listed_time",
    "expiry",
    "closed_time",
]


def _read_subset_postings(postings_path: Path, target_rows: int = 6000) -> pd.DataFrame:
    """Read a manageable subset of postings.csv with valid labels in USD.

    The file is streamed in chunks of 100,000 rows. From each chunk only rows
    satisfying all of the following are kept:
    - currency == 'USD' (only checked when a ``currency`` column exists)
    - normalized_salary parses as a number and is strictly positive
    - at least one of listed_time/original_listed_time is a valid number

    Reading stops after the first chunk that brings the number of kept rows to
    at least ``target_rows``; chunks are appended whole, so the result may hold
    more rows than ``target_rows`` (and fewer after de-duplication).

    Args:
        postings_path: Path of the postings CSV file.
        target_rows: Minimum number of kept rows after which reading stops.

    Returns:
        The kept rows with all original columns plus ``__split_time``
        (listed_time, falling back to original_listed_time), sorted by
        ``job_id`` and holding one row per ``job_id`` (the one with the
        latest ``__split_time``).

    Raises:
        AssertionError: If a required column is missing or no row passes
            the filters.
    """
    required = {"job_id", "normalized_salary"}
    time_cols = ("listed_time", "original_listed_time")
    collected = []
    total = 0

    # Use the chunk reader as a context manager so the underlying file handle
    # is released even when we stop early or raise.
    with pd.read_csv(postings_path, chunksize=100_000) as reader:
        for chunk in reader:
            missing = required - set(chunk.columns)
            if missing:
                raise AssertionError(f"postings.csv missing required columns: {missing}")

            if "currency" in chunk.columns:
                chunk = chunk[chunk["currency"] == "USD"]
            # Work on an explicit copy: assigning columns onto a filtered
            # slice is chained assignment and triggers SettingWithCopyWarning.
            chunk = chunk.copy()

            chunk["normalized_salary"] = pd.to_numeric(chunk["normalized_salary"], errors="coerce")
            # NaN > 0 is False, so unparsable salaries are dropped here too.
            chunk = chunk[chunk["normalized_salary"] > 0].copy()

            # Build the split time: listed_time, else original_listed_time.
            for col in time_cols:
                if col in chunk.columns:
                    chunk[col] = pd.to_numeric(chunk[col], errors="coerce")
            if "listed_time" in chunk.columns:
                split_time = chunk["listed_time"].copy()
            else:
                split_time = pd.Series(np.nan, index=chunk.index)
            if "original_listed_time" in chunk.columns:
                split_time = split_time.fillna(chunk["original_listed_time"])
            chunk["__split_time"] = split_time
            chunk = chunk[chunk["__split_time"].notna()]

            if chunk.empty:
                continue

            collected.append(chunk)
            total += len(chunk)
            if total >= target_rows:
                break

    if not collected:
        raise AssertionError("No valid rows found in postings.csv under the required filters.")

    df = pd.concat(collected, axis=0, ignore_index=True)

    # De-duplicate by job_id keeping the latest listing
    df = df.sort_values(["job_id", "__split_time"]).drop_duplicates(subset=["job_id"], keep="last")

    # Return full df including label and split time; feature selection is later
    return df


def _time_split(df: pd.DataFrame, quantile: float = 0.8) -> Tuple[list[int], list[int]]:
    """Split job ids into train/test by a quantile cutoff on ``__split_time``.

    Rows with ``__split_time`` strictly below the cutoff go to train; rows at
    or above it go to test.

    Args:
        df: Frame with ``job_id`` and ``__split_time`` columns.
        quantile: Quantile of ``__split_time`` used as the cutoff.

    Returns:
        ``(train_ids, test_ids)`` as lists of ints, in the row order of ``df``.

    Raises:
        AssertionError: If the two id sets overlap, or if there are fewer
            than 100 train ids or fewer than 50 test ids.
    """
    cutoff = float(df["__split_time"].quantile(quantile))
    train_ids = df.loc[df["__split_time"] < cutoff, "job_id"].astype(int).tolist()
    test_ids = df.loc[df["__split_time"] >= cutoff, "job_id"].astype(int).tolist()

    # Raise explicitly instead of using `assert`, so the checks still run
    # when Python is started with -O (which strips assert statements).
    if not set(train_ids).isdisjoint(test_ids):
        raise AssertionError("Train/Test overlap detected")
    # Ensure we have reasonable sizes
    if len(train_ids) < 100 or len(test_ids) < 50:
        raise AssertionError("Insufficient samples for train/test split")
    return train_ids, test_ids


def _write_description(public: Path, train_cols: list[str], test_cols: list[str]):
    """Write ``description.txt`` into ``public``.

    The file lists the public competition files followed by the column names
    of the train and test tables, and is written as UTF-8 without a trailing
    newline.

    Args:
        public: Directory that receives ``description.txt``.
        train_cols: Column names of train.csv.
        test_cols: Column names of test.csv.
    """
    file_section = (
        "Competition files (public):",
        "- train.csv: Training features and target (target_salary).",
        "- test.csv: Test features without target.",
        "- sample_submission.csv: Example submission with columns [job_id, target_salary].",
    )
    column_section = (
        "Columns:",
        f"- Train columns: {train_cols}",
        f"- Test columns: {test_cols}",
    )
    # An empty entry between the two sections yields one blank line.
    body = "\n".join((*file_section, "", *column_section))
    destination = public / "description.txt"
    destination.write_text(body, encoding="utf-8")


def prepare(raw: Path, public: Path, private: Path):
    """Build the public/private competition files from ``raw/postings.csv``.

    Reads a filtered subset of the postings, splits it by time into train and
    test, and writes:
    - public/train.csv (features plus ``target_salary``)
    - public/test.csv (features only)
    - public/sample_submission.csv (``job_id``, ``target_salary``)
    - public/description.txt
    - private/test_answer.csv (``job_id``, ``target_salary``)

    The sample submission is drawn with a generator seeded by ``RANDOM_SEED``,
    so repeated runs on the same input produce identical files.

    Args:
        raw: Directory containing ``postings.csv``.
        public: Output directory for public files (created if missing).
        private: Output directory for private files (created if missing).

    Raises:
        AssertionError: If ``postings.csv`` is missing, the data cannot be
            split, or any integrity check on the written files fails.
    """
    # Ensure directories
    public.mkdir(parents=True, exist_ok=True)
    private.mkdir(parents=True, exist_ok=True)

    postings_path = raw / "postings.csv"
    # Explicit raise (not `assert`) so the check survives `python -O`.
    if not postings_path.exists():
        raise AssertionError(f"Missing source postings: {postings_path}")

    # Read subset and split
    df = _read_subset_postings(postings_path, target_rows=7000)

    # Labels come from normalized_salary; features must not contain it or
    # any other salary/URL column.
    labels = df[["job_id", "normalized_salary"]].rename(columns={"normalized_salary": "target_salary"})
    feats = df.drop(columns=SALARY_LEAK_COLS + URL_COLS, errors="ignore")

    # Optionally reduce to preferred feature columns if present, plus job_id
    keep_cols = [c for c in PREFERRED_FEATURE_COLS if c in feats.columns]
    if "job_id" not in keep_cols:
        keep_cols = ["job_id"] + keep_cols
    feats = feats[keep_cols]

    # Time-based split
    train_ids, test_ids = _time_split(df)
    train_ids_set, test_ids_set = set(train_ids), set(test_ids)

    train_feats = feats[feats["job_id"].isin(train_ids_set)].copy()
    test_feats = feats[feats["job_id"].isin(test_ids_set)].copy()

    train_labels = labels[labels["job_id"].isin(train_ids_set)].copy()
    test_labels = labels[labels["job_id"].isin(test_ids_set)].copy()

    # Merge labels into train features for public train.csv
    train = train_feats.merge(train_labels, on="job_id", how="left")
    # test.csv should NOT contain target
    test = test_feats.copy()

    # Sort by time for reproducibility; __split_time is only a sort key and
    # is removed again before writing.
    split_times = df[["job_id", "__split_time"]]
    train = (
        train.merge(split_times, on="job_id", how="left")
        .sort_values("__split_time")
        .drop(columns=["__split_time"])
    )
    test = (
        test.merge(split_times, on="job_id", how="left")
        .sort_values("__split_time")
        .drop(columns=["__split_time"])
    )

    # Save outputs to exact locations
    train.to_csv(public / "train.csv", index=False)
    test.to_csv(public / "test.csv", index=False)

    # Private answers
    test_labels_sorted = test_labels.sort_values("job_id")[["job_id", "target_salary"]]
    test_labels_sorted.to_csv(private / "test_answer.csv", index=False)

    # Sample submission: random reasonable predictions sampled from the train
    # label distribution (log-normal). A dedicated seeded generator keeps the
    # file reproducible and leaves NumPy's global random state untouched.
    y = train_labels["target_salary"].astype(float).clip(lower=1.0)
    logy = np.log1p(y)
    mu, sigma = float(logy.mean()), float(logy.std(ddof=0) + 1e-6)
    rng = np.random.default_rng(RANDOM_SEED)
    sample = test[["job_id"]].copy()
    sample["target_salary"] = np.expm1(rng.normal(mu, sigma, size=len(sample)))
    sample["target_salary"] = (
        sample["target_salary"].replace([np.inf, -np.inf], np.nan).fillna(y.median()).clip(lower=0)
    )
    sample.to_csv(public / "sample_submission.csv", index=False)

    # Write public description aligned with actual files
    _write_description(public, train.columns.tolist(), test.columns.tolist())

    # Integrity checks
    # Required files
    assert (public / "train.csv").exists(), "public/train.csv should exist"
    assert (public / "test.csv").exists(), "public/test.csv should exist"
    assert (public / "sample_submission.csv").exists(), "public/sample_submission.csv should exist"
    assert (public / "description.txt").exists(), "public/description.txt should exist"
    assert (private / "test_answer.csv").exists(), "private/test_answer.csv should exist"

    # Column expectations
    assert "job_id" in train.columns and "target_salary" in train.columns, "train must contain job_id and target_salary"
    assert "job_id" in test.columns and "target_salary" not in test.columns, "test must contain job_id and no target_salary"
    ss = pd.read_csv(public / "sample_submission.csv")
    assert list(ss.columns) == ["job_id", "target_salary"], "sample_submission must have ['job_id','target_salary']"

    # No overlap
    assert set(train["job_id"]).isdisjoint(set(test["job_id"])), "Train and test job_id sets must be disjoint"

    # Alignment with answers
    answers = pd.read_csv(private / "test_answer.csv")
    assert list(answers.columns) == ["job_id", "target_salary"], "test_answer must have ['job_id','target_salary']"
    assert set(test["job_id"]) == set(answers["job_id"]) == set(ss["job_id"])  # consistent ids

    # Basic sanity on values
    assert np.isfinite(answers["target_salary"]).all() and (answers["target_salary"] >= 0).all()
    assert np.isfinite(ss["target_salary"]).all() and (ss["target_salary"] >= 0).all()

    # Uniqueness
    assert train["job_id"].is_unique and test["job_id"].is_unique and answers["job_id"].is_unique
