import argparse
import math
from pathlib import Path
from typing import List, Dict, Any, Tuple

import pandas as pd
import numpy as np
try:
    from scipy.stats import pearsonr, spearmanr
except Exception:
    pearsonr = None
    spearmanr = None


# Identifier / timestamp columns excluded from numeric coercion,
# categorical detection and numeric correlation.
IGNORE_COLS = set(("conversation_id", "created_at"))
# Column every feature is correlated against.
TARGET_COL = "average_rating"
# NOTE(review): not referenced by the code visible in this file. Because the
# name ends in "_count", the numeric coercion step still treats it as a
# regular feature — confirm whether it should be ignored instead.
RATING_COUNT_COL = "rating_count"


def _is_binary(series: pd.Series) -> bool:
    vals = series.dropna().unique()
    if len(vals) == 0:
        return False
    return set(pd.Series(vals).astype(float).dropna().unique()).issubset({0.0, 1.0})


def _coerce_numeric(df: pd.DataFrame) -> Tuple[pd.DataFrame, List[str]]:
    """Convert numeric-looking columns of ``df`` to numbers.

    Every column not in ``IGNORE_COLS`` is passed through
    ``pd.to_numeric(..., errors="coerce")``. The column is treated as numeric
    (converted in the returned copy and listed in the returned names) when
    either:

    * its name ends in a count/length suffix (``_count``, ``_len_chars``,
      ``_len_words``); such columns are coerced even if they parse poorly, or
    * at least 90% of its non-missing values parse as numbers.

    Returns:
        A copy of ``df`` with the numeric columns converted (unparseable
        entries become NaN), and the numeric column names in column order.
    """
    count_like_suffixes = ("_count", "_len_chars", "_len_words")
    num_cols: List[str] = []
    out = df.copy()
    for c in df.columns:
        if c in IGNORE_COLS:
            continue
        raw = df[c]
        s = pd.to_numeric(raw, errors="coerce")
        # Measure parse success over values that were present to begin with.
        # Counting blanks as "non-numeric" would drop a numeric column with
        # more than 10% missing values from the numeric analysis.
        present = int(raw.notna().sum())
        numeric_ratio = int(s.notna().sum()) / present if present else 0.0
        # str() so non-string column labels do not crash endswith().
        if str(c).endswith(count_like_suffixes) or numeric_ratio >= 0.9:
            out[c] = s
            num_cols.append(c)
    return out, num_cols


def _categorical_columns(df: pd.DataFrame, exclude: List[str]) -> List[str]:
    """Pick the columns of ``df`` to analyse as categorical variables.

    Columns listed in ``exclude``, in ``IGNORE_COLS``, or equal to
    ``TARGET_COL`` are skipped. Of the rest, a column is categorical when it
    has ``object`` dtype, or when it has between 2 and 12 distinct non-null
    values (an absolute cap, regardless of row count).

    Returns:
        Matching column names, in the order they appear in ``df``.
    """
    def _looks_categorical(col) -> bool:
        series = df[col]
        if series.dtype == object:
            return True
        # Low-cardinality non-object columns (e.g. small integer codes).
        return 1 < series.nunique(dropna=True) <= 12

    return [
        col
        for col in df.columns
        if not (col in exclude or col in IGNORE_COLS or col == TARGET_COL)
        and _looks_categorical(col)
    ]


def _safe_corr(x: pd.Series, y: pd.Series, method: str = "pearson") -> Tuple[float, float, int]:
    m = pd.concat([x, y], axis=1).dropna()
    if len(m) < 3:
        return float("nan"), float("nan"), len(m)
    a = m.iloc[:, 0].astype(float)
    b = m.iloc[:, 1].astype(float)
    if method == "pearson":
        if pearsonr is not None:
            try:
                r, p = pearsonr(a, b)
                return float(r), float(p), len(m)
            except Exception:
                pass
        # fallback without p-value
        try:
            r = float(pd.Series(a).corr(pd.Series(b), method="pearson"))
            return r, float("nan"), len(m)
        except Exception:
            return float("nan"), float("nan"), len(m)
    else:
        if spearmanr is not None:
            try:
                rho, p = spearmanr(a, b)
                return float(rho), float(p), len(m)
            except Exception:
                pass
        # fallback: rank transform then Pearson
        try:
            ra = pd.Series(a).rank(method="average")
            rb = pd.Series(b).rank(method="average")
            r = float(ra.corr(rb, method="pearson"))
            return r, float("nan"), len(m)
        except Exception:
            return float("nan"), float("nan"), len(m)


