import csv
import re
import math
from dataclasses import dataclass
from typing import Dict, List, Tuple, Any
from datetime import datetime

# Column-name conventions. Engineered columns start with one of two prefixes;
# flag columns additionally carry FLAG_TOKEN right after that prefix
# (e.g. "jsonl_first_flag_..." / "jsonl_agg_flag_...").
FIRST_PREFIX = "jsonl_first_"
AGG_PREFIX = "jsonl_agg_"
FLAG_TOKEN = "flag_"

# Column-name endings that _is_numeric_candidate treats as numeric.
NUMERIC_SUFFIXES = {
    # endings of the "last" goal-length columns (agg side in the original notes)
    "user_goal_last_len_chars",
    "user_goal_last_len_words",
    # endings of the goal-length columns (first side in the original notes)
    "user_goal_len_chars",
    "user_goal_len_words",
    # also matched by the generic "_count" rule; listed for explicitness
    "segment_count",
}

@dataclass
class Dataset:
    X: List[Dict[str, Any]]
    y: List[float]
    feature_names: List[str]
    w: List[float] | None = None  # optional sample weights aligned with y


def read_csv(path: str) -> List[Dict[str, str]]:
    """Read a CSV file with a header row into a list of ``{column: cell}`` dicts.

    Args:
        path: filesystem path of the CSV file.

    Returns:
        One dict per data row, keyed by the header names (empty list for an
        empty file).
    """
    # "utf-8-sig" decodes plain UTF-8 exactly like "utf-8", but also strips a
    # leading byte-order mark (common in spreadsheet exports). With "utf-8"
    # the BOM would be glued onto the first header name ("\ufeffcolumn"),
    # so lookups of that column by name would silently miss.
    with open(path, newline="", encoding="utf-8-sig") as f:
        return list(csv.DictReader(f))


def parse_datetime(s: str) -> datetime | None:
    s = (s or "").strip()
    if not s:
        return None
    # try ISO-like first
    for fmt in ["%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d"]:
        try:
            return datetime.strptime(s, fmt)
        except Exception:
            pass
    # fallback: extract numbers
    m = re.match(r"(\d{4}-\d{2}-\d{2})", s)
    if m:
        try:
            return datetime.strptime(m.group(1), "%Y-%m-%d")
        except Exception:
            return None
    return None


def _is_flag(col: str) -> bool:
    """Return True when *col* is a flag column under either feature prefix."""
    flag_prefixes = (FIRST_PREFIX + FLAG_TOKEN, AGG_PREFIX + FLAG_TOKEN)
    return col.startswith(flag_prefixes)


def _is_first_static(col: str) -> bool:
    """Return True for a FIRST_PREFIX column that is not a flag column."""
    if _is_flag(col):
        return False
    return col.startswith(FIRST_PREFIX)


def _is_agg_static(col: str) -> bool:
    """Return True for an AGG_PREFIX column that is not a flag column."""
    if _is_flag(col):
        return False
    return col.startswith(AGG_PREFIX)


def _is_numeric_candidate(col: str) -> bool:
    """Return True when *col* looks numeric by name.

    The selector is deliberately conservative: a column qualifies only if its
    name ends in "_count" or in one of NUMERIC_SUFFIXES.
    """
    # str.endswith accepts a tuple and is True if any ending matches
    return col.endswith(("_count", *NUMERIC_SUFFIXES))


def _weight_from_row(r: Dict[str, str]) -> float:
    rc = (r.get("rating_count") or "").strip()
    try:
        n = float(rc)
        if n <= 0:
            return 1.0
        # dampen large counts to avoid overweighting a few items
        return float(min(10.0, math.sqrt(n)))
    except Exception:
        return 1.0


def _cast_feature_value(raw: str) -> Any:
    """Convert one raw cell to int, else float, else the stripped string."""
    value = raw.strip()
    # isdecimal() (unlike isdigit()) accepts only characters int() can parse;
    # e.g. "²" passes isdigit() but makes int() raise ValueError.
    if value.isdecimal():
        return int(value)
    try:
        return float(value)
    except ValueError:
        return value


