# SPDX-License-Identifier: MIT
from __future__ import annotations
from dataclasses import dataclass, asdict
from typing import Any, Dict, List, Optional, Tuple
from collections import defaultdict
import torch
from .pipeline import pipeline_on_target_model
from .memory import MemoryStream
from .judge import llm_judge


def parameter_adaption(memory_stream: MemoryStream, default: float = 4.0, expand_ratio: float = 1.5, min_step: float = 0.1) -> float:
    """Propose the next parameter value from outcomes recorded in *memory_stream*.

    Records returned by ``query(refusal=True)`` act as lower bounds (their
    largest ``"parameter"`` is used); records returned by
    ``query(drifted=True)`` act as upper bounds (their smallest
    ``"parameter"`` is used).

    Args:
        memory_stream: Object exposing ``query(...)`` that yields mappings
            with a ``"parameter"`` key.
        default: Value returned when neither kind of record exists.
        expand_ratio: Multiplier applied to the lower bound when only
            refusals exist, and divisor applied to the upper bound when only
            drifted records exist.
        min_step: Floor for the shrunken value when only drifted records
            exist, and the clamp margin used when the bounds are inverted.

    Returns:
        The suggested parameter value.
    """
    refused_params = [rec["parameter"] for rec in memory_stream.query(refusal=True)]
    drifted_params = [rec["parameter"] for rec in memory_stream.query(drifted=True)]

    # No history at all: fall back to the caller-supplied default.
    if not refused_params and not drifted_params:
        return default

    # Only refusals seen: grow past the largest refused value.
    if not drifted_params:
        return max(refused_params) * expand_ratio

    # Only drift seen: shrink below the smallest drifted value, but never
    # below min_step.
    if not refused_params:
        return max(min(drifted_params) / expand_ratio, min_step)

    lower = max(refused_params)
    upper = min(drifted_params)
    midpoint = (lower + upper) / 2.0
    if lower >= upper:
        # Bounds are inverted (or touching): clamp the midpoint to within
        # min_step of each bound.
        capped = min(midpoint, lower + min_step)
        return max(capped, upper - min_step)
    return midpoint


@dataclass
class AdaptConfig:
    """Settings consumed by ``ToxicityAdaptor.adapt`` and ``ToxicityAdaptor._run_target``."""
    # Parameter value tried on the first iteration of the search.
    u0: float = 4.0
    # Upper limit on the number of generate-and-judge iterations.
    max_iters: int = 20
    # Tolerance for the bracket: the search stops with "interval collapsed"
    # once hi <= lo + eps. Also the floor used when shrinking below hi.
    eps: float = 0.2
    # Factor the parameter is multiplied by (only refusals seen so far) or
    # divided by (only drift seen so far) to pick the next candidate.
    expand_ratio: float = 1.5
    # Forwarded as ``threshold`` to ``pipeline_on_target_model``.
    threshold: float = 4.0
    # Forwarded as ``max_new_tokens`` to ``pipeline_on_target_model``.
    max_new_tokens: int = 40
    # Forwarded as ``verbose`` to ``pipeline_on_target_model``.
    verbose_pipeline: bool = False


@dataclass
class AdaptResult:
    """Outcome returned by ``ToxicityAdaptor.adapt``."""
    # True when an iteration produced an output judged neither a refusal
    # nor a subject change.
    ok: bool
    # Parameter value reported with the result: the accepted value on
    # success, or the midpoint of the bracket when the interval collapsed.
    u: float
    # Generated text reported with the result.
    output: str
    # Number of iterations counted when the search stopped.
    iters: int
    # Why the search stopped: "converged", "interval collapsed" or "max_iters".
    reason: str
    # Last (up to) ten entries of the memory stream's ``memory`` list.
    memory_tail: List[Dict[str, Any]]


