import argparse
import json
import os
from itertools import combinations
from typing import Dict, List, Tuple

import joblib
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler

# Directory (next to this script) where the effect-size CSV/JSON outputs are written.
PPI_RESULTS_DIR = os.path.join(os.path.dirname(__file__), 'ppi_results')
# Directory one level above this script holding the joblib artifacts read by load_artifacts().
SAVED_MODELS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'saved_models')
# CSV one level above this script used to add feature columns to labeled rows by conversation_id.
COMBINED_LABELED_CSV = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'combined_labeled_results.csv')


def load_artifacts() -> Dict:
    """Load the trained rating regressor and the objects needed to build its inputs.

    Every artifact is read with ``joblib.load`` from ``SAVED_MODELS_DIR``.
    joblib files are pickles, so only point this at a trusted directory.

    Returns:
        dict with keys 'model', 'feature_columns', 'sent_le' and 'task_le'.

    Raises:
        FileNotFoundError: if the regressor file is absent, the feature-column
            file is absent (or loads as None), or either label encoder is
            unavailable.
    """
    def _artifact_path(filename: str) -> str:
        return os.path.join(SAVED_MODELS_DIR, filename)

    def _load_if_present(path: str):
        # None signals "missing"; callers turn that into FileNotFoundError.
        return joblib.load(path) if os.path.exists(path) else None

    model_path = _artifact_path('feedback_rating_regressor.pkl')
    feat_path = _artifact_path('feature_columns_reg.pkl')
    sent_le_path = _artifact_path('sentiment_label_encoder_reg.pkl')
    task_le_path = _artifact_path('task_type_label_encoder_reg.pkl')

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Regressor model not found at {model_path}")
    model = joblib.load(model_path)

    feature_columns = _load_if_present(feat_path)
    if feature_columns is None:
        raise FileNotFoundError(f"Feature columns (reg) not found at {feat_path}")

    encoders = [_load_if_present(path) for path in (sent_le_path, task_le_path)]
    if any(encoder is None for encoder in encoders):
        raise FileNotFoundError("Saved label encoders (reg) not found")
    sent_le, task_le = encoders

    return dict(
        model=model,
        feature_columns=feature_columns,
        sent_le=sent_le,
        task_le=task_le,
    )


def ensure_columns(df: pd.DataFrame, cols: List[str]):
    """Add, in place, every column named in *cols* that *df* lacks, filled with NaN.

    Existing columns are left untouched. Returns None.
    """
    absent = [name for name in cols if name not in df.columns]
    for name in absent:
        df[name] = np.nan

