from typing import Dict, Optional

import requests
from pydantic import BaseModel, Field

from llm_mcts.data_types import Action
from llm_mcts.mcts_algo.node import Node
from llm_mcts.mcts_scorer.base import MCTSScorer
from llm_mcts.tasks.arc.task import ARCProblem
from llm_mcts.tasks.base import Task


class VerifierScorerConfig(BaseModel):
    """Configuration for ``VerifierScorer``.

    ``score_factor`` maps a node's last action (plus the special
    ``"verifier"`` key) to the weight applied to that score component.
    """

    # Weight per action type; "verifier" weights the remote verifier's
    # probability. A fresh dict is built per instance via default_factory.
    score_factor: Dict[Action, float] = Field(
        default_factory=lambda: dict(
            transform=1.0,
            question=0.0,
            multi_questions=0.0,
            verifier=1.0,
        )
    )


class VerifierScorer(MCTSScorer):
    """Score MCTS nodes by combining local eval results with a remote verifier.

    The final score is a weighted sum of (a) the summed per-test scores in
    ``node.eval_results`` and (b) a probability returned by the verifier
    service at ``URL``. Weights come from ``VerifierScorerConfig.score_factor``.
    """

    URL = "http://localhost:8000/predict_batch"
    # Seconds to wait for the verifier service. Without a timeout, requests
    # may block indefinitely on an unresponsive server and stall the search.
    TIMEOUT = 60.0

    def __init__(self, task: Task, config: Optional[VerifierScorerConfig] = None):
        if config is None:
            config = VerifierScorerConfig()

        self.scorer_config = config
        self.task = task

    def get_score(self, node: Node) -> float:
        """Return the weighted score for ``node``.

        score = score_factor[node.last_action] * eval_score
                + score_factor["verifier"] * verifier_probability

        Returns 0.0 for nodes whose ``eval_results`` is None. The verifier
        service is contacted only when ``score_factor["verifier"]`` is
        non-zero, because otherwise its result cannot affect the score.

        Raises:
            AssertionError: if the task is not an ``ARCProblem``.
            KeyError: if ``node.last_action`` or ``"verifier"`` is missing
                from ``score_factor``.
        """
        if node.eval_results is None:
            return 0.0

        assert isinstance(self.task, ARCProblem)

        factors = self.scorer_config.score_factor
        verifier_factor = factors["verifier"]
        # Skip the network round-trip when the verifier term is weighted out.
        score_verifier = self._verifier_score(node) if verifier_factor else 0.0

        return (
            factors[node.last_action] * self._eval_score(node)
            + verifier_factor * score_verifier
        )

    def _verifier_score(self, node: Node) -> float:
        """Query the verifier service and return its probability for ``node``.

        Returns 0.0 (after printing a diagnostic) when the request fails at
        the transport level or the service answers with a non-200 status.
        """
        # Pair each test input with the node's predicted output. Note: zip
        # silently truncates if the two sequences differ in length.
        test_predictions = [
            {"input": test["input"], "output": pred_output.answer}
            for test, pred_output in zip(self.task.tests, node.eval_results)
        ]
        payload = [{"demos": self.task.demos, "preds": test_predictions}]

        try:
            response = requests.post(self.URL, json=payload, timeout=self.TIMEOUT)
        except requests.RequestException as err:
            # Treat connection errors/timeouts like HTTP errors: best-effort 0.0.
            print(f"Verifier Request Error: {err}")
            return 0.0

        if response.status_code != 200:
            print(f"Verifier Prediction Error: {response.status_code}")
            print(response.text)
            return 0.0

        # NOTE(review): assumes "probabilities" holds one row per batch item
        # and column 1 is the positive-class probability; confirm with service.
        return response.json()["probabilities"][0][1]

    @staticmethod
    def _eval_score(node: Node) -> float:
        """Sum the per-test scores in ``node.eval_results``.

        For ``multi_questions`` nodes each result's ``get_score()`` returns an
        iterable of scores, which is summed first.
        """
        # TODO: Refactor EvalResults to a separate class so both shapes share one API.
        if node.last_action == "multi_questions":
            return sum(sum(result.get_score()) for result in node.eval_results)
        return sum(result.get_score() for result in node.eval_results)
