#!/usr/bin/env python3
"""
Script to generate a comparison plot (two vertically stacked panels sharing one
x-axis) of spotlight anomaly scores from two different attacks.
"""

import argparse
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np


def load_json_data(json_file: str) -> Dict[str, Any]:
    """Load and parse the JSON file containing anomaly detection results.

    Args:
        json_file: Path of the JSON file to read.

    Returns:
        The decoded JSON document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON interchange text is UTF-8; pin the encoding so decoding does not
    # depend on the platform's locale default.
    with open(json_file, "r", encoding="utf-8") as f:
        return json.load(f)


def extract_spotlight_scores(
    json_data: Dict[str, Any],
) -> Tuple[
    List[float],
    List[int],
    List[bool],
    List[int],
    Optional[float],
    Optional[int],
]:
    """
    Extract spotlight (structural anomaly) scores from parsed detection results.

    Scores are read from ``json_data["structural_detection"]["anomaly_scores"]``.
    Batch IDs are assigned by position (0, 1, 2, ...), i.e. the scores are
    assumed to be listed in sequential batch order.

    Args:
        json_data: Parsed JSON results, as returned by ``json.load``.

    Returns:
        A tuple ``(scores, batch_ids, is_attack_flags, ground_truth_attacks,
        threshold, test_start_batch)`` where:

        * ``scores`` is the list of anomaly scores,
        * ``batch_ids`` is the positional batch ID of each score,
        * ``is_attack_flags`` tells whether each batch ID appears in
          ``attack_batch_ids``,
        * ``ground_truth_attacks`` holds the same flags as 1/0 integers,
        * ``threshold`` is the stored ``threshold`` value, or ``None`` if absent,
        * ``test_start_batch`` is the smallest attack batch ID, or ``None`` when
          no attack batch IDs are listed.

        When the structural detection section or its scores are missing, the
        four lists are empty and the last two items are ``None``.
    """
    structural = (
        json_data.get("structural_detection")
        if isinstance(json_data, dict)
        else None
    )
    if not isinstance(structural, dict) or "anomaly_scores" not in structural:
        return [], [], [], [], None, None

    # ``or []`` treats an explicit JSON ``null`` the same as a missing key.
    scores = list(structural["anomaly_scores"] or [])
    attack_batch_ids = set(structural.get("attack_batch_ids") or [])
    threshold = structural.get("threshold")

    # The first attacked batch marks where the test period begins.
    test_start_batch = min(attack_batch_ids) if attack_batch_ids else None

    batch_ids = list(range(len(scores)))
    is_attack_flags = [batch_id in attack_batch_ids for batch_id in batch_ids]
    ground_truth_attacks = [int(flag) for flag in is_attack_flags]

    return (
        scores,
        batch_ids,
        is_attack_flags,
        ground_truth_attacks,
        threshold,
        test_start_batch,
    )


def _plot_attack_panel(
    ax: Any,
    scores: List[float],
    batch_ids: List[int],
    test_start: Optional[int],
    threshold: Optional[float],
    attack_name: str,
    y_max: float,
    x_label: Optional[str] = None,
) -> None:
    """Draw one attack's score curves, threshold line and styling on ``ax``."""
    line_style = dict(linewidth=3, marker="o", markersize=6, alpha=0.7)

    if test_start is not None:
        # Normal period: every batch before the first attacked batch.
        normal_ids = [b for b in batch_ids if b < test_start]
        normal_scores = [s for b, s in zip(batch_ids, scores) if b < test_start]
        if normal_ids:
            ax.plot(
                normal_ids,
                normal_scores,
                color="blue",
                label="Normal Batches",
                **line_style,
            )

        # Test period: from the first attacked batch onwards.
        test_ids = [b for b in batch_ids if b >= test_start]
        test_scores = [s for b, s in zip(batch_ids, scores) if b >= test_start]
        if test_ids:
            ax.plot(
                test_ids,
                test_scores,
                color="red",
                label=f"{attack_name} Attacked Batches",
                **line_style,
            )
    else:
        # Fallback: plot all as normal if no test start found.
        ax.plot(
            batch_ids,
            scores,
            color="blue",
            label="Normal Batches",
            **line_style,
        )

    if threshold is not None:
        ax.axhline(
            y=threshold,
            color="gray",
            linestyle="--",
            linewidth=2,
            alpha=0.7,
            label=f"Threshold ({threshold:.3f})",
        )

    if x_label is not None:
        ax.set_xlabel(x_label, fontsize=20)
    ax.set_ylabel("Anomaly Score", fontsize=20)
    ax.set_title(str(attack_name), fontsize=22)
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=18, loc="upper left", frameon=True, fancybox=True, shadow=True)
    ax.tick_params(axis="both", which="major", labelsize=14)
    ax.tick_params(axis="both", which="minor", labelsize=12)
    ax.set_ylim(0, y_max)


