from pathlib import Path
import random
import math
import shutil
from typing import Tuple

import pandas as pd

# Deterministic seed for all randomness in this module.
# It is used three ways: to seed the global `random` state below, to build the
# private `random.Random` instances in `_stratified_split` and `prepare`, and as
# the `random_state` for the pandas shuffles in `_stratified_split`.
RANDOM_SEED = 42
# NOTE: this runs at import time and reseeds the process-wide `random` module.
random.seed(RANDOM_SEED)

# File name of the raw CSV that `prepare` looks up inside the `raw` directory.
RAW_FILE_NAME = "arabic_nli_26lang_raw_combined.csv"

# Columns that must be present in the raw CSV; all other columns are discarded.
REQUIRED_COLS = ["premise", "hypothesis", "label"]
# The only label values accepted after casting the label column to int.
VALID_LABELS = {0, 1, 2}


def _clean_text(s: str) -> str:
    """Return *s* with every run of whitespace collapsed to a single space.

    Leading and trailing whitespace is removed. Any value that is not a
    ``str`` (for example ``None`` or a number) yields the empty string.
    """
    if not isinstance(s, str):
        return ""
    # split() with no argument discards leading/trailing whitespace and treats
    # consecutive whitespace as one separator, so the joined result is already
    # trimmed.
    tokens = s.split()
    return " ".join(tokens)


def _validate_and_clean_df(df: pd.DataFrame) -> pd.DataFrame:
    """Validate the raw frame and return a cleaned copy.

    The steps run in this order:
      1. check that every column in ``REQUIRED_COLS`` is present;
      2. keep only those columns and drop rows with a missing value;
      3. convert ``label`` to int and check it against ``VALID_LABELS``;
      4. normalise whitespace in ``premise`` and ``hypothesis`` and drop rows
         where either text ends up empty;
      5. check that more than 1000 rows remain.

    Args:
        df: Raw data as loaded from the CSV. It is not modified.

    Returns:
        A new DataFrame with exactly the ``REQUIRED_COLS`` columns.

    Raises:
        AssertionError: If a required column is missing, a label is not
            numeric, not a whole number, or not in ``VALID_LABELS``, or too few
            rows remain. The checks raise explicitly (rather than using
            ``assert`` statements) so they still run under ``python -O``.
    """
    # Ensure required columns exist
    for c in REQUIRED_COLS:
        if c not in df.columns:
            raise AssertionError(f"Missing required column in raw data: {c}")

    # Keep only required columns
    df = df[REQUIRED_COLS].copy()

    # Drop rows with nulls in essential columns
    df = df.dropna(subset=["premise", "hypothesis", "label"]).copy()

    # Cast label to int and validate
    try:
        numeric_label = pd.to_numeric(df["label"], errors="raise")
        int_label = numeric_label.astype(int)
    except Exception as e:
        raise AssertionError(f"Label column could not be cast to int: {e}") from e

    # astype(int) truncates toward zero, so a label such as 1.7 would silently
    # become 1 and pass the VALID_LABELS check; reject such values instead.
    fractional = sorted(set(numeric_label[numeric_label != int_label].tolist()))
    if fractional:
        raise AssertionError(f"Found non-integer labels: {fractional}")
    df["label"] = int_label

    bad = set(df["label"].tolist()) - VALID_LABELS
    if bad:
        raise AssertionError(f"Found invalid labels: {bad}")

    # Clean text fields and drop empties
    df["premise"] = df["premise"].astype(str).map(_clean_text)
    df["hypothesis"] = df["hypothesis"].astype(str).map(_clean_text)
    df = df[(df["premise"].str.len() > 0) & (df["hypothesis"].str.len() > 0)].copy()

    # Sanity size check
    if len(df) <= 1000:
        raise AssertionError(f"Dataset too small after cleaning: {len(df)}")

    return df


def _resolve_conflicts(df: pd.DataFrame) -> pd.DataFrame:
    """Collapse repeated (premise, hypothesis) pairs to one row each.

    The label kept for a pair is the one that occurs most often for it. When
    several labels tie, the smallest label id wins, which keeps the result
    deterministic.

    Returns:
        A DataFrame with columns ``premise``, ``hypothesis`` and ``label``,
        one row per distinct pair, ordered by pair, with a fresh 0..n-1 index.
    """
    pair_cols = ["premise", "hypothesis"]

    # Number of occurrences of every (pair, label) combination.
    tallies = df.groupby(pair_cols + ["label"]).size().reset_index(name="count")

    # Within each pair, put the most frequent label first; ties are broken by
    # the ascending label id.
    ranked = tallies.sort_values(
        by=pair_cols + ["count", "label"],
        ascending=[True, True, False, True],
    )

    # The first row of each pair is the winning label.
    winners = ranked.drop_duplicates(subset=pair_cols, keep="first")
    winners = winners[pair_cols + ["label"]]

    assert winners.duplicated(subset=pair_cols).sum() == 0
    return winners.reset_index(drop=True)