def _parse_target(raw: str | None) -> float | None:
    """Return the target cell as a finite float, or None if it is unusable."""
    text = (raw or "").strip()
    if not text:
        return None
    try:
        value = float(text)
    except ValueError:
        return None
    # float() accepts "nan"/"inf"; such labels would poison any fit on y.
    return value if math.isfinite(value) else None


def _row_features(row: Dict[str, str], fcols: List[str]) -> Dict[str, Any]:
    """Cast the row's cells for *fcols*; cells whose value is None are omitted."""
    fx: Dict[str, Any] = {}
    for c in fcols:
        v = row.get(c)
        if v is None:
            # e.g. a column missing from a short csv.DictReader row
            continue
        fx[c] = _cast_feature_value(v)
    return fx


def build_feature_sets(rows: List[Dict[str, str]], target: str = "average_rating") -> Dict[str, Dataset]:
    """Build one labeled Dataset per named group of feature columns.

    Column groups are derived from the keys of ``rows[0]``:

    * ``first_static`` / ``first_flags``: FIRST_PREFIX columns without / with
      the flag token;
    * ``agg_static`` / ``agg_flags``: the same split for AGG_PREFIX columns;
    * ``numeric_only``: columns selected by ``_is_numeric_candidate``;
    * ``all``: every remaining column;
    * ``git_user_feedback``: a fixed list of columns, added only when at least
      one of them is present.

    ``conversation_id``, the target, ``rating_count`` (the weight source) and
    ``created_at`` are excluded from every group.

    Rows are handled as follows:

    * Rows whose target cell is empty, unparseable or non-finite are skipped.
    * Feature cells become int, then float where possible, and otherwise stay
      as stripped strings.
    * Sample weights come from ``_weight_from_row``.

    Args:
        rows: rows of string cells keyed by column name (e.g. from read_csv).
        target: name of the column holding the numeric label.

    Returns:
        Mapping of group name to Dataset; empty when *rows* is empty.
    """
    if not rows:
        return {}
    # csv.DictReader files the surplus cells of an over-long row under the key
    # None; that is not a column name and would crash the prefix checks.
    cols = [c for c in rows[0].keys() if c is not None]
    # Identifier, label, weight source and timestamp must never be features.
    excluded = {"conversation_id", target, "rating_count", "created_at"}
    candidates = [c for c in cols if c not in excluded]

    groups: Dict[str, List[str]] = {
        "first_static": [c for c in candidates if _is_first_static(c)],
        "first_flags": [c for c in candidates if c.startswith(FIRST_PREFIX + FLAG_TOKEN)],
        "agg_static": [c for c in candidates if _is_agg_static(c)],
        "agg_flags": [c for c in candidates if c.startswith(AGG_PREFIX + FLAG_TOKEN)],
        # "rating_count" ends in "_count", so it used to leak in here even though
        # it drives the sample weights; drawing from `candidates` drops it.
        "numeric_only": [c for c in candidates if _is_numeric_candidate(c)],
        "all": candidates,
    }

    # Custom feature set: git actions + user_message_count + feedback sentiment/task
    git_feedback_cols = [
        c for c in (
            "git_commit",
            "git_push",
            "git_pull",
            "git_reset",
            "git_rebase",
            "user_message_count",
            "feedback_sentiment_classification",
            "feedback_task_type",
        ) if c in candidates
    ]
    if git_feedback_cols:
        groups["git_user_feedback"] = git_feedback_cols

    # Label and weight depend only on the row, so compute them once here
    # instead of once per group.
    labeled: List[Tuple[Dict[str, str], float, float]] = []
    for r in rows:
        ty = _parse_target(r.get(target))
        if ty is None:
            continue
        labeled.append((r, ty, _weight_from_row(r)))

    datasets: Dict[str, Dataset] = {}
    for name, fcols in groups.items():
        # fresh lists per group so datasets never share mutable state
        datasets[name] = Dataset(
            X=[_row_features(r, fcols) for r, _, _ in labeled],
            y=[ty for _, ty, _ in labeled],
            feature_names=fcols,
            w=[wt for _, _, wt in labeled],
        )
    return datasets
