from prompt.omni_prompt import *
from prompt.criteria_prompt import *
from utils import *
import random
import uuid
import json
import os

class PairedData:
    """One preference sample: a query plus a chosen and a rejected response.

    Wraps a single JSON record; ``dumps`` returns the same five fields as a
    plain dict in the order ``id``, ``suffix``, ``query``, ``chosen``,
    ``rejected``.
    """

    def __init__(self, json_item):
        """Build the sample from ``json_item``.

        Args:
            json_item: Dict with keys ``id``, ``suffix``, ``query``,
                ``chosen`` and ``rejected``. The last three must be dicts.

        Raises:
            AssertionError: If ``json_item`` (or its query/chosen/rejected
                values) is not a dict.
            KeyError: If a required key is missing.
        """
        assert isinstance(json_item, dict) and all(
            isinstance(json_item[field], dict) for field in ("query", "chosen", "rejected")
        )

        record = json_item
        self._id = record['id']
        self.suffix = record['suffix']
        self.query = record['query']
        self.chosen = record['chosen']
        self.rejected = record['rejected']

    @property
    def id(self):
        """Identifier of the sample (read-only)."""
        return self._id

    def dumps(self):
        """Return the sample as a plain dict mirroring the input record."""
        return dict(
            id=self._id,
            suffix=self.suffix,
            query=self.query,
            chosen=self.chosen,
            rejected=self.rejected,
        )