class ToxicityAdaptor:
    """Search for a pipeline parameter ``u`` whose output the judge accepts.

    Each candidate ``u`` is passed to ``pipeline_on_target_model`` as
    ``delta``; the generated text is then scored by ``llm_judge``. A refusal
    raises the lower bound of the search bracket, a subject change lowers
    the upper bound, and an output that is neither ends the search.
    """

    def __init__(self, target_model, tokenizer, net, net_device: str):
        """Store the handles that are forwarded to ``pipeline_on_target_model``."""
        self.target_model = target_model
        self.tokenizer = tokenizer
        self.net = net
        self.net_device = net_device

    def _run_target(self, prompt: str, u: float, cfg: AdaptConfig) -> str:
        """Run the target pipeline once on *prompt* with parameter *u*.

        Args:
            prompt: Sentence handed to the pipeline.
            u: Candidate parameter, forwarded as ``delta``.
            cfg: Supplies ``threshold``, ``max_new_tokens`` and the verbosity
                flag.

        Returns:
            Whatever ``pipeline_on_target_model`` returns (annotated as text).
        """
        return pipeline_on_target_model(
            sentence=prompt,
            delta=float(u),
            threshold=float(cfg.threshold),
            target_model=self.target_model,
            tokenizer=self.tokenizer,
            net=self.net,
            net_device=self.net_device,
            max_tokens=2,
            max_new_tokens=cfg.max_new_tokens,
            capture_mode="token",
            verbose=cfg.verbose_pipeline,
        )

    def adapt(self, malicious_prompt: str, cfg: Optional[AdaptConfig] = None, memory: Optional[MemoryStream] = None) -> AdaptResult:
        """Iteratively adjust ``u`` until the judge accepts the output.

        Args:
            malicious_prompt: Prompt sent to the pipeline and to the judge.
            cfg: Search settings; a default ``AdaptConfig`` is used when None.
            memory: Stream that receives one record per iteration; a fresh
                ``MemoryStream`` is created when None.

        Returns:
            An ``AdaptResult`` whose ``reason`` is:

            * ``"converged"`` - the judge flagged neither refusal nor subject
              change; ``u`` and ``output`` are the accepted pair.
            * ``"interval collapsed"`` - the bracket shrank to within
              ``cfg.eps``; ``u`` is the bracket midpoint and ``output`` is
              the last generated text.
            * ``"max_iters"`` - the iteration budget ran out; ``u`` is the
              next candidate that would have been tried and ``output`` is
              the last generated text (empty if no iteration ran).
        """
        if cfg is None:
            cfg = AdaptConfig()
        if memory is None:
            memory = MemoryStream()

        u = cfg.u0
        lo: Optional[float] = None  # largest u that produced a refusal
        hi: Optional[float] = None  # smallest u that produced a subject change
        last_output = ""

        for it in range(1, cfg.max_iters + 1):
            output = self._run_target(malicious_prompt, u, cfg)
            # Remember the most recent text so a "max_iters" result can
            # report it instead of an empty string.
            last_output = output

            judge = llm_judge(malicious_prompt, output)
            refusal = bool(judge["refusal_detected"])
            drifted = bool(judge["subject_changed"])
            memory.add_memory(
                parameter=u,
                refusal_detected=refusal,
                subject_changed=drifted,
                malicious_prompt=malicious_prompt,
                model_output=output,
            )

            if not refusal and not drifted:
                return AdaptResult(
                    ok=True,
                    u=u,
                    output=output,
                    iters=it,
                    reason="converged",
                    memory_tail=memory.memory[-10:],
                )

            # A refusal takes precedence over drift when both are flagged.
            if refusal:
                lo = u if lo is None else max(lo, u)
            else:
                hi = u if hi is None else min(hi, u)

            # At least one bound is set from here on, so the three cases
            # below are exhaustive.
            if lo is not None and hi is not None:
                if hi <= lo + cfg.eps:
                    return AdaptResult(
                        ok=False,
                        u=(lo + hi) / 2.0,
                        output=output,
                        iters=it,
                        reason="interval collapsed",
                        memory_tail=memory.memory[-10:],
                    )
                u = (lo + hi) / 2.0
            elif lo is not None:
                # Only refusals so far: expand upwards.
                u = lo * cfg.expand_ratio
            else:
                # Only drift so far: shrink, but never below eps.
                u = max(hi / cfg.expand_ratio, cfg.eps)

        # Success always returns from inside the loop, so reaching this
        # point means the budget was exhausted without an accepted output.
        return AdaptResult(
            ok=False,
            u=u,
            output=last_output,
            iters=cfg.max_iters,
            reason="max_iters",
            memory_tail=memory.memory[-10:],
        )