from data.dialogue import PairwiseDialogue
from server.llm_server import LLMServer
from .registry import auto_register
from .meta import Action
from cachetools import FIFOCache
import pickle
import os
import re
import random
from .fennec import (
    FennecBranchAction,
    FennecScoringAction,
    FennecPairwiseSolvingAction,
    FennecPairwiseMergeAction
)


@auto_register("fennec_beam_branch")
class FennecBeamBranchAction(FennecBranchAction):
    """Branch action that draws ``beam_size`` independent branch generations.

    Each generation is one ``get_result_from_server`` call. The raw responses
    and the branches extracted from all of them are returned together.
    """

    action_name = "fennec_beam_branch"

    def __init__(self, config, llm_server: LLMServer) -> None:
        super().__init__(config, llm_server)

        # Not read in this class; presumably consumed by the parent — TODO confirm.
        self.max_branch = 10
        # Number of branch generations requested per execute() call.
        self.beam_size = 1

    def execute(self, **action_input):
        """Run ``beam_size`` branch generations and gather their branches.

        Returns:
            dict with
              - "beam_result": the raw server responses, one per generation;
              - "beam_branch_list": a flat list of the branches extracted
                from every response, in generation order.
        """
        responses = []
        branches = []
        for _ in range(self.beam_size):
            response = self.get_result_from_server(
                action_input["dialogue"],
                action_input["context"],
                action_input["server"],
                action_input["eval_model"],
            )
            responses.append(response)
            branches += self.extract_branch_list(response)
        return {"beam_result": responses, "beam_branch_list": branches}


@auto_register("fennec_beam_scoring")
class FennecBeamScoringAction(FennecScoringAction):
    """Scoring action that samples ``beam_size`` scoring guidelines per branch."""

    action_name = "fennec_beam_scoring"

    def __init__(self, config, llm_server: LLMServer) -> None:
        super().__init__(config, llm_server)

        # Number of scoring samples requested for every branch.
        self.beam_size = 2

    def execute(self, **action_input):
        """Request scoring guidelines for each branch of the branch action.

        For every entry of ``action_input['branch']["beam_branch_list"]`` the
        server is queried ``beam_size`` times. Every occurrence of the
        ``"[Scoring Guideline]:"`` marker is removed from each response.

        Returns:
            dict with "beam_result": one list of cleaned responses per branch,
            in branch order.
        """
        marker = "[Scoring Guideline]:"
        per_branch = []
        for branch in action_input['branch']["beam_branch_list"]:
            samples = [
                self.get_result_from_server(
                    action_input["dialogue"],
                    action_input["server"],
                    action_input["eval_model"],
                    branch,
                ).replace(marker, "")
                for _ in range(self.beam_size)
            ]
            per_branch.append(samples)
        return {"beam_result": per_branch}

@auto_register("fennec_beam_pairwise_solving")
class FennecBeamPairwiseSolvingAction(FennecPairwiseSolvingAction):
    """Pairwise solving over every (branch, scoring guideline) combination.

    For each branch and each of its scoring guidelines, the server is queried
    ``beam_size`` times. When ``action_input["eval"]`` is truthy, each sample
    is also paired with an "exchanged" query (``exchange=True``).
    """

    action_name = "fennec_beam_pairwise_solving"

    def __init__(self, config, llm_server: LLMServer) -> None:
        super().__init__(config, llm_server)

        # Number of solving samples per (branch, scoring) pair.
        self.beam_size = 2

    def execute(self, **action_input):
        """Solve every branch/scoring pair and parse the two ratings.

        Returns:
            dict whose six list values are all nested as
            ``[branch][scoring][sample]``:
              - "beam_result" / "beam_ex_result": raw responses;
              - "beam_rating_a" / "beam_rating_b" and the "beam_ex_" variants:
                the two values returned by ``rating_format``.
            The "beam_ex_*" innermost lists stay empty unless
            ``action_input["eval"]`` is truthy.
        """
        output_keys = (
            "beam_result",
            "beam_rating_a",
            "beam_rating_b",
            "beam_ex_result",
            "beam_ex_rating_a",
            "beam_ex_rating_b",
        )
        feedback = {key: [] for key in output_keys}
        # Same truthiness as `"eval" in action_input and action_input["eval"]`.
        want_exchange = action_input.get("eval")

        for branch_pos, raw_branch in enumerate(action_input["branch"]["beam_branch_list"]):
            branch = raw_branch.strip()
            branch_acc = {key: [] for key in output_keys}

            for raw_scoring in action_input["scoring"]["beam_result"][branch_pos]:
                scoring = raw_scoring.strip()
                scoring_acc = {key: [] for key in output_keys}

                for _ in range(self.beam_size):
                    answer = self.get_result_from_server(
                        action_input["dialogue"],
                        action_input["server"],
                        action_input["eval_model"],
                        action_input["context"],
                        branch,
                        scoring,
                    )
                    scoring_acc["beam_result"].append(answer)
                    first, second = self.rating_format(answer, action_input["eval_model"])
                    scoring_acc["beam_rating_a"].append(first)
                    scoring_acc["beam_rating_b"].append(second)

                    if want_exchange:
                        # Same query with the exchange flag set.
                        swapped = self.get_result_from_server(
                            action_input["dialogue"],
                            action_input["server"],
                            action_input["eval_model"],
                            action_input["context"],
                            branch,
                            scoring,
                            exchange=True,
                        )
                        swapped_first, swapped_second = self.rating_format(
                            swapped, action_input["eval_model"])
                        scoring_acc["beam_ex_result"].append(swapped)
                        scoring_acc["beam_ex_rating_a"].append(swapped_first)
                        scoring_acc["beam_ex_rating_b"].append(swapped_second)

                for key in output_keys:
                    branch_acc[key].append(scoring_acc[key])

            for key in output_keys:
                feedback[key].append(branch_acc[key])
        return feedback


