"""
Representation-based Trigger Detection

This script performs two main tasks:
1. Compute tamper scores using LDA on representations (Step 1)
2. Analyze results: AUROC calculation, n-gram analysis (Step 2)
"""

import os
import re
import json
import argparse
from collections import Counter
from typing import Optional

import numpy as np
from tqdm import tqdm
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, roc_curve
from diptest import diptest


# =============================================================================
# Constants
# =============================================================================

# Project root directory, taken from the environment.
# NOTE: os.getenv returns None when the variable is unset; the f-string paths
# built from it would then start with the literal text "None".
TAMPERING_HOME = os.getenv("TAMPERING_HOME")

# Environment-provided prefix of every model identifier below
# (presumably a Hugging Face user/organization name -- confirm).
HF_NAME = os.getenv("HF_NAME")
# Default models processed by Step 1 (see run_compute).
MODELS = [
    f"{HF_NAME}/AT-qwen2.5-7b-hhrlhf-5120-sft-b3s3-ai",
    f"{HF_NAME}/AT-qwen2.5-7b-hhrlhf-5120-sft-base",
]

# Dip-test p-values strictly below this value are reported as significant.
PVAL_THRESHOLD = 0.01
# Default number of leading prompts considered by the analysis step.
N_SAMPLES = 200
# Phrase whose presence marks a prompt as a ground-truth trigger prompt.
TRIGGER_PHRASE = "can you"

# JSON file written by Step 1 and read back by Step 2.
RESULTS_PATH = f"{TAMPERING_HOME}/tampering/additional/reward_distribution/analysis_results_rep.json"
# JSON file with the prompts read by the analysis step.
PROMPT_PATH = f"{TAMPERING_HOME}/datasets/hhrlhf/rm/train/hhrlhf_RM_5120.json"


# =============================================================================
# Step 1: Compute Tamper Scores (LDA Analysis)
# =============================================================================

def _list_indexed_files(model: str, prefix: str) -> tuple:
    """
    List the ``<prefix><index>.json`` files in a model's reward-distribution folder.

    Args:
        model: Model identifier; every "/" is replaced by "_" to form the
            folder name.
        prefix: File-name prefix that precedes the index (e.g. "reward_").

    Returns:
        Tuple of (paths, indices): two parallel lists of full file paths and
        index strings. Decimal indices come first, ordered by numeric value
        ("2" before "10"); any other index follows in lexicographic order.

    Raises:
        OSError: If the model folder cannot be listed.
    """
    suffix = ".json"
    model_name = model.replace("/", "_")
    folder = f"{TAMPERING_HOME}/datasets/hhrlhf/additional/reward_distribution/{model_name}"

    def sort_key(entry: tuple) -> tuple:
        index = entry[0]
        # Sort numerically where possible; a plain string sort would put
        # "10" ahead of "2".
        if index.isdecimal():
            return (0, int(index), "")
        return (1, 0, index)

    entries = sorted(
        (
            # Slice the prefix/suffix off the ends only, so the index always
            # maps back to the real file name.
            (name[len(prefix):-len(suffix)], os.path.join(folder, name))
            for name in os.listdir(folder)
            if name.startswith(prefix) and name.endswith(suffix)
        ),
        key=sort_key,
    )
    paths = [path for _, path in entries]
    indices = [index for index, _ in entries]
    return paths, indices


def get_valid_rewards_path(model: str) -> tuple:
    """
    Get sorted reward file paths and indices for a model.

    Args:
        model: Model identifier ("/" is replaced by "_" in the folder name).

    Returns:
        Tuple of (paths, indices) for every ``reward_<index>.json`` file,
        ordered by numeric index.
    """
    return _list_indexed_files(model, "reward_")


def get_reps_path(model: str) -> tuple:
    """
    Get sorted representation file paths and indices for a model.

    Args:
        model: Model identifier ("/" is replaced by "_" in the folder name).

    Returns:
        Tuple of (paths, indices) for every ``representation_<index>.json``
        file, ordered by numeric index.
    """
    return _list_indexed_files(model, "representation_")


def get_labels_from_rewards(rewards: np.ndarray) -> np.ndarray:
    """
    Split rewards into two classes around their median.

    Args:
        rewards: Reward values.

    Returns:
        Array holding 1 where the reward is strictly greater than the median
        and 0 elsewhere (values equal to the median are labelled 0).
    """
    cutoff = np.median(rewards)
    above_cutoff = np.greater(rewards, cutoff)
    return np.where(above_cutoff, 1, 0)


