import pandas as pd
import json
import numpy as np
from pathlib import Path
import re
from functools import partial
import os
import random
import warnings
from typing import Optional
from datasets import load_dataset
import csv


# Directory containing this module, with symlinks resolved.
MODULE_DIR = Path(__file__).resolve().parent
# Parent of MODULE_DIR; presumably the repository root (assumes this module
# sits one directory below it) — TODO confirm if the module is moved.
REPO_ROOT = MODULE_DIR.parent
# Default data locations used as default arguments by the loaders in this file.
DATA_DIR = REPO_ROOT / "data"
SCORE_DIR = DATA_DIR / "score"  # feedbackqa / helpsteer2 / helpsteer3 / ultrafeedback CSVs
MISC_DIR = DATA_DIR / "miscellaneous"  # tulu3, anthropic_hh_rlhf, injected_bias
BINARY_DIR = DATA_DIR / "binary"  # output location of prepare_civilcomments
INJECTED_BIAS_DIR = MISC_DIR / "injected_bias"


def load_feedbackqa(path=SCORE_DIR / "feedbackqa.csv"):
    """Read the preprocessed FeedbackQA annotation table.

    Args:
        path: CSV file to read; defaults to ``<repo>/data/score/feedbackqa.csv``.

    Returns:
        pd.DataFrame: the CSV contents, unmodified.
    """
    annotations = pd.read_csv(path)
    return annotations


def load_helpsteer2(path=SCORE_DIR / "helpsteer2.csv"):
    """Read the preprocessed HelpSteer2 validation annotation table.

    Args:
        path: CSV file to read; defaults to ``<repo>/data/score/helpsteer2.csv``.

    Returns:
        pd.DataFrame: the CSV contents, unmodified.
    """
    annotations = pd.read_csv(path)
    return annotations

def load_helpsteer3(path=SCORE_DIR / "helpsteer3_sampled.csv"):
    """Read the sampled HelpSteer3 table, exposing the last user turn as ``question``.

    Args:
        path: CSV file to read; defaults to ``<repo>/data/score/helpsteer3_sampled.csv``.

    Returns:
        pd.DataFrame: the CSV with ``last_user_conversation`` renamed to ``question``
        (no-op if that column is absent).
    """
    column_map = {"last_user_conversation": "question"}
    return pd.read_csv(path).rename(columns=column_map)

def load_tulu3_8b(path=MISC_DIR / "tulu3" / "tulu3_8b_sampled.csv"):
    """Read the sampled Tulu3-8B table, exposing ``prompt`` as ``question``.

    Args:
        path: CSV file to read; defaults to
            ``<repo>/data/miscellaneous/tulu3/tulu3_8b_sampled.csv``.

    Returns:
        pd.DataFrame: the CSV with ``prompt`` renamed to ``question``
        (no-op if that column is absent).
    """
    column_map = {"prompt": "question"}
    return pd.read_csv(path).rename(columns=column_map)

def load_hhrlhf(path=MISC_DIR / "anthropic_hh_rlhf" / "hh_rlhf_sampled.csv"):
    """Read the sampled Anthropic HH-RLHF table.

    Args:
        path: CSV file to read; defaults to
            ``<repo>/data/miscellaneous/anthropic_hh_rlhf/hh_rlhf_sampled.csv``.

    Returns:
        pd.DataFrame: the CSV contents, unmodified.
    """
    return pd.read_csv(path)