class Traj:
    """One pairwise-judging trajectory built on top of a ``PairedData`` sample.

    ``answer`` fixes the presentation order of the pair: when it is 0 the
    chosen response is shown as response 1 (and the rejected one as
    response 2); when it is 1 the order is swapped.

    Each ``build_*`` method appends one or more prompt conversations to
    ``self.conversations``. The matching ``*_dumps`` method reads
    ``self.llm_response`` and returns the trajectory serialized as a JSON
    string. NOTE(review): ``llm_response`` is presumably filled by the
    caller with one model output per conversation, in order. Confirm
    against the driver script.
    """

    # Pixel budgets attached to media items in multimodal prompts.
    _VIDEO_PIXELS = 256 * 256
    _VIDEO_TOTAL_PIXELS = 24 * 256 * 256
    _IMAGE_PIXELS = 512 * 512

    def __init__(self):
        self.paired_data = None  # PairedData instance, normally set via ``loads``

        self.conversations = []  # prompts appended by the build_* methods
        self.llm_response = []   # model outputs consumed by the *_dumps methods

        self.criteria_list = None  # list of {"id": <uuid str>, "content": <criterion>}
        self.criteria_step = None

        self.judge_pair = None  # {"judge_a_list": [...], "judge_b_list": [...]}
        self.judge = None

        self.ranking_pair = None
        self._answer = None  # 0 or 1, see class docstring

    # ------------------------------------------------------------------ #
    # Read-only views
    # ------------------------------------------------------------------ #
    @property
    def id(self):
        """Identifier of the underlying paired sample."""
        return self.paired_data.id

    @property
    def answer(self):
        """0 if response 1 is the chosen one, 1 if response 2 is (None until loaded)."""
        return self._answer

    @property
    def query(self):
        """The query dict of the underlying paired sample."""
        return self.paired_data.query

    @property
    def response_1(self):
        """Response shown first: ``chosen`` when answer == 0, ``rejected`` when answer == 1.

        Raises:
            ValueError: If ``answer`` is not 0 or 1.
        """
        if self.answer == 0:
            return self.paired_data.chosen
        elif self.answer == 1:
            return self.paired_data.rejected
        else:
            raise ValueError(f"Unexpected answer value: {self.answer}")

    @property
    def response_2(self):
        """Response shown second: ``chosen`` when answer == 1, ``rejected`` when answer == 0.

        Raises:
            ValueError: If ``answer`` is not 0 or 1.
        """
        if self.answer == 1:
            return self.paired_data.chosen
        elif self.answer == 0:
            return self.paired_data.rejected
        else:
            raise ValueError(f"Unexpected answer value: {self.answer}")

    # ------------------------------------------------------------------ #
    # (De)serialization
    # ------------------------------------------------------------------ #
    def loads(self, data: dict, shuffle: bool = False):
        """Populate this trajectory from ``data``. This can be done only once.

        Args:
            data: Mapping of attribute name to value. ``"answer"`` (0 or 1)
                is required unless ``shuffle`` is true. ``"conversations"``
                and ``"llm_response"`` are ignored. A ``"paired_data"``
                value must be a ``PairedData`` instance.
            shuffle: If true, choose ``answer`` at random, which decides which
                of chosen/rejected is shown first. ``data`` must then not
                contain ``"answer"``.

        Raises:
            ValueError: If ``answer`` was already set, or if ``shuffle`` is
                requested while ``data`` carries an ``"answer"``.
            AssertionError: If the provided answer is not 0 or 1, or
                ``paired_data`` is not a ``PairedData``.
        """
        if self._answer is not None:
            raise ValueError("When answer was set, you cannot load data again ! check your code !")

        if shuffle:
            if "answer" in data:
                raise ValueError("You want to ramdom select who is the Response A or B, but you provide answer in `data`! It's conflict, maybe there is a bug! Check your code !")
            self._answer = 0 if random.random() > 0.5 else 1
        else:
            answer = data["answer"]
            # Validate before storing, so a bad value does not leave the
            # trajectory holding an invalid answer.
            assert answer in [0, 1]
            self._answer = answer

        # ``answer`` is a read-only property handled above. The
        # conversation/response buffers are never taken from ``data``.
        forbidden_keys = {"answer", "conversations", "llm_response"}
        for key, value in data.items():
            if key == "paired_data":
                assert isinstance(value, PairedData)
            if key in forbidden_keys:
                continue
            setattr(self, key, value)

    def dumps(self):
        """Return all non-None instance attributes as a plain dict.

        ``paired_data`` is serialized through its own ``dumps()`` and omitted
        while unset. The private ``_answer`` is exported under ``"answer"``.
        """
        results = {}
        for attr_name, attr_value in self.__dict__.items():
            if attr_value is None:
                continue
            if attr_name == "paired_data":
                results[attr_name] = attr_value.dumps()
            elif attr_name == "_answer":
                results["answer"] = attr_value
            else:
                results[attr_name] = attr_value
        return results

    # ------------------------------------------------------------------ #
    # Private conversation builders
    # ------------------------------------------------------------------ #
    @staticmethod
    def _user_conversation(content):
        """Wrap ``content`` as a single-turn user conversation."""
        return [{'role': 'user', 'content': content}]

    @staticmethod
    def _single_media_conversation(prefix, suffix, media_type, media, **media_opts):
        """User turn of the form: text ``prefix``, one media item, text ``suffix``."""
        return [
            {
                'role': 'user',
                'content': [
                    {"type": "text", "text": prefix},
                    {"type": media_type, media_type: media, **media_opts},
                    {"type": "text", "text": suffix},
                ],
            }
        ]

    @staticmethod
    def _paired_media_conversation(prefix, infix, suffix, media_type, media_1, media_2, **media_opts):
        """User turn of the form: ``prefix``, media 1, ``infix``, media 2, ``suffix``."""
        return [
            {
                'role': 'user',
                'content': [
                    {"type": "text", "text": prefix},
                    {"type": media_type, media_type: media_1, **media_opts},
                    {"type": "text", "text": infix},
                    {"type": media_type, media_type: media_2, **media_opts},
                    {"type": "text", "text": suffix},
                ],
            }
        ]

    def _chat_pair(self, answer_text):
        """Two-turn conversation: the query as user, ``answer_text`` as assistant."""
        return [
            {"role": "user", "content": self.query['content']},
            {"role": "assistant", "content": answer_text},
        ]

    # ------------------------------------------------------------------ #
    # Criteria / judge prompt builders
    # ------------------------------------------------------------------ #
    def build_exploration_criteria_conversaion(self):
        """Append one criteria-exploration prompt over ``criteria_list`` (name typo kept for callers)."""
        prompt = build_exploration_criteria_prompt(query=self.query['content'], response_1=self.response_1['content'], response_2=self.response_2['content'], candidate_criteria=criteria2str(self.criteria_list), think=True)
        self.conversations.append(self._user_conversation(prompt))

    def build_exploration_judge_conversaion(self):
        """Append one exploration-judge prompt per criterion in ``criteria_list``."""
        for based_criterion in self.criteria_list:
            prompt = build_exploration_judge_prompt(query=self.query['content'], response_1=self.response_1['content'], response_2=self.response_2['content'], based_criterion=criteria2str([based_criterion]))
            self.conversations.append(self._user_conversation(prompt))

    def build_stepwise_criteria_conversation(self, step=0):
        """Record ``step`` and append one stepwise-criteria prompt over ``criteria_list``."""
        self.criteria_step = step
        prompt = build_stepwise_criteria_prompt(query=self.query['content'], response_1=self.response_1['content'], response_2=self.response_2['content'], candidate_criteria=criteria2str(self.criteria_list), think=False)
        self.conversations.append(self._user_conversation(prompt))

    def build_criteria_n_conversation(self, step=0, criteria_n=3):
        """Record ``step`` and append one prompt that asks for ``criteria_n`` criteria."""
        self.criteria_step = step
        prompt = build_criteria_prompt(query=self.query['content'], response_1=self.response_1['content'], response_2=self.response_2['content'], criteria_n=criteria_n)
        self.conversations.append(self._user_conversation(prompt))

    def build_direct_judge_criteria_conversation(self, step=0, modality="language"):
        """Record ``step`` and append one direct-judge prompt (no candidate criteria).

        Args:
            step: Stored as ``criteria_step``.
            modality: ``"language"``, ``"video"``, ``"image"`` or ``"audio"``.
                For media modalities the item comes either from the query
                (``task="und"``) or from both responses (``task="gen"``).
                See the branches below for which keys are checked.

        Raises:
            NotImplementedError: For video/image pairs lacking the expected media.
            ValueError: For an unsupported ``modality``, unless the
                ``TEST_UNIFIED`` override supplies a conversation.
        """
        self.criteria_step = step
        conversation = None

        if modality == "language":
            prompt = build_direct_judge_without_candidate_prompt(query=self.query['content'], response_1=self.response_1['content'], response_2=self.response_2['content'])
            conversation = self._user_conversation(prompt)

        elif modality == "video":
            if 'videos' not in self.response_1:
                raise NotImplementedError
            prefix_pm, infix_pm, suffix_pm = build_direct_judge_split_without_candidate_prompt(query=self.query['content'], response_1=self.response_1['content'], response_2=self.response_2['content'], task="gen", think=True)
            conversation = self._paired_media_conversation(
                prefix_pm, infix_pm, suffix_pm, "video",
                self.response_1['videos'][0], self.response_2['videos'][0],
                min_pixels=self._VIDEO_PIXELS,
                max_pixels=self._VIDEO_PIXELS,
                total_pixels=self._VIDEO_TOTAL_PIXELS,
            )

        elif modality == "image":
            if 'images' in self.query:
                prefix_pm, suffix_pm = build_direct_judge_split_without_candidate_prompt(query=self.query['content'], response_1=self.response_1['content'], response_2=self.response_2['content'], task="und", think=True)
                conversation = self._single_media_conversation(
                    prefix_pm + "<Image>: ", suffix_pm, "image", self.query['images'][0],
                    min_pixels=self._IMAGE_PIXELS,
                    max_pixels=self._IMAGE_PIXELS,
                )
            elif 'images' in self.response_1:
                prefix_pm, infix_pm, suffix_pm = build_direct_judge_split_without_candidate_prompt(query=self.query['content'], response_1=self.response_1['content'], response_2=self.response_2['content'], task="gen", think=True)
                conversation = self._paired_media_conversation(
                    prefix_pm, infix_pm, suffix_pm, "image",
                    self.response_1['images'][0], self.response_2['images'][0],
                    min_pixels=self._IMAGE_PIXELS,
                    max_pixels=self._IMAGE_PIXELS,
                )
            else:
                raise NotImplementedError

        elif modality == "audio":
            if 'audios' in self.query:
                prefix_pm, suffix_pm = build_direct_judge_split_without_candidate_prompt(query=self.query['content'], response_1=self.response_1['content'], response_2=self.response_2['content'], task="und", think=True)
                conversation = self._single_media_conversation(prefix_pm, suffix_pm, "audio", self.query['audios'][0])
            else:
                # The text response slots are passed as empty strings here;
                # the responses are supplied as audio clips below.
                prefix_pm, infix_pm, suffix_pm = build_direct_judge_split_without_candidate_prompt(query=self.query['content'], response_1="", response_2="", task="gen", think=True)
                conversation = self._paired_media_conversation(
                    prefix_pm, infix_pm, suffix_pm, "audio",
                    self.response_1['audios'][0], self.response_2['audios'][0],
                )

        # When TEST_UNIFIED is set, the conversation built above is replaced
        # by a fixed five-factor image-QA prompt.
        if os.environ.get("TEST_UNIFIED", "False").lower() in ["1", "true", "yes"]:
            print("TEST_UNIFIED !!!.")
            prompt_text = (
                "You are given an image and a question related to it. Your job is to evaluate the two responses based on these five factors:\n\n"
                "1. Accuracy of Object Descriptions: Review how accurately the objects are described in the responses, ensuring they match those in the ground truth. Be mindful of irrelevant or incorrect objects being mentioned.\n\n"
                "2. Relationship Between Objects: Check if the response properly describes how the objects relate to each other, reflecting their actual positions or interactions, as seen in the image.\n\n"
                "3. Description of Attributes: Assess how well the response captures the attributes (e.g., size, color, shape) of the objects in the image, in line with the ground truth.\n\n"
                "4. Helpfulness: Consider whether the response offers useful information that enhances the understanding of the image. Does it add context or provide extra insights? Also, evaluate whether it follows the instructions given in the prompt.\n\n"
                "5. Ethical Concerns: Review the response to ensure it avoids sensitive, harmful, or inappropriate content. The response should be fair, respect privacy, and be free of bias or offensive material.\n\n"
                "After evaluating both answers, determine which one is better based on these factors and clearly state your decision, such as 'Answer 1 is better' or 'Answer 2 is better.'\n\n"
                f"Question: {self.query['content']}\n"
                f"Answer 1: {self.response_1['content']}\n"
                f"Answer 2: {self.response_2['content']}\n"
            )
            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": self.query['images'][0]},
                        {"type": "text", "text": prompt_text},
                    ]
                }
            ]

        if conversation is None:
            # Previously this path failed with an UnboundLocalError.
            raise ValueError(f"Unsupported modality: {modality!r}")
        self.conversations.append(conversation)

    # ------------------------------------------------------------------ #
    # Criteria / judge parsers
    # ------------------------------------------------------------------ #
    def criteria_dumps(self, criteria_n: int = None):
        """Parse criteria from ``llm_response[0]``, store them, and return the JSON dump.

        ``paired_data`` is dropped from the dump when ``criteria_step`` is
        truthy (non-zero).

        Raises:
            AssertionError: If ``criteria_n`` is an int and the parsed count differs.
        """
        criteria_list = extract_criteria_from_text(self.llm_response[0])
        self.criteria_list = [{"id": str(uuid.uuid4()), "content": c} for c in criteria_list]
        _dump_data = self.dumps()
        if self.criteria_step:
            _dump_data.pop("paired_data", None)
        if isinstance(criteria_n, int):
            assert len(self.criteria_list) == criteria_n
        return json.dumps(_dump_data, ensure_ascii=False)

    def criteria_n_dumps(self, criteria_n_number):
        """Parse exactly ``criteria_n_number`` criteria from ``llm_response[0]`` and dump.

        Raises:
            AssertionError: If the parsed count differs from ``criteria_n_number``.
        """
        criteria_list = extract_criteria_n_from_text(clear_think(self.llm_response[0]))
        self.criteria_list = [{"id": str(uuid.uuid4()), "content": c} for c in criteria_list]
        assert len(criteria_list) == criteria_n_number
        return json.dumps(self.dumps(), ensure_ascii=False)

    def criteria_and_judge_dumps(self):
        """Parse criteria plus per-response judgements from ``llm_response[0]`` and dump.

        Raises:
            AssertionError: If the output does not contain exactly 3 criteria.
        """
        parsed_result = parse_direct_judge_output(self.llm_response[0])
        self.criteria_list = []
        self.judge_pair = {
            "judge_a_list": [],
            "judge_b_list": []
        }
        for c in parsed_result['criteria']:
            self.criteria_list.append({"id": str(uuid.uuid4()), "content": c['criterion']})
            self.judge_pair['judge_a_list'].append(c['judge_A'])
            self.judge_pair['judge_b_list'].append(c['judge_B'])
        assert len(self.criteria_list) == 3, f"criteria_and_judge_dumps(): need the parsed criteria_num to be exactly 3, but got: [{len(self.criteria_list)}]"
        self.judge = [self.llm_response[0]]
        return json.dumps(self.dumps(), ensure_ascii=False)

    def criteria_and_judge_dumps_simple(self):
        """Store ``llm_response[0]`` verbatim as the judge and dump, without parsing."""
        self.judge = [self.llm_response[0]]
        return json.dumps(self.dumps(), ensure_ascii=False)

    def build_stepwise_judge_conversation(self, step, criteria_per_step=3):
        """Append one judge prompt covering the ``step``-th group of criteria.

        Args:
            step: Zero-based group index into ``criteria_list``.
            criteria_per_step: Group size (default 3, the original hard-coded value).
        """
        # Uses the order-aware response properties; PairedData itself has no
        # response_1/response_2 attributes (the previous code raised AttributeError).
        start = step * criteria_per_step
        prompt = build_stepwise_judge_prompt(self.query['content'], self.response_1['content'], self.response_2['content'], criteria2str(self.criteria_list[start:start + criteria_per_step]))
        self.conversations.append(self._user_conversation(prompt))

    def build_criteria_n_judge_conversation(self, n):
        """Append one judge prompt for the ``n``-th criterion in ``criteria_list``."""
        prompt = build_judge_prompt(self.query['content'], self.response_1['content'], self.response_2['content'], single_criteria2str(self.criteria_list[n]))
        self.conversations.append(self._user_conversation(prompt))

    def judge_dumps(self):
        """Collect judge lists from every response in ``llm_response`` and dump."""
        self.judge_pair = {
            "judge_a_list": [],
            "judge_b_list": [],
        }
        for judge_str in self.llm_response:
            # Parse once per response instead of twice.
            pair = extract_judge_pair(clear_think(judge_str))
            self.judge_pair['judge_a_list'].extend(pair['judge_a_list'])
            self.judge_pair['judge_b_list'].extend(pair['judge_b_list'])

        self.judge = self.llm_response
        return json.dumps(self.dumps(), ensure_ascii=False)

    def criteria_n_judge_dumps(self):
        """Extract one (judge_a, judge_b) pair from each response in ``llm_response`` and dump."""
        self.judge_pair = {
            "judge_a_list": [],
            "judge_b_list": [],
        }
        for judge_str in self.llm_response:
            judge = extract_judge_only(clear_think(judge_str))
            self.judge_pair['judge_a_list'].append(judge['judge_a'])
            self.judge_pair['judge_b_list'].append(judge['judge_b'])

        self.judge = self.llm_response
        return json.dumps(self.dumps(), ensure_ascii=False)

    # ------------------------------------------------------------------ #
    # Refinement / ranking
    # ------------------------------------------------------------------ #
    def _refinement_conversation(self, response, judge, modality):
        """Build one correction prompt for ``response`` guided by ``judge``.

        Raises:
            Exception: ``"Invalid"`` for a modality other than language/vision.
        """
        if modality == "language":
            prompt = build_correct_prompt(self.query['content'], response['content'], judge)
            return self._user_conversation(prompt)
        if modality == "vision":
            prefix_pm, suffix_pm = build_correct_split_prompt(query=self.query['content'], response=response['content'], judge=judge, task="und", think=True)
            return self._single_media_conversation(
                prefix_pm + "<Image>: ", suffix_pm, "image", self.query['images'][0],
                min_pixels=self._IMAGE_PIXELS,
                max_pixels=self._IMAGE_PIXELS,
            )
        raise Exception("Invalid")

    def build_refinement_conversation(self, judge_pair, modality="language"):
        """Append correction prompts: all for response 1 first, then all for response 2.

        Args:
            judge_pair: Dict with ``judge_a_list`` (judgements of response 1)
                and ``judge_b_list`` (judgements of response 2).
            modality: ``"language"`` or ``"vision"``.
        """
        for judge in judge_pair['judge_a_list']:
            self.conversations.append(self._refinement_conversation(self.response_1, judge, modality))
        for judge in judge_pair['judge_b_list']:
            self.conversations.append(self._refinement_conversation(self.response_2, judge, modality))

    def refinement_dumps(self):
        """Split ``llm_response`` into refinements of response 1 and response 2, then dump.

        The first half goes to ``refinement_a`` and the rest to
        ``refinement_b``. NOTE(review): this assumes the judge_a/judge_b
        lists used to build the prompts had equal length.
        """
        half = len(self.llm_response) // 2
        self.refinement_list = {
            "refinement_a": self.llm_response[:half],
            "refinement_b": self.llm_response[half:]
        }
        return json.dumps(self.dumps(), ensure_ascii=False)

    def build_ranking_conversation(self, refinement_pair: dict):
        """Append query/answer conversations for scoring.

        The order is: raw response 1, raw response 2, each refinement of
        response 1, then each refinement of response 2 (think sections
        removed via ``clear_think``).
        """
        raw_a = self._chat_pair(self.response_1['content'])
        raw_b = self._chat_pair(self.response_2['content'])
        self.conversations.append(raw_a)
        self.conversations.append(raw_b)

        for refinement in refinement_pair['refinement_a']:
            self.conversations.append(self._chat_pair(clear_think(refinement)))
        for refinement in refinement_pair['refinement_b']:
            self.conversations.append(self._chat_pair(clear_think(refinement)))

    def ranking_dumps(self):
        """Split ranking outputs in the order produced by ``build_ranking_conversation`` and dump.

        Side effect: the two raw-response outputs are removed from the front
        of ``self.llm_response``.
        """
        ranking_raw = self.llm_response[:2]
        self.llm_response = self.llm_response[2:]
        half = len(self.llm_response) // 2
        self.ranking_pair = {
            "ranking_raw_a": ranking_raw[0],
            "ranking_raw_b": ranking_raw[1],
            "ranking_a": self.llm_response[:half],
            "ranking_b": self.llm_response[half:]
        }
        return json.dumps(self.dumps(), ensure_ascii=False)

    def raw_data_dumps(self):
        """Dump the pair as a chosen/rejected chat record, independent of display order.

        Raises:
            ValueError: If ``answer`` is not 0 or 1.
        """
        if self.answer == 0:
            chosen, rejected = self.response_1, self.response_2
        elif self.answer == 1:
            chosen, rejected = self.response_2, self.response_1
        else:
            raise ValueError(f"Invalid answer: {self.answer}")

        _dump_data = {
            "conversations": [
                {"role": "user", "content": self.paired_data.query['content']}
            ],
            "chosen": {"role": "assistant", "content": chosen['content']},
            "rejected": {"role": "assistant", "content": rejected['content']},
            "suffix": self.paired_data.suffix,
            "id": self.paired_data.id
        }
        return json.dumps(_dump_data, ensure_ascii=False)
