# SPDX-License-Identifier: MIT
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple
import time
import pandas as pd
from collections import defaultdict


class MemoryStream:
    """In-memory log of probe results with simple filtering and export helpers.

    Each record is a dict with the keys listed in ``_COLUMNS``: the probed
    ``parameter`` (float), the ``refusal_detected`` / ``subject_changed``
    flags (bool), the stripped ``malicious_prompt`` / ``model_output`` text,
    and the wall-clock ``timestamp`` at which the record was added.
    """

    # Canonical record columns, shared by the empty and non-empty
    # ``to_dataframe`` paths so both produce the same schema.
    _COLUMNS: Tuple[str, ...] = (
        "parameter",
        "refusal_detected",
        "subject_changed",
        "malicious_prompt",
        "model_output",
        "timestamp",
    )

    def __init__(self):
        # All records, in insertion order.
        self.memory: List[Dict[str, Any]] = []
        # Records grouped by flag value (True/False), kept in sync by
        # ``add_memory``. NOTE: ``query`` scans ``memory`` directly and does
        # not consult this index.
        self.index = {
            "refusal": defaultdict(list),
            "drifted": defaultdict(list),
        }

    def add_memory(self, parameter, refusal_detected, subject_changed, malicious_prompt: str = "", model_output: str = "") -> bool:
        """Append a normalized record and update the flag index.

        Args:
            parameter: Value coerced with ``float()``; a non-numeric value
                raises ``ValueError``/``TypeError`` before anything is stored.
            refusal_detected: Coerced with ``bool()``.
            subject_changed: Coerced with ``bool()``.
            malicious_prompt: Converted with ``str()`` and whitespace-stripped.
            model_output: Converted with ``str()`` and whitespace-stripped.

        Returns:
            Always ``True`` once the record has been stored.
        """
        param = float(parameter)
        record = {
            "parameter": param,
            "refusal_detected": bool(refusal_detected),
            "subject_changed": bool(subject_changed),
            "malicious_prompt": str(malicious_prompt).strip(),
            "model_output": str(model_output).strip(),
            "timestamp": time.time(),
        }
        self.memory.append(record)
        self._update_index(record)
        return True

    def _update_index(self, record):
        """File *record* under its two flag values in ``self.index``."""
        self.index["refusal"][record["refusal_detected"]].append(record)
        self.index["drifted"][record["subject_changed"]].append(record)

    def query(self, refusal: Optional[bool] = None, drifted: Optional[bool] = None, param_range: Optional[Tuple[float, float]] = None):
        """Return stored records matching every given filter, sorted by parameter.

        Args:
            refusal: If not ``None``, keep records whose ``refusal_detected``
                equals this value.
            drifted: If not ``None``, keep records whose ``subject_changed``
                equals this value.
            param_range: Inclusive ``(lo, hi)`` bounds on ``parameter``; a
                falsy value (``None`` or an empty tuple) disables the filter.

        Returns:
            A new list of the matching record dicts, ordered by ``parameter``.
            Ties keep insertion order because ``sorted`` is stable. The dicts
            themselves are the stored objects, not copies.
        """
        candidates = self.memory
        if refusal is not None:
            candidates = [r for r in candidates if r["refusal_detected"] == refusal]
        if drifted is not None:
            candidates = [r for r in candidates if r["subject_changed"] == drifted]
        if param_range:
            lo, hi = param_range
            candidates = [r for r in candidates if lo <= r["parameter"] <= hi]
        return sorted(candidates, key=lambda x: x["parameter"])

    def get_statistics(self) -> Dict[str, Any]:
        """Return the record count and the fraction of records with each flag set.

        The rates are always floats. An empty stream reports ``0.0`` rather
        than raising ``ZeroDivisionError``.
        """
        n = len(self.memory)
        if not n:
            return {"total_records": 0, "refusal_rate": 0.0, "drifted_rate": 0.0}
        return {
            "total_records": n,
            "refusal_rate": sum(r["refusal_detected"] for r in self.memory) / n,
            "drifted_rate": sum(r["subject_changed"] for r in self.memory) / n,
        }

    def dump_json(self, path: str):
        """Write statistics and all records to *path* as indented UTF-8 JSON."""
        import json  # imported before opening so a failure cannot truncate *path*

        with open(path, "w", encoding="utf-8") as f:
            json.dump({"statistics": self.get_statistics(), "records": self.memory}, f, ensure_ascii=False, indent=2)

    @staticmethod
    def _flag_column(series: pd.Series) -> pd.Series:
        """Coerce a flag column to bool, treating missing values (None/NaN) as False."""
        # Plain ``astype(bool)`` maps NaN to True, which would count a record
        # missing the flag as flagged, while an absent (None-filled) column
        # maps to False. ``notna()`` makes both cases consistently False.
        return series.notna() & series.astype(bool)

    def to_dataframe(self):
        """Return all records as a DataFrame with an added ``attack_success`` column.

        ``attack_success`` is True when neither a refusal nor a subject change
        was recorded. Any canonical column that is absent is filled with
        ``None``. An empty stream yields an empty frame with the same columns.
        """
        if not self.memory:
            return pd.DataFrame(columns=[*self._COLUMNS, "attack_success"])
        df = pd.DataFrame(self.memory)
        for col in self._COLUMNS:
            if col not in df.columns:
                df[col] = None
        refused = self._flag_column(df["refusal_detected"])
        drifted = self._flag_column(df["subject_changed"])
        df["attack_success"] = ~refused & ~drifted
        return df