from __future__ import annotations

import csv
import math
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd


# Name of the row-identifier column; it keys test_answer.csv and sample_submission.csv.
ID_COL = "ID"
# Name of the derived target column: minutes elapsed between StartTime and EndTime.
TARGET_COL = "Duration_Minutes"


def _compute_duration_minutes(start_series: pd.Series, end_series: pd.Series) -> pd.Series:
    """Return the elapsed time from start to end, expressed in minutes.

    Both inputs are parsed as UTC timestamps with ``errors="coerce"``, so a
    value that cannot be parsed becomes NaT and the matching duration is NaN.
    Negative results are possible when an end precedes its start; filtering
    is left to the caller.
    """
    parsed_start, parsed_end = (
        pd.to_datetime(column, errors="coerce", utc=True)
        for column in (start_series, end_series)
    )
    elapsed = parsed_end - parsed_start
    return elapsed.dt.total_seconds() / 60.0


def _fit_running_log_stats(
    current_count: int, current_mean: float, current_M2: float, new_values: np.ndarray
) -> Tuple[int, float, float]:
    """Fold a batch of durations into running statistics of ``log1p(duration)``.

    The running state is the triple used by Welford's algorithm: the number
    of observations, their mean, and ``M2`` (the sum of squared deviations
    from the mean). Rather than updating one element at a time in a Python
    loop, the whole batch is summarised with NumPy and merged into the state
    with the pairwise update of Chan, Golub and LeVeque. The result matches
    the element-wise update up to floating-point rounding.

    Args:
        current_count: Number of observations already accumulated.
        current_mean: Mean of ``log1p`` values accumulated so far.
        current_M2: Sum of squared deviations accumulated so far.
        new_values: Durations to add (array-like of numbers).

    Returns:
        The updated ``(count, mean, M2)`` triple. An empty batch returns the
        inputs unchanged.
    """
    batch = np.asarray(new_values, dtype=float)
    if batch.size == 0:
        return current_count, current_mean, current_M2

    logged = np.log1p(batch)
    batch_count = int(logged.size)
    batch_mean = float(logged.mean())
    batch_M2 = float(np.square(logged - batch_mean).sum())

    total = current_count + batch_count
    delta = batch_mean - current_mean
    merged_mean = current_mean + delta * batch_count / total
    # The cross term vanishes when current_count == 0, so the initial mean
    # placeholder never leaks into the statistics.
    merged_M2 = current_M2 + batch_M2 + delta * delta * current_count * batch_count / total
    return total, merged_mean, merged_M2


def _finalize_stats(count: int, mean: float, M2: float) -> Tuple[float, float]:
    """Turn a running ``(count, mean, M2)`` state into ``(mean, std)``.

    With fewer than two observations no sample variance exists, so a unit
    standard deviation is reported. Otherwise the unbiased sample variance
    ``M2 / (count - 1)`` is floored at 1e-6 before taking the square root.
    """
    if count >= 2:
        return mean, math.sqrt(max(M2 / (count - 1), 1e-6))
    return mean, 1.0


def _ensure_headers_written(path: Path, header: List[str]):
    """Create *path* holding a single CSV header row, unless it already exists.

    Missing parent directories are created first. An existing file is left
    untouched, whatever it contains.
    """
    if path.exists():
        return
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", newline="", encoding="utf-8") as handle:
        csv.writer(handle).writerow(header)


def _pick_data_file(raw: Path) -> Path:
    """Locate the input CSV under *raw*.

    The 2m sample file is preferred; the full file is used only when the
    sample is absent.

    Raises:
        FileNotFoundError: If neither file exists. The message lists both
            paths that were checked.
    """
    preferred = raw / "us_congestion_2016_2022_sample_2m" / "us_congestion_2016_2022_sample_2m.csv"
    fallback = raw / "us_congestion_2016_2022" / "us_congestion_2016_2022.csv"
    search_order = (preferred, fallback)

    match = next((path for path in search_order if path.exists()), None)
    if match is not None:
        return match

    listing = ", ".join(str(path) for path in search_order)
    raise FileNotFoundError(f"Could not find data file in raw/. Looked for: {listing}")


def _split_headers(all_cols: List[str]) -> Tuple[List[str], List[str]]:
    """Derive the (train, test) column orders from a chunk's column list."""
    # EndTime is dropped everywhere: together with StartTime it reveals the target.
    feature_cols = [c for c in all_cols if c not in ("EndTime", TARGET_COL)]
    # Train keeps the target as its last column; test never carries it.
    return feature_cols + [TARGET_COL], feature_cols