def load_injected_bias(
    raw_path=INJECTED_BIAS_DIR / "raw.json",
    perturbed_path=INJECTED_BIAS_DIR / "rich_content.json",
):
    """Pair original answers with their rich-content perturbations.

    Both inputs are JSON-lines files. Each record has a ``question`` and an
    ``answers`` mapping whose values carry ``answer_id`` and ``answer``; the
    perturbed file may also carry a ``perturb`` tag. A perturbed answer is
    linked to the raw answer of the *same question* whose ``answer_id`` is the
    longest prefix of the perturbed ``answer_id``.

    Args:
        raw_path: JSON-lines file with the unperturbed answers.
        perturbed_path: JSON-lines file with the perturbed answers.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: ``(raw_matched, rich_matched)``,
        row-aligned frames with columns ``raw_id``, ``question`` and
        ``answer`` holding the original and the perturbed answer text
        respectively. Only answers with a non-empty ``perturb`` tag are kept.
    """
    raw = pd.read_json(raw_path, lines=True)
    rich_content = pd.read_json(perturbed_path, lines=True)

    # Flatten both files to one row per answer.
    raw_flat = pd.DataFrame([
        {
            'question': row['question'],
            'answer_id': ans['answer_id'],
            'raw_answer': ans['answer'],
        }
        for _, row in raw.iterrows()
        for ans in row['answers'].values()
    ])
    rich_flat = pd.DataFrame([
        {
            'question': row['question'],
            'answer_id': ans['answer_id'],  # full perturbed ID
            # A missing key and an explicit null both mean "not perturbed".
            'perturb': ans.get('perturb') or '',
            'perturbed_answer': ans['answer'],
        }
        for _, row in rich_content.iterrows()
        for ans in row['answers'].values()
    ])

    # Candidate raw IDs per question, longest first, so that e.g. "a10" wins
    # over "a1" for "a10_rich". Two distinct IDs of equal length cannot both be
    # prefixes of the same string, so the match is deterministic.
    candidates = {}
    for question, answer_id in zip(raw_flat['question'], raw_flat['answer_id']):
        candidates.setdefault(question, set()).add(answer_id)
    candidates = {q: sorted(ids, key=len, reverse=True) for q, ids in candidates.items()}

    def find_raw_prefix(question, pert_id):
        """Return the longest raw ID of *question* that prefixes *pert_id*, or None."""
        return next(
            (prefix for prefix in candidates.get(question, ()) if pert_id.startswith(prefix)),
            None,
        )

    rich_flat['raw_id'] = [
        find_raw_prefix(question, pert_id)
        for question, pert_id in zip(rich_flat['question'], rich_flat['answer_id'])
    ]

    # Merge on (question, raw.answer_id == rich.raw_id).
    merged = raw_flat.merge(
        rich_flat.dropna(subset=['raw_id']),
        left_on=['question', 'answer_id'],
        right_on=['question', 'raw_id'],
        how='inner',
    )

    # Keep only actually perturbed rows.
    paired = merged[merged['perturb'] != '']

    raw_matched = paired[['raw_id', 'question', 'raw_answer']].rename(columns={'raw_answer': 'answer'})
    rich_matched = paired[['raw_id', 'question', 'perturbed_answer']].rename(columns={'perturbed_answer': 'answer'})

    return raw_matched, rich_matched

def load_ultrafeedback(path=SCORE_DIR / "ultrafeedback_sampled.csv"):
    """Read the sampled UltraFeedback score table.

    Args:
        path: CSV file to read; defaults to
            ``<repo>/data/score/ultrafeedback_sampled.csv``.

    Returns:
        pd.DataFrame: the CSV contents, unmodified.
    """
    return pd.read_csv(path)


def _extract_review5k_prompts(messages):
    """Return (system_prompt, paper_content) from the Review-5K message list.

    The dataset stores the reviewing guidelines in the first element (system
    role) and the submitted paper in the second element (user role). We fall
    back to the first user message we can find if the expected structure is
    missing so downstream code remains robust.
    """
    system_prompt = ''
    paper_content = ''

    if not isinstance(messages, list):
        return system_prompt, paper_content

    if messages:
        first = messages[0]
        if isinstance(first, dict):
            system_prompt = (first.get('content') or '').strip()

    if len(messages) > 1:
        second = messages[1]
        if isinstance(second, dict) and (second.get('role') or '').strip().lower() == 'user':
            paper_content = (second.get('content') or '').strip()

    if not paper_content:
        for msg in messages:
            if not isinstance(msg, dict):
                continue
            if (msg.get('role') or '').strip().lower() == 'user':
                paper_content = (msg.get('content') or '').strip()
                if paper_content:
                    break

    return system_prompt, paper_content