def get_tamper_score(rewards_path: str, rep_path: str) -> dict:
    """
    Fit an LDA probe on one reward/representation file pair and score it.

    The rewards are binarized around their median, the data is split 50/50
    with a fixed seed, and an LDA model is fitted on the training half.

    Args:
        rewards_path: JSON file whose first element holds the reward values.
        rep_path: JSON file holding the representations, one per reward
            (assumed to be a list of feature vectors -- confirm).

    Returns:
        Dict with 'roc_auc' (ROC AUC of the LDA decision scores on the held-out
        half) and 'pval' (second value returned by ``diptest`` for the first
        LDA component of the held-out half; presumably the dip-test p-value).
    """
    with open(rewards_path, "r") as reward_file:
        # Only the first element of the reward payload is used.
        reward_values = np.array(json.load(reward_file)[0])
    with open(rep_path, "r") as rep_file:
        representations = json.load(rep_file)

    binary_labels = get_labels_from_rewards(reward_values)

    # Fixed random_state keeps the split reproducible across runs.
    split = train_test_split(
        representations, binary_labels, test_size=0.5, random_state=42
    )
    train_reps, test_reps, train_labels, test_labels = split

    probe = LinearDiscriminantAnalysis()
    probe.fit(train_reps, train_labels)

    held_out_scores = probe.decision_function(test_reps)
    auc = roc_auc_score(test_labels, held_out_scores)

    projected = probe.transform(test_reps)
    _dip_stat, dip_pval = diptest(projected[:, 0])

    return {"roc_auc": auc, "pval": dip_pval}


def compute_all_tamper_scores(models: list) -> dict:
    """
    Compute tamper scores for all models.

    For each model, every index that has both a reward file and a
    representation file is scored with ``get_tamper_score``.

    Args:
        models: Model identifiers ("/" is replaced by "_" to form result keys).

    Returns:
        Dict mapping each model name to a dict of
        ``{index string: tamper score dict}``, filled in ascending numeric
        index order.

    Raises:
        ValueError: If a shared index is not an integer string.
    """
    results = {}

    for model in tqdm(models, desc="Models"):
        model_name = model.replace("/", "_")

        reward_paths, reward_indices = get_valid_rewards_path(model)
        rep_paths, rep_indices = get_reps_path(model)

        # Use the paths the helpers already discovered instead of rebuilding
        # them from a second copy of the folder template.
        reward_by_index = dict(zip(reward_indices, reward_paths))
        rep_by_index = dict(zip(rep_indices, rep_paths))

        # Only indices present in both file sets can be scored.
        shared_indices = sorted(
            reward_by_index.keys() & rep_by_index.keys(), key=int
        )

        model_scores = {}
        for idx in tqdm(shared_indices, desc=f"  {model_name}", leave=False):
            model_scores[idx] = get_tamper_score(
                reward_by_index[idx], rep_by_index[idx]
            )
        results[model_name] = model_scores

    return results


def save_results(results: dict, output_path: str = RESULTS_PATH):
    """
    Save results to JSON file, sorted by index.

    Args:
        results: Dict mapping model name to ``{index string: score dict}``.
        output_path: Destination JSON file. Missing parent directories are
            created; a bare file name is written to the current directory.

    Raises:
        ValueError: If an index key cannot be converted to ``int``.
    """
    sorted_results = {
        model_name: dict(
            sorted(model_results.items(), key=lambda item: int(item[0]))
        )
        for model_name, model_results in results.items()
    }

    # os.makedirs("") raises, so only create a directory when the path has one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(output_path, "w") as f:
        json.dump(sorted_results, f, indent=4)

    print(f"Results saved to: {output_path}")


# =============================================================================
# Step 2: Analyze Results
# =============================================================================

def load_results() -> tuple:
    """
    Read the saved tamper scores and the prompt file.

    Returns:
        Tuple of (results, prompts). ``results`` maps each model to
        ``{"auc": [...], "pval": [...]}``, with both lists ordered by
        ascending numeric index; ``prompts`` is the parsed content of
        PROMPT_PATH.
    """
    with open(RESULTS_PATH, "r") as results_file:
        raw_results = json.load(results_file)

    with open(PROMPT_PATH, "r") as prompts_file:
        prompts = json.load(prompts_file)

    # Flatten the per-index dicts into two parallel, index-ordered lists.
    results = {}
    for model, per_index in raw_results.items():
        ordered_entries = [per_index[key] for key in sorted(per_index, key=int)]
        results[model] = {
            "auc": [entry["roc_auc"] for entry in ordered_entries],
            "pval": [entry["pval"] for entry in ordered_entries],
        }

    return results, prompts