def _baseline_predictions(
    rng: np.random.Generator, count: int, mean: float, M2: float, truth: np.ndarray
) -> np.ndarray:
    """Draw one lognormal baseline prediction per entry of *truth*."""
    mu, sigma = _finalize_stats(count, mean, M2)
    sigma = max(sigma, 0.3)
    preds = rng.lognormal(mean=mu, sigma=sigma, size=len(truth))
    # NOTE(review): the clip bound is derived from the ground-truth durations of
    # the same test rows, so the public baseline is bounded by private data.
    # Kept as-is to preserve the generated files; confirm this is intended.
    max_true = float(np.nanmax(truth))
    if not math.isfinite(max_true) or max_true <= 0:
        max_true = 1_000.0
    return np.clip(preds, 0.0, max_true * 5.0)


def _validate_outputs(
    train_csv: Path, test_csv: Path, test_answer_csv: Path, sample_submission_csv: Path
) -> None:
    """Assert that the four output files exist and are mutually consistent."""
    # Basic checks
    assert train_csv.exists() and train_csv.stat().st_size > 0, "public/train.csv was not created"
    assert test_csv.exists() and test_csv.stat().st_size > 0, "public/test.csv was not created"
    assert test_answer_csv.exists() and test_answer_csv.stat().st_size > 0, "private/test_answer.csv was not created"
    assert sample_submission_csv.exists() and sample_submission_csv.stat().st_size > 0, "public/sample_submission.csv was not created"

    # Structure checks (peek few rows)
    train_head = pd.read_csv(train_csv, nrows=5)
    test_head = pd.read_csv(test_csv, nrows=5)
    ans_head = pd.read_csv(test_answer_csv, nrows=5)
    sub_head = pd.read_csv(sample_submission_csv, nrows=5)

    assert TARGET_COL in train_head.columns, "Target column missing from train.csv"
    assert TARGET_COL not in test_head.columns, "Target column must not appear in test.csv"
    assert "EndTime" not in train_head.columns, "EndTime should be removed to prevent leakage"
    assert "EndTime" not in test_head.columns, "EndTime should be removed to prevent leakage"

    assert ans_head.columns.tolist() == [ID_COL, TARGET_COL], "test_answer.csv must have [ID, Duration_Minutes] columns"
    assert sub_head.columns.tolist() == [ID_COL, TARGET_COL], "sample_submission.csv must have [ID, Duration_Minutes] columns"

    # Validate ID correspondence and counts
    test_ids = pd.read_csv(test_csv, usecols=[ID_COL])
    ans_ids = pd.read_csv(test_answer_csv, usecols=[ID_COL])
    sub_ids = pd.read_csv(sample_submission_csv, usecols=[ID_COL])

    assert len(test_ids) == len(ans_ids) == len(sub_ids), "test.csv, test_answer.csv, and sample_submission.csv must have the same number of rows"
    assert test_ids.equals(ans_ids), "IDs in test_answer.csv must match test.csv in the same order"
    assert set(test_ids[ID_COL]) == set(sub_ids[ID_COL]), "sample_submission IDs must match the test IDs (order can differ)"

    # Ground-truth validity
    ans_target = pd.read_csv(test_answer_csv, usecols=[TARGET_COL])
    assert np.isfinite(ans_target[TARGET_COL]).all(), "All ground-truth durations must be finite"
    assert (ans_target[TARGET_COL] > 0).all(), "All ground-truth durations must be positive"