def _stratified_split(df: pd.DataFrame, test_size: float = 0.2) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Split *df* into train and test frames, class by class on ``label``.

    For each label the row index is shuffled with a ``random.Random`` seeded
    from ``RANDOM_SEED``. About ``test_size`` of the rows go to the test side:
    at least one, and never all of them. A class with a single row therefore
    stays entirely in train. Both frames are then shuffled with the same fixed
    seed and given a fresh 0..n-1 index.

    Args:
        df: Frame with a ``label`` column.
        test_size: Target fraction of each class to place in the test frame.

    Returns:
        ``(train_df, test_df)``.
    """
    shuffler = random.Random(RANDOM_SEED)
    train_chunks = []
    test_chunks = []

    for _, members in df.groupby("label", sort=True):
        class_size = len(members)
        # Clamp so both sides get at least one row whenever the class allows.
        held_out = min(max(1, int(round(class_size * test_size))), class_size - 1)

        order = list(members.index)
        shuffler.shuffle(order)

        # Kept as a set on purpose: its iteration order decides the row order
        # of the test chunk before the final shuffle.
        held_labels = set(order[:held_out])
        kept_labels = [i for i in order if i not in held_labels]

        train_chunks.append(df.loc[kept_labels])
        test_chunks.append(df.loc[list(held_labels)])

    def _shuffled(chunks):
        # Concatenate the per-class chunks, shuffle the rows, renumber the index.
        combined = pd.concat(chunks, axis=0)
        return combined.sample(frac=1.0, random_state=RANDOM_SEED).reset_index(drop=True)

    return _shuffled(train_chunks), _shuffled(test_chunks)


def prepare(raw: Path, public: Path, private: Path):
    """Build the public and private dataset files from the raw combined CSV.

    The raw file ``raw / RAW_FILE_NAME`` is loaded, validated and cleaned,
    duplicate pairs are collapsed, and the result is split with a fixed seed.
    The following files are then written:
      - public/train.csv, public/test.csv, public/sample_submission.csv
      - private/test_answer.csv
      - public/description.txt, copied from the parent directory of ``public``
        only when that file exists

    Notes:
    - Files in ``raw`` are only read, never modified.
    - Splitting and the sample submission use a fixed seed.
    - Every consistency check is an assertion, so a failure surfaces as
      ``AssertionError``.
    """
    # Work with absolute paths and make sure both output directories exist
    raw_dir, public_dir, private_dir = (p.resolve() for p in (raw, public, private))
    for out_dir in (public_dir, private_dir):
        out_dir.mkdir(parents=True, exist_ok=True)

    source_csv = raw_dir / RAW_FILE_NAME
    assert source_csv.exists(), f"Raw file not found: {source_csv}"

    # Load, validate and clean
    cleaned = _validate_and_clean_df(pd.read_csv(source_csv))

    # Collapse duplicate pairs, guarding against losing too much data
    n_cleaned = len(cleaned)
    deduped = _resolve_conflicts(cleaned)
    n_deduped = len(deduped)
    assert n_deduped >= 0.60 * n_cleaned, (
        f"Too many rows removed during conflict resolution: kept {n_deduped} of {n_cleaned}."
    )

    # Deterministic stratified split
    train_split, test_split = _stratified_split(deduped, test_size=0.2)

    # Ids are assigned after the split: train takes 0..n_train-1 and test
    # continues from n_train, so the two id ranges cannot collide.
    n_train = len(train_split)
    n_test = len(test_split)
    train_split = train_split.assign(id=range(n_train))
    test_split = test_split.assign(id=range(n_train, n_train + n_test))

    # Frames exactly as they will be written
    train_public = train_split[["id", "premise", "hypothesis", "label"]].copy()
    test_public = test_split[["id", "premise", "hypothesis"]].copy()
    test_key = test_split[["id", "label"]].copy()

    # Integrity checks
    train_ids = set(train_public["id"])
    test_ids = set(test_public["id"])
    assert train_ids.isdisjoint(test_ids), "Train/Test id overlap detected."

    train_pairs = set(zip(train_public["premise"], train_public["hypothesis"]))
    test_pairs = set(zip(test_public["premise"], test_public["hypothesis"]))
    assert train_pairs.isdisjoint(test_pairs), "Train/Test text pair overlap detected."

    train_labels = set(train_public["label"])
    key_labels = set(test_key["label"])
    assert train_labels.issubset(VALID_LABELS), "Invalid labels in train."
    assert key_labels.issubset(VALID_LABELS), "Invalid labels in test answers."

    assert key_labels.issubset(train_labels), "Some test labels are missing in train."

    assert len(test_public) == len(test_key), "Test and test_answer length mismatch."
    assert test_ids == set(test_key["id"]), "ID mismatch between test and test_answer."

    null_checks = (
        (train_public, "Nulls in train output."),
        (test_public, "Nulls in test output."),
        (test_key, "Nulls in test_answer."),
    )
    for frame, message in null_checks:
        assert not frame.isnull().any().any(), message

    # Write CSVs to their respective directories
    train_path = public_dir / "train.csv"
    test_path = public_dir / "test.csv"
    test_ans_path = private_dir / "test_answer.csv"
    sample_sub_path = public_dir / "sample_submission.csv"

    train_public.to_csv(train_path, index=False)
    test_public.to_csv(test_path, index=False)
    test_key.to_csv(test_ans_path, index=False)

    # Sample submission: one seeded random label per test id, drawn from the
    # labels that occur in train
    picker = random.Random(RANDOM_SEED)
    candidate_labels = sorted(train_public["label"].unique().tolist())
    guessed = [picker.choice(candidate_labels) for _ in range(len(test_public))]
    sample_sub = pd.DataFrame({"id": test_public["id"], "label": guessed})

    assert set(sample_sub.columns) == {"id", "label"}
    assert len(sample_sub) == len(test_public)
    assert set(sample_sub["id"]) == test_ids  # ids align
    assert set(sample_sub["label"]).issubset(VALID_LABELS)

    sample_sub.to_csv(sample_sub_path, index=False)

    # Copy description.txt next to the public files when it is available
    description_src = (public_dir.parent / "description.txt").resolve()
    if description_src.exists():
        shutil.copy(description_src, public_dir / "description.txt")

    # Post-conditions: required files exist
    for written in (train_path, test_path, test_ans_path, sample_sub_path):
        assert written.exists(), f"Missing {written}"