def correlate_numeric(df: pd.DataFrame, num_cols: List[str], target: str = TARGET_COL) -> pd.DataFrame:
    """Correlate each numeric feature in ``num_cols`` with ``target``.

    This covers every column in ``num_cols`` other than ``target`` and the
    entries of ``IGNORE_COLS``. For each one it computes Pearson and Spearman
    correlations with the target (via ``_safe_corr``) plus the feature's
    mean/std/min/max. Columns that ``_is_binary`` reports as 0/1 flags also
    get the mean target value and row count for each flag value, and their
    difference.

    Both feature and target are coerced with ``pd.to_numeric(...,
    errors="coerce")`` first, so unparseable entries become NaN and are
    dropped pairwise instead of raising.

    Args:
        df: Frame holding the feature columns and the target column.
        num_cols: Names of the columns to evaluate.
        target: Name of the column to correlate against.

    Returns:
        One row per evaluated feature, sorted by ``pearson_r`` descending
        (NaN last). If no feature is evaluated, an empty frame that still has
        the core result columns is returned, so callers can select them.
    """
    core_columns = [
        "feature", "n", "pearson_r", "pearson_p", "spearman_rho",
        "spearman_p", "mean", "std", "min", "max", "is_binary",
    ]
    # Coerce once: a target parsed as text would otherwise make every
    # correlation fail inside astype(float).
    y = pd.to_numeric(df[target], errors="coerce")
    rows: List[Dict[str, Any]] = []
    for c in num_cols:
        if c == target or c in IGNORE_COLS:
            continue
        values = pd.to_numeric(df[c], errors="coerce")
        r_p, p_p, n_p = _safe_corr(values, y, method="pearson")
        r_s, p_s, n_s = _safe_corr(values, y, method="spearman")
        try:
            is_bin = bool(_is_binary(values))
        except (TypeError, ValueError):
            # An unclassifiable column is reported as a plain feature.
            is_bin = False
        row: Dict[str, Any] = {
            "feature": c,
            "n": int(max(n_p, n_s)),
            "pearson_r": r_p,
            "pearson_p": p_p,
            "spearman_rho": r_s,
            "spearman_p": p_s,
            "mean": float(values.mean()),
            "std": float(values.std()),
            "min": float(values.min()),
            "max": float(values.max()),
            "is_binary": is_bin,
        }
        if is_bin:
            # Pairwise-complete rows only; explicit keys avoid any clash
            # between the feature's and the target's column names.
            paired = pd.concat([values, y], axis=1, keys=["flag", "target"]).dropna()
            on = paired.loc[paired["flag"] == 1, "target"]
            off = paired.loc[paired["flag"] == 0, "target"]
            g1 = float(on.mean())   # NaN when the group is empty
            g0 = float(off.mean())
            row.update({
                "mean_rating_if_1": g1,
                "mean_rating_if_0": g0,
                "n_1": int(len(on)),
                "n_0": int(len(off)),
                "diff_1_minus_0": g1 - g0,  # NaN propagates from an empty group
            })
        rows.append(row)
    if not rows:
        return pd.DataFrame(columns=core_columns)
    out = pd.DataFrame(rows)
    return out.sort_values(by=["pearson_r"], ascending=False)


def correlate_categorical(df: pd.DataFrame, cat_cols: List[str], target: str = TARGET_COL, min_count: int = 20) -> Dict[str, pd.DataFrame]:
    """Correlate one-vs-rest indicators of categorical levels with ``target``.

    For each column in ``cat_cols``, every level occurring at least
    ``min_count`` times is turned into a 0/1 indicator. Rows whose value is
    missing count as 0 (absent). The indicator is then correlated with the
    target via ``_safe_corr`` (Pearson and Spearman). The mean target when
    the level is present vs. absent is recorded alongside.

    The target is coerced with ``pd.to_numeric(..., errors="coerce")`` so a
    text-typed target does not make the correlations fail.

    Args:
        df: Frame holding the categorical columns and the target column.
        cat_cols: Names of the columns to evaluate.
        target: Name of the column to correlate against.
        min_count: Minimum number of occurrences for a level to be evaluated.

    Returns:
        Mapping from column name to a frame with one row per evaluated level,
        sorted by ``pearson_r`` descending. Columns with no evaluated level
        are omitted.
    """
    y = pd.to_numeric(df[target], errors="coerce")
    results: Dict[str, pd.DataFrame] = {}
    for c in cat_cols:
        raw = df[c]
        # Stringify only present values. A plain astype(str) would turn
        # None / pd.NA into the levels "None" / "<NA>", and mapping "nan"
        # back to NaN would also erase a genuine "nan" string level.
        s = raw.astype(str).where(raw.notna())
        vc = s.value_counts(dropna=True)
        keep = vc[vc >= min_count].index
        rows: List[Dict[str, Any]] = []
        for v in keep:
            ind = (s == v).astype(float)
            r_p, p_p, n_p = _safe_corr(ind, y, method="pearson")
            r_s, p_s, n_s = _safe_corr(ind, y, method="spearman")
            paired = pd.concat([ind, y], axis=1, keys=["ind", "target"]).dropna()
            if paired.empty:
                continue
            r1 = float(paired.loc[paired["ind"] == 1, "target"].mean())  # NaN if empty
            r0 = float(paired.loc[paired["ind"] == 0, "target"].mean())
            rows.append({
                "category": str(v),
                "count": int(vc.loc[v]),
                "pearson_r": r_p,
                "pearson_p": p_p,
                "spearman_rho": r_s,
                "spearman_p": p_s,
                "mean_rating_if_present": r1,
                "mean_rating_if_absent": r0,
                "diff_present_minus_absent": r1 - r0,  # NaN propagates
            })
        if rows:
            results[c] = pd.DataFrame(rows).sort_values(by=["pearson_r"], ascending=False)
    return results


