import os
import re
import time
from abc import ABC, abstractmethod
from collections import Counter


class Evaluator(ABC):
    """
    Abstract base class for an evaluator that uses different LLMs for evaluation.

    Subclasses supply the model call (``_generate_response``) and the parser
    that pulls a decision out of the model output (``decision_extract``).
    ``evaluate`` then decides by majority vote over repeated generations.
    """

    # A voter whose output still cannot be parsed after this many generations
    # is recorded with the sentinel decision "None".
    _MAX_ATTEMPTS_PER_VOTER = 10
    # After this many consecutive parse failures, pause before each retry
    # (repeated failures were assumed to hint at a connection problem).
    _BACKOFF_AFTER_FAILURES = 4
    _BACKOFF_SECONDS = 10

    @abstractmethod
    def _generate_response(self, system_msg, user_msg):
        """
        Generate a response using the specific LLM model with its cost.

        Returns:
            A ``(response_text, cost)`` tuple. ``evaluate`` sums the costs and
            writes ``response_text`` to disk when a save path is given.
        """
        pass

    @staticmethod
    @abstractmethod
    def decision_extract(last_msg):
        """
        Extract the final decision from an LLM output message.

        Return ``None`` when no well-formed decision is present; ``evaluate``
        treats ``None`` as "ask this voter again".
        """
        pass

    def evaluate(self, usr_msg, sys_msg, save_path="", first_round_voters=3, second_round_voters=2, prefix=""):
        """
        Decide on a verdict by majority vote over repeated LLM generations.

        A first round of ``first_round_voters`` votes is collected; if they all
        agree, that vote is returned. Otherwise ``second_round_voters`` more
        votes are added, and while the highest vote count is shared by two or
        more options, one extra vote at a time is added until one option leads.

        Args:
            usr_msg: user message passed to ``_generate_response``.
            sys_msg: system message passed to ``_generate_response``.
            save_path: if non-empty, an existing directory in which each
                voter's final raw response is written as ``<round>_res<i>.txt``.
            first_round_voters: number of votes in the first round (>= 1).
            second_round_voters: number of extra votes when round one disagrees.
            prefix: currently unused; kept for backward compatibility.

        Returns:
            ``(decision, total_cost)``. ``decision`` is the winning vote (the
            string ``"None"`` stands for a voter that never produced a parsable
            decision). ``total_cost`` sums the cost of every generation,
            retries included.

        Raises:
            ValueError: if ``first_round_voters`` is less than 1.
        """
        if first_round_voters < 1:
            raise ValueError("first_round_voters must be at least 1")

        def conduct_vote(num_voters, prefix="first-round"):
            # Collect `num_voters` decisions; return (decisions, cost of all calls).
            results = []
            total_cost = 0
            for i in range(num_voters):
                response, cost = self._generate_response(sys_msg, usr_msg)
                total_cost += cost
                decision = self.decision_extract(response)
                cnt = 0  # number of failed parses for this voter so far
                while decision is None:
                    cnt += 1
                    print(f"trying at {prefix} with voter {i} for {cnt} times")
                    if cnt >= self._MAX_ATTEMPTS_PER_VOTER:
                        decision = "None"  # the evaluator failed to generate a formatted decision
                        break

                    if cnt >= self._BACKOFF_AFTER_FAILURES:
                        print("there might be a connection error, waiting...")
                        time.sleep(self._BACKOFF_SECONDS)
                    response, cost = self._generate_response(sys_msg, usr_msg)
                    total_cost += cost
                    decision = self.decision_extract(response)

                results.append(decision)
                if save_path:
                    # Only the voter's last response (after any retries) is saved.
                    out_file = os.path.join(save_path, f"{prefix}_res{i}.txt")
                    with open(out_file, "w", encoding="utf-8") as f:
                        f.write(response)
            return results, total_cost

        def is_tied(votes):
            # True when two or more options share the highest vote count.
            ranked = Counter(votes).most_common(2)
            return len(ranked) == 2 and ranked[0][1] == ranked[1][1]

        total_cost = 0
        # First voting round
        first_round_results, cost1 = conduct_vote(first_round_voters, prefix="first-round")
        total_cost += cost1
        if all(result == first_round_results[0] for result in first_round_results):  # if all votes are the same
            print("aligned vote with {} votes. Votes: {}".format(first_round_voters, first_round_results))
            return first_round_results[0], total_cost

        # Second voting round if needed
        print("discrepancy in vote, more judgment needed: {}".format(first_round_results))

        second_round_results, cost2 = conduct_vote(second_round_voters, prefix="second-round")
        total_cost += cost2
        total_results = first_round_results + second_round_results

        # While the top count is shared, add one vote at a time until a single
        # option leads. Each extra round gets its own index so saved responses
        # from earlier tie-breaks are not overwritten.
        cnt = 0
        while is_tied(total_results):
            print("discrepancy in vote, more judgment needed: {}".format(total_results))
            extra_results, cost3 = conduct_vote(1, prefix="additional-{}".format(cnt))
            total_cost += cost3
            total_results += extra_results
            cnt += 1

        majority_vote = Counter(total_results).most_common(1)[0][0]
        return majority_vote, total_cost

