from pathlib import Path
import hashlib
import math
from typing import Tuple

import numpy as np
import pandas as pd

# Determinism: one fixed seed for the randomness used in this module.
# It seeds NumPy's global RNG at import time (next line) and is also used to
# seed the dedicated generator that draws the sample-submission predictions.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

# Raw review CSV shards expected inside the raw directory; prepare() reads and
# concatenates them in this order.
REVIEWS_FILES = [
    "reviews_0-250.csv",
    "reviews_250-500.csv",
    "reviews_500-750.csv",
    "reviews_750-1250.csv",
    "reviews_1250-end.csv",
]
# Product catalog CSV that prepare() left-joins onto the reviews by product_id.
PRODUCTS_FILE = "product_info.csv"


def _read_reviews_csv(path: Path) -> pd.DataFrame:
    """Load one reviews CSV shard and verify that it has the expected schema.

    Columns whose header is empty or starts with ``"Unnamed"`` (typically a
    stray index column written by an earlier export) are discarded.

    Args:
        path: Location of the reviews CSV file.

    Returns:
        The loaded frame without the stray index columns.

    Raises:
        ValueError: If any of the required review columns is absent.
    """
    frame = pd.read_csv(path, low_memory=False)

    # Discard leftover index columns from previous exports.
    stray = [name for name in frame.columns if name == "" or name.startswith("Unnamed")]
    if stray:
        frame = frame.drop(columns=stray)

    # Every shard must provide all of these columns.
    expected = (
        "author_id",
        "rating",
        "is_recommended",
        "helpfulness",
        "total_feedback_count",
        "total_neg_feedback_count",
        "total_pos_feedback_count",
        "submission_time",
        "review_text",
        "review_title",
        "skin_tone",
        "eye_color",
        "skin_type",
        "hair_color",
        "product_id",
        "product_name",
        "brand_name",
        "price_usd",
    )
    present = set(frame.columns)
    missing = [name for name in expected if name not in present]
    if missing:
        raise ValueError(f"Missing required columns in reviews file {path}: {missing}")
    return frame


def _read_products_csv(path: Path) -> pd.DataFrame:
    """Load the product catalog CSV and rename columns that clash with reviews.

    The catalog columns ``rating``, ``reviews``, ``product_name``,
    ``brand_name`` and ``price_usd`` are renamed (when present) so they stay
    distinguishable from the review-level columns of the same name.

    Args:
        path: Location of the product catalog CSV file.

    Returns:
        The catalog frame with the renamed columns.

    Raises:
        ValueError: If the catalog has no ``product_id`` column.
    """
    catalog = pd.read_csv(path, low_memory=False)

    # Catalog column -> name it should carry after loading.
    aliases = {
        "rating": "product_avg_rating",
        "reviews": "product_review_count",
        "product_name": "product_name_catalog",
        "brand_name": "brand_name_catalog",
        "price_usd": "price_usd_catalog",
    }
    applicable = {old: new for old, new in aliases.items() if old in catalog.columns}
    catalog = catalog.rename(columns=applicable)

    if "product_id" in catalog.columns:
        return catalog
    raise ValueError("product_info.csv must have product_id column")


def _make_review_id(df: pd.DataFrame) -> pd.Series:
    """Build a deterministic, unique review ID for every row of ``df``.

    Each ID is ``"R_" + md5(author_id|product_id|submission_time|text_hash)``,
    where ``text_hash`` is the first 12 hex digits of the MD5 of the review
    text. Rows that produce the same digest are disambiguated in row order:
    the first keeps the plain ID, later ones get ``_1``, ``_2``, ... appended.

    Args:
        df: Frame with ``author_id``, ``product_id`` and ``submission_time``
            columns. ``review_text`` is optional; when it is absent every row
            is treated as having empty text.

    Returns:
        A Series of ID strings sharing ``df``'s index.
    """

    def _as_text(column: pd.Series) -> pd.Series:
        # NOTE(review): depending on the pandas version, astype(str) may turn
        # missing values into the string "nan" or leave them missing; the
        # fillna("") guarantees real strings either way -- TODO confirm.
        return column.astype(str).fillna("")

    if "review_text" in df.columns:
        review_text = df["review_text"]
    else:
        # Build the fallback on df's own index: a default RangeIndex would be
        # misaligned with a non-default df.index during concatenation below
        # and turn every token into NaN.
        review_text = pd.Series("", index=df.index, dtype=object)

    key = (
        _as_text(df["author_id"])
        + "|"
        + _as_text(df["product_id"])
        + "|"
        + _as_text(df["submission_time"])
    )
    text_hash = _as_text(review_text).apply(
        lambda s: hashlib.md5(s.encode("utf-8", errors="ignore")).hexdigest()[:12]
    )
    tokens = (key + "|" + text_hash).tolist()
    digests = [hashlib.md5(t.encode("utf-8", errors="ignore")).hexdigest() for t in tokens]

    # Disambiguate repeated digests with a running per-digest counter.
    occurrences = {}
    out_ids = []
    for digest in digests:
        count = occurrences.get(digest, 0)
        out_ids.append(f"R_{digest}" if count == 0 else f"R_{digest}_{count}")
        occurrences[digest] = count + 1
    return pd.Series(out_ids, index=df.index)