def run(input_path: str, output_dir: str, min_count: int = 20) -> None:
    """Run the full correlation scan on one CSV and write the results.

    Reads ``input_path`` and correlates numeric features and one-vs-rest
    categorical indicators with ``TARGET_COL``. Writes CSVs into
    ``output_dir`` (created if missing) and prints a short console summary.

    Files written:
        correlations_numeric.csv: one row per numeric feature.
        correlations_flags.csv: numeric features whose ``is_binary`` is True.
        correlations_categories_<column>.csv: one per categorical column with
            at least one level meeting ``min_count``.

    Args:
        input_path: Path of the input CSV.
        output_dir: Directory receiving the CSV outputs.
        min_count: Minimum occurrences for a categorical level to be evaluated.

    Raises:
        ValueError: If ``TARGET_COL`` is not a column of the input.
    """
    df = pd.read_csv(input_path)
    if TARGET_COL not in df.columns:
        raise ValueError(f"Target column '{TARGET_COL}' not found in {input_path}")
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    df_num, num_cols = _coerce_numeric(df)

    # Numeric correlations
    num_corr = correlate_numeric(df_num, num_cols, target=TARGET_COL)
    num_corr.to_csv(out_dir / "correlations_numeric.csv", index=False)

    # With no evaluable numeric feature the result may be empty or lack
    # columns entirely; selecting them below would raise KeyError.
    summary_cols = ["feature", "pearson_r", "spearman_rho", "n", "is_binary"]
    have_numeric = not num_corr.empty and set(summary_cols).issubset(num_corr.columns)

    # Quick binary/flag summary (subset)
    if have_numeric:
        flags = num_corr[num_corr["is_binary"].eq(True)]
        flags = flags.sort_values(by=["pearson_r"], ascending=False)
    else:
        flags = num_corr.iloc[0:0]
    flags.to_csv(out_dir / "correlations_flags.csv", index=False)

    # Categorical correlations via one-vs-rest indicators
    cat_cols = _categorical_columns(df, exclude=num_cols + [TARGET_COL])
    cat_results = correlate_categorical(df, cat_cols, target=TARGET_COL, min_count=min_count)
    for c, dfx in cat_results.items():
        # str() handles non-string labels; path separators would otherwise be
        # interpreted as sub-directories of out_dir.
        safe_c = str(c).replace("/", "_").replace("\\", "_")
        dfx.to_csv(out_dir / f"correlations_categories_{safe_c}.csv", index=False)

    # Console summary
    print("Top numeric features by Pearson correlation with rating:")
    if have_numeric:
        top = num_corr.sort_values(by=["pearson_r"], ascending=False).head(15)
        print(top[["feature", "pearson_r", "spearman_rho", "n"]].to_string(index=False))
    else:
        print("(no numeric features evaluated)")

    if cat_results:
        print("\nTop categorical signals (by column):")
        for c, dfx in cat_results.items():
            head = dfx.head(5)
            print(f"- {c}:")
            print(head[["category", "count", "pearson_r", "diff_present_minus_absent"]].to_string(index=False))


def main():
    """Command-line entry point: parse the CLI options and call ``run``."""
    parser = argparse.ArgumentParser(description="Preliminary correlation scan of features vs average_rating")
    cli_options = [
        ("--input", {"default": "present_conversations_aggregated.csv", "help": "Path to present_conversations_aggregated.csv"}),
        ("--output_dir", {"default": "unified_analysis/outputs", "help": "Directory to write correlation CSVs"}),
        ("--min_count", {"type": int, "default": 20, "help": "Min count for categorical levels to be evaluated"}),
    ]
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    parsed = parser.parse_args()

    run(parsed.input, parsed.output_dir, min_count=parsed.min_count)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
