import json
import os
from collections import defaultdict
from dataclasses import dataclass
from glob import glob
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import pearsonr  # type: ignore



# Root directories that are globbed for per-run result files: one for the
# sweep over training-time user suffixes, one for the sweep over eval-time
# user suffixes.
TRAIN_SWEEP_ROOT = "experiments/gemma_gcd_usrans1000_responly_suffsweep"
EVAL_SWEEP_ROOT = "experiments/gemma_gcd_eval_suffsweep"


@dataclass
class Aggregate:
    """Running sum and sample count used to compute an arithmetic mean."""

    # Number of values folded in so far.
    count: int = 0
    # Sum of all values folded in so far.
    total: float = 0.0

    def add(self, value: float) -> None:
        """Fold one more observation into the running sum."""
        self.total += value
        self.count += 1

    @property
    def mean(self) -> float:
        """Average of the added values, or NaN when nothing has been added."""
        return self.total / self.count if self.count != 0 else float("nan")


def read_confirms_incorrect_mean(json_path: str) -> float:
    """Load a results JSON file and return its ``confirms_incorrect.mean``.

    Args:
        json_path: Path to a JSON file whose top-level object has a
            ``"confirms_incorrect"`` entry containing a ``"mean"`` value.

    Returns:
        The ``mean`` value converted to ``float``.
    """
    with open(json_path, "r") as handle:
        payload = json.load(handle)
    # Only the mean of the confirms_incorrect metric is used downstream.
    metric = payload["confirms_incorrect"]  # type: ignore[index]
    return float(metric["mean"])


def read_suffix_from_config(
    results_json_path: str, kind: str, max_levels: int = 8
) -> str:
    """Return the user suffix recorded in the nearest ``config.json``.

    Starting at the directory containing ``results_json_path``, each
    directory is checked for a ``config.json`` file, moving to the parent
    directory after every miss. The first config found is used.

    Args:
        results_json_path: Path to a results JSON file; only its directory
            location is used here.
        kind: ``"train"`` selects the ``train_user_suffix`` key; any other
            value selects ``eval_user_suffix``.
        max_levels: Maximum number of directories to inspect, counting the
            starting directory.

    Returns:
        The configured suffix with surrounding whitespace stripped.

    Raises:
        FileNotFoundError: If no ``config.json`` is found within
            ``max_levels`` directories (or before the path root is reached).
        KeyError: If the config lacks the key selected by ``kind``.
    """
    key = "train_user_suffix" if kind == "train" else "eval_user_suffix"
    start_dir = os.path.dirname(results_json_path)
    current_dir = start_dir
    for _ in range(max_levels):
        config_path = os.path.join(current_dir, "config.json")
        if os.path.isfile(config_path):
            # JSON is UTF-8 by specification; do not depend on the locale.
            with open(config_path, "r", encoding="utf-8") as f:
                cfg = json.load(f)
            return cfg[key].strip()
        parent = os.path.dirname(current_dir)
        if parent == current_dir:
            # Reached the top of the path; nothing further to search.
            break
        current_dir = parent
    # Fail loudly instead of implicitly returning None, which would end up
    # as a bogus "suffix" key in the caller's aggregates.
    raise FileNotFoundError(
        f"No config.json found within {max_levels} directory levels above "
        f"{results_json_path!r} (search started at {start_dir!r})"
    )


def collect_aggregates(root: str, kind: str) -> Dict[str, Aggregate]:
    """Gather per-suffix averages of ``confirms_incorrect.mean`` under ``root``.

    Result files are located with a glob pattern whose directory and file
    names depend on ``kind``. Each file's value is grouped by the suffix read
    from the run's ``config.json``. Progress is printed along the way.

    Args:
        root: Sweep root directory to search.
        kind: ``"train"`` uses the train-sweep layout; any other value uses
            the eval-sweep layout.

    Returns:
        Mapping from suffix to an ``Aggregate`` over all matching result
        files that share that suffix.
    """
    # The two layouts differ only in these three path components.
    if kind == "train":
        suffix_dir = "train_user_suffix-*"
        evals_dir = "gemma_gcd_train_suffix_sweep_evals"
        results_name = "ood_test_eval_results.json"
    else:
        suffix_dir = "eval_user_suffix-*"
        evals_dir = "gemma_gcd_eval_suffix_sweep_evals"
        results_name = "task_test_eval_results.json"
    pattern = os.path.join(
        root, suffix_dir, "seed_*", "results", "*", evals_dir, results_name
    )

    print(f"\nCollecting {kind} data from: {root}")
    print(f"Pattern: {pattern}")

    matches = glob(pattern)
    print(f"Found {len(matches)} JSON files")

    per_suffix: Dict[str, Aggregate] = defaultdict(Aggregate)
    for path in matches:
        value = read_confirms_incorrect_mean(path)
        # Label each file by the directory two levels above it.
        run_dir = os.path.dirname(os.path.dirname(path))
        print(f"  - {os.path.basename(run_dir)}: confirms_incorrect mean = {value:.4f}")
        suffix = read_suffix_from_config(path, kind=kind)
        print(f"    Suffix: '{suffix}'")
        per_suffix[suffix].add(value)

    print(f"\nFound {len(per_suffix)} unique suffixes in {kind} data")
    for name, agg in per_suffix.items():
        print(f"  - '{name}': {agg.count} samples, mean = {agg.mean:.4f}")

    return per_suffix


def align_suffixes(train_aggs: Dict[str, Aggregate], eval_aggs: Dict[str, Aggregate]) -> List[Tuple[str, float, float]]:
    """Pair up eval and train means for suffixes present in both mappings.

    Args:
        train_aggs: Per-suffix aggregates from the train sweep.
        eval_aggs: Per-suffix aggregates from the eval sweep.

    Returns:
        ``(suffix, eval_mean, train_mean)`` tuples, ordered by suffix. Only
        suffixes that appear in both inputs are included.
    """
    shared = train_aggs.keys() & eval_aggs.keys()

    print("\nAligning suffixes:")
    print(f"  Train suffixes: {len(train_aggs)}")
    print(f"  Eval suffixes: {len(eval_aggs)}")
    print(f"  Common suffixes: {len(shared)}")

    paired: List[Tuple[str, float, float]] = []
    for name in sorted(shared):
        eval_mean = eval_aggs[name].mean
        train_mean = train_aggs[name].mean
        paired.append((name, eval_mean, train_mean))
        print(f"  - '{name}': eval={eval_mean:.4f}, train={train_mean:.4f}")

    return paired


def main() -> None:
    """Correlate per-suffix eval and train means, then save a scatter plot.

    Collects aggregates from both sweep roots and aligns them on the
    suffixes they share. Prints the Pearson correlation between the eval
    mean (x) and the train mean (y), then writes an annotated scatter plot
    to ``gemma_gcd/sycophancy_eval_vs_train_scatter.png``. Returns early,
    without plotting, when the two sweeps share no suffix.
    """
    print("=== Correlating sycophancy between train and eval suffixes ===")
    print(f"Train root: {TRAIN_SWEEP_ROOT}")
    print(f"Eval root: {EVAL_SWEEP_ROOT}")

    train_aggs = collect_aggregates(TRAIN_SWEEP_ROOT, kind="train")
    eval_aggs = collect_aggregates(EVAL_SWEEP_ROOT, kind="eval")

    aligned = align_suffixes(train_aggs, eval_aggs)
    if not aligned:
        print("\nNo overlapping suffixes found between train and eval sweeps.")
        return

    _, xs, ys = zip(*aligned)  # xs: eval, ys: train

    print("\nCorrelation results:")
    if len(aligned) < 2:
        # pearsonr raises ValueError for fewer than two points; report that
        # the statistic is undefined and still produce the plot.
        print("  Pearson r is undefined (need at least 2 common suffixes)")
    else:
        r, p = pearsonr(list(xs), list(ys))
        print(f"  Pearson r = {r:.3f}")
        print(f"  p-value = {p:.3e}")
    print(f"  n = {len(xs)}")

    fig = plt.figure(figsize=(7, 6))
    plt.scatter(xs, ys, alpha=0.8)
    plt.xlabel("Sycophancy elicited by eval suffix.")
    plt.ylabel("Sycophancy of model when trained with suffix.")
    for suffix, x, y in aligned:
        # light annotation to help identify points; truncate long names
        label = suffix if len(suffix) <= 28 else suffix[:25] + "..."
        plt.annotate(label, (x, y), textcoords="offset points", xytext=(4, 2), fontsize=7)
    plt.grid(True, alpha=0.3)
    out_dir = "gemma_gcd"
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, "sycophancy_eval_vs_train_scatter.png")
    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    # Release the figure once it is on disk.
    plt.close(fig)
    print(f"\nSaved scatter plot to {out_path}")


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()


