from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
import json
import tqdm

from sae_lens import SAE, HookedSAETransformer

from ..datasets.benchmarks import get_benchmark
from .steerer import FeatureSteerer, SteeringConfig


@dataclass
class EvaluationResult:
    """Outcome of running one benchmark under one condition.

    The per-sample lists ``predictions``, ``expected`` and ``is_correct`` are
    index-aligned: entry ``i`` of each describes the same benchmark sample.
    """

    # Which benchmark was run and under which condition label (e.g. "baseline").
    benchmark_name: str
    condition: str

    # Aggregate score; BenchmarkEvaluator.evaluate sets accuracy = correct / total
    # (0.0 when total == 0).
    accuracy: float
    correct: int
    total: int

    # Per-sample model outputs, reference answers, and grading outcomes.
    predictions: list[str]
    expected: list[str]
    is_correct: list[bool]

    # Serialized steering parameters, or None when no steering was recorded.
    steering_config: Optional[dict] = None

    # Generation settings used for the run (token budget, sampling params, ...).
    generation_params: dict = field(default_factory=dict)

    def save(self, path: Path):
        """Write this result to ``path`` as indented JSON.

        ``path`` may be a ``str`` or ``Path``; missing parent directories are
        created. The file is written as UTF-8 explicitly (JSON's standard
        encoding) instead of relying on the platform's locale default.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        payload = {
            "benchmark_name": self.benchmark_name,
            "condition": self.condition,
            "accuracy": self.accuracy,
            "correct": self.correct,
            "total": self.total,
            "predictions": self.predictions,
            "expected": self.expected,
            "is_correct": self.is_correct,
            "steering_config": self.steering_config,
            "generation_params": self.generation_params,
        }
        with open(path, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2)

    @classmethod
    def load(cls, path: Path) -> "EvaluationResult":
        """Read a result previously written by ``save``.

        ``path`` may be a ``str`` or ``Path``. The file is decoded as UTF-8
        so loading does not depend on the platform's locale encoding.

        Raises:
            TypeError: if the JSON object has keys that are not fields of this
                class, or lacks a required field.
        """
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
        return cls(**data)


class BenchmarkEvaluator:
    """Run a benchmark through a model, with or without SAE feature steering.

    Text generation is delegated to a ``FeatureSteerer`` built from ``model``
    and ``sae``; grading is delegated to the benchmark's ``check_answer``.
    """

    def __init__(
        self,
        model: HookedSAETransformer,
        sae: SAE,
        layer_index: int = 8,
    ):
        """Create an evaluator.

        Args:
            model: Transformer used for generation.
            sae: Sparse autoencoder passed to the ``FeatureSteerer``.
            layer_index: Stored as ``self.layer_index``; it is not read
                anywhere in this class. NOTE(review): the steering layer that
                gets recorded comes from ``SteeringConfig.layer_index`` —
                confirm whether this attribute is still needed.
        """
        self.model = model
        self.sae = sae
        self.layer_index = layer_index
        self.steerer = FeatureSteerer(model, sae)

    def evaluate(
        self,
        benchmark_name: str,
        condition: str = "baseline",
        steering_config: Optional[SteeringConfig] = None,
        max_new_tokens: int = 512,
        temperature: float = 0.1,
        top_p: float = 0.95,
        do_sample: bool = True,
        apply_chat_template: bool = True,
        max_samples: Optional[int] = None,
        verbose: bool = True,
    ) -> EvaluationResult:
        """Generate an answer for each benchmark sample and grade it.

        Args:
            benchmark_name: Name passed to ``get_benchmark``.
            condition: ``"baseline"`` generates without steering. Any other
                value generates with ``steering_config`` applied and is used
                verbatim as the result's condition label.
            steering_config: Steering to apply. Required for every
                non-baseline condition; ignored (and not recorded) for the
                baseline condition.
            max_new_tokens, temperature, top_p, do_sample, apply_chat_template:
                Forwarded to the steerer's generation call and recorded in the
                result's ``generation_params``.
            max_samples: If given, only the first ``max_samples`` samples are
                evaluated. Must be non-negative.
            verbose: Show a tqdm progress bar.

        Returns:
            An ``EvaluationResult`` holding per-sample predictions and the
            overall accuracy (0.0 when no samples were evaluated).

        Raises:
            ValueError: if a non-baseline condition is requested without a
                ``steering_config``, or ``max_samples`` is negative.
        """
        # Every condition other than "baseline" takes the steering path, so
        # every such condition needs a config — not only the literal "steered".
        steered = condition != "baseline"
        if steered and steering_config is None:
            raise ValueError("steering_config required for steered condition")
        if max_samples is not None and max_samples < 0:
            # A negative slice bound would silently drop samples from the end.
            raise ValueError(f"max_samples must be non-negative, got {max_samples}")

        benchmark = get_benchmark(benchmark_name)
        benchmark.load()

        samples = list(benchmark)
        if max_samples is not None:
            samples = samples[:max_samples]

        # Shared by both generation paths and recorded verbatim in the result,
        # so the recorded settings always match what was actually used.
        generation_params = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": do_sample,
            "apply_chat_template": apply_chat_template,
        }

        iterator = samples
        if verbose:
            iterator = tqdm.tqdm(samples, desc=f"Evaluating {benchmark_name} ({condition})")

        predictions = []
        expected = []
        is_correct = []
        for sample in iterator:
            prompt = benchmark.format_prompt(sample.question)
            if steered:
                response = self.steerer.generate_with_steering(
                    prompt, steering_config, **generation_params
                )
            else:
                response = self.steerer.generate_baseline(prompt, **generation_params)

            predictions.append(response)
            expected.append(sample.expected_answer)
            is_correct.append(benchmark.check_answer(response, sample.expected_answer))

        n_correct = sum(is_correct)
        total = len(is_correct)

        return EvaluationResult(
            benchmark_name=benchmark_name,
            condition=condition,
            accuracy=n_correct / total if total else 0.0,
            correct=n_correct,
            total=total,
            predictions=predictions,
            expected=expected,
            is_correct=is_correct,
            # Only record steering that was actually applied; a config passed
            # alongside the baseline condition is never used.
            steering_config=self._steering_config_to_dict(steering_config) if steered else None,
            generation_params=dict(generation_params),
        )

    @staticmethod
    def _steering_config_to_dict(config: SteeringConfig) -> dict:
        """Snapshot the steering parameters stored in an ``EvaluationResult``."""
        return {
            "feature_index": config.feature_index,
            "gamma": config.gamma,
            "max_feature_activation": config.max_feature_activation,
            "layer_index": config.layer_index,
        }
