from pathlib import Path
from typing import List, Dict, Tuple
import re
import json

import numpy as np
import pandas as pd

# Deterministic settings
# Seed used for the row shuffle in prepare() and for the RNG of the split.
RANDOM_STATE = 17
# Requested share of rows for the test set; the split may return fewer rows
# because it skips rows whose labels are too rare to remove from training.
TEST_FRACTION = 0.2
# Name of the sequential, 1-based identifier column inserted by prepare().
ID_COL = "id"
# Column of the raw CSV that holds the semicolon-separated labels.
TARGET_COL = "Medicinal Properties"


def _slugify_label(label: str) -> str:
    base = str(label).strip()
    base = base.replace("µ", "u")
    base = re.sub(r"[^A-Za-z0-9]+", "_", base)
    base = base.strip("_")
    return f"label_{base}"


def _parse_labels(series: pd.Series) -> List[List[str]]:
    labs = []
    for x in series.fillna(""):
        tokens = [p.strip() for p in str(x).split(";") if str(p).strip()]
        labs.append(tokens)
    return labs


def _iterative_label_aware_split(y_lists: List[List[str]], test_frac: float, random_state: int) -> Tuple[np.ndarray, np.ndarray]:
    n = len(y_lists)
    rng = np.random.default_rng(seed=random_state)
    order = rng.permutation(n)

    # Count positives per label
    label_counts: Dict[str, int] = {}
    for labs in y_lists:
        for l in labs:
            label_counts[l] = label_counts.get(l, 0) + 1

    in_test = np.zeros(n, dtype=bool)
    in_train = np.ones(n, dtype=bool)

    target_test = int(round(test_frac * n))
    remaining_in_train = label_counts.copy()

    for idx in order:
        if in_test.sum() >= target_test:
            break
        labs = set(y_lists[idx])
        ok = True
        for l in labs:
            if remaining_in_train.get(l, 0) <= 1:
                ok = False
                break
        if ok:
            in_test[idx] = True
            in_train[idx] = False
            for l in labs:
                remaining_in_train[l] -= 1

    return in_train, in_test