def unify_feature_column_names(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* whose space-separated git command columns use underscores.

    For example 'git commit' becomes 'git_commit'; only the commit, push, pull,
    reset and rebase columns are affected. The input frame is not modified.
    """
    aliases = {
        f'git {command}': f'git_{command}'
        for command in ('commit', 'push', 'pull', 'reset', 'rebase')
    }
    present = {old: new for old, new in aliases.items() if old in df.columns}
    return df.rename(columns=present)


def encode_with_fallback(series: pd.Series, le) -> pd.Series:
    """Label-encode *series* with a fitted encoder, mapping unseen values to a fallback class.

    Values are compared as strings. Anything not in ``le.classes_`` is replaced by
    'Unknown' when that is a known class, otherwise by the first class (or '' if
    the encoder has no classes) before ``le.transform`` is applied.

    Returns:
        Series of encoded labels sharing the index of *series*.
    """
    # Depending on the pandas version, astype(str) either renders missing values
    # as the text 'nan' or keeps them missing; fillna('') covers the latter case.
    labels = series.astype(str).fillna('')
    classes = le.classes_
    known = set(classes)

    if 'Unknown' in known:
        fallback = 'Unknown'
    elif len(classes) > 0:
        fallback = classes[0]
    else:
        fallback = ''

    unseen = ~labels.isin(known)
    if unseen.any():
        labels.loc[unseen] = fallback

    encoded = le.transform(labels)
    return pd.Series(encoded, index=series.index)


def preprocess(df: pd.DataFrame, artifacts: Dict, is_labeled: bool, median_ref: Dict = None) -> Tuple[pd.DataFrame, Dict]:
    """Build the numeric feature matrix expected by the saved regressor.

    Steps:
      1. drop rows whose ``feedback_task_type`` contains a comma;
      2. derive ``message_count`` from ``user_message_count`` (NaN if absent);
      3. coerce every column starting with 'git' and the rubric columns to
         numbers, 0 for missing/unparseable (absent rubric columns become 0.0);
      4. encode sentiment and task type with the saved label encoders
         (0 when the source column is absent);
      5. impute ``message_count`` with a median;
      6. select ``artifacts['feature_columns']``, creating absent ones, and
         coerce each to numeric with 0 for missing/unparseable values.

    Args:
        df: raw rows; not modified.
        artifacts: dict from ``load_artifacts`` (reads 'feature_columns',
            'sent_le', 'task_le').
        is_labeled: accepted for call-site symmetry; not used here.
        median_ref: imputation medians. When None, or lacking the
            'message_count' key, the median of this frame's message counts is
            computed and stored; a passed-in dict is updated in place.

    Returns:
        ``(X, median_ref)``; X keeps the index labels of the surviving rows.
    """
    frame = df.copy()

    if 'feedback_task_type' in frame.columns:
        multi_task = frame['feedback_task_type'].astype(str).str.contains(',', na=False)
        frame = frame[~multi_task]

    frame['message_count'] = (
        frame['user_message_count'] if 'user_message_count' in frame.columns else np.nan
    )

    for name in [c for c in frame.columns if c.startswith('git')]:
        frame[name] = pd.to_numeric(frame[name], errors='coerce').fillna(0)

    rubric = (
        'misunderstood_intention', 'did_not_follow_instruction', 'insufficient_analysis',
        'insufficient_testing', 'insufficient_debugging', 'incomplete_implementation', 'scope_creep',
    )
    for name in rubric:
        frame[name] = (
            pd.to_numeric(frame[name], errors='coerce').fillna(0)
            if name in frame.columns
            else 0.0
        )

    encodings = (
        ('feedback_sentiment_classification', 'sentiment_encoded', 'sent_le'),
        ('feedback_task_type', 'task_type_encoded', 'task_le'),
    )
    for source, target, encoder_key in encodings:
        if source in frame.columns:
            frame[target] = encode_with_fallback(frame[source], artifacts[encoder_key])
        else:
            frame[target] = 0

    if median_ref is None:
        median_ref = {}
    counts = pd.to_numeric(frame['message_count'], errors='coerce')
    if 'message_count' not in median_ref:
        median_ref['message_count'] = counts.median()
    frame['message_count'] = counts.fillna(median_ref['message_count'])

    feature_columns = artifacts['feature_columns']
    ensure_columns(frame, feature_columns)

    features = frame[feature_columns].copy()
    for name in features.columns:
        features[name] = pd.to_numeric(features[name], errors='coerce').fillna(0)

    return features, median_ref


def ppi_mean(Y_lab: np.ndarray, f_lab: np.ndarray, f_unlab: np.ndarray) -> Tuple[float, float, float, Dict[str, float]]:
    """Power-tuned prediction-powered (PPI++) estimate of the mean of Y.

    Point estimate::

        mu_hat = mean(Y_lab) + lambda * (mean(f_unlab) - mean(f_lab))
        lambda = Cov(Y, f_lab) / ((1 + n / N) * Var(f))

    where Var(f) is pooled over labeled and unlabeled predictions.

    The variance of mu_hat is estimated from its two independent parts,
    ``Var(Y - lambda * f_lab) / n + lambda**2 * Var(f_unlab) / N``. Each part is a
    sample variance, so the estimate cannot go negative. A pooled-variance
    plug-in can go negative and then collapse the standard error to ~0.

    With no unlabeled predictions (N == 0) this reduces to the sample mean and
    its usual standard error, with lambda = 0.

    Args:
        Y_lab: observed outcomes for the n labeled rows.
        f_lab: model predictions for the same labeled rows, in the same order.
        f_unlab: model predictions for the N unlabeled rows.

    Returns:
        ``(mu_hat, se_ppi, lambda_hat, extras)`` where extras holds
        'f_bar_lab', 'f_bar_unlab', 'se_standard' (SE of the plain labeled
        mean), 'sigma2_ppi' (n * se_ppi**2) and 'efficiency_gain'
        (se_standard**2 / se_ppi**2).

    Raises:
        ValueError: if no labeled outcomes are given, or if f_lab does not hold
            exactly one prediction per labeled outcome.
    """
    Y_lab = np.asarray(Y_lab, dtype=float)
    f_lab = np.asarray(f_lab, dtype=float)
    f_unlab = np.asarray(f_unlab, dtype=float)

    n = len(Y_lab)
    N = len(f_unlab)
    if n == 0:
        raise ValueError("No labeled data provided for PPI")
    if len(f_lab) != n:
        raise ValueError(
            f"f_lab has {len(f_lab)} predictions but Y_lab has {n} outcomes; "
            "they must be paired row-for-row"
        )

    def _sample_var(x: np.ndarray) -> float:
        # Unbiased (ddof=1) variance; 0.0 when it is undefined (fewer than 2 points).
        return float(np.var(x, ddof=1)) if len(x) > 1 else 0.0

    Y_bar = float(Y_lab.mean())
    f_bar_lab = float(f_lab.mean())
    var_Y = _sample_var(Y_lab)
    se_standard = float(np.sqrt(var_Y / n))

    if N > 0:
        f_bar_unlab = float(f_unlab.mean())
        cov_Y_f = float(np.cov(Y_lab, f_lab, ddof=1)[0, 1]) if n > 1 else 0.0
        var_f_all = _sample_var(np.concatenate([f_lab, f_unlab]))
        ratio = n / float(N)
        denom = (1.0 + ratio) * var_f_all
        # Variance-minimizing weight; 0 falls back to the classical estimator.
        lambda_hat = (cov_Y_f / denom) if denom > 1e-12 else 0.0
        mu_hat = Y_bar + lambda_hat * (f_bar_unlab - f_bar_lab)

        # n * Var(mu_hat): rectifier term from the labeled rows plus the imputed
        # term from the unlabeled rows (scaled by n / N).
        sigma2_ppi = (
            _sample_var(Y_lab - lambda_hat * f_lab)
            + ratio * (lambda_hat ** 2) * _sample_var(f_unlab)
        )
        # Floor keeps the SE strictly positive (e.g. for constant inputs).
        sigma2_ppi = max(sigma2_ppi, 1e-12)
        se_ppi = float(np.sqrt(sigma2_ppi / float(n)))
    else:
        mu_hat = Y_bar
        lambda_hat = 0.0
        f_bar_unlab = f_bar_lab
        sigma2_ppi = var_Y
        se_ppi = se_standard

    efficiency_gain = (se_standard ** 2) / (se_ppi ** 2) if se_ppi > 0 else 1.0

    extras = {
        'f_bar_lab': float(f_bar_lab),
        'f_bar_unlab': float(f_bar_unlab),
        'se_standard': float(se_standard),
        'sigma2_ppi': float(sigma2_ppi),
        'efficiency_gain': float(efficiency_gain),
    }
    return float(mu_hat), float(se_ppi), float(lambda_hat), extras


def mean_ci(mu: float, se: float, z: float = 1.96) -> Tuple[float, float]:
    """Return the normal-approximation interval ``(mu - z*se, mu + z*se)``.

    The default z = 1.96 gives a two-sided 95% interval.
    """
    half_width = z * se
    low = mu - half_width
    high = mu + half_width
    return low, high


def detect_variants(df_lab: pd.DataFrame, df_unlab: pd.DataFrame) -> Tuple[Dict[str, Dict[str, List[str]]], str]:
    """Decide which experiment arms to compare and how to select their rows.

    The unlabeled frame determines the design:
      * an ``llm_model`` column -> arms 'claude-3-7' and 'claude-sonnet-4';
      * otherwise the lowercased values of ``variant`` select, in priority order,
        planning/unplanned, control/gpt5, control/treatment, or (fallback) one
        arm per distinct value.

    Args:
        df_lab: labeled frame; not inspected here.
        df_unlab: unlabeled frame.

    Returns:
        ``(variant_map, unlab_col)``. variant_map maps each arm name to the
        substrings used to match labeled rows ('labeled_contains') and
        unlabeled rows ('unlabeled_contains'); unlab_col names the unlabeled
        column to match on. ``({}, '')`` when neither column exists.
    """
    def _arm(labeled: List[str], unlabeled: List[str]) -> Dict[str, List[str]]:
        # Fresh lists per arm so no two entries share a mutable list.
        return {'labeled_contains': list(labeled), 'unlabeled_contains': list(unlabeled)}

    columns = df_unlab.columns

    if 'llm_model' in columns:
        return {
            'claude-3-7': _arm(['claude-3-7'], ['claude-3-7']),
            'claude-sonnet-4': _arm(['claude-sonnet-4'], ['claude-sonnet-4']),
        }, 'llm_model'

    if 'variant' not in columns:
        return {}, ''

    seen = {str(value).lower() for value in df_unlab['variant'].dropna().unique()}
    control_markers = ['control', 'claude-sonnet-4', 'claude-4', 'sonnet-4']

    if 'planning' in seen and 'unplanned' in seen:
        return {
            'planning': _arm(['planning'], ['planning']),
            'unplanned': _arm(['unplanned'], ['unplanned']),
        }, 'variant'

    if 'control' in seen and ('gpt5' in seen or 'gpt-5' in seen):
        return {
            'control': _arm(control_markers, ['control']),
            'gpt5': _arm(['gpt5', 'gpt-5', 'treatment'], ['gpt5', 'gpt-5']),
        }, 'variant'

    if 'treatment' in seen:
        return {
            'control': _arm(control_markers, ['control']),
            'treatment': _arm(['treatment', 'gpt5', 'gpt-5'], ['treatment']),
        }, 'variant'

    # Fallback: each distinct value is its own arm, matched by its own text.
    return {value: _arm([value], [value]) for value in sorted(seen)}, 'variant'


def subset_by_contains(df: pd.DataFrame, column: str, substrings: List[str], case: bool = False, regex: bool = False) -> pd.DataFrame:
    """Return the rows of *df* whose *column* contains any of *substrings*.

    By default, substrings are matched literally (``regex=False``). This matters
    because fallback variant names come straight from the data and may contain
    regex metacharacters such as '+', '.', '(' or '['. Matching is also
    case-insensitive by default (``case=False``), because ``detect_variants``
    builds its patterns from lowercased values.

    Args:
        df: frame to filter; not modified.
        column: column to search (values are compared as strings).
        substrings: candidate substrings; a row is kept if any one matches.
            An empty list keeps no rows.
        case: set True for case-sensitive matching.
        regex: set True to treat each substring as a regular expression.

    Returns:
        A copy of the matching rows, or a copy of all rows when *column* is
        not present in *df*.
    """
    if column not in df.columns:
        return df.copy()
    values = df[column].astype(str)
    mask = pd.Series(False, index=df.index)
    for sub in substrings:
        mask |= values.str.contains(sub, case=case, regex=regex, na=False)
    # Positional boolean selection avoids index alignment issues with duplicate labels.
    return df.loc[mask.to_numpy()].copy()


def compute_group_estimates(df_lab_v: pd.DataFrame, df_unlab_v: pd.DataFrame, artifacts: Dict, conf_method: str = 'knn', conf_q: float = 0.5, knn_k: int = 5, min_unlabeled: int = 100, use_confident_unlabeled: bool = False, std_downsample: bool = False, std_n_per_arm: int = 100, std_seed: int = 12345) -> Dict:
    """Estimate one arm's mean rating with both the standard and the PPI estimator.

    Labeled rows without a numeric ``average_rating`` are dropped. Features come
    from ``preprocess``; the unlabeled rows reuse the labeled message-count
    median. They are scored with ``artifacts['model']``, and the predictions feed
    ``ppi_mean``.

    ``preprocess`` may drop rows (multi-valued ``feedback_task_type``). Labeled
    outcomes are therefore taken from the rows that survived it, so every
    prediction is paired with its own outcome.

    Args:
        df_lab_v: labeled rows for this arm; must contain 'average_rating'.
        df_unlab_v: unlabeled rows for this arm.
        artifacts: dict from ``load_artifacts``.
        conf_method, conf_q, knn_k, min_unlabeled: accepted for interface
            compatibility; not used by this function.
        use_confident_unlabeled: keep only unlabeled predictions whose distance
            from 3.0 is at least the median distance (assumes a rating scale
            centred on 3 -- adjust if the label scale differs).
        std_downsample: if True, the standard estimator uses a seeded random
            subset of at most ``std_n_per_arm`` labeled rows; PPI always uses
            every usable labeled row.
        std_n_per_arm: subset size for ``std_downsample``.
        std_seed: RNG seed for ``std_downsample``.

    Returns:
        dict with 'n_labeled' (labeled rows with a rating), 'n_labeled_ppi',
        'n_labeled_standard', 'n_unlabeled' (unlabeled predictions actually
        used), 'mu_standard', 'se_standard', 'mu_ppi', 'se_ppi', 'lambda_hat',
        'corr_y_f', 'f_labeled_mean', 'f_unlabeled_mean', 'efficiency_gain'.

    Raises:
        ValueError: if 'average_rating' is missing or no usable labeled rows remain.
    """
    model = artifacts['model']

    if 'average_rating' not in df_lab_v.columns:
        raise ValueError("'average_rating' missing in labeled data")
    df_lab_v = df_lab_v.copy()
    df_lab_v['average_rating'] = pd.to_numeric(df_lab_v['average_rating'], errors='coerce')
    # A fresh RangeIndex makes the label-based alignment with preprocess() output unambiguous.
    df_lab_v = df_lab_v[df_lab_v['average_rating'].notna()].reset_index(drop=True)

    n_lab = len(df_lab_v)
    if n_lab == 0:
        raise ValueError("No labeled rows after filtering")

    # Standard estimator may use a seeded subsample; PPI always uses the full labeled set.
    df_lab_std = df_lab_v
    if std_downsample:
        rng = np.random.default_rng(std_seed)
        if n_lab > std_n_per_arm:
            keep_idx = rng.choice(n_lab, size=std_n_per_arm, replace=False)
            df_lab_std = df_lab_v.iloc[np.sort(keep_idx)]

    X_lab, median_ref = preprocess(df_lab_v, artifacts, is_labeled=True)
    X_unlab, _ = preprocess(df_unlab_v, artifacts, is_labeled=False, median_ref=median_ref)

    print(X_lab.shape, X_unlab.shape)

    if len(X_lab) == 0:
        raise ValueError("No labeled rows left after preprocessing")
    # Outcomes for exactly the rows preprocess() kept, in the same order as X_lab.
    y_lab = df_lab_v.loc[X_lab.index, 'average_rating'].to_numpy(dtype=float)

    f_lab = np.asarray(model.predict(X_lab), dtype=float)
    f_unlab = np.asarray(model.predict(X_unlab), dtype=float) if len(X_unlab) > 0 else np.array([])

    # Optionally keep only "confident" unlabeled points, judged by the prediction's distance from mid-scale.
    if use_confident_unlabeled and len(f_unlab) > 0:
        mid = 3.0
        margin = np.abs(f_unlab - mid)
        thresh = np.nanmedian(margin)
        keep_idx = np.where(margin >= thresh)[0]
        if len(keep_idx) < len(f_unlab):
            print(f"Filtering confident unlabeled: kept {len(keep_idx)}/{len(f_unlab)} (~{len(keep_idx)/max(1,len(f_unlab))*100:.1f}%)")
        f_unlab = f_unlab[keep_idx]
    n_unlab = int(len(f_unlab))

    n_ppi = len(y_lab)
    corr = float(np.corrcoef(y_lab, f_lab)[0, 1]) if n_ppi > 1 else 0.0
    if not np.isfinite(corr):
        # e.g. constant outcomes or predictions
        corr = 0.0

    mu_ppi, se_ppi, lambda_hat, extras = ppi_mean(y_lab, f_lab, f_unlab)

    n_std = len(df_lab_std)
    y_std = df_lab_std['average_rating']
    mu_std = float(y_std.mean())
    se_std = float(np.sqrt(y_std.var(ddof=1) / n_std)) if n_std > 1 else 0.0

    return {
        'n_labeled': int(n_lab),
        'n_labeled_ppi': int(n_ppi),
        'n_labeled_standard': int(n_std),
        'n_unlabeled': int(n_unlab),
        'mu_standard': mu_std,
        'se_standard': se_std,
        'mu_ppi': float(mu_ppi),
        'se_ppi': float(se_ppi),
        'lambda_hat': float(lambda_hat),
        'corr_y_f': float(corr),
        'f_labeled_mean': float(extras.get('f_bar_lab', float('nan'))),
        'f_unlabeled_mean': float(extras.get('f_bar_unlab', float('nan'))),
        'efficiency_gain': float(extras.get('efficiency_gain', float('nan'))),
    }


def choose_pair_order(a: str, b: str) -> Tuple[str, str]:
    """Order two arm names as (treatment-like, control-like) for an A - B difference.

    A name counts as treatment-like or control-like if it contains (case-
    insensitively) one of the marker substrings below. When the pair is not
    one treatment-like and one control-like name, the names are returned in
    sorted order so the output is deterministic.
    """
    treat_markers = ('treatment', 'gpt5', 'gpt-5', 'claude-3-7')
    ctrl_markers = ('control', 'claude-sonnet-4', 'claude-4', 'sonnet-4')

    def _has_marker(name: str, markers: Tuple[str, ...]) -> bool:
        lowered = name.lower()
        return any(marker in lowered for marker in markers)

    if _has_marker(a, treat_markers) and _has_marker(b, ctrl_markers):
        return a, b
    if _has_marker(b, treat_markers) and _has_marker(a, ctrl_markers):
        return b, a
    return (a, b) if a <= b else (b, a)


def standardize_labeled(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* using the canonical ID and rating column names.

    'convo_id' becomes 'conversation_id' and 'avg_rating' becomes
    'average_rating', but only when the canonical name is not already present.
    The input frame is not modified.
    """
    aliases = {}
    if 'convo_id' in df.columns and 'conversation_id' not in df.columns:
        aliases['convo_id'] = 'conversation_id'
    if 'avg_rating' in df.columns and 'average_rating' not in df.columns:
        aliases['avg_rating'] = 'average_rating'
    return df.rename(columns=aliases)



def enrich_labeled_with_combined(df_lab: pd.DataFrame) -> pd.DataFrame:
    """Add feature columns to labeled rows by joining COMBINED_LABELED_CSV on conversation_id.

    The combined file is normalised with ``standardize_labeled`` and
    ``unify_feature_column_names``, then de-duplicated per conversation_id. When
    a ``created_at`` column exists, the row that sorts last by ``created_at`` is
    kept. Only columns absent from *df_lab* and outside the core set
    (conversation_id, average_rating, created_at, variant, rating_count) are
    added, via a left join, so every labeled row is kept.

    Best-effort: if the CSV cannot be read, or either frame lacks
    conversation_id, a warning is printed and *df_lab* is returned unchanged.
    """
    try:
        combined = pd.read_csv(COMBINED_LABELED_CSV)
    except Exception as exc:
        print(f"Warning: could not read {COMBINED_LABELED_CSV}: {exc}; proceeding without enrichment")
        return df_lab

    combined = unify_feature_column_names(standardize_labeled(combined))

    key = 'conversation_id'
    if key in combined.columns:
        if 'created_at' in combined.columns:
            combined = combined.sort_values('created_at')
        combined = combined.drop_duplicates(key, keep='last')

    protected = {key, 'average_rating', 'created_at', 'variant', 'rating_count'}
    # Skip columns df_lab already has so the merge never produces _x/_y suffixes.
    extra_cols = [c for c in combined.columns if c not in protected and c not in df_lab.columns]

    if key not in df_lab.columns or key not in combined.columns:
        print("Warning: 'conversation_id' missing in labeled or combined data; skipping enrichment")
        return df_lab

    joined = df_lab.merge(combined[[key, *extra_cols]], on=key, how='left', indicator=True)
    n_matched = int(joined['_merge'].eq('both').sum())
    print(f"Feature enrichment: matched {n_matched}/{len(df_lab)} labeled rows by conversation_id; added {len(extra_cols)} feature cols")
    return joined.drop(columns='_merge')


def _pairwise_effect_row(a: str, b: str, sa: Dict, sb: Dict) -> Dict:
    """Build one output row for the difference in mean rating between arms a and b (a - b).

    ``sa``/``sb`` are per-arm dicts from ``compute_group_estimates``. The arms
    are treated as independent, so standard errors combine in quadrature, and
    intervals are 95% normal-approximation CIs from ``mean_ci``.
    """
    diff_std = sa['mu_standard'] - sb['mu_standard']
    se_diff_std = float(np.sqrt(sa['se_standard'] ** 2 + sb['se_standard'] ** 2))
    ci_std_low, ci_std_high = mean_ci(diff_std, se_diff_std)

    diff_ppi = sa['mu_ppi'] - sb['mu_ppi']
    se_diff_ppi = float(np.sqrt(sa['se_ppi'] ** 2 + sb['se_ppi'] ** 2))
    ci_ppi_low, ci_ppi_high = mean_ci(diff_ppi, se_diff_ppi)

    return {
        'variant_a': a,
        'variant_b': b,
        'n_labeled_a': sa['n_labeled'],
        'n_labeled_b': sb['n_labeled'],
        'n_unlabeled_a': sa['n_unlabeled'],
        'n_unlabeled_b': sb['n_unlabeled'],
        'diff_standard': float(diff_std),
        'se_diff_standard': float(se_diff_std),
        'ci95_standard_low': float(ci_std_low),
        'ci95_standard_high': float(ci_std_high),
        'diff_ppi': float(diff_ppi),
        'se_diff_ppi': float(se_diff_ppi),
        'ci95_ppi_low': float(ci_ppi_low),
        'ci95_ppi_high': float(ci_ppi_high),
        'lambda_a': sa['lambda_hat'],
        'lambda_b': sb['lambda_hat'],
        'corr_y_f_a': sa['corr_y_f'],
        'corr_y_f_b': sb['corr_y_f'],
        'efficiency_gain_a': sa['efficiency_gain'],
        'efficiency_gain_b': sb['efficiency_gain'],
    }


def main():
    """Command-line entry point: estimate pairwise effect sizes between variants with PPI.

    Loads a paired labeled/unlabeled dataset (``--dataset``, or explicit
    ``--labeled_csv``/``--unlabeled_csv``), enriches the labeled rows with
    feature columns, detects the variants, estimates each arm's mean rating
    with the standard and PPI estimators, and writes every pairwise
    difference (A - B) to CSV and JSON under ``PPI_RESULTS_DIR``.

    Output file names encode the options that change the estimates
    (``_stdN<n>`` for --std-downsample, ``_confident`` for
    --confident-unlabeled) so runs with different settings do not overwrite
    each other.
    """
    parser = argparse.ArgumentParser(description='Estimate effect size (difference in means) using PPI')
    parser.add_argument('--dataset', choices=['claude4_vs_gpt5','claude37_vs_claude4','condenser','planning'], default='claude4_vs_gpt5', help='Paired labeled/unlabeled set to use from labeled_results/ and unlabeled_results/')
    parser.add_argument('--labeled_csv', default=None, help='Override path to labeled CSV')
    parser.add_argument('--unlabeled_csv', default=None, help='Override path to unlabeled CSV')
    parser.add_argument('--confident-unlabeled', dest='confident_unlabeled', action='store_true',
                        help='Use only confident unlabeled points (top 50% by |prediction - mid|)')
    parser.add_argument('--std-downsample', dest='std_downsample', action='store_true',
                        help='Downsample Standard to a fixed n per arm; PPI remains on full labeled')
    parser.add_argument('--std-n-per-arm', dest='std_n_per_arm', type=int, default=100,
                        help='Target labeled samples per arm for Standard when downsampling')
    parser.add_argument('--std-seed', dest='std_seed', type=int, default=12345,
                        help='Random seed for Standard downsampling')
    args = parser.parse_args()

    # Choose paired files
    base_dir = os.path.dirname(os.path.dirname(__file__))
    labeled_dir = os.path.join(base_dir, 'labeled_results')
    unlabeled_dir = os.path.join(base_dir, 'unlabeled_results')
    dataset_to_files = {
        'claude4_vs_gpt5': (
            os.path.join(labeled_dir, 'claude4_vs_gpt5_labeled.csv'),
            os.path.join(unlabeled_dir, 'combined_unlabeled_claude4_vs_gpt5.csv'),
        ),
        'claude37_vs_claude4': (
            os.path.join(labeled_dir, 'claude37_vs_claude4_labeled.csv'),
            os.path.join(unlabeled_dir, 'combined_unlabeled_claude37_vs_claude4.csv'),
        ),
        'condenser': (
            os.path.join(labeled_dir, 'condenser_120_vs_80_labeled.csv'),
            os.path.join(unlabeled_dir, 'combined_unlabeled_memory.csv'),
        ),
        'planning': (
            os.path.join(labeled_dir, 'planning_vs_unplanned_labeled.csv'),
            os.path.join(unlabeled_dir, 'planning_vs_unplanned_unlabeled.csv'),
        ),
    }

    labeled_csv, unlabeled_csv = dataset_to_files[args.dataset]
    if args.labeled_csv:
        labeled_csv = args.labeled_csv
    if args.unlabeled_csv:
        unlabeled_csv = args.unlabeled_csv

    os.makedirs(PPI_RESULTS_DIR, exist_ok=True)

    artifacts = load_artifacts()

    df_lab = pd.read_csv(labeled_csv)
    df_lab = standardize_labeled(df_lab)
    df_lab = enrich_labeled_with_combined(df_lab)
    df_lab = unify_feature_column_names(df_lab)
    df_unlab = pd.read_csv(unlabeled_csv)
    df_unlab = unify_feature_column_names(df_unlab)

    variant_map, unlab_col = detect_variants(df_lab, df_unlab)
    if not variant_map:
        print("Could not detect variants; exiting.")
        return

    if 'variant' not in df_lab.columns:
        # Every arm would silently reuse the full labeled set, making the comparison meaningless.
        print("Warning: labeled data has no 'variant' column; every arm will reuse all labeled rows")

    # Build per-variant labeled/unlabeled subsets and compute estimates
    per_variant = {}
    for variant_name, patt in variant_map.items():
        if 'variant' not in df_lab.columns:
            df_lab_v = df_lab.copy()
        else:
            df_lab_v = subset_by_contains(df_lab, 'variant', patt['labeled_contains'])
        if unlab_col not in df_unlab.columns:
            df_unlab_v = df_unlab.copy()
        else:
            df_unlab_v = subset_by_contains(df_unlab, unlab_col, patt['unlabeled_contains'])

        if len(df_lab_v) == 0:
            print(f"Variant {variant_name}: no labeled rows; skipping")
            continue

        try:
            per_variant[variant_name] = compute_group_estimates(
                df_lab_v,
                df_unlab_v,
                artifacts,
                use_confident_unlabeled=args.confident_unlabeled,
                std_downsample=args.std_downsample,
                std_n_per_arm=args.std_n_per_arm,
                std_seed=args.std_seed,
            )
        except Exception as e:
            # One failing arm should not abort the others; report and continue.
            print(f"Variant {variant_name}: failed to compute stats: {e}")

    if len(per_variant) < 2:
        print("Fewer than two variants with labeled data; nothing to compare.")
        return

    # Create pairwise effect-size results
    results = []
    for a_raw, b_raw in combinations(list(per_variant), 2):
        a, b = choose_pair_order(a_raw, b_raw)
        results.append(_pairwise_effect_row(a, b, per_variant[a], per_variant[b]))

    out_df = pd.DataFrame(results)
    suffix = ''
    if args.std_downsample:
        suffix += f"_stdN{args.std_n_per_arm}"
    if args.confident_unlabeled:
        suffix += "_confident"
    csv_path = os.path.join(PPI_RESULTS_DIR, f'effect_size_ppi_estimates_{args.dataset}{suffix}.csv')
    json_path = os.path.join(PPI_RESULTS_DIR, f'effect_size_ppi_estimates_{args.dataset}{suffix}.json')
    out_df.to_csv(csv_path, index=False)
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print("Effect size estimates (A - B):")
    if len(out_df) > 0:
        disp_cols = [
            'variant_a','variant_b','n_labeled_a','n_labeled_b','n_unlabeled_a','n_unlabeled_b',
            'diff_standard','se_diff_standard','diff_ppi','se_diff_ppi','ci95_ppi_low','ci95_ppi_high'
        ]
        print(out_df[disp_cols].to_string(index=False, float_format=lambda x: f"{x:.4f}"))
    else:
        print("No results computed (check data and filters)")

    print(f"\nSaved CSV to: {csv_path}")
    print(f"Saved JSON to: {json_path}")


if __name__ == '__main__':
    main()
