import json
import csv
import os
from collections import defaultdict
from typing import Dict, List, Tuple

import numpy as np
from scipy.stats import pearsonr, spearmanr
from statistics import mean, median

# Paths to data files (modify if your files are elsewhere).
# ROOT_DIR is made absolute so the default paths below stay valid even if the
# working directory changes after this module is imported.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
# Placeholder absolute path -- replace with the real location of the JSON file.
# Do not wrap it in os.path.join(ROOT_DIR, ...): join discards all earlier
# components when a later one is absolute, so ROOT_DIR would be ignored anyway.
JSON_PATH = "/path/to/steer_14aug2merged_qwen_persona_static_web_aes.json"
CSV_PATH = os.path.join(ROOT_DIR, "website_group_mean.csv")

OUTPUT_PATH = os.path.join(ROOT_DIR, "steer14aug2vlqwen5_web_vector_correlations.txt")


def extract_website_id(image_path: str) -> str:
    """Map an image path to its website identifier.

    Examples:
        "/english_resized/327.png" -> "english_327"
        "/foreign_resized/33.png"  -> "foreign_33"
        any other path             -> the bare file stem, e.g. "/misc/5.png" -> "5"
    """
    stem = os.path.splitext(os.path.basename(image_path))[0]

    # Folder markers are matched anywhere in the path; the English marker is
    # checked first, so it wins if a path somehow contains both.
    for folder_marker, prefix in (("english_resized", "english"), ("foreign_resized", "foreign")):
        if folder_marker in image_path:
            return f"{prefix}_{stem}"

    # Unknown layout: fall back to the stem alone.
    return stem


def _extract_mean_prediction(details: dict):
    """Attempt to obtain a numeric mean prediction from a persona details dict.

    Sources are tried in priority order; the first one yielding a finite number wins:

    1. ``details["mean_prediction"]`` if present, not None and convertible to float.
    2. Mean of the non-empty list ``details["all_predictions"]``.
    3. Mean of the non-empty list ``details["predictions"]``.

    Returns:
        The mean as a Python float, or None if ``details`` is not a dict or no
        source yields a finite value. NaN/inf are rejected so they cannot
        silently turn downstream correlations into NaN.
    """
    if not isinstance(details, dict):
        return None

    mean_pred = details.get("mean_prediction")
    if mean_pred is not None:
        try:
            value = float(mean_pred)
        except (TypeError, ValueError):
            value = None  # unusable scalar; fall back to the prediction lists
        if value is not None and np.isfinite(value):
            return value

    for list_key in ("all_predictions", "predictions"):
        vals = details.get(list_key)
        if not isinstance(vals, list) or not vals:
            continue
        try:
            value = float(np.mean(vals))
        except (TypeError, ValueError):
            # Non-numeric items (TypeError) or ragged nested lists (ValueError).
            continue
        if np.isfinite(value):
            return value

    return None


def _find_persona_container(entry: dict):
    """Return the persona -> details dict for one JSON entry, or None if absent.

    Prefers an explicit "persona_predictions" or "persona_responses" dict.
    Otherwise, every dict-valued key that carries a "mean_prediction",
    "all_predictions" or "predictions" field is treated as a persona block.
    """
    for key in ("persona_predictions", "persona_responses"):
        value = entry.get(key)
        if isinstance(value, dict):
            return value

    candidate = {
        k: v
        for k, v in entry.items()
        if isinstance(v, dict)
        and ("mean_prediction" in v or "all_predictions" in v or "predictions" in v)
    }
    return candidate or None


def load_predictions(json_path: str) -> Dict[str, Dict[str, float]]:
    """Return mapping persona -> {website_id: mean_prediction}.

    The JSON file must hold a list of entry dicts. For each entry:
    - the website id is derived from its "image" path via extract_website_id;
    - persona data is read from "persona_predictions" or "persona_responses"
      when either is a dict; otherwise from every dict-valued key that has a
      "mean_prediction", "all_predictions" or "predictions" field;
    - each persona's score comes from _extract_mean_prediction.

    Entries that are not dicts, have no non-empty string "image" path, or
    carry no usable persona data are skipped. If several entries map to the
    same website id, the later entry overwrites the earlier one.

    Raises:
        ValueError: if the top-level JSON value is not a list.
    """
    predictions: Dict[str, Dict[str, float]] = defaultdict(dict)

    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if not isinstance(data, list):
        raise ValueError(
            f"Expected a JSON list of entries in {json_path}, got {type(data).__name__}"
        )

    for entry in data:
        if not isinstance(entry, dict):
            continue

        image_path = entry.get("image")
        if not isinstance(image_path, str) or not image_path:
            # Without an image path, every such entry would collapse onto the
            # same "" website id and overwrite the others.
            continue
        website_id = extract_website_id(image_path)

        persona_container = _find_persona_container(entry)
        if not persona_container:
            continue  # nothing usable in this entry

        for persona, details in persona_container.items():
            mean_pred = _extract_mean_prediction(details)
            if mean_pred is not None:
                predictions[persona][website_id] = mean_pred

    return predictions