def generate_overlay_plot(
    data1: Tuple[
        List[float], List[int], List[bool], List[int], Optional[float], Optional[int]
    ],
    data2: Tuple[
        List[float], List[int], List[bool], List[int], Optional[float], Optional[int]
    ],
    attack1_name: str,
    attack2_name: str,
    model_name: str,
    dataset_name: str,
    output_file: str,
) -> str:
    """Generate two separate plots stacked vertically comparing two attacks.

    Each panel shows one attack's anomaly scores per batch. Batches before the
    test start are drawn in blue as "Normal Batches"; batches from the test
    start onwards are drawn in red as attacked batches. A dashed gray line
    marks the threshold when one is given. Both panels share the x-axis and
    use the same y-range, so they can be compared directly.

    Args:
        data1: ``(scores, batch_ids, is_attack_flags, ground_truth_attacks,
            threshold, test_start_batch)`` for the first attack. Only the
            scores, batch IDs, threshold and test start are used for plotting.
        data2: Same structure as ``data1``, for the second attack.
        attack1_name: Title / legend name of the first attack.
        attack2_name: Title / legend name of the second attack.
        model_name: Currently unused; kept for interface compatibility.
        dataset_name: Currently unused; kept for interface compatibility.
        output_file: Output path; ``.pdf`` is appended if it is not already
            the suffix (compared case-insensitively).

    Returns:
        The path the PDF was written to.
    """
    scores1, batch_ids1, _flags1, _truth1, threshold1, test_start1 = data1
    scores2, batch_ids2, _flags2, _truth2, threshold2, test_start2 = data2

    # Shared y-range across both panels, with some headroom above the maximum.
    all_scores = list(scores1) + list(scores2)
    max_score = max(all_scores) if all_scores else 0
    y_max = max_score + 5

    pdf_file = str(output_file)
    if not pdf_file.lower().endswith(".pdf"):
        pdf_file += ".pdf"

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
    try:
        _plot_attack_panel(
            ax1, scores1, batch_ids1, test_start1, threshold1, attack1_name, y_max
        )
        # Only the bottom panel carries the x-axis label (the x-axis is shared).
        _plot_attack_panel(
            ax2,
            scores2,
            batch_ids2,
            test_start2,
            threshold2,
            attack2_name,
            y_max,
            x_label="Time",
        )

        fig.tight_layout()
        fig.savefig(pdf_file, format="pdf", dpi=300, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

    return pdf_file


def main():
    """Command-line entry point.

    Parses the two input JSON paths and the optional naming/output arguments,
    extracts the spotlight scores from both files, and writes the comparison
    plot as a PDF.

    Raises:
        SystemExit: With a non-zero exit status when either input file yields
            no spotlight scores (argparse also raises it on invalid arguments).
    """
    parser = argparse.ArgumentParser(
        description="Generate overlay plot comparing two attacks"
    )
    parser.add_argument(
        "json1", help="Path to first JSON file with anomaly detection results"
    )
    parser.add_argument(
        "json2", help="Path to second JSON file with anomaly detection results"
    )
    parser.add_argument(
        "--attack1-name",
        default="Attack 1",
        help="Name of the first attack type",
    )
    parser.add_argument(
        "--attack2-name",
        default="Attack 2",
        help="Name of the second attack type",
    )
    parser.add_argument("--model", default="", help="Model name (e.g., TGN)")
    parser.add_argument("--dataset", default="", help="Dataset name (e.g., Wikipedia)")
    parser.add_argument(
        "--output", default="spotlight_overlay_comparison.pdf", help="Output filename"
    )

    args = parser.parse_args()

    # Load data
    print("Loading JSON data...")
    json_data1 = load_json_data(args.json1)
    json_data2 = load_json_data(args.json2)

    # Extract spotlight scores
    print("Extracting spotlight scores...")
    data1 = extract_spotlight_scores(json_data1)
    data2 = extract_spotlight_scores(json_data2)

    if not data1[0] or not data2[0]:
        # Exit with a failure status (message goes to stderr) so that calling
        # shells and pipelines can detect that no plot was produced.
        raise SystemExit("Error: No spotlight scores found in one or both JSON files")

    # Generate overlay plot
    print(f"Generating overlay plot: {args.output}")
    pdf_file = generate_overlay_plot(
        data1,
        data2,
        args.attack1_name,
        args.attack2_name,
        args.model,
        args.dataset,
        args.output,
    )

    print("Overlay plot generation complete!")
    print(f"PDF plot saved to: {pdf_file}")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