def _build_review5k_cache(cache_path: Path) -> pd.DataFrame:
    """Download Review-5K (train + test), reduce it to one row per paper, and cache it.

    Each row keeps ``id``, ``split``, the reviewing guidelines
    (``system_prompt``), the paper text (``paper_content``) and the mean of the
    non-null reviewer ratings (``avg_rate``; NaN when there are none). Rows
    without paper text are skipped.

    Args:
        cache_path: destination of the gzip-compressed CSV cache.

    Returns:
        pd.DataFrame: the table that was written to ``cache_path``.
    """
    records = []
    for split_name in ('train', 'test'):
        ds = load_dataset('WestlakeNLP/Review-5K', split=split_name)
        for idx, row in enumerate(ds):
            system_prompt, paper_content = _extract_review5k_prompts(row.get('messages'))
            if not paper_content:
                continue
            numeric_rates = [float(r) for r in (row.get('rates') or []) if r is not None]
            record_id = row.get('id')
            records.append({
                # Fall back to a synthetic id when the column is absent *or* null.
                'id': record_id if record_id is not None else f'{split_name}_{idx}',
                'split': split_name,
                'system_prompt': system_prompt,
                'paper_content': paper_content,
                'avg_rate': float(np.mean(numeric_rates)) if numeric_rates else np.nan,
            })
    df = pd.DataFrame(records)
    # No escapechar: with the default doublequote=True the csv writer escapes
    # embedded quotes by doubling them, which pd.read_csv (used without an
    # escapechar when loading this cache) reads back verbatim. Setting
    # escapechar='\\' makes recent csv writers also escape every backslash,
    # which that reader does not undo, so LaTeX-heavy paper text would come
    # back with doubled backslashes.
    df.to_csv(
        cache_path,
        index=False,
        quoting=csv.QUOTE_MINIMAL,
        compression='gzip',
    )
    return df


def load_review5k(split: str = 'all', project_root: Optional[Path] = None):
    """Load the Review-5K table from its local cache, building the cache if needed.

    The cache lives at ``<project_root>/data/score/review_5k_minimal.csv.gz``.
    It is (re)built by ``_build_review5k_cache`` in three cases: the file is
    missing, it lacks a ``split`` column, or a known split (``'train'`` /
    ``'test'``) has no usable rows.

    Args:
        split: ``'train'``, ``'test'`` or ``'all'``; a falsy value also means
            all rows. Any other value yields an empty frame plus a
            ``RuntimeWarning``, without rebuilding the cache.
        project_root: repository root; defaults to the parent of this module's
            directory.

    Returns:
        pd.DataFrame: rows with a numeric ``avg_rate`` (rows where it is
        missing or non-numeric are dropped). ``split`` is cast to ``str`` and
        the index is reset to ``0..n-1``.
    """
    # Splits produced by _build_review5k_cache; only these justify a rebuild.
    known_splits = ('train', 'test')

    project_root = Path(project_root) if project_root is not None else Path(__file__).resolve().parents[1]
    cache_dir = project_root / 'data/score'
    cache_dir.mkdir(parents=True, exist_ok=True)
    cache_path = cache_dir / 'review_5k_minimal.csv.gz'

    def _normalize(frame):
        """Cast split labels to str and drop rows without a numeric average rating."""
        frame['split'] = frame['split'].astype(str)
        frame['avg_rate'] = pd.to_numeric(frame['avg_rate'], errors='coerce')
        return frame.dropna(subset=['avg_rate']).reset_index(drop=True)

    df = pd.read_csv(cache_path, compression='gzip') if cache_path.exists() else None
    if df is None or 'split' not in df.columns:
        # Missing cache, or one written before the split column existed.
        df = _build_review5k_cache(cache_path)
    df = _normalize(df)

    if not split or split == 'all':
        return df

    df_split = df[df['split'] == split].reset_index(drop=True)
    if df_split.empty:
        if split in known_splits:
            # A known split with no rows suggests a stale cache: rebuild once.
            df = _normalize(_build_review5k_cache(cache_path))
            df_split = df[df['split'] == split].reset_index(drop=True)
        else:
            # Rebuilding cannot produce an unknown split, so skip the download.
            warnings.warn(
                f"Unknown Review-5K split {split!r}; expected one of {known_splits} or 'all'.",
                RuntimeWarning,
                stacklevel=2,
            )
    return df_split