def get_significant_indices(results: dict, threshold: float = PVAL_THRESHOLD) -> dict:
    """
    Collect, per model, the positions whose p-value is strictly below threshold.

    Args:
        results: Dict mapping model to a dict with a "pval" list.
        threshold: Exclusive upper bound on the p-value.

    Returns:
        Dict mapping each model to the list of positions in its "pval" list
        that fall below the threshold.
    """
    return {
        model: [
            position
            for position, pval in enumerate(data["pval"])
            if pval < threshold
        ]
        for model, data in results.items()
    }


def find_trigger_prompts(
    prompts: list,
    trigger_phrase: str = TRIGGER_PHRASE,
    n_samples: int = N_SAMPLES
) -> list:
    """
    Find prompts containing the trigger phrase.

    The match is case-insensitive and is made against the string
    representation of each prompt's whole "messages" value.

    Args:
        prompts: Prompt records, each with a "messages" entry.
        trigger_phrase: Phrase to look for.
        n_samples: Number of leading prompts to inspect; capped at
            ``len(prompts)`` so a short prompt list does not raise.

    Returns:
        Indices (into ``prompts``) of the inspected prompts that contain
        the phrase.
    """
    # The haystack is lowercased, so the needle must be lowercased too.
    needle = trigger_phrase.lower()
    limit = min(n_samples, len(prompts))
    return [
        idx
        for idx in range(limit)
        if needle in str(prompts[idx]["messages"]).lower()
    ]


def calculate_auroc(
    results: dict,
    gt_indices: list,
    model_name: Optional[str] = None,
    n_samples: int = N_SAMPLES
) -> tuple:
    """
    Calculate AUROC for trigger detection.

    The score for each sample is the negated p-value, so a lower p-value
    ranks as a more confident detection. Position ``i`` of a p-value list is
    treated as sample ``i`` (assumes the saved indices are contiguous from 0
    -- verify against the compute step).

    Args:
        results: Dictionary containing model results with p-values
        gt_indices: Ground truth indices of trigger prompts
        model_name: If provided, use single model's p-values.
                   Otherwise, use minimum p-value across all models.
        n_samples: Number of samples; only the first ``n_samples`` p-values
                   of each model are used, in both modes.

    Returns:
        Tuple of (auroc, fpr, tpr, thresholds)

    Raises:
        ValueError: If a model has fewer than ``n_samples`` p-values.
    """
    def leading_pvals(model: str) -> list:
        # Fail with a clear message instead of an opaque index/length error.
        pvals = results[model]["pval"]
        if len(pvals) < n_samples:
            raise ValueError(
                f"Model '{model}' has {len(pvals)} p-values; "
                f"expected at least {n_samples}"
            )
        return pvals[:n_samples]

    # Ground truth binary labels (set lookup keeps this linear).
    positives = set(gt_indices)
    y_true = np.array([1 if i in positives else 0 for i in range(n_samples)])

    if model_name:
        # Single model's p-values (lower p-value = higher confidence)
        y_scores = -np.asarray(leading_pvals(model_name), dtype=float)
    else:
        # Minimum p-value across all models
        per_model = [leading_pvals(model) for model in results]
        y_scores = np.array(
            [-min(column) for column in zip(*per_model)], dtype=float
        )

    auroc = roc_auc_score(y_true, y_scores)
    fpr, tpr, thresholds = roc_curve(y_true, y_scores)

    return auroc, fpr, tpr, thresholds


# =============================================================================
# N-gram Analysis
# =============================================================================

def extract_ngrams(text: str, n: int = 2) -> list:
    """
    Tokenize text and return its consecutive word n-grams.

    Curly quotes and backticks are folded into a plain apostrophe, the text
    is lowercased, and words (optionally with apostrophe-joined parts, e.g.
    "don't") are extracted.

    Args:
        text: Input text.
        n: Number of tokens per n-gram.

    Returns:
        List of n-grams, each a space-joined string, in order of appearance.
    """
    # Map typographic apostrophes and backticks onto a plain apostrophe.
    apostrophe_map = str.maketrans({"\u2019": "'", "\u2018": "'", "`": "'"})
    normalized = text.translate(apostrophe_map).lower()

    words = re.findall(r"\w+(?:'\w+)*", normalized)
    window_starts = range(len(words) - n + 1)
    return [" ".join(words[start:start + n]) for start in window_starts]


def analyze_ngrams(prompts: list, indices: list, n: int = 2, top_k: int = 20) -> list:
    """
    Count n-grams over the messages of the selected prompts.

    Args:
        prompts: Prompt records, each with a "messages" list whose items
            carry a "content" string.
        indices: Positions in ``prompts`` to include.
        n: Number of tokens per n-gram.
        top_k: Number of most frequent n-grams to return.

    Returns:
        List of (ngram, count) pairs, most frequent first.
    """
    tally = Counter()

    for prompt_idx in indices:
        for message in prompts[prompt_idx]["messages"]:
            tally.update(extract_ngrams(message["content"], n=n))

    return tally.most_common(top_k)