def _split_and_write(
    merged: pd.DataFrame,
    public: Path,
    private: Path,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Split ``merged`` chronologically into train/test and write the dataset files.

    Steps:
      1. Add a ``review_id`` column and sort by submission time, then by
         ``review_id`` (unparseable timestamps sort first).
      2. Put the first 80% of rows in train and the last 20% in test.
      3. Round/clip ratings to integers in 1..5.
      4. If train lacks any of the labels 1..5, move the shortest leading run
         of test rows that supplies the missing labels into train (the whole
         test set if a label never appears).
      5. Write ``train.csv``, ``test.csv`` (no rating) and
         ``sample_submission.csv`` to ``public`` and ``test_answer.csv`` to
         ``private``; copy ``description.txt`` from ``public``'s parent
         directory when it exists.
      6. Re-read the written files and sanity-check them.

    Args:
        merged: Reviews joined with product info. Must contain
            ``submission_time`` and ``rating`` plus the columns needed by
            ``_make_review_id``. Assumes ratings are already numeric and
            non-missing (``prepare`` filters them before calling).
        public: Output directory for files visible to participants.
        private: Output directory for the held-out answers.

    Returns:
        ``(train_out, test_out, test_answer)`` as written to disk.

    Raises:
        AssertionError: If any consistency check on the split or on the
            written files fails.
    """
    required_labels = {1, 2, 3, 4, 5}

    # Work on a copy so the caller's frame is left untouched.
    merged = merged.copy()
    merged.insert(0, "review_id", _make_review_id(merged))

    # Sort chronologically, then by review_id, so the split is deterministic.
    merged["_ts"] = pd.to_datetime(merged["submission_time"], errors="coerce")
    # Unparseable timestamps get a sentinel date so they sort first.
    merged.loc[merged["_ts"].isna(), "_ts"] = pd.Timestamp("1900-01-01")
    merged = merged.sort_values(["_ts", "review_id"]).reset_index(drop=True)

    # 80/20 split by position in the sorted order.
    n = len(merged)
    test_size = int(math.floor(n * 0.20))
    train_size = n - test_size
    train_df = merged.iloc[:train_size].copy()
    test_df_full = merged.iloc[train_size:].copy()

    # Normalise labels to integers in 1..5.
    train_df["rating"] = pd.to_numeric(train_df["rating"], errors="coerce").round().clip(1, 5).astype(int)
    test_df_full["rating"] = pd.to_numeric(test_df_full["rating"], errors="coerce").round().clip(1, 5).astype(int)

    # Ensure label coverage. Rather than moving rows one at a time (which
    # copies the whole training frame per row), compute in one pass how many
    # leading test rows are needed to supply every missing label.
    missing_labels = required_labels - set(train_df["rating"].unique().tolist())
    if missing_labels:
        test_ratings = test_df_full["rating"].to_numpy()
        move_count = 0
        for label in missing_labels:
            hits = np.flatnonzero(test_ratings == label)
            if hits.size == 0:
                # Label never occurs in test: every test row ends up in train.
                move_count = len(test_df_full)
                break
            move_count = max(move_count, int(hits[0]) + 1)
        if move_count:
            train_df = pd.concat([train_df, test_df_full.iloc[:move_count]], ignore_index=True)
        test_df_full = test_df_full.iloc[move_count:].copy()

    # Final safety checks
    assert len(train_df) + len(test_df_full) == n, "Train+Test size mismatch"
    assert train_df["review_id"].is_unique and test_df_full["review_id"].is_unique
    assert set(train_df["review_id"]).isdisjoint(set(test_df_full["review_id"]))

    # Held-out answers: only the id and the true label.
    test_answer = test_df_full[["review_id", "rating"]].copy()

    # Drop the sorting helper column.
    train_df = train_df.drop(columns=["_ts"])
    test_df_full = test_df_full.drop(columns=["_ts"])

    # Column order: id first; rating last in train; rating removed in test.
    feature_cols = [c for c in train_df.columns if c not in ("review_id", "rating")]
    train_out = train_df[["review_id"] + feature_cols + ["rating"]]
    test_out = test_df_full[["review_id"] + feature_cols]

    # Sample submission: ratings drawn from the train label distribution.
    label_counts = train_out["rating"].value_counts().sort_index()
    label_values = label_counts.index.values
    label_probs = (label_counts / label_counts.sum()).values
    rng = np.random.default_rng(RANDOM_SEED)
    sample_preds = rng.choice(label_values, size=len(test_out), p=label_probs)
    sample_sub = pd.DataFrame({"review_id": test_out["review_id"].values, "rating": sample_preds})

    # Save to public/private
    public.mkdir(parents=True, exist_ok=True)
    private.mkdir(parents=True, exist_ok=True)

    train_path = public / "train.csv"
    test_path = public / "test.csv"
    sample_path = public / "sample_submission.csv"
    ans_path = private / "test_answer.csv"

    train_out.to_csv(train_path, index=False)
    test_out.to_csv(test_path, index=False)
    sample_sub.to_csv(sample_path, index=False)
    test_answer.to_csv(ans_path, index=False)

    # Copy description.txt into public if it exists next to the public directory.
    desc_src = public.parent / "description.txt"
    if desc_src.exists():
        (public / "description.txt").write_text(desc_src.read_text(encoding="utf-8"), encoding="utf-8")

    # Validate what actually landed on disk.
    train_chk = pd.read_csv(train_path)
    test_chk = pd.read_csv(test_path)
    samp_chk = pd.read_csv(sample_path)
    ans_chk = pd.read_csv(ans_path)

    assert train_chk["review_id"].is_unique and test_chk["review_id"].is_unique and ans_chk["review_id"].is_unique
    assert set(test_chk["review_id"]) == set(ans_chk["review_id"]) == set(samp_chk["review_id"]) \
        , "ID sets must match across test, answers, and sample"
    assert test_chk.shape[1] == train_chk.shape[1] - 1, "Test should have one fewer column (no rating)"
    tr_labels = set(train_chk["rating"].astype(int).unique().tolist())
    assert tr_labels.issuperset({1, 2, 3, 4, 5}), "Train must include all 1..5 classes"
    ts_labels = set(ans_chk["rating"].astype(int).unique().tolist())
    assert ts_labels.issubset(tr_labels), "All test labels must occur in training set"

    return train_out, test_out, test_answer


def prepare(raw: Path, public: Path, private: Path):
    """
    Complete data preparation process.

    - Reads raw CSVs from raw/ (the files in REVIEWS_FILES and PRODUCTS_FILE)
    - Drops reviews whose rating is missing or non-numeric; clips ratings to 1..5
    - Merges product info into reviews (one catalog row per product_id)
    - Deterministically splits into train/test by time order (80/20)
    - Writes:
        public/train.csv, public/test.csv, public/sample_submission.csv
        private/test_answer.csv
      and copies description.txt into public/ if present.

    Args:
        raw: Directory containing the raw review shards and product catalog.
        public: Output directory for participant-visible files.
        private: Output directory for the held-out answers.

    Raises:
        FileNotFoundError: If a reviews shard or the product catalog is missing.
        ValueError: If an input file lacks required columns.
    """
    review_paths = [raw / name for name in REVIEWS_FILES]
    prod_path = raw / PRODUCTS_FILE

    # Check every input up front so a missing file fails before the large
    # review CSVs are parsed.
    for p in review_paths:
        if not p.exists():
            raise FileNotFoundError(f"Missing reviews file: {p}")
    if not prod_path.exists():
        raise FileNotFoundError(f"Missing products file: {prod_path}")

    # Load and concat reviews
    reviews = pd.concat(
        [_read_reviews_csv(p) for p in review_paths], axis=0, ignore_index=True
    )

    # Keep only rows with a valid numeric rating. Coercion turns non-numeric
    # values into NaN, so a single filter removes both missing and bad ratings.
    reviews["rating"] = pd.to_numeric(reviews["rating"], errors="coerce")
    reviews = reviews[reviews["rating"].notna()].copy()
    reviews["rating"] = reviews["rating"].clip(1, 5)

    # Merge products. Keep one catalog row per product_id: with repeated keys
    # a left merge would emit each matching review once per catalog row, and
    # the duplicated reviews could land on both sides of the train/test split.
    prods = _read_products_csv(prod_path)
    prods = prods.drop_duplicates(subset="product_id", keep="first")
    merged = reviews.merge(prods, on="product_id", how="left", suffixes=("", "_prod"))

    _split_and_write(merged, public=public, private=private)