def prepare_civilcomments(
    n_per: int = 2500,
    thresh: float = 0.5,
    seed: int = 42,
    split: str = "test",
    out_path: Path | str = BINARY_DIR / "civilcomments.csv",
):
    """Build a class-balanced CivilComments sample and write it to CSV.

    A comment is labelled 1 when its ``toxicity`` is at least ``thresh`` and
    0 otherwise. ``n_per`` rows are drawn without replacement from each class,
    and the result is shuffled using ``seed``.

    Args:
        n_per: rows to draw per class.
        thresh: toxicity cut-off for the positive label.
        seed: random state for both sampling and shuffling.
        split: dataset split passed to ``load_dataset``.
        out_path: CSV destination; parent directories are created.

    Returns:
        pd.DataFrame: the balanced sample, with ``label`` and ``text`` as the
        first two columns followed by the remaining original columns.

    Raises:
        ValueError: if either class has fewer than ``n_per`` rows.
    """
    frame = pd.DataFrame(load_dataset("civil_comments", split=split))
    frame["label"] = (frame["toxicity"] >= thresh).astype(int)

    class_counts = frame["label"].value_counts()
    if any(class_counts.get(lbl, 0) < n_per for lbl in (0, 1)):
        raise ValueError(f"Not enough per class. Have: {class_counts.to_dict()}")

    per_class = [
        frame[frame["label"] == lbl].sample(n=n_per, random_state=seed, replace=False)
        for lbl in (0, 1)
    ]
    balanced = pd.concat(per_class).sample(frac=1, random_state=seed).reset_index(drop=True)

    leading = ["label", "text"]
    balanced = balanced[leading + [col for col in frame.columns if col not in leading]]

    destination = Path(out_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    balanced.to_csv(destination, index=False)
    return balanced


def check_parsed_output(row, min_rating=-3.0, max_rating=3.0):
    """
    Report whether a row's ``parsed_output`` holds an in-range rating.

    The first (optionally negative) number in ``str(row['parsed_output'])``
    is taken as the rating.

    Args:
        row (pd.Series): Input row from dataframe
        min_rating (float): Smallest accepted rating (inclusive)
        max_rating (float): Largest accepted rating (inclusive)

    Returns:
        bool: False when the column is missing, no number is found, or the
        number lies outside ``[min_rating, max_rating]``; True otherwise.
    """
    if 'parsed_output' in row.index:
        found = re.search(r'-?\d+\.?\d*', str(row['parsed_output']))
        if found is not None:
            number = pd.to_numeric(found.group(), errors='coerce')
            return not pd.isna(number) and min_rating <= number <= max_rating
    return False

def extract_score_from_parsed_output(x, min_rating=-3.0, max_rating=3.0):
    """Return the first number found in ``str(x)`` if it is within range, else NaN.

    Args:
        x: raw judge output (any type; converted with ``str``).
        min_rating: smallest accepted score (inclusive).
        max_rating: largest accepted score (inclusive).

    Returns:
        The parsed number, or ``np.nan`` when no number is found or it is out
        of range.
    """
    found = re.search(r'-?\d+\.?\d*', str(x))
    if found is None:
        return np.nan
    score = pd.to_numeric(found.group(), errors='coerce')
    return score if min_rating <= score <= max_rating else np.nan

def collect_judge_outputs(data_dir="../judge_outputs/fully_gaussian/feedbackqa", valid_ratio_threshold=0.7, min_rating=-9.0, max_rating=9.0):
    """Gather the ``parsed_output`` column of every judge CSV in *data_dir*.

    A CSV (one per judge) is kept only when the fraction of rows whose
    ``parsed_output`` contains an in-range number (see ``check_parsed_output``)
    exceeds ``valid_ratio_threshold``. Each file's decision is printed.

    Args:
        data_dir: directory scanned (non-recursively) for ``*.csv`` files, in
            sorted order. Relative paths are resolved against the CWD.
        valid_ratio_threshold: strict lower bound on the valid-row ratio.
        min_rating: smallest accepted score (inclusive).
        max_rating: largest accepted score (inclusive).

    Returns:
        pd.DataFrame: one column per accepted file (named by file stem) with
        the raw, unparsed ``parsed_output`` values.
    """
    in_range = partial(check_parsed_output, min_rating=min_rating, max_rating=max_rating)
    kept = {}
    for csv_file in sorted(Path(data_dir).glob("*.csv")):
        frame = pd.read_csv(csv_file)
        ratio = frame.apply(in_range, axis=1).mean()
        accepted = ratio > valid_ratio_threshold
        if accepted:
            kept[csv_file.stem] = frame["parsed_output"]
        print(csv_file.stem, ratio, "Added to parsed_outputs" if accepted else "Skipped")
    return pd.DataFrame(kept)

def is_valid_score(x, min_rating=1.0, max_rating=4.0):
    """
    Tell whether *x* converts with ``float()`` to a value in [min_rating, max_rating].

    Numbers and numeric strings are accepted. Anything ``float()`` rejects
    (``None``, non-numeric text, ...) and NaN both yield False.
    """
    try:
        numeric = float(x)
    except (ValueError, TypeError):
        return False
    else:
        return numeric >= min_rating and numeric <= max_rating

def is_valid_binary_score(x):
    """
    Tell whether *x* converts with ``float()`` to exactly 0.0 or 1.0.

    Numbers and numeric strings (``'1'``, ``'0.0'``) are accepted. Anything
    ``float()`` rejects yields False.
    """
    try:
        numeric = float(x)
    except (ValueError, TypeError):
        return False
    else:
        return numeric in (0.0, 1.0)


def load_judge_dataset_bundle(dataset_name, project_root=None, allow_trim=True, valid_ratio_threshold=0.7):
    """Load judge outputs and matching human evaluation scores for a dataset.

    Parameters
    ----------
    dataset_name : str
        Identifier used by the analysis notebooks (e.g. ``'feedbackqa'`` or
        ``'yelp_with_scores'``). Legacy aliases such as ``'yelp_review_full/test'``
        are also accepted. Matching is case-insensitive.
    project_root : str or Path, optional
        Root directory of the repository. Defaults to the parent of the
        directory containing this module.
    allow_trim : bool, optional
        Whether to trim to the shorter length when the judge outputs and
        human labels disagree. When trimming happens, a ``RuntimeWarning`` is
        emitted, attributed to the caller's line.
    valid_ratio_threshold : float, optional
        Threshold forwarded to ``collect_judge_outputs`` when applicable.

    Returns
    -------
    tuple[pd.DataFrame, pd.Series]
        A dataframe with one float column per non-constant judge, and a series
        with human evaluations. Both are row-aligned and re-indexed from 0.
        Rows where any judge score is missing or lies outside the observed
        human score range are dropped from both.

    Raises
    ------
    ValueError
        If the dataset is unsupported, no judge outputs were collected, sizes
        mismatch with ``allow_trim=False``, no row has a valid score from
        every judge, or every judge column is constant.
    FileNotFoundError
        If the expected judge-output directory or CSV files are missing
        (raised by the branches that glob for CSVs directly).
    """
    dataset_key = dataset_name.lower()
    alias_map = {
        'summarize_from_feedback_axis/test': 'summarize_from_feedback',
        'yelp_review_full/test': 'yelp_with_scores',
        'asset_ratings': 'asset',
    }
    dataset_key = alias_map.get(dataset_key, dataset_key)
    project_root = Path(project_root) if project_root is not None else Path(__file__).resolve().parents[1]

    def _post_process(judge_df, human_eval, min_rating, max_rating, judge_output_dir):
        """Align judge and human rows, coerce scores to floats, and drop unusable rows/columns."""
        judge_df = judge_df.reset_index(drop=True)
        if judge_df.empty:
            raise ValueError(f'No judge outputs found for {dataset_name} in {judge_output_dir}')

        human_eval = human_eval.reset_index(drop=True)
        if len(judge_df) != len(human_eval):
            if not allow_trim:
                raise ValueError(
                    f'Size mismatch for {dataset_name}: judge outputs={len(judge_df)}, human labels={len(human_eval)}'
                )
            min_len = min(len(judge_df), len(human_eval))
            # stacklevel=3 skips this helper and load_judge_dataset_bundle itself,
            # so the warning points at the caller's line.
            warnings.warn(
                f"{dataset_name}: trimming judge/human data to {min_len} rows (judges={len(judge_df)}, human={len(human_eval)})",
                RuntimeWarning,
                stacklevel=3,
            )
            judge_df = judge_df.iloc[:min_len].reset_index(drop=True)
            human_eval = human_eval.iloc[:min_len].reset_index(drop=True)

        def _coerce_value(val):
            # Numbers and numeric strings are used directly. Other text falls back
            # to the first number found by extract_score_from_parsed_output.
            # Missing or out-of-range values become NaN.
            if pd.isna(val):
                return np.nan
            try:
                val_float = float(val)
            except (TypeError, ValueError):
                val_float = extract_score_from_parsed_output(val, min_rating=min_rating, max_rating=max_rating)
            if pd.isna(val_float):
                return np.nan
            if not (min_rating <= val_float <= max_rating):
                return np.nan
            return float(val_float)

        judge_df = judge_df.apply(lambda col: col.map(_coerce_value)).astype(float)
        judge_df = judge_df.replace([np.inf, -np.inf], np.nan)

        # Keep only rows where every judge produced a usable score. Drop the same
        # rows from the human labels so the two stay aligned.
        valid_mask = judge_df.notna().all(axis=1)
        judge_df = judge_df.loc[valid_mask].reset_index(drop=True)
        human_eval = human_eval.loc[valid_mask].reset_index(drop=True)
        if judge_df.empty:
            raise ValueError(
                f'No rows of {dataset_name} have an in-range score from every judge '
                f'(valid range [{min_rating}, {max_rating}]).'
            )

        # Drop judges whose scores have zero (or undefined) standard deviation.
        stds = judge_df.std(axis=0)
        non_constant_cols = stds[stds > 0].index
        if len(non_constant_cols) == 0:
            raise ValueError(f'All judge columns are constant for {dataset_name} after filtering.')
        judge_df = judge_df[non_constant_cols]
        return judge_df, human_eval

    def _collect_scored(judge_output_dir, human_eval):
        """Collect judge CSVs, treating scores outside the observed human range as invalid."""
        min_rating, max_rating = human_eval.min(), human_eval.max()
        judge_df = collect_judge_outputs(
            str(judge_output_dir),
            valid_ratio_threshold=valid_ratio_threshold,
            min_rating=min_rating,
            max_rating=max_rating,
        )
        return judge_df, min_rating, max_rating

    def _csv_paths_or_raise(judge_output_dir, label, pattern='*.csv'):
        """Return the sorted CSVs matching *pattern*, or raise if there are none."""
        paths = sorted(judge_output_dir.glob(pattern))
        if not paths:
            raise FileNotFoundError(f'No judge outputs found for {label}.')
        return paths

    fully_root = project_root / 'judge_outputs/fully_gaussian'
    misc_root = project_root / 'judge_outputs/miscellaneous'

    if dataset_key == 'feedbackqa':
        judge_output_dir = fully_root / 'feedbackqa'
        ratings = load_feedbackqa(project_root / 'data/score/feedbackqa.csv')
        human_eval = ratings['score'].astype(float)
        judge_df, min_rating, max_rating = _collect_scored(judge_output_dir, human_eval)
    elif dataset_key == 'helpsteer2':
        judge_output_dir = fully_root / 'helpsteer2'
        ratings = load_helpsteer2(project_root / 'data/score/helpsteer2.csv')
        human_eval = ratings['score'].astype(float)
        judge_df, min_rating, max_rating = _collect_scored(judge_output_dir, human_eval)
    elif dataset_key == 'ultrafeedback':
        judge_output_dir = fully_root / 'ultrafeedback_sampled'
        ratings = load_ultrafeedback(project_root / 'data/score/ultrafeedback_sampled.csv')
        human_eval = ratings['score'].astype(float)
        judge_df, min_rating, max_rating = _collect_scored(judge_output_dir, human_eval)
    elif dataset_key in {'helpsteer3', 'helpsteer3_pref'}:
        # The first existing directory wins; the last two are legacy locations.
        candidates = (
            fully_root / 'helpsteer3',
            misc_root / 'helpsteer3_scores',
            misc_root / 'helpsteer3_pref',
        )
        judge_output_dir = next((c for c in candidates if c.exists()), None)
        if judge_output_dir is None:
            raise FileNotFoundError('No judge outputs directory found for helpsteer3.')
        ratings = load_helpsteer3(project_root / 'data/score/helpsteer3_sampled.csv')
        human_eval = ratings['overall_preference'].astype(float)
        min_rating, max_rating = human_eval.min(), human_eval.max()
        parsed_outputs = {}
        for csv_path in _csv_paths_or_raise(judge_output_dir, 'helpsteer3'):
            df = pd.read_csv(csv_path, engine='python')
            # Prefer an explicit numeric score column. Otherwise pull the first
            # integer out of the raw model output.
            score_col = next(
                (c for c in ('score_ab', 'score_original_order', 'parsed_output') if c in df.columns),
                None,
            )
            if score_col is not None:
                scores = pd.to_numeric(df[score_col], errors='coerce')
            else:
                src_col = df.get('raw_output', df.get('raw_pref', pd.Series([''] * len(df))))
                scores = src_col.astype(str).str.extract(r'([-+]?\d+)')[0].astype(float)
            parsed_outputs[csv_path.stem.replace('_prefs', '')] = scores
        judge_df = pd.DataFrame(parsed_outputs)
    elif dataset_key == 'summarize_from_feedback':
        judge_output_dir = fully_root / 'summarize_from_feedback'
        sample_paths = _csv_paths_or_raise(judge_output_dir, 'summarize_from_feedback')
        # Human labels are read from the first judge file; all files are assumed row-aligned.
        human_eval = pd.read_csv(sample_paths[0])['overall_score'].astype(float)
        judge_df, min_rating, max_rating = _collect_scored(judge_output_dir, human_eval)
    elif dataset_key == 'review_5k':
        judge_output_dir = fully_root / 'review_5k'
        review_df = load_review5k(split='all', project_root=project_root)
        human_eval = review_df['avg_rate'].astype(float)
        _csv_paths_or_raise(judge_output_dir, 'review_5k')
        judge_df, min_rating, max_rating = _collect_scored(judge_output_dir, human_eval)
    elif dataset_key == 'yelp_with_scores':
        judge_output_dir = fully_root / 'yelp_with_scores'
        sample_paths = _csv_paths_or_raise(judge_output_dir, 'yelp_with_scores', pattern='*_test.csv')
        human_eval = pd.read_csv(sample_paths[0])['label'].astype(float)
        min_rating, max_rating = human_eval.min(), human_eval.max()
        judge_df = pd.DataFrame({
            csv_path.stem.replace('_test', ''): pd.to_numeric(pd.read_csv(csv_path)['parsed_output'], errors='coerce')
            for csv_path in sample_paths
        })
    elif dataset_key == 'tripadvisor_reviews':
        judge_output_dir = fully_root / 'tripadvisor_reviews'
        sample_paths = _csv_paths_or_raise(judge_output_dir, 'tripadvisor_reviews')
        human_eval = pd.read_csv(sample_paths[0])['label'].astype(float)
        judge_df, min_rating, max_rating = _collect_scored(judge_output_dir, human_eval)
    elif dataset_key == 'asset':
        judge_output_dir = fully_root / 'asset'
        sample_paths = _csv_paths_or_raise(judge_output_dir, 'asset')
        human_eval = pd.read_csv(sample_paths[0])['human_rating'].astype(float)
        min_rating, max_rating = human_eval.min(), human_eval.max()
        judge_df = pd.DataFrame({
            csv_path.stem: pd.to_numeric(pd.read_csv(csv_path)['parsed_output'], errors='coerce')
            for csv_path in sample_paths
        })
    else:
        raise ValueError(f'Unsupported dataset: {dataset_name}')

    return _post_process(judge_df, human_eval, min_rating, max_rating, judge_output_dir)


def standardize(arr):
    """Return the z-scores of *arr* over all elements (population std, ``ddof=0``).

    A constant input has zero spread. Instead of dividing by zero (which yields
    an all-NaN result and a RuntimeWarning), the centred values are returned;
    these are all zeros.

    Args:
        arr: array-like of numbers.

    Returns:
        np.ndarray: float array with the same shape as ``np.array(arr)``.
    """
    values = np.array(arr)
    centred = values - values.mean()
    std = values.std()
    if std == 0:
        return centred
    return centred / std