def prepare(raw: Path, public: Path, private: Path):
    """
    Complete preparation process.

    Reads the raw CSV, drops rows without any label, shuffles the rows with a
    fixed seed, assigns 1-based ids, splits into train/test so that every test
    label also occurs in training, and writes the resulting files.

    Inputs
    - raw: directory containing original data files (expects pfaf_plants_merged.csv)
    - public: output directory for all public files (train/test/sample_submission/description)
    - private: output directory for hidden files (test_answer.csv)

    Outputs
    - public/train.csv: training rows including the target column
    - public/test.csv: test rows without the target column
    - public/sample_submission.csv: id plus one random probability per label column
    - public/label_map.json: mapping from raw label to sanitized column name
    - public/description.txt: copied from next to this script when present
    - private/test_answer.csv: id plus a 0/1 indicator per label column

    Raises
    - AssertionError: when an input is missing or an integrity check fails
    """
    raw = Path(raw).resolve()
    public = Path(public).resolve()
    private = Path(private).resolve()

    public.mkdir(parents=True, exist_ok=True)
    private.mkdir(parents=True, exist_ok=True)

    raw_csv = raw / "pfaf_plants_merged.csv"
    assert raw_csv.exists(), f"Raw CSV not found at {raw_csv}"

    # Read and filter
    df = pd.read_csv(raw_csv)
    assert TARGET_COL in df.columns, f"Target column '{TARGET_COL}' missing in raw CSV"

    y_lists = _parse_labels(df[TARGET_COL])
    # Explicit bool dtype keeps the mask boolean even when there are no rows.
    has_label = np.array([len(x) > 0 for x in y_lists], dtype=bool)
    df = df.loc[has_label].reset_index(drop=True)
    y_lists = [x for x in y_lists if len(x) > 0]

    # Create deterministic ID and shuffle order deterministically
    rng = np.random.default_rng(RANDOM_STATE)
    perm = rng.permutation(len(df))
    df = df.iloc[perm].reset_index(drop=True)
    y_lists = [y_lists[i] for i in perm]
    df.insert(0, ID_COL, np.arange(1, len(df) + 1))

    # Split
    train_mask, test_mask = _iterative_label_aware_split(y_lists, TEST_FRACTION, RANDOM_STATE)
    train_df = df.loc[train_mask].copy().reset_index(drop=True)
    test_df = df.loc[test_mask].copy().reset_index(drop=True)

    train_y = [y_lists[i] for i in np.where(train_mask)[0]]
    test_y = [y_lists[i] for i in np.where(test_mask)[0]]

    # Label universe and mapping (for answer/submission columns)
    all_labels = sorted({l for labs in y_lists for l in labs})
    label_to_col = {l: _slugify_label(l) for l in all_labels}
    answer_cols = [label_to_col[l] for l in all_labels]

    # Distinct raw labels can sanitize to the same column name. Detect this
    # before anything is written, instead of failing later on a column-count
    # mismatch with partially written outputs.
    col_to_labels: Dict[str, List[str]] = {}
    for l in all_labels:
        col_to_labels.setdefault(label_to_col[l], []).append(l)
    collisions = {c: ls for c, ls in col_to_labels.items() if len(ls) > 1}
    assert not collisions, f"Distinct labels map to the same column name: {collisions}"

    # Assertions and integrity checks
    assert len(train_df) + len(test_df) == len(df)
    assert set(train_df[ID_COL]).isdisjoint(set(test_df[ID_COL]))
    labels_in_train = {l for labs in train_y for l in labs}
    labels_in_test = {l for labs in test_y for l in labs}
    assert labels_in_test.issubset(labels_in_train), "All test labels must appear in training"
    assert train_df[ID_COL].is_unique and test_df[ID_COL].is_unique

    # Save train.csv (with target) and test.csv (without target)
    train_out = train_df.copy()
    test_out = test_df.drop(columns=[TARGET_COL], errors="ignore")
    assert TARGET_COL in train_out.columns
    assert TARGET_COL not in test_out.columns

    train_csv = public / "train.csv"
    test_csv = public / "test.csv"
    train_out.to_csv(train_csv, index=False)
    test_out.to_csv(test_csv, index=False)

    # Build private test answers as binary matrix with sanitized columns
    label_index = {l: i for i, l in enumerate(all_labels)}
    test_bin = np.zeros((len(test_df), len(all_labels)), dtype=int)
    for i, labs in enumerate(test_y):
        for l in labs:
            test_bin[i, label_index[l]] = 1

    ans_df = pd.DataFrame(test_bin, columns=answer_cols)
    ans_df.insert(0, ID_COL, test_out[ID_COL].values)
    test_answer_csv = private / "test_answer.csv"
    ans_df.to_csv(test_answer_csv, index=False)

    # Create sample_submission with random probabilities in (0.02, 0.98).
    # Columns are drawn one at a time, in answer_cols order, from the same rng
    # that produced the shuffle, then assembled in a single DataFrame
    # construction rather than by repeated column insertion.
    sample_data = {ID_COL: test_out[ID_COL].values}
    for col in answer_cols:
        sample_data[col] = rng.uniform(0.02, 0.98, size=len(test_out))
    sample_df = pd.DataFrame(sample_data)
    sample_submission_csv = public / "sample_submission.csv"
    sample_df.to_csv(sample_submission_csv, index=False)

    # Save label mapping for reference (optional for users, but keep public)
    label_map_json = public / "label_map.json"
    with open(label_map_json, "w", encoding="utf-8") as f:
        json.dump(label_to_col, f, indent=2, ensure_ascii=False)

    # Copy description.txt into public (if it exists next to this script)
    root_desc = Path(__file__).resolve().parent / "description.txt"
    if root_desc.exists():
        (public / "description.txt").write_text(root_desc.read_text(encoding="utf-8"), encoding="utf-8")

    # Post checks
    # 1) Required files
    assert train_csv.exists() and test_csv.exists(), "train.csv and test.csv must exist in public/"
    assert sample_submission_csv.exists(), "sample_submission.csv must exist in public/"
    assert test_answer_csv.exists(), "test_answer.csv must exist in private/"

    # Re-read what was written once and run all remaining checks on it
    written_sub = pd.read_csv(sample_submission_csv)
    written_ans = pd.read_csv(test_answer_csv)

    # 2) Columns alignment for sample and answers
    assert list(written_sub.columns) == [ID_COL] + answer_cols
    assert list(written_ans.columns) == [ID_COL] + answer_cols

    # 3) No NaNs and values in range for submission
    sub_vals = written_sub.drop(columns=[ID_COL]).values
    assert np.isfinite(sub_vals).all()
    assert ((sub_vals >= 0.0) & (sub_vals <= 1.0)).all()

    # 4) Ensure ids match in sample and answers
    sub_ids = written_sub[ID_COL].values
    ans_ids = written_ans[ID_COL].values
    assert np.array_equal(np.sort(sub_ids), np.sort(ans_ids)), "IDs must match between sample_submission and test_answer"