@auto_register("fennec_beam_pairwise_merge")
class FennecBeamPairwiseMergeAction(FennecPairwiseMergeAction):
    """Aggregate the nested beam ratings of ``fennec_beam_pairwise_solving``.

    The solving ratings are nested as ``[branch][scoring][sample]``. This action
    flattens the outer two levels into a sequence of "scoring groups": the
    per-sample ratings of one branch/scoring pair. It keeps at most
    ``max_candidate`` of those groups and reduces them to one score per model,
    for both the normal and the exchanged ("ex_") ratings.
    """

    action_name = "fennec_beam_pairwise_merge"

    def __init__(self, config, llm_server: LLMServer) -> None:
        super().__init__(config, llm_server)

        # Maximum number of scoring groups (branch x scoring pairs) merged.
        self.max_candidate = 40

    def sum_tree(self, score, feedback):
        """Add the total rating of every scoring group to ``feedback`` in place.

        Args:
            score: dict with lists under "model_a", "model_b", "ex_model_a"
                and "ex_model_b". Each element is one group's per-sample ratings.
            feedback: dict with the same four keys; values are incremented.
        """
        for a, b, ex_a, ex_b in zip(
            score["model_a"], score["model_b"], score["ex_model_a"], score["ex_model_b"]
        ):
            feedback["model_a"] += sum(a)
            feedback["model_b"] += sum(b)
            feedback["ex_model_a"] += sum(ex_a)
            feedback["ex_model_b"] += sum(ex_b)

    def mean_tree(self, score, feedback):
        """Overwrite ``feedback`` with the mean rating of self-consistent groups.

        A group contributes its mean only if it is non-empty and all of its
        samples are equal. The result for each key is the average of those
        group means. Keys with no contributing group are set to 0; without
        that fallback, averaging an empty list would divide by zero.

        Args:
            score: dict of per-group rating lists (see ``sum_tree``).
            feedback: dict whose four model keys are overwritten.
        """
        keys = ("model_a", "model_b", "ex_model_a", "ex_model_b")
        consistent = {key: [] for key in keys}
        # zip keeps the original "truncate to the shortest list" behaviour.
        for group in zip(*(score[key] for key in keys)):
            for key, ratings in zip(keys, group):
                if ratings and all(r == ratings[0] for r in ratings):
                    consistent[key].append(sum(ratings) / len(ratings))
        for key, means in consistent.items():
            feedback[key] = sum(means) / len(means) if means else 0

    @staticmethod
    def _group_by_outcome(pairs):
        """Bucket (a_ratings, b_ratings) pairs by A winning, losing or tying.

        Returns the summed A and B ratings of the most populated bucket. When
        bucket counts tie, the earliest of "win", "lose", "tie" is chosen,
        because ``max`` returns the first maximal key.
        """
        a_total = {"win": 0, "lose": 0, "tie": 0}
        b_total = {"win": 0, "lose": 0, "tie": 0}
        count = {"win": 0, "lose": 0, "tie": 0}
        for a, b in pairs:
            sum_a, sum_b = sum(a), sum(b)
            if sum_a > sum_b:
                outcome = "win"
            elif sum_a < sum_b:
                outcome = "lose"
            else:
                outcome = "tie"
            a_total[outcome] += sum_a
            b_total[outcome] += sum_b
            count[outcome] += 1
        majority = max(count, key=count.get)
        return a_total[majority], b_total[majority]

    def group_tree(self, score, feedback):
        """Set ``feedback`` to the summed ratings of the majority outcome bucket.

        The normal (model_a/model_b) and the exchanged (ex_model_a/ex_model_b)
        ratings are bucketed independently via ``_group_by_outcome``.
        """
        groups = list(zip(
            score["model_a"], score["model_b"], score["ex_model_a"], score["ex_model_b"]
        ))
        feedback["model_a"], feedback["model_b"] = self._group_by_outcome(
            (a, b) for a, b, _, _ in groups)
        feedback["ex_model_a"], feedback["ex_model_b"] = self._group_by_outcome(
            (ex_a, ex_b) for _, _, ex_a, ex_b in groups)

    @staticmethod
    def _iter_scoring_groups(solving):
        """Yield (a, b, ex_a, ex_b) per-sample rating lists for each branch/scoring pair."""
        for branch_a, branch_b, ex_branch_a, ex_branch_b in zip(
            solving["beam_rating_a"],
            solving["beam_rating_b"],
            solving["beam_ex_rating_a"],
            solving["beam_ex_rating_b"],
        ):
            yield from zip(branch_a, branch_b, ex_branch_a, ex_branch_b)

    def execute(self, **action_input):
        """Sum the ratings of the first ``max_candidate`` scoring groups.

        Returns:
            dict with "model_a", "model_b", "ex_model_a" and "ex_model_b"
            totals. All four are 0 when ``action_input`` has no "solving"
            entry.
        """
        action_feedback = {"model_a": 0, "model_b": 0, "ex_model_a": 0, "ex_model_b": 0}
        score = {"model_a": [], "model_b": [], "ex_model_a": [], "ex_model_b": []}
        if "solving" in action_input:
            groups = self._iter_scoring_groups(action_input["solving"])
            for kept, (a, b, ex_a, ex_b) in enumerate(groups):
                if kept >= self.max_candidate:
                    break  # remaining groups would be ignored anyway
                score["model_a"].append(a)
                score["model_b"].append(b)
                score["ex_model_a"].append(ex_a)
                score["ex_model_b"].append(ex_b)
        self.sum_tree(score, action_feedback)
        return action_feedback