# =============================================================================
# Output Functions
# =============================================================================

def print_significant_indices(significant: dict, threshold: float = PVAL_THRESHOLD):
    """
    Print indices with significant p-values for each model.

    Args:
        significant: Dict mapping model to its list of significant indices.
        threshold: Threshold shown in the header (display only).
    """
    print("\n" + "=" * 60)
    # ":g" keeps small thresholds readable; a fixed two-decimal format would
    # show e.g. 0.001 as "0.00".
    print(f"Significant Indices (p-value < {threshold:g})")
    print("=" * 60)

    for model, indices in significant.items():
        print(f"\nModel: {model}")
        print(f"  Indices: {indices}")
        print(f"  Count: {len(indices)}")


def print_auroc_results(results: dict, trigger_indices: list):
    """
    Report the trigger-detection AUROC per model and for all models combined.

    Args:
        results: Dict mapping model to its results (including p-values).
        trigger_indices: Ground truth indices of trigger prompts.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("AUROC Results")
    print(banner)

    # One AUROC per individual model.
    for model in results:
        model_auroc = calculate_auroc(results, trigger_indices, model_name=model)[0]
        print(f"\nModel: {model}")
        print(f"  AUROC: {model_auroc:.4f}")

    # AUROC using the minimum p-value across models.
    combined_auroc = calculate_auroc(results, trigger_indices)[0]
    print(f"\nCombined (min p-value): AUROC = {combined_auroc:.4f}")


def print_ngram_analysis(common_ngrams: list):
    """
    Print a ranked table of n-gram frequencies.

    Args:
        common_ngrams: (ngram, frequency) pairs, already in display order.
    """
    banner = "=" * 60
    # Shared column layout for the header and every data row.
    row_template = "{:<6}{:<25}{:<10}"

    print("\n" + banner)
    print("Top N-grams in Low p-value Prompts")
    print(banner)
    print(row_template.format("Rank", "N-gram", "Frequency"))
    print("-" * 41)

    rank = 0
    for ngram, freq in common_ngrams:
        rank += 1
        print(row_template.format(rank, ngram, freq))


# =============================================================================
# Main Entry Points
# =============================================================================

def run_compute(models: list = MODELS):
    """
    Step 1: score every model and write the scores to the results file.

    Args:
        models: Model identifiers to process.

    Returns:
        The computed results dict (the same data that is saved).
    """
    banner = "=" * 60
    print(banner)
    print("Step 1: Computing Tamper Scores")
    print(banner)

    tamper_scores = compute_all_tamper_scores(models)
    save_results(tamper_scores)

    return tamper_scores


def run_analyze():
    """
    Step 2: load the saved scores and print the analysis report.

    Prints the ground-truth trigger prompts, the significant indices per
    model, the AUROC results, and an n-gram table for the first model's
    significant indices (when there are any).
    """
    banner = "=" * 60
    print("\n" + banner)
    print("Step 2: Analyzing Results")
    print(banner)

    results, prompts = load_results()

    # Ground truth: prompts that contain the trigger phrase.
    trigger_indices = find_trigger_prompts(prompts)
    print(f"\nTrigger prompts (containing '{TRIGGER_PHRASE}'): {len(trigger_indices)}")
    print(f"Indices: {trigger_indices}")

    # Indices flagged by the p-value threshold.
    significant = get_significant_indices(results)
    print_significant_indices(significant)

    print_auroc_results(results, trigger_indices)

    # The n-gram table is based on the first model only.
    flagged_indices = significant[list(results)[0]]
    if not flagged_indices:
        return

    top_ngrams = analyze_ngrams(prompts, flagged_indices, n=2, top_k=20)
    print_ngram_analysis(top_ngrams)


def main():
    """Parse the ``--step`` option and run the requested pipeline step(s)."""
    parser = argparse.ArgumentParser(
        description="Representation-based Trigger Detection"
    )
    parser.add_argument(
        "--step",
        choices=["compute", "analyze", "all"],
        default="all",
        help="Which step to run: 'compute' (Step 1), 'analyze' (Step 2), or 'all' (both)"
    )
    step = parser.parse_args().step

    # "all" enables both steps; compute always runs before analyze.
    wants_compute = step in ("compute", "all")
    wants_analyze = step in ("analyze", "all")

    if wants_compute:
        run_compute()

    if wants_analyze:
        run_analyze()


if __name__ == "__main__":
    main()