def load_ground_truth(csv_path: str) -> Dict[str, Dict[str, float]]:
    """Return mapping persona (CSV "group") -> {website_id: mean_response}.

    Columns are read positionally as ``website, group, mean_response``. The
    first row is treated as a header and skipped regardless of its contents.
    Leading and trailing whitespace is stripped from the website and group
    fields. A row is skipped when:
    - the website or group is missing or empty;
    - the response is missing or non-numeric;
    - the response is non-finite (e.g. "nan").
    """
    ground_truth: Dict[str, Dict[str, float]] = defaultdict(dict)
    with open(csv_path, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f, fieldnames=["website", "group", "mean_response"])
        # With explicit fieldnames DictReader yields the header as data; skip it.
        next(reader, None)
        for row in reader:
            # Short rows get None for missing fields, hence the `or ""`.
            website = (row["website"] or "").strip()
            group = (row["group"] or "").strip()
            if not website or not group:
                continue
            try:
                mean_response = float(row["mean_response"])
            except (TypeError, ValueError):
                # TypeError: field missing (None); ValueError: non-numeric text.
                continue
            if not np.isfinite(mean_response):
                continue
            ground_truth[group][website] = mean_response
    return ground_truth


def compute_vector_correlations(
    preds: Dict[str, Dict[str, float]],
    truths: Dict[str, Dict[str, float]],
    min_websites: int = 3,
):
    """Correlate synthetic and human scores per persona across shared websites.

    For every persona present in both ``preds`` and ``truths``, the websites
    scored in both sources with finite values on both sides are sorted. They
    are turned into two aligned vectors, which are compared with Pearson and
    Spearman correlation.

    Args:
        preds: persona -> {website_id: synthetic score}.
        truths: persona -> {website_id: human score}.
        min_websites: minimum number of usable shared websites required to
            report a persona (default 3).

    Returns:
        List of ``(persona, n_websites, pearson_r, spearman_r, pearson_p,
        spearman_p)`` tuples, ordered by persona. A persona is skipped, with a
        message printed to stdout, when:
        - it has too few shared websites;
        - either vector is constant;
        - the correlation computation fails.
    """
    results = []

    for persona in sorted(set(preds.keys()) & set(truths.keys())):
        synthetic_scores = preds[persona]
        human_scores = truths[persona]

        # Only websites scored on both sides with finite values are usable; a
        # single NaN would otherwise turn both coefficients into NaN.
        common_websites = sorted(
            website
            for website in set(synthetic_scores.keys()) & set(human_scores.keys())
            if np.isfinite(synthetic_scores[website]) and np.isfinite(human_scores[website])
        )

        if len(common_websites) < min_websites:
            print(f"Skipping {persona}: only {len(common_websites)} common websites")
            continue

        synthetic_vector = [synthetic_scores[website] for website in common_websites]
        human_vector = [human_scores[website] for website in common_websites]

        # Correlation is undefined when either vector has zero variance.
        if len(set(synthetic_vector)) <= 1 or len(set(human_vector)) <= 1:
            print(f"Skipping {persona}: one of the vectors has no variance")
            continue

        try:
            pearson_r, pearson_p = pearsonr(synthetic_vector, human_vector)
            spearman_r, spearman_p = spearmanr(synthetic_vector, human_vector)
        except Exception as e:  # best-effort per persona: report and move on
            print(f"Error computing correlation for {persona}: {e}")
            continue

        results.append(
            (persona, len(common_websites), pearson_r, spearman_r, pearson_p, spearman_p)
        )

    return results


def main():
    """Load both datasets, correlate them per persona, and report the results.

    A per-persona table with p-values and summary statistics is printed to
    stdout. A condensed report (persona, Pearson r, Spearman r, summary
    statistics) is written to OUTPUT_PATH.
    """
    print("Loading data...")
    preds = load_predictions(JSON_PATH)
    truths = load_ground_truth(CSV_PATH)

    print(f"Loaded predictions for {len(preds)} personas")
    print(f"Loaded ground truth for {len(truths)} personas")

    # Listing both persona sets makes naming mismatches between sources visible.
    print("\nSynthetic personas:", sorted(preds.keys()))
    print("Human personas:", sorted(truths.keys()))

    vector_corrs = compute_vector_correlations(preds, truths)

    lines_file = []  # lines for the txt report

    print("\nVector-based Persona Correlations")
    print("=================================")
    if vector_corrs:
        print(f"{'Persona':20s} | n_sites | Pearson  | p-val   | Spearman | p-val")
        print("-" * 75)

        for persona, n_sites, r_p, r_s, p_p, p_s in vector_corrs:
            print(f"{persona:20s} | {n_sites:7d} | {r_p:8.4f} | {p_p:.4f} | {r_s:8.4f} | {p_s:.4f}")

        pearson_values = [row[2] for row in vector_corrs]
        spearman_values = [row[3] for row in vector_corrs]

        # Format the summary once so the console and the file report agree exactly.
        summary_lines = [
            f"Pearson:  mean={mean(pearson_values):.4f}, median={median(pearson_values):.4f}",
            f"Spearman: mean={mean(spearman_values):.4f}, median={median(spearman_values):.4f}",
        ]

        print("\nSummary Statistics:")
        for line in summary_lines:
            print(line)

        lines_file.append("Vector-based Persona Correlations")
        lines_file.append("Persona | Pearson | Spearman")
        lines_file.append("-" * 40)
        lines_file.extend(
            f"{persona} | {r_p:.4f} | {r_s:.4f}" for persona, _, r_p, r_s, _, _ in vector_corrs
        )
        lines_file.append("")
        lines_file.append("Summary Statistics:")
        lines_file.extend(summary_lines)
    else:
        message = "No vector correlations could be computed."
        print(message)
        lines_file.append(message)

    # End the report with a newline so it is a well-formed text file.
    with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
        f.write("\n".join(lines_file) + "\n")

    print(f"\nResults saved to {OUTPUT_PATH}")


if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    main()