import json
import pdb
import os
import pyarrow as pa
import pyarrow.parquet as pq
import pandas as pd
from data.eval_event import EvalEvent
from .meta import Evaluation
from .registry import auto_register
from actions.fennec_v2 import (
    FennecBeamBranchAction,
    FennecBeamScoringAction,
    FennecBeamPairwiseSolvingAction,
    FennecBeamPairwiseMergeAction,
)
from config.task_config import TaskConfig
from scipy.stats import kendalltau
from scipy.stats import spearmanr
from scipy.stats import pearsonr
import threading
import numpy as np


# Module-level lock created once at import time. Nothing in the visible part
# of this file acquires it; presumably it is meant to serialize writes to the
# shared training-data files — TODO confirm against callers.
write_lock = threading.Lock()


@auto_register("fennec_v2")
class TreeFennecEvaluation(Evaluation):
    task_name = "fennec_v2"

    def __init__(self, config: TaskConfig, action_runner) -> None:
        """Set up the evaluation and open the SFT/DPO training-data outputs.

        Args:
            config: Task configuration. It is forwarded to the base class and
                then used through ``self.config`` to resolve the output paths
                (presumably the base class stores it there — confirm).
            action_runner: Stored unchanged; ``eval`` compares it against the
                string ``"pairwise_eval"``.
        """
        super().__init__(config)
        self.action_runner = action_runner

        # Resolve where this task's generated training data is written.
        sft_path = self.train_sft_file = self.config.get_train_sft_file(self.task_name)
        dpo_path = self.train_dpo_file = self.config.get_train_dpo_file(self.task_name)

        # Both files are truncated on open and kept open on the instance;
        # nothing in this method closes them.
        self.train_sft = open(sft_path, 'w')
        self.train_dpo = open(dpo_path, 'w')

        # Running statistics, filled in lazily by the pairwise_eval_* methods.
        self.eval_score, self.eval_sft, self.eval_dpo = {}, {}, {}
        
    def eval(self, eval_event: EvalEvent):
        """Score one evaluation event and emit its SFT training rows.

        The dialogue is always fetched from ``eval_event``; the scoring and
        SFT generation steps only run when ``self.action_runner`` equals
        ``"pairwise_eval"``. DPO generation (``pairwise_eval_dpo_gen``) is
        not invoked from here.
        """
        dialogue = eval_event.get_dialogue()

        if self.action_runner != "pairwise_eval":
            return

        # pairwise_eval_score returns the normalized reference label, which
        # is then used as the target when selecting SFT samples.
        label = self.pairwise_eval_score(dialogue, eval_event)
        self.pairwise_eval_sft_gen(dialogue, eval_event, label)
            
    def pairwise_eval_score(self, dialogue, eval_event):
        """Update the running agreement statistics for one pairwise sample.

        Reads the reference label from ``meta_info["judge"]`` and the merged
        scores (``model_a``, ``model_b``, ``ex_model_a``, ``ex_model_b``)
        stored for ``FennecBeamPairwiseMergeAction`` on the current turn, then
        updates ``self.eval_score``:

        * ``g_win`` / ``g_lose`` / ``g_tie``: tally of the reference label.
        * ``win`` / ``lose`` / ``tie``: tally of the verdict implied by
          ``model_a`` vs ``model_b``.
        * ``single_agreement``: 1 when the verdict from the summed normal and
          ``ex_`` scores matches the reference label, else 0.
        * ``consistency``: 1 when the normal and ``ex_`` scores point the same
          way, else 0.
        * ``agreement``: 1 when they point the same way *and* that direction
          matches the reference label, else 0.
        * ``error``: incremented for inconsistent samples where ``model_a`` or
          ``model_b`` is 0.

        Returns:
            The normalized reference label (0 = model_a, 1 = model_b,
            2 = tie). Labels that match none of the normalization rules are
            returned unchanged. Note that the verdict derived from the model
            scores is only tallied, not returned.
        """
        meta_info = dialogue.get_meta_info()
        turn = meta_info["turn"] - 1

        feedback = eval_event.get_memories(
            self.task_name,
            FennecBeamPairwiseMergeAction.action_name,
            "turn{}".format(turn),
        )

        # Normalize the reference label to 0 / 1 / 2.
        judge = meta_info["judge"][0]
        if isinstance(judge, str):
            if judge == "model_a":
                judge = 0
            elif judge == "model_b":
                judge = 1
            elif "tie" in judge:
                judge = 2
        elif len(meta_info["judge"]) == 3:
            # panda lm: 0 -> tie, 1 -> first response, anything else -> second
            judge = 2 if judge == 0 else (0 if judge == 1 else 1)

        stats = self.eval_score
        if "agreement" not in stats:
            stats.update({
                "agreement": [],
                "consistency": [],
                "single_agreement": [],
                "error": 0,
                "g_win": 0,
                "win": 0,
                "g_lose": 0,
                "lose": 0,
                "g_tie": 0,
                "tie": 0,
            })

        # Tally the reference label.
        for label, counter in ((0, "g_win"), (1, "g_lose"), (2, "g_tie")):
            if judge == label:
                stats[counter] += 1
                break

        # Tally the verdict implied by the non-exchanged scores.
        score_a = feedback["model_a"]
        score_b = feedback["model_b"]
        if score_a > score_b:
            stats["win"] += 1
        elif score_a < score_b:
            stats["lose"] += 1
        elif score_a == score_b:
            stats["tie"] += 1

        # Single agreement: compare the summed normal + exchanged scores.
        ex_score_a = feedback["ex_model_a"]
        ex_score_b = feedback["ex_model_b"]
        total_a = score_a + ex_score_a
        total_b = score_b + ex_score_b
        single_hit = (
            (total_a > total_b and judge == 0)
            or (total_a < total_b and judge == 1)
            or (total_a == total_b and judge == 2)
        )
        stats["single_agreement"].append(1 if single_hit else 0)

        # Agreement / consistency: both orderings must point the same way.
        if score_a > score_b and ex_score_a > ex_score_b:
            consistent_verdict = 0
        elif score_a < score_b and ex_score_a < ex_score_b:
            consistent_verdict = 1
        elif score_a == score_b and ex_score_a == ex_score_b:
            consistent_verdict = 2
        else:
            consistent_verdict = None

        if consistent_verdict is None:
            if score_a == 0 or score_b == 0:
                stats["error"] += 1
            stats["consistency"].append(0)
            stats["agreement"].append(0)
        else:
            stats["agreement"].append(1 if judge == consistent_verdict else 0)
            stats["consistency"].append(1)

        return judge
    
    def split_judgment(self, judge1_a, judge1_b, judge2_a, judge2_b, branch, scoring, judgment1, judgment2, judge):
        """Classify a pair of judgments against the target label.

        Each judgment rates both responses (``*_a`` vs ``*_b``). A judgment
        matches the label when ``a > b`` for ``judge == 0``, ``a < b`` for
        ``judge == 1`` and ``a == b`` for ``judge == 2``.

        Args:
            judge1_a, judge1_b: Ratings given by the first judgment.
            judge2_a, judge2_b: Ratings given by the second judgment.
            branch, scoring: Copied verbatim into the emitted record.
            judgment1, judgment2: Payloads of the first / second judgment.
            judge: Target label (0, 1 or 2); any other value yields no record.

        Returns:
            ``(consistency, inconsistency)``, two lists holding at most one
            record between them. When both judgments match the label, the
            record goes to ``consistency`` with ``judgment1`` as chosen. When
            exactly one matches, the record goes to ``inconsistency`` with the
            matching judgment as chosen and the other as rejected. When
            neither matches, both lists are empty.
        """
        def _record(chosen, rejected):
            return {
                "branch": branch,
                "scoring": scoring,
                "judgment_chosen": chosen,
                "judgment_rejected": rejected,
            }

        a_ahead_twice = judge1_a > judge1_b and judge2_a > judge2_b
        b_ahead_twice = judge1_a < judge1_b and judge2_a < judge2_b
        level_twice = judge1_a == judge1_b and judge2_a == judge2_b
        if (
            (a_ahead_twice and judge == 0)
            or (b_ahead_twice and judge == 1)
            or (level_twice and judge == 2)
        ):
            return [_record(judgment1, judgment2)], []

        # Exactly-one-matches detection, per target label.
        if judge == 0:
            only_first = judge1_a > judge1_b and judge2_a <= judge2_b
            only_second = judge2_a > judge2_b and judge1_a <= judge1_b
        elif judge == 1:
            only_first = judge1_a < judge1_b and judge2_a >= judge2_b
            only_second = judge2_a < judge2_b and judge1_a >= judge1_b
        elif judge == 2:
            only_first = judge1_a == judge1_b and judge2_a != judge2_b
            only_second = judge2_a == judge2_b and judge1_a != judge1_b
        else:
            only_first = only_second = False

        if only_first:
            return [], [_record(judgment1, judgment2)]
        if only_second:
            return [], [_record(judgment2, judgment1)]
        return [], []

    def split_eval_branch(self, fb_action_feedback, fs_action_feedback, fps_action_feedback, judge):
        
        scoring_self_consistency = {"consistency": [], "inconsistency": []}
        judgment_self_consistency = {"consistency": [], "inconsistency": []}
        judgment_exchange_consistency = {"consistency": [], "inconsistency": []}
        
        branch_idx = 0
        for beam_branch_a, beam_branch_b, beam_ex_branch_a, beam_ex_branch_b in zip(
            fps_action_feedback['beam_rating_a'],
            fps_action_feedback['beam_rating_b'],
            fps_action_feedback["beam_ex_rating_a"],
            fps_action_feedback["beam_ex_rating_b"],
        ):
            branch = fb_action_feedback['beam_branch_list'][branch_idx]
            
            scoring_1 = fs_action_feedback['beam_result'][branch_idx][0]
            scoring_2 = fs_action_feedback['beam_result'][branch_idx][1]
            
            # Scoring self-consistency
            s1_a = sum(beam_branch_a[0]) + sum(beam_ex_branch_a[0])
            s1_b = sum(beam_branch_b[0]) + sum(beam_ex_branch_b[0])
            
            s2_a = sum(beam_branch_a[1]) + sum(beam_ex_branch_a[1])
            s2_b = sum(beam_branch_b[1]) + sum(beam_ex_branch_b[1])
            if (s1_a > s1_b and s2_a > s2_b and judge == 0) or (s1_a < s1_b and s2_a < s2_b and judge == 1) or (s1_a == s1_b and s2_a == s2_b and judge == 2):
                scoring_self_consistency["consistency"].append({
                    "branch": branch,
                    "scoring_chosen": scoring_1,
                    "scoring_rejected": scoring_2,
                })
            else:
                if judge == 0:
                    if s1_a > s1_b:
                        scoring_self_consistency["inconsistency"].append({
                            "branch": branch,
                            "scoring_chosen": scoring_1,
                            "scoring_rejected": scoring_2,
                        })
                    elif s2_a > s2_b:
                        scoring_self_consistency["inconsistency"].append({
                            "branch": branch,
                            "scoring_chosen": scoring_2,
                            "scoring_rejected": scoring_1,
                        })
                elif judge == 1:
                    if s1_a < s1_b:
                        scoring_self_consistency["inconsistency"].append({
                            "branch": branch,
                            "scoring_chosen": scoring_1,
                            "scoring_rejected": scoring_2,
                        })
                    elif s2_a < s2_b:
                        scoring_self_consistency["inconsistency"].append({
                            "branch": branch,
                            "scoring_chosen": scoring_2,
                            "scoring_rejected": scoring_1,
                        })
                elif judge == 2:
                    if s1_a == s1_b:
                        scoring_self_consistency["inconsistency"].append({
                            "branch": branch,
                            "scoring_chosen": scoring_1,
                            "scoring_rejected": scoring_2,
                        })
                    elif s2_a == s2_b:
                        scoring_self_consistency["inconsistency"].append({
                            "branch": branch,
                            "scoring_chosen": scoring_2,
                            "scoring_rejected": scoring_1,
                        })

            judgment_11 = fps_action_feedback['beam_result'][branch_idx][0][0]
            judgment_12 = fps_action_feedback['beam_result'][branch_idx][0][1]
            judgment_21 = fps_action_feedback['beam_result'][branch_idx][1][0]
            judgment_22 = fps_action_feedback['beam_result'][branch_idx][1][1]
            ex_judgment_11 = fps_action_feedback['beam_ex_result'][branch_idx][0][0]
            ex_judgment_12 = fps_action_feedback['beam_ex_result'][branch_idx][0][1]
            ex_judgment_21 = fps_action_feedback['beam_ex_result'][branch_idx][1][0]
            ex_judgment_22 = fps_action_feedback['beam_ex_result'][branch_idx][1][1]
            
            # Judgment self-consistency 1
            j11_a = beam_branch_a[0][0]
            j11_b = beam_branch_b[0][0]
            j12_a = beam_branch_a[0][1]
            j12_b = beam_branch_b[0][1]
            consistency_list, inconsistency_list = self.split_judgment(j11_a, j11_b, j12_a, j12_b, branch, scoring_1, judgment_11, judgment_12, judge)
            judgment_self_consistency["consistency"].extend(consistency_list)
            judgment_self_consistency["inconsistency"].extend(inconsistency_list)
            
            # Judgment self-consistency 2
            j21_a = beam_branch_a[1][0]
            j21_b = beam_branch_b[1][0]
            j22_a = beam_branch_a[1][1]
            j22_b = beam_branch_b[1][1]
            consistency_list, inconsistency_list = self.split_judgment(j21_a, j21_b, j22_a, j22_b, branch, scoring_2, judgment_21, judgment_22, judge)
            judgment_self_consistency["consistency"].extend(consistency_list)
            judgment_self_consistency["inconsistency"].extend(inconsistency_list)
            
            # Judgment exchange self-consistency 3
            ex_j11_a = beam_ex_branch_a[0][0]
            ex_j11_b = beam_ex_branch_b[0][0]
            ex_j12_a = beam_ex_branch_a[0][1]
            ex_j12_b = beam_ex_branch_b[0][1]
            consistency_list, inconsistency_list = self.split_judgment(ex_j11_a, ex_j11_b, ex_j12_a, ex_j12_b, branch, scoring_1, ex_judgment_11, ex_judgment_12, judge)
            judgment_self_consistency["consistency"].extend(consistency_list)
            judgment_self_consistency["inconsistency"].extend(inconsistency_list)
            
            # Judgment self-consistency 4
            ex_j21_a = beam_ex_branch_a[1][0]
            ex_j21_b = beam_ex_branch_b[1][0]
            ex_j22_a = beam_ex_branch_a[1][1]
            ex_j22_b = beam_ex_branch_b[1][1]
            consistency_list, inconsistency_list = self.split_judgment(ex_j21_a, ex_j21_b, ex_j22_a, ex_j22_b, branch, scoring_2, ex_judgment_21, ex_judgment_22, judge)
            judgment_self_consistency["consistency"].extend(consistency_list)
            judgment_self_consistency["inconsistency"].extend(inconsistency_list)

            # J1 J2 exchange-consistency
            j11_a = beam_branch_a[0][0]
            j11_b = beam_branch_b[0][0]
            ex_j11_a = beam_ex_branch_a[0][0]
            ex_j11_b = beam_ex_branch_b[0][0]
            consistency_list, inconsistency_list = self.split_judgment(j11_a, j11_b, ex_j11_a, ex_j11_b, branch, scoring_1, judgment_11, ex_judgment_11, judge)
            judgment_exchange_consistency["consistency"].extend(consistency_list)
            judgment_exchange_consistency["inconsistency"].extend(inconsistency_list)
            
            j11_a = beam_branch_a[0][0]
            j11_b = beam_branch_b[0][0]
            ex_j12_a = beam_ex_branch_a[0][1]
            ex_j12_b = beam_ex_branch_b[0][1]
            consistency_list, inconsistency_list = self.split_judgment(j11_a, j11_b, ex_j12_a, ex_j12_b, branch, scoring_1, judgment_11, ex_judgment_12, judge)
            judgment_exchange_consistency["consistency"].extend(consistency_list)
            judgment_exchange_consistency["inconsistency"].extend(inconsistency_list)

            j12_a = beam_branch_a[0][1]
            j12_b = beam_branch_b[0][1]
            ex_j11_a = beam_ex_branch_a[0][0]
            ex_j11_b = beam_ex_branch_b[0][0]
            consistency_list, inconsistency_list = self.split_judgment(j12_a, j12_b, ex_j11_a, ex_j11_b, branch, scoring_1, judgment_12, ex_judgment_11, judge)
            judgment_exchange_consistency["consistency"].extend(consistency_list)
            judgment_exchange_consistency["inconsistency"].extend(inconsistency_list)
            
            j12_a = beam_branch_a[0][1]
            j12_b = beam_branch_b[0][1]
            ex_j12_a = beam_ex_branch_a[0][1]
            ex_j12_b = beam_ex_branch_b[0][1]
            consistency_list, inconsistency_list = self.split_judgment(j12_a, j12_b, ex_j12_a, ex_j12_b, branch, scoring_1, judgment_12, ex_judgment_12, judge)
            judgment_exchange_consistency["consistency"].extend(consistency_list)
            judgment_exchange_consistency["inconsistency"].extend(inconsistency_list)
            
            j21_a = beam_branch_a[1][0]
            j21_b = beam_branch_b[1][0]
            ex_j21_a = beam_ex_branch_a[1][0]
            ex_j21_b = beam_ex_branch_b[1][0]
            consistency_list, inconsistency_list = self.split_judgment(j21_a, j21_b, ex_j21_a, ex_j21_b, branch, scoring_2, judgment_21, ex_judgment_21, judge)
            judgment_exchange_consistency["consistency"].extend(consistency_list)
            judgment_exchange_consistency["inconsistency"].extend(inconsistency_list)

            j21_a = beam_branch_a[1][0]
            j21_b = beam_branch_b[1][0]
            ex_j22_a = beam_ex_branch_a[1][1]
            ex_j22_b = beam_ex_branch_b[1][1]
            consistency_list, inconsistency_list = self.split_judgment(j21_a, j21_b, ex_j22_a, ex_j22_b, branch, scoring_2, judgment_21, ex_judgment_22, judge)
            judgment_exchange_consistency["consistency"].extend(consistency_list)
            judgment_exchange_consistency["inconsistency"].extend(inconsistency_list)

            j22_a = beam_branch_a[1][1]
            j22_b = beam_branch_b[1][1]
            ex_j21_a = beam_ex_branch_a[1][0]
            ex_j21_b = beam_ex_branch_b[1][0]
            consistency_list, inconsistency_list = self.split_judgment(j22_a, j22_b, ex_j21_a, ex_j21_b, branch, scoring_2, judgment_22, ex_judgment_21, judge)
            judgment_exchange_consistency["consistency"].extend(consistency_list)
            judgment_exchange_consistency["inconsistency"].extend(inconsistency_list)
            
            j22_a = beam_branch_a[1][1]
            j22_b = beam_branch_b[1][1]
            ex_j22_a = beam_ex_branch_a[1][1]
            ex_j22_b = beam_ex_branch_b[1][1]
            consistency_list, inconsistency_list = self.split_judgment(j22_a, j22_b, ex_j22_a, ex_j22_b, branch, scoring_2, judgment_22, ex_judgment_22, judge)
            judgment_exchange_consistency["consistency"].extend(consistency_list)
            judgment_exchange_consistency["inconsistency"].extend(inconsistency_list)

            branch_idx += 1
        return scoring_self_consistency, judgment_self_consistency, judgment_exchange_consistency
            
    def pairwise_eval_sft_gen(self, dialogue, eval_event, judge_prediction):
        """Append the SFT training rows for one pairwise sample to ``self.train_sft``.

        One JSON object is written per line. Every row carries the same 14
        columns, each wrapped in a single-element list; columns that do not
        apply to a row hold ``[""]``. Rows are written in this order:

        1. one "branch" row holding the newline-joined branch list;
        2. one row per scoring self-consistency sample;
        3. one row per judgment self-consistency sample;
        4. one row per judgment exchange-consistency sample.

        Only the ``"consistency"`` samples returned by ``split_eval_branch``
        are used. The per-category counters in ``self.eval_sft`` are
        incremented once per written row.

        Args:
            dialogue: Supplies meta info, the first query and both responses.
            eval_event: Supplies the stored feedback of the branch, scoring
                and pairwise-solving actions for the current turn.
            judge_prediction: Label written to the ``judge`` column and used
                as the target when selecting consistent samples.
        """
        meta_info = dialogue.get_meta_info()
        turn_key = "turn{}".format(meta_info["turn"] - 1)

        query = dialogue.get_query_by_idx(0)["content"]
        response_1 = dialogue.get_pairwise_response_by_idx(0, "model_a")["content"]
        response_2 = dialogue.get_pairwise_response_by_idx(0, "model_b")["content"]

        fb_action_feedback = eval_event.get_memories(
            self.task_name, FennecBeamBranchAction.action_name, turn_key
        )
        fs_action_feedback = eval_event.get_memories(
            self.task_name, FennecBeamScoringAction.action_name, turn_key
        )
        fps_action_feedback = eval_event.get_memories(
            self.task_name, FennecBeamPairwiseSolvingAction.action_name, turn_key
        )

        stats = self.eval_sft
        if "count" not in stats:
            stats.update({
                "table": None,
                "count": 0,
                "branch": 0,
                "score_self_consistency": 0,
                "self_consistency": 0,
                "ex_consistency": 0,
            })

        # Values shared by every row of this sample; computed once.
        all_branches = "\n".join(fb_action_feedback['beam_branch_list'])
        if "context" in meta_info and meta_info["context"]:
            context = [meta_info["context"]]
        else:
            context = [[""]]

        def write_row(**columns):
            """Write one JSON line; ``columns`` override the blank defaults."""
            # Key order is fixed here so the serialized column order never
            # depends on which columns a row overrides.
            row = {
                "query": [query],
                "branch": [all_branches],
                "branch_chosen": [""],
                "branch_rejected": [""],
                "cur_branch": [""],
                "scoring": [""],
                "scoring_chosen": [""],
                "scoring_rejected": [""],
                "judgment_chosen": [""],
                "judgment_rejected": [""],
                "response_1": [response_1],
                "response_2": [response_2],
                "judge": [judge_prediction],
                "context": context,
            }
            for column, value in columns.items():
                row[column] = [value]
            # write(), not writelines(): the payload is a single string.
            self.train_sft.write(json.dumps(row) + "\n")

        # Branch-level row.
        write_row()
        stats["branch"] += 1

        scoring_self_consistency, judgment_self_consistency, judgment_exchange_consistency = (
            self.split_eval_branch(
                fb_action_feedback, fs_action_feedback, fps_action_feedback, judge_prediction
            )
        )

        for item in scoring_self_consistency["consistency"]:
            write_row(
                cur_branch=item["branch"],
                scoring_chosen=item["scoring_chosen"],
                scoring_rejected=item["scoring_rejected"],
            )
            stats["score_self_consistency"] += 1

        # Both judgment categories share the same row layout.
        for samples, counter in (
            (judgment_self_consistency["consistency"], "self_consistency"),
            (judgment_exchange_consistency["consistency"], "ex_consistency"),
        ):
            for item in samples:
                write_row(
                    cur_branch=item["branch"],
                    scoring=item["scoring"],
                    judgment_chosen=item["judgment_chosen"],
                    judgment_rejected=item["judgment_rejected"],
                )
                stats[counter] += 1

    def pairwise_eval_dpo_gen(self, dialogue, eval_event, judge_prediction):
        """Append the DPO preference rows for one pairwise sample to ``self.train_dpo``.

        One JSON object is written per line. Every row carries the same 14
        columns, each wrapped in a single-element list; columns that do not
        apply to a row hold ``[""]``. Rows are written in this order:

        1. one branch-ordering row: ``branch_chosen`` is the branch list
           re-ordered by ``branch_sorted`` and ``branch_rejected`` is the
           branch list in its stored order;
        2. one row per scoring self-consistency sample;
        3. one row per judgment self-consistency sample;
        4. one row per judgment exchange-consistency sample.

        Only the ``"inconsistency"`` samples returned by ``split_eval_branch``
        are used, since those carry a chosen/rejected pair that disagree. The
        per-category counters in ``self.eval_dpo`` are incremented once per
        written row.

        Args:
            dialogue: Supplies meta info, the first query and both responses.
            eval_event: Supplies the stored feedback of the branch, scoring
                and pairwise-solving actions for the current turn.
            judge_prediction: Label written to the ``judge`` column and used
                as the target for branch ordering and sample selection.
        """
        meta_info = dialogue.get_meta_info()
        turn_key = "turn{}".format(meta_info["turn"] - 1)

        query = dialogue.get_query_by_idx(0)["content"]
        response_1 = dialogue.get_pairwise_response_by_idx(0, "model_a")["content"]
        response_2 = dialogue.get_pairwise_response_by_idx(0, "model_b")["content"]

        fb_action_feedback = eval_event.get_memories(
            self.task_name, FennecBeamBranchAction.action_name, turn_key
        )
        fs_action_feedback = eval_event.get_memories(
            self.task_name, FennecBeamScoringAction.action_name, turn_key
        )
        fps_action_feedback = eval_event.get_memories(
            self.task_name, FennecBeamPairwiseSolvingAction.action_name, turn_key
        )

        stats = self.eval_dpo
        if "count" not in stats:
            stats.update({
                "table": None,
                "count": 0,
                "branch": 0,
                "score_self_consistency": 0,
                "self_consistency": 0,
                "ex_consistency": 0,
            })

        # Values shared by every row of this sample; computed once.
        branch_list = fb_action_feedback['beam_branch_list']
        all_branches = "\n".join(branch_list)
        if "context" in meta_info and meta_info["context"]:
            context = [meta_info["context"]]
        else:
            context = [[""]]

        def write_row(**columns):
            """Write one JSON line; ``columns`` override the blank defaults."""
            # Key order is fixed here so the serialized column order never
            # depends on which columns a row overrides.
            row = {
                "query": [query],
                "branch": [all_branches],
                "branch_chosen": [""],
                "branch_rejected": [""],
                "cur_branch": [""],
                "scoring": [""],
                "scoring_chosen": [""],
                "scoring_rejected": [""],
                "judgment_chosen": [""],
                "judgment_rejected": [""],
                "response_1": [response_1],
                "response_2": [response_2],
                "judge": [judge_prediction],
                "context": context,
            }
            for column, value in columns.items():
                row[column] = [value]
            # write(), not writelines(): the payload is a single string.
            self.train_dpo.write(json.dumps(row) + "\n")

        # Per-branch margin: total rating of response a minus response b over
        # both scorings of the normal and the exchanged run. A branch text
        # that occurs more than once keeps only its last margin.
        branch_diff = {}
        rating_rows = zip(
            fps_action_feedback['beam_rating_a'],
            fps_action_feedback['beam_rating_b'],
            fps_action_feedback["beam_ex_rating_a"],
            fps_action_feedback["beam_ex_rating_b"],
        )
        for idx, (rate_a, rate_b, ex_rate_a, ex_rate_b) in enumerate(rating_rows):
            total_a = sum(rate_a[0]) + sum(rate_a[1]) + sum(ex_rate_a[0]) + sum(ex_rate_a[1])
            total_b = sum(rate_b[0]) + sum(rate_b[1]) + sum(ex_rate_b[0]) + sum(ex_rate_b[1])
            branch_diff[branch_list[idx]] = total_a - total_b

        branch_chosen = self.branch_sorted(branch_diff, judge_prediction)

        # Branch-ordering row.
        write_row(
            branch="",
            branch_chosen="\n".join(branch_chosen),
            branch_rejected=all_branches,
        )
        stats["branch"] += 1

        scoring_self_consistency, judgment_self_consistency, judgment_exchange_consistency = (
            self.split_eval_branch(
                fb_action_feedback, fs_action_feedback, fps_action_feedback, judge_prediction
            )
        )

        for item in scoring_self_consistency["inconsistency"]:
            write_row(
                cur_branch=item["branch"],
                scoring_chosen=item["scoring_chosen"],
                scoring_rejected=item["scoring_rejected"],
            )
            stats["score_self_consistency"] += 1

        # Both judgment categories share the same row layout.
        for samples, counter in (
            (judgment_self_consistency["inconsistency"], "self_consistency"),
            (judgment_exchange_consistency["inconsistency"], "ex_consistency"),
        ):
            for item in samples:
                write_row(
                    cur_branch=item["branch"],
                    scoring=item["scoring"],
                    judgment_chosen=item["judgment_chosen"],
                    judgment_rejected=item["judgment_rejected"],
                )
                stats[counter] += 1
    
    def branch_sorted(self, branch_diff, judge_prediction):
        """Order branch keys by their score difference for a given verdict.

        Keys of ``branch_diff`` are split by the sign of their value into
        positive ("win"), negative ("lose") and everything else ("tie").
        The returned list concatenates the three groups in an order that
        depends on ``judge_prediction``:

        * ``0``: wins (largest first), ties, loses (closest to zero first)
        * ``1``: loses (most negative first), ties, wins (smallest first)
        * anything else: ties, wins (smallest first), loses (closest to
          zero first)

        Args:
            branch_diff: mapping from branch key to a numeric difference.
            judge_prediction: verdict selecting the ordering (see above).

        Returns:
            A new list holding every key of ``branch_diff`` exactly once.
        """
        winners, losers, ties = {}, {}, []
        for name, diff in branch_diff.items():
            if diff > 0:
                winners[name] = diff
            elif diff < 0:
                losers[name] = diff
            else:
                ties.append(name)

        # sorted() is stable (also with reverse=True), so keys sharing the
        # same difference keep their insertion order within a group.
        if judge_prediction == 0:
            groups = (
                sorted(winners, key=winners.get, reverse=True),
                ties,
                sorted(losers, key=losers.get, reverse=True),
            )
        elif judge_prediction == 1:
            groups = (
                sorted(losers, key=losers.get),
                ties,
                sorted(winners, key=winners.get),
            )
        else:
            groups = (
                ties,
                sorted(winners, key=winners.get),
                sorted(losers, key=losers.get, reverse=True),
            )

        return [name for group in groups for name in group]

    def pairwise_eval(self):
        if "single_agreement" in self.eval_score:
            single_agreement = self.eval_score["single_agreement"]
            self.logger.info(
                "Single Agreement Average Score = {} = {} / {}".format(
                    str(sum(single_agreement) / (len(single_agreement) - self.eval_score["error"])),
                    str(sum(single_agreement)),
                    str(len(single_agreement) - self.eval_score["error"]),
                    )
                )

            self.logger.info("G win {} lose {} tie {}".format(
                self.eval_score['g_win'], self.eval_score['g_lose'], self.eval_score['g_tie']))
            self.logger.info("win {} lose {} tie {}".format(
                self.eval_score['win'], self.eval_score['lose'], self.eval_score['tie']))

            if "agreement" in self.eval_score:
                agreement = self.eval_score["agreement"]
                consistency = self.eval_score["consistency"]
                self.logger.info("Error = {}".format(self.eval_score["error"]))
                self.logger.info(
                    "Agreement Average Score = {} = {} / {}".format(
                        str(
                            sum(agreement) / (len(agreement) - self.eval_score["error"])
                        ),
                        str(sum(agreement)),
                        str(len(agreement) - self.eval_score["error"]),
                    )
                )
                self.logger.info(
                    "Consistency Average Score = {} = {} / {}".format(
                        str(
                            sum(consistency)
                            / (len(consistency) - self.eval_score["error"])
                        ),
                        str(sum(consistency)),
                        str(len(consistency) - self.eval_score["error"]),
                    )
                )

    def pairwise_eval_sft_serialize(self):
        """Log the SFT bookkeeping entries and write the SFT table to disk.

        Every entry of ``self.eval_sft`` whose key does not contain
        ``"table"`` is logged. The object stored under ``"table"`` is then
        written to ``self.train_sft_file`` with ``pq.write_table``.
        """
        sft_state = self.eval_sft
        for name, value in sft_state.items():
            if "table" in name:
                continue
            self.logger.info("key = {}, value = {}".format(name, value))

        sft_table = sft_state["table"]
        self.logger.info("SFT Table size = {}".format(len(sft_table)))
        pq.write_table(sft_table, self.train_sft_file)
        
    def pairwise_eval_dpo_serialize(self):
        """Log the DPO bookkeeping entries and write the DPO table to disk.

        Every entry of ``self.eval_dpo`` whose key does not contain
        ``"table"`` is logged. The object stored under ``"table"`` is then
        written to ``self.train_dpo_file`` with ``pq.write_table``.
        """
        dpo_state = self.eval_dpo
        for name, value in dpo_state.items():
            if "table" in name:
                continue
            self.logger.info("key = {}, value = {}".format(name, value))

        dpo_table = dpo_state["table"]
        self.logger.info("DPO Table size = {}".format(len(dpo_table)))
        pq.write_table(dpo_table, self.train_dpo_file)
        
    def serialize(self):
        
        if self.action_runner == "pairwise_eval":
            self.pairwise_eval()
            # self.pairwise_eval_sft_serialize()
            # self.pairwise_eval_dpo_serialize()
            
        