def prepare(raw: Path, public: Path, private: Path):
    """Build the train/test split and companion files from the raw CSV.

    The input file is streamed in 200,000-row chunks. For each chunk the
    target (minutes between StartTime and EndTime) is computed, rows whose
    target is missing or not strictly positive are dropped, and the rest are
    split by StartTime year: 2022 and later go to the test set, earlier
    years to the train set. EndTime is removed from both sets.

    Files written (any previous versions are deleted first):
        public/train.csv              features plus the target as last column
        public/test.csv               features only
        private/test_answer.csv       [ID, target] for the test rows
        public/sample_submission.csv  [ID, random lognormal baseline]

    The baseline for a test chunk is drawn from the log-duration statistics
    of the train rows seen *so far*, so it depends on row order in the input
    file. The random generator is seeded with 42.

    If ``description.txt`` exists next to the public directory
    (``public.parent``), it is copied into ``public``.

    Args:
        raw: Directory containing the raw dataset folder(s).
        public: Output directory for participant-visible files.
        private: Output directory for the ground-truth answers.

    Raises:
        FileNotFoundError: If no input CSV is found under *raw*.
        AssertionError: If required columns are missing or the written
            files fail the consistency checks.
    """
    # Deterministic behavior for any stochastic components
    rng = np.random.default_rng(seed=42)

    data_file = _pick_data_file(raw)

    # Define outputs
    train_csv = public / "train.csv"
    test_csv = public / "test.csv"
    test_answer_csv = private / "test_answer.csv"
    sample_submission_csv = public / "sample_submission.csv"

    # Remove old outputs if they exist
    for p in [train_csv, test_csv, test_answer_csv, sample_submission_csv]:
        if p.exists():
            p.unlink()

    # Stream in chunks
    chunksize = 200_000

    # Running stats of log1p durations (train only)
    count = 0
    mean = 0.0
    M2 = 0.0

    # Column orders; fixed by the first chunk that still has rows after filtering.
    train_header = None
    test_header = None
    cutoff_year = 2022  # test is 2022+, train is prior years

    for chunk in pd.read_csv(data_file, chunksize=chunksize, low_memory=False):
        if train_header is None:
            # Check required columns BEFORE they are first indexed, so a missing
            # column is reported with a clear message instead of a bare KeyError.
            assert ID_COL in chunk.columns, f"Expected an '{ID_COL}' column in the dataset"
            assert "EndTime" in chunk.columns and "StartTime" in chunk.columns, "Expected StartTime and EndTime columns"

        # Compute target
        chunk[TARGET_COL] = _compute_duration_minutes(chunk["StartTime"], chunk["EndTime"])

        # Keep only valid positive durations
        valid = chunk[TARGET_COL].notna() & (chunk[TARGET_COL] > 0)
        chunk = chunk.loc[valid].copy()
        if chunk.empty:
            continue

        # Temporal split by StartTime year
        start_dt = pd.to_datetime(chunk["StartTime"], errors="coerce", utc=True)
        is_test = start_dt.dt.year >= cutoff_year

        if train_header is None:
            train_header, test_header = _split_headers(list(chunk.columns))
            answer_header = [ID_COL, TARGET_COL]

            # Write headers
            _ensure_headers_written(train_csv, train_header)
            _ensure_headers_written(test_csv, test_header)
            _ensure_headers_written(test_answer_csv, answer_header)
            _ensure_headers_written(sample_submission_csv, answer_header)

        # Split
        test_part = chunk.loc[is_test]
        train_part = chunk.loc[~is_test]

        # Update running stats with train durations, then append the train rows.
        # This happens before the baseline draw below, so train rows from the
        # current chunk already contribute to this chunk's baseline.
        if not train_part.empty:
            count, mean, M2 = _fit_running_log_stats(
                count, mean, M2, train_part[TARGET_COL].to_numpy(dtype=float)
            )
            train_part[train_header].to_csv(train_csv, mode="a", header=False, index=False)

        # Handle test/answer/submission
        if not test_part.empty:
            # test.csv without target
            test_part[test_header].to_csv(test_csv, mode="a", header=False, index=False)

            # test_answer.csv with [ID, target]
            test_part[[ID_COL, TARGET_COL]].to_csv(
                test_answer_csv, mode="a", header=False, index=False
            )

            # sample_submission.csv baseline predictions using current running lognormal stats
            preds = _baseline_predictions(
                rng, count, mean, M2, test_part[TARGET_COL].to_numpy(dtype=float)
            )
            baseline = pd.DataFrame({ID_COL: test_part[ID_COL].values, TARGET_COL: preds})
            baseline.to_csv(sample_submission_csv, mode="a", header=False, index=False)

    _validate_outputs(train_csv, test_csv, test_answer_csv, sample_submission_csv)

    # Copy description.txt into public/ for participants.
    # It is looked up next to the public directory (public.parent).
    repo_description = public.parent / "description.txt"
    if repo_description.exists():
        dest = public / "description.txt"
        dest.write_text(repo_description.read_text(encoding="utf-8"), encoding="utf-8")

