import uuid
import random
from typing import Any, Dict, List, Tuple, Optional
from pathlib import Path
from agents.base.agent import AgentBase
from exp.base.base import ExperimentBase
from exp.utils.registry import register_experiment
from exp.utils.datatypes import ExperimentMetrics
from utils.parse_output import OutputStringParser

@register_experiment("simple_sum")
class SimpleSumExperiment(ExperimentBase):
    """Toy experiment that asks an agent to add two random integers.

    Every test case is a pair of integers drawn from ``[1, 100]``. The
    agent's answer is read from the ``"sum"`` key of its parsed output and
    scored by exact equality against the true sum.
    """

    def __init__(
        self,
        task: str = "simple_sum",
        num_test: int = 10,
        logs_dir: Optional[Path] = None,
        agent: Optional[AgentBase] = None,
        **kwargs: Any,
    ):
        """Forward the standard experiment settings to ``ExperimentBase``.

        Args:
            task: Task name handed to the base class.
            num_test: Number of test cases ``data_iterator`` generates.
            logs_dir: Directory later passed to the agent on each query.
            agent: Agent that answers the queries.
            **kwargs: Accepted and ignored; nothing extra is forwarded to
                the base class.
        """
        super().__init__(task=task, num_test=num_test, logs_dir=logs_dir, agent=agent)

    def prepare_data(self) -> None:
        """Do nothing: test cases are generated on the fly, not loaded."""
        pass

    def data_iterator(self):
        """Yield ``self.num_test`` freshly generated test cases.

        Each yielded dict holds a random ``"subject_id"`` (UUID4 string), the
        two operands under ``"data"`` and the ground truth as
        ``{"sum": a + b}`` under ``"GT"``. The module-level ``random``
        generator is used unseeded, so cases differ between runs.
        """
        # NOTE(review): relies on ExperimentBase.__init__ storing num_test on
        # the instance - confirm against the base class.
        for _ in range(self.num_test):
            # Both bounds of randint are inclusive.
            a = random.randint(1, 100)
            b = random.randint(1, 100)
            yield {
                "subject_id": str(uuid.uuid4()),
                "data": [a, b],
                "GT": {"sum": a + b},
            }

    async def run_agent(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Query the agent for one test case and parse its answer.

        Args:
            data: One item produced by ``data_iterator``; its ``"data"``,
                ``"subject_id"`` and ``"GT"`` entries are read.

        Returns:
            A result record with the keys ``"query_id"``, ``"subject_id"``,
            ``"GT"``, ``"solution"`` and ``"fail_reason"``, the last two
            being whatever ``parse_output`` returned.
        """
        query_id = str(uuid.uuid4())
        prompt = "Sum the numbers in the data list."

        # Only the operands are exposed to the agent; the ground truth stays
        # on this side.
        agent_input = {"data": data["data"]}

        response = await self.agent.query(prompt, agent_input, self.logs_dir, query_id)

        solution, fail_reason = self.parse_output(response)

        return {
            "query_id": query_id,
            "subject_id": data["subject_id"],
            "GT": data["GT"],
            "solution": solution,
            "fail_reason": fail_reason,
        }

    def parse_output(self, content: str, query_id: Optional[str] = None) -> Tuple[Dict[str, Any], Any]:
        """Parse the agent response via ``OutputStringParser.parse_dict``.

        Args:
            content: Raw agent response.
            query_id: Accepted for interface compatibility; not used here.

        Returns:
            The parser's return value unchanged, unpacked by ``run_agent``
            as ``(solution, fail_reason)``.
        """
        # NOTE(review): only ``int`` is declared for "sum"; whether a float
        # answer such as 5.0 is rejected depends on OutputStringParser -
        # confirm if float answers should be accepted.
        return OutputStringParser.parse_dict(
            content,
            expected_keys=["sum"],
            expected_value_types={"sum": int},
        )

    def calculate_metrics(self, result_list: List[dict]) -> ExperimentMetrics:
        """Score result records produced by ``run_agent``.

        A record counts as a failure when it carries a truthy
        ``"fail_reason"`` or when its ``"solution"`` is missing or not a
        dict, since there is no answer to compare in either case. Any other
        record is correct when ``solution["sum"]`` equals ``GT["sum"]``.

        Args:
            result_list: Result records to score.

        Returns:
            A mapping with ``"Accuracy"`` (correct / total, ``0.0`` for empty
            input; failures stay in the denominator), ``"Failures"`` (failure
            count) and an empty ``"additional_metrics"`` dict.
        """
        total = 0
        correct = 0
        failures = 0

        for total, res in enumerate(result_list, start=1):
            solution = res.get("solution")
            # A malformed record is counted instead of raising, so that one
            # bad entry cannot abort the scoring of the whole run.
            if res.get("fail_reason") or not isinstance(solution, dict):
                failures += 1
                continue

            if solution.get("sum") == res["GT"]["sum"]:
                correct += 1

        # Always a float, including for an empty result list.
        accuracy = correct / total if total else 0.0

        return {
            "Accuracy": accuracy,
            "Failures": failures,
            "additional_metrics": {},
        }

    def save_data(self, data: Dict[str, Any], query_id: Optional[str] = None) -> None:
        """Do nothing: this experiment persists no per-query data."""
        pass
