#!/usr/bin/env python3
"""
Generic script to analyze spotlight data and generate plots and statistics.

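This script processes a JSON file containing anomaly detection results; the
scores are read from its "structural_detection" section. The command-line
interface takes only that JSON file (no separate TXT file of edge data with
timestamps is read).

Outputs (written next to the input JSON, reusing its base name):
1. PDF plot of spotlight scores over time
2. TXT file with average statistics for normal vs adversarial edges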
"""

import argparse
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def load_json_data(json_file: str) -> Dict[str, Any]:
    """Load and parse the JSON file containing anomaly detection results.

    Args:
        json_file: Path of the JSON file to read.

    Returns:
        The decoded JSON document, exactly as produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Decode as UTF-8 explicitly: without ``encoding`` Python falls back to the
    # platform's locale encoding, which can mis-decode non-ASCII content.
    with open(json_file, "r", encoding="utf-8") as f:
        return json.load(f)


def extract_spotlight_scores(
    json_data: Dict[str, Any],
) -> Tuple[List[float], List[int], List[bool], List[int], Optional[float]]:
    """
    Extract spotlight scores from JSON data.

    Scores are read from ``json_data["structural_detection"]["anomaly_scores"]``.
    Batch IDs are assumed to be sequential, so each score's position in that
    list is used as its batch ID.

    Args:
        json_data: Parsed detection results.

    Returns: (scores, batch_ids, is_attack_flags, ground_truth_attacks, threshold)
        scores: The anomaly scores, in file order.
        batch_ids: ``0 .. len(scores) - 1``.
        is_attack_flags: ``True`` where the batch ID is listed in the
            section's ``attack_batch_ids``.
        ground_truth_attacks: The same information as 1/0 integers.
        threshold: The section's ``threshold`` value, or ``None`` if absent.

        When the section or its scores are missing, four empty lists and a
        ``None`` threshold are returned.
    """
    section = (
        json_data.get("structural_detection")
        if isinstance(json_data, dict)
        else None
    )
    if not isinstance(section, dict) or "anomaly_scores" not in section:
        return [], [], [], [], None

    # ``or []`` tolerates an explicit JSON null where a list is expected.
    scores = list(section["anomaly_scores"] or [])
    attack_batch_ids = set(section.get("attack_batch_ids") or [])
    threshold = section.get("threshold")

    # Map scores to batch IDs (assuming sequential batch IDs)
    batch_ids = list(range(len(scores)))
    is_attack_flags = [batch_id in attack_batch_ids for batch_id in batch_ids]
    ground_truth_attacks = [int(flag) for flag in is_attack_flags]

    return scores, batch_ids, is_attack_flags, ground_truth_attacks, threshold


def generate_spotlight_plot(
    timestamps: np.ndarray,
    scores: np.ndarray,
    is_attack_flags: List[bool],
    batch_ids: List[int],
    ground_truth_attacks: List[int],
    threshold: Optional[float],
    output_file: str,
    attack_name: str = "Ground Truth Attack",
    model_name: str = "",
    dataset_name: str = "",
):
    """Generate PDF plot of spotlight scores over time (reproducing the 3rd subplot).

    Args:
        timestamps: Not used by this function; kept in the signature for callers.
        scores: Anomaly score per batch, plotted on the y-axis.
        is_attack_flags: Not used by this function; kept in the signature for callers.
        batch_ids: X position of each score.
        ground_truth_attacks: Per-batch markers; entries equal to 1 are
            highlighted with a red star.
        threshold: Value for a horizontal dashed line, or ``None`` for no line.
        output_file: Destination path of the PDF.
        attack_name: Legend label of the highlighted batches.
        model_name: Optional model name appended to the title.
        dataset_name: Optional dataset name appended to the title.
    """
    # Work on an explicit figure/axes pair instead of pyplot's implicit
    # "current figure" so exactly this figure is drawn on, saved and closed.
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        # Plot anomaly scores as a line (same as original)
        ax.plot(
            batch_ids,
            scores,
            color="blue",
            linewidth=2,
            marker="o",
            markersize=4,
            alpha=0.8,
            label="SPOTLIGHT Scores",
        )

        # Highlight ground truth attack batches (same as original)
        attack_indices = [
            i for i, gt in enumerate(ground_truth_attacks) if gt == 1
        ]
        if attack_indices:
            ax.scatter(
                [batch_ids[i] for i in attack_indices],
                [scores[i] for i in attack_indices],
                color="red",
                s=100,
                marker="*",
                label=f"{attack_name}",
                alpha=0.9,
                zorder=3,
            )

        # Add threshold line if available (same as original)
        if threshold is not None:
            ax.axhline(
                y=threshold,
                color="gray",
                linestyle="--",
                alpha=0.7,
                label=f"Threshold ({threshold:.3f})",
            )

        # Set labels and title (same as original, but xlabel changed to "Time")
        ax.set_xlabel("Time", fontsize=12)
        ax.set_ylabel("Anomaly Score", fontsize=12)

        # Create title with model and dataset info
        title_parts = ["SPOTLIGHT Anomaly Scores Over Time"]
        if model_name and dataset_name:
            title_parts.append(f"({model_name} on {dataset_name})")
        elif model_name:
            title_parts.append(f"({model_name})")
        elif dataset_name:
            title_parts.append(f"({dataset_name})")

        ax.set_title(" - ".join(title_parts), fontsize=14)
        ax.grid(True, alpha=0.3)
        ax.legend()

        # Set y-axis range to 5 points above max score (same as original)
        max_score = np.max(scores) if len(scores) > 0 else 0
        ax.set_ylim(0, max_score + 5)

        # Save as PDF
        fig.tight_layout()
        fig.savefig(output_file, format="pdf", dpi=300, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving failed;
        # pyplot otherwise keeps it registered until it is explicitly closed.
        plt.close(fig)


def _score_summary(prefix: str, values: np.ndarray) -> Dict[str, float]:
    """Summarize *values* under ``<prefix>_*`` keys; every entry is zero when empty."""
    if len(values) == 0:
        return {
            f"{prefix}_avg_score": 0.0,
            f"{prefix}_std_score": 0.0,
            f"{prefix}_min_score": 0.0,
            f"{prefix}_max_score": 0.0,
            f"{prefix}_count": 0,
        }
    return {
        f"{prefix}_avg_score": float(np.mean(values)),
        f"{prefix}_std_score": float(np.std(values)),
        f"{prefix}_min_score": float(np.min(values)),
        f"{prefix}_max_score": float(np.max(values)),
        f"{prefix}_count": len(values),
    }


def calculate_statistics(
    scores: np.ndarray, is_attack_flags: List[bool]
) -> Dict[str, float]:
    """Calculate average spotlight scores and deviation statistics.

    Args:
        scores: Anomaly scores, one per batch.
        is_attack_flags: Truthy where the score at the same position belongs
            to an attack batch. Entries beyond ``len(scores)`` are ignored.

    Returns:
        A dict with ``normal_*`` and ``attack_*`` entries, each group holding
        ``avg_score``, ``std_score`` (``np.std`` with its default settings),
        ``min_score``, ``max_score`` and ``count``. A group without any
        scores reports 0 for every entry.
    """
    scores = np.asarray(scores)
    # Force a boolean dtype before inverting: ``~`` on an empty (float) array
    # raises TypeError, and on 0/1 integers it is a bitwise NOT whose result
    # would be used as fancy indices instead of a mask.
    attack_mask = np.asarray(is_attack_flags[: len(scores)], dtype=bool)
    normal_mask = ~attack_mask

    stats = {}
    stats.update(_score_summary("normal", scores[normal_mask]))
    stats.update(_score_summary("attack", scores[attack_mask]))
    return stats


def save_statistics(stats: Dict[str, float], output_file: str):
    """Write the normal/attack score summary to a plain-text report.

    Args:
        stats: Mapping with ``normal_*`` and ``attack_*`` entries
            (``avg_score``, ``std_score``, ``min_score``, ``max_score``,
            ``count``).
        output_file: Path of the text file to create or overwrite.
    """
    sections = (("Normal Edges:\n", "normal"), ("Attack Edges:\n", "attack"))
    metrics = (
        ("Average Score", "avg_score"),
        ("Standard Deviation", "std_score"),
        ("Min Score", "min_score"),
        ("Max Score", "max_score"),
    )

    with open(output_file, "w") as report:
        report.write("Spotlight Score Statistics\n")
        report.write("=" * 50 + "\n\n")

        # Both groups share one layout: heading, four float metrics, count.
        for heading, prefix in sections:
            report.write(heading)
            for label, suffix in metrics:
                value = stats[f"{prefix}_{suffix}"]
                report.write(f"  {label}: {value:.6f}\n")
            count = stats[f"{prefix}_count"]
            report.write(f"  Count: {count}\n\n")

        # The difference line only makes sense when both groups have data.
        have_both = stats["normal_count"] > 0 and stats["attack_count"] > 0
        if have_both:
            score_diff = stats["attack_avg_score"] - stats["normal_avg_score"]
            report.write(f"Score Difference (Attack - Normal): {score_diff:.6f}\n")


def main(argv: Optional[List[str]] = None):
    """Command-line entry point: plot the scores and write the statistics.

    Reads the JSON file named on the command line and writes a PDF plot and a
    TXT statistics report next to it, both reusing the input's base name.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``; ``None``
            (the default) keeps the normal command-line behavior.

    Raises:
        SystemExit: On invalid arguments (raised by argparse) or when the
            JSON file contains no spotlight scores.
    """
    parser = argparse.ArgumentParser(
        description="Analyze spotlight data and generate plots"
    )
    parser.add_argument(
        "json_file", help="Path to JSON file with anomaly detection results"
    )
    parser.add_argument(
        "--attack-name",
        default="Ground Truth Attack",
        help="Name of the attack type (default: Ground Truth Attack)",
    )
    parser.add_argument("--model", default="", help="Model name (e.g., TGN)")
    parser.add_argument("--dataset", default="", help="Dataset name (e.g., Wikipedia)")

    args = parser.parse_args(argv)

    # Load data
    print("Loading JSON data...")
    json_data = load_json_data(args.json_file)

    # Extract spotlight scores
    print("Extracting spotlight scores...")
    scores, batch_ids, is_attack_flags, ground_truth_attacks, threshold = (
        extract_spotlight_scores(json_data)
    )

    if not scores:
        # Exit with a non-zero status (message goes to stderr) so shells and
        # pipelines can tell that no output files were produced.
        raise SystemExit("Error: No spotlight scores found in the JSON file")

    # Generate output filenames in the same directory as input files
    json_path = Path(args.json_file)
    output_dir = json_path.parent
    json_basename = json_path.stem
    pdf_output = output_dir / f"{json_basename}.pdf"
    txt_output = output_dir / f"{json_basename}.txt"

    # Convert once and share the array between the plot and the statistics.
    score_array = np.array(scores)

    # Generate plot
    print(f"Generating plot: {pdf_output}")
    generate_spotlight_plot(
        np.array([]),  # timestamps not needed
        score_array,
        is_attack_flags,
        batch_ids,
        ground_truth_attacks,
        threshold,
        pdf_output,
        args.attack_name,
        args.model,
        args.dataset,
    )

    # Calculate and save statistics
    print("Calculating statistics...")
    stats = calculate_statistics(score_array, is_attack_flags)

    print(f"Saving statistics: {txt_output}")
    save_statistics(stats, txt_output)

    print("Analysis complete!")
    print(f"Plot saved to: {pdf_output}")
    print(f"Statistics saved to: {txt_output}")


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
