import os
import csv
import ast
import random
from pathlib import Path
from typing import List, Dict, Set, Tuple

# Seed shared by every random step in this script (split shuffle and sample
# submission), so reruns over the same input are reproducible.
RANDOM_SEED = 42

# Header layouts of the generated CSV files, in column order.
TRAIN_COLUMNS = ["id", "title", "summary", "labels"]
# The public test file is the train layout minus the answer column.
TEST_COLUMNS = [col for col in TRAIN_COLUMNS if col != "labels"]
# The sample submission and the private answer key share this layout.
SUB_COLUMNS = ["id", "labels"]


def _split_terms(terms_raw: str) -> List[str]:
    """Turn a raw ``terms`` cell into a list of whitespace-free label tokens.

    The cell is first parsed as a Python literal (e.g. ``"['cs.LG', 'stat.ML']"``).
    A list/tuple contributes every whitespace-separated token of each element, and a
    string literal contributes its own tokens. Anything else, or text that is not a
    literal at all, is tokenized on commas and whitespace.
    """
    try:
        parsed = ast.literal_eval(terms_raw)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Not a Python literal: fall back to plain tokenization below.
        parsed = terms_raw
    if isinstance(parsed, (list, tuple)):
        # Split each element as well, so a label containing whitespace cannot later
        # re-split into different (possibly duplicate) tokens once the labels are
        # stored as one space-joined string.
        return [tok for item in parsed for tok in str(item).split()]
    text = parsed if isinstance(parsed, str) else terms_raw
    return text.replace(",", " ").split()


def _read_source_rows(path: Path) -> List[Dict[str, str]]:
    """Read a raw CSV file and normalize rows to a unified schema.

    Header names are matched case-insensitively, ignoring surrounding whitespace:
    the title comes from ``titles``/``title``, the summary from
    ``summaries``/``summary``/``abstracts``/``abstract``, and the labels from
    ``terms``/``labels``. A file lacking any of the three yields no rows. Rows
    with an empty title, summary or label set are dropped.

    Args:
        path: CSV file to read (UTF-8; a leading byte-order mark is ignored).

    Returns:
        A list of dicts with keys ``title``, ``summary`` and ``labels``. The value
        of ``labels`` is the row's unique label tokens, sorted and joined by single
        spaces.
    """
    rows: List[Dict[str, str]] = []
    # "utf-8-sig" drops a leading BOM, which would otherwise be glued onto the first
    # header name and defeat column detection; BOM-less files decode as plain UTF-8.
    with path.open("r", encoding="utf-8-sig", newline="") as f:
        reader = csv.DictReader(f)
        if not reader.fieldnames:
            return rows
        # Map normalized header names to their spelling in this file.
        cols = {c.strip().lower(): c for c in reader.fieldnames}
        title_col = cols.get("titles") or cols.get("title")
        summary_col = cols.get("summaries") or cols.get("summary") or cols.get("abstracts") or cols.get("abstract")
        terms_col = cols.get("terms") or cols.get("labels")
        if not (title_col and summary_col and terms_col):
            # Not a recognized schema; skip this file.
            return rows
        for r in reader:
            # Short rows yield None for missing cells; treat those as empty.
            title = (r.get(title_col) or "").strip()
            summary = (r.get(summary_col) or "").strip()
            terms_raw = (r.get(terms_col) or "").strip()
            if not (title and summary and terms_raw):
                continue
            labels = sorted(set(_split_terms(terms_raw)))
            if not labels:
                continue
            rows.append({
                "title": title,
                "summary": summary,
                "labels": " ".join(labels),
            })
    return rows


# Raw file names tried first, in priority order: the timestamped export is the
# preferred source and the generic name is the fallback. Any other *.csv in raw/
# is only considered when neither of these yields rows.
essential_sources = ["arxiv_data_210930-054931.csv", "arxiv_data.csv"]


def _load_all_rows(raw: Path) -> List[Dict[str, str]]:
    """Load normalized rows from the first usable CSV file in ``raw``.

    Files named in ``essential_sources`` are tried first, in list order. Then every
    ``*.csv`` in ``raw`` is tried in sorted order. Each file is read at most once,
    and only regular files are considered. The first file yielding at least one row
    wins.

    Raises:
        FileNotFoundError: If no file in ``raw`` yields any rows.
    """
    tried: Set[Path] = set()
    candidates = [raw / fname for fname in essential_sources] + sorted(raw.glob("*.csv"))
    for p in candidates:
        # Skip files already attempted (preferred names also match the glob) and
        # anything that is not a regular file (e.g. a directory named "x.csv").
        if p in tried or not p.is_file():
            continue
        tried.add(p)
        rows = _read_source_rows(p)
        if rows:
            return rows
    raise FileNotFoundError(f"No valid source data file found in {raw}.")


def _build_label_set(rows: List[Dict[str, str]]) -> Set[str]:
    """Return the union of the whitespace-separated label tokens of all ``rows``."""
    return {token for row in rows for token in row["labels"].split()}


def _multi_label_split(rows: List[Dict[str, str]], test_frac: float = 0.2, seed: int = RANDOM_SEED) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
    """Randomly split ``rows`` into train/test so every test label also occurs in train.

    Row indices are shuffled with ``random.Random(seed)``. The first
    ``round(len(rows) * test_frac)`` shuffled rows form the initial test set.
    Each label that then appears in test but not in train is handled in sorted
    order: the lowest-index test row carrying it is moved to train. Labels already
    covered by an earlier move are skipped, so no more rows leave the test set than
    necessary.

    Args:
        rows: Normalized rows; ``labels`` holds space-separated label tokens.
        test_frac: Target fraction of rows in the test split, within ``[0, 1]``.
        seed: Seed for the shuffle, making the split reproducible.

    Returns:
        ``(train_rows, test_rows)``, each in original row order. The row dicts are
        the same objects as in ``rows`` (not copies).

    Raises:
        ValueError: If ``test_frac`` lies outside ``[0, 1]``. A negative value would
            otherwise silently select most rows through negative slicing.
    """
    if not 0.0 <= test_frac <= 1.0:
        raise ValueError(f"test_frac must be within [0, 1], got {test_frac!r}")

    # Tokenize each row's labels once rather than re-splitting on every lookup.
    row_labels = [set(r["labels"].split()) for r in rows]

    idxs = list(range(len(rows)))
    rng = random.Random(seed)
    rng.shuffle(idxs)

    n_test = int(round(len(rows) * test_frac))
    test_set = set(idxs[:n_test])
    train_set = set(idxs[n_test:])

    train_labels: Set[str] = set()
    for i in train_set:
        train_labels |= row_labels[i]
    test_labels: Set[str] = set()
    for i in test_set:
        test_labels |= row_labels[i]

    for lb in sorted(test_labels - train_labels):
        if lb in train_labels:
            # Already covered by a row moved for an earlier label.
            continue
        # Scan in index order so the chosen row does not depend on set internals.
        for i in sorted(test_set):
            if lb in row_labels[i]:
                test_set.remove(i)
                train_set.add(i)
                train_labels |= row_labels[i]
                break

    # Moves only add labels to train, so this invariant holds; keep it as a guard.
    still_missing = {lb for i in test_set for lb in row_labels[i]} - train_labels
    assert not still_missing, f"Failed to ensure label coverage in train for labels: {sorted(still_missing)}"

    train_rows = [rows[i] for i in sorted(train_set)]
    test_rows = [rows[i] for i in sorted(test_set)]
    return train_rows, test_rows


def _write_csv(path: Path, fieldnames: List[str], rows: List[Dict[str, str]]):
    """Write ``rows`` as a UTF-8 CSV with a header line, creating parent directories.

    Columns follow ``fieldnames`` order. Keys a row lacks are written as empty cells,
    and keys not listed in ``fieldnames`` are dropped.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8", newline="") as out:
        writer = csv.DictWriter(out, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(
            {name: row.get(name, "") for name in fieldnames} for row in rows
        )


def _dedupe_exact(rows: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Drop exact (title, summary, labels) repeats, keeping first occurrences in order."""
    seen: Set[Tuple[str, str, str]] = set()
    unique: List[Dict[str, str]] = []
    for r in rows:
        key = (r["title"], r["summary"], r["labels"])
        if key not in seen:
            seen.add(key)
            unique.append(r)
    return unique


def _random_labels(label_vocab: List[str], rng: random.Random) -> str:
    """Draw up to 3 distinct labels from ``label_vocab`` and return them sorted and space-joined."""
    k = rng.randint(0, 3)
    if k == 0 or not label_vocab:
        return ""
    return " ".join(sorted(rng.sample(label_vocab, k=min(k, len(label_vocab)))))


def _read_csv_header(p: Path) -> List[str]:
    """Return the header row of the CSV file at ``p``."""
    with p.open("r", encoding="utf-8", newline="") as f:
        return next(csv.reader(f))


def _read_ids_labels(p: Path) -> Tuple[List[str], Dict[str, str]]:
    """Return the ``id`` column of ``p`` in file order, plus an id -> stripped ``labels`` map if that column exists."""
    ids: List[str] = []
    labels: Dict[str, str] = {}
    with p.open("r", encoding="utf-8", newline="") as f:
        for r in csv.DictReader(f):
            ids.append(r["id"])
            if "labels" in r:
                labels[r["id"]] = r["labels"].strip()
    return ids, labels


def _label_tokens_ok(label_str: str) -> bool:
    """True if no single-space-delimited token of ``label_str`` contains other whitespace (tab, newline, ...)."""
    # Splitting on " " (not on arbitrary whitespace) is what makes this check
    # meaningful: tokens from str.split() could never contain whitespace.
    return not any(ch.isspace() for tok in label_str.split(" ") for ch in tok)


def _validate_outputs(public: Path, private: Path, train_rows: List[Dict[str, str]]) -> None:
    """Re-read the written files and assert they exist, have the expected layout and are consistent."""
    # 1) Files exist
    assert (public / "train.csv").exists(), "Missing public/train.csv"
    assert (public / "test.csv").exists(), "Missing public/test.csv"
    assert (public / "sample_submission.csv").exists(), "Missing public/sample_submission.csv"
    assert (private / "test_answer.csv").exists(), "Missing private/test_answer.csv"

    # 2) Column checks
    assert _read_csv_header(public / "train.csv") == TRAIN_COLUMNS
    assert _read_csv_header(public / "test.csv") == TEST_COLUMNS
    assert _read_csv_header(private / "test_answer.csv") == SUB_COLUMNS
    assert _read_csv_header(public / "sample_submission.csv") == SUB_COLUMNS

    # 3) ID alignment: test.csv ids match test_answer ids exactly and no duplicates
    test_ids, _ = _read_ids_labels(public / "test.csv")
    ans_ids, ans_labels = _read_ids_labels(private / "test_answer.csv")
    assert test_ids == ans_ids, "public/test.csv and private/test_answer.csv id order mismatch"
    assert len(set(test_ids)) == len(test_ids), "Duplicate ids in public/test.csv"

    # 4) Train/test disjointness
    train_ids, _ = _read_ids_labels(public / "train.csv")
    assert set(train_ids).isdisjoint(test_ids), "Train and test ids overlap"

    # 5) All test label tokens occur at least once in training
    train_label_tokens = _build_label_set(train_rows)
    test_label_tokens = {tok for lb in ans_labels.values() for tok in lb.split()}
    assert test_label_tokens.issubset(train_label_tokens), "Found test labels absent from training"

    # 6) Label formatting and id format
    for r in train_rows:
        assert _label_tokens_ok(r["labels"]), "Invalid label token with whitespace in train"
    for lb in ans_labels.values():
        assert _label_tokens_ok(lb), "Invalid label token with whitespace in test_answer"
    for pid in train_ids + test_ids:
        assert pid.isdigit(), f"Non-numeric id encountered: {pid}"


def prepare(raw: Path, public: Path, private: Path):
    """Prepare competition data splits and files.

    Inputs:
      - raw: absolute path to raw/ directory containing the original CSV(s)
      - public: absolute path to public/ output directory
      - private: absolute path to private/ output directory

    Outputs (written exactly to):
      - public/train.csv              (id, title, summary, labels)
      - public/test.csv               (id, title, summary)
      - public/sample_submission.csv  (id, labels; random labels from the train vocabulary)
      - public/description.txt        (copied from the directory containing this
                                       script, if a description.txt exists there)
      - private/test_answer.csv       (id, labels)

    Exact-duplicate rows are dropped. The rest are split 80/20 with a fixed seed
    so that every test label also occurs in train. Ids are assigned sequentially
    from 100000, train rows first and then test rows.

    Raises:
        AssertionError: If a path is not absolute, the split is empty, or any
            post-write consistency check fails.
        FileNotFoundError: If ``raw`` holds no usable source CSV.
    """
    assert raw.is_absolute() and public.is_absolute() and private.is_absolute(), "Please pass absolute paths"

    # Deduplicate exact duplicates deterministically (first occurrence wins).
    rows = _dedupe_exact(_load_all_rows(raw))

    # Split deterministically, and sanity-check the split before any file is
    # written, so a bad split cannot leave partial outputs behind.
    train_rows, test_rows = _multi_label_split(rows, test_frac=0.2, seed=RANDOM_SEED)
    assert len(train_rows) > 0 and len(test_rows) > 0, "Empty split"
    assert len(train_rows) + len(test_rows) == len(rows), "Split size mismatch"

    # Deterministic numeric ids: train rows first, then test rows, from 100000.
    first_id = 100000
    for offset, r in enumerate(train_rows + test_rows):
        r["id"] = str(first_id + offset)

    # Public files never carry test labels; the answers go to private/ only.
    _write_csv(public / "train.csv", TRAIN_COLUMNS, train_rows)
    test_public_rows = [{"id": r["id"], "title": r["title"], "summary": r["summary"]} for r in test_rows]
    _write_csv(public / "test.csv", TEST_COLUMNS, test_public_rows)
    test_answer_rows = [{"id": r["id"], "labels": r["labels"]} for r in test_rows]
    _write_csv(private / "test_answer.csv", SUB_COLUMNS, test_answer_rows)

    # Sample submission: seeded random labels drawn from the train label vocabulary.
    train_label_list = sorted(_build_label_set(train_rows))
    rng = random.Random(RANDOM_SEED)
    sample_rows = [{"id": r["id"], "labels": _random_labels(train_label_list, rng)} for r in test_rows]
    _write_csv(public / "sample_submission.csv", SUB_COLUMNS, sample_rows)

    # Copy description.txt from this script's directory into public/, if present.
    desc_src = Path(__file__).resolve().parent / "description.txt"
    if desc_src.exists():
        (public / "description.txt").write_text(desc_src.read_text(encoding="utf-8"), encoding="utf-8")

    _validate_outputs(public, private, train_rows)
