import json
import random
import sys
import os
sys.path.append("")
from datasets import Dataset, Features, Value, Sequence
from global_vars import PROMPT, TASK_NAME
from utils import re_response_number_extraction, re_response_pro_con_extraction, accuracy, mean_absolute_error, normalized_inverse_error, llm_as_judge

def load_json(path=""):
    """Read and parse a JSON file.

    Args:
        path: Path of the JSON file to read. The default ``""`` looks like a
            placeholder (callers in this module pass no argument) — TODO
            confirm the real dataset path; ``open("")`` raises
            FileNotFoundError.

    Returns:
        The decoded JSON value. Callers in this module iterate over it as a
        list of topic dicts.
    """
    # Context manager so the handle is closed deterministically; explicit
    # UTF-8 so decoding does not depend on the host locale.
    with open(path, encoding="utf-8") as f:
        return json.load(f)

def get_pool(item):
    """Flatten a topic's opinions into two shuffled, stance-tagged pools.

    Args:
        item: Topic dict with ``"pros"`` and ``"cons"`` lists. Each entry is a
            dict holding an ``"opinion"`` and an iterable ``"expanded"`` of
            paraphrased texts.

    Returns:
        ``(pro_pool, con_pool)``: one dict per expanded text, with keys
        ``"opinion"``, ``"expanded"`` and ``"stance"`` (``"pro"``/``"con"``).
        Each pool is shuffled in place with the global ``random`` RNG.
    """
    pro_entries, con_entries = item["pros"], item["cons"]

    def _flatten(entries, stance):
        # One record per expanded text, keeping the source opinion for grouping.
        return [
            {"opinion": entry["opinion"], "expanded": text, "stance": stance}
            for entry in entries
            for text in entry["expanded"]
        ]

    pro_pool = _flatten(pro_entries, "pro")
    con_pool = _flatten(con_entries, "con")
    # Shuffle order (pro first, then con) matters for seeded reproducibility.
    random.shuffle(pro_pool)
    random.shuffle(con_pool)
    return pro_pool, con_pool


def decide_pro_con_counts(m: int):
    """Randomly split ``m`` opinions into ``(pro_count, con_count)`` with a strict majority.

    The counts always sum to ``m`` and are never equal, so the majority stance
    is well defined. Which side holds the majority is random for both odd and
    even ``m``. One side may be zero (``m == 2`` always yields ``(2, 0)`` or
    ``(0, 2)``).

    Args:
        m: Total number of opinions; must be at least 2.

    Returns:
        Tuple ``(pro_count, con_count)``.

    Raises:
        ValueError: If ``m < 2``.
    """
    if m < 2:
        raise ValueError("m must be >= 2 to be split into pro/con")

    # Odd m can never tie; pick a majority size of at least one more than half.
    if m % 2 == 1:
        majority = random.randint((m // 2) + 1, m)
        minority = m - majority
        # Randomise which side wins; otherwise odd m would always produce a
        # "pro" majority and bias downstream polarity labels.
        if random.random() < 0.5:
            return majority, minority
        return minority, majority

    # Even m: draw a split in [1, m-1] and nudge it off the m/2 : m/2 tie.
    split = random.randint(1, m - 1)
    if split == m // 2:
        split += 1 if random.random() < 0.5 else -1
    pro_count = split
    con_count = m - split
    return pro_count, con_count

class OpinionCounting:
    """Opinion-counting task: count the distinct opinions in a bundle of texts.

    Each sample concatenates several expanded opinion texts for one topic. The
    ground truth is the number of distinct underlying ``"opinion"`` values among
    them; stance (pro/con) is ignored.
    """

    def __init__(self, num_data_per_topic=5, maximum_opinion_per_data=8, minimum_opinion_per_data=2):
        self.task_name = TASK_NAME.OPINION_COUNTING
        self.prompt = PROMPT.opinion_counting
        # Samples generated per topic.
        self.num_data_per_topic = num_data_per_topic
        # Inclusive bounds on how many expanded texts are drawn per sample.
        self.maximum_opinion_per_data = maximum_opinion_per_data
        self.minimum_opinion_per_data = minimum_opinion_per_data

    @staticmethod
    def _parse_prediction(extracted):
        """Return ``extracted`` as an int, or None if it is missing or not numeric."""
        if extracted is None:
            return None
        try:
            return int(extracted)
        except (TypeError, ValueError):
            return None

    def get_dataset(self):
        """Build ``self.ds``, ``self.gt`` and ``self.dialogs`` from the topic JSON.

        For each topic, draws ``num_data_per_topic`` random subsets of its
        expanded opinions. The subset size is between
        ``minimum_opinion_per_data`` and ``maximum_opinion_per_data``; it can
        be smaller if the topic's pool has fewer texts.
        """
        out_data = []
        data = load_json()

        for item in data:
            topic = item["topic"]
            pro_pool, con_pool = get_pool(item)
            # Stance is irrelevant for counting, so both pools are merged.
            merged_pool = pro_pool + con_pool
            for _ in range(self.num_data_per_topic):
                num_opinion = random.randint(self.minimum_opinion_per_data, self.maximum_opinion_per_data)
                random.shuffle(merged_pool)
                selected_opinions = merged_pool[:num_opinion]
                # Several expanded texts may share one base opinion; count the
                # distinct base opinions, not the texts.
                gt_num_opinion = len({opinion["opinion"] for opinion in selected_opinions})

                concatenated_opinions = "\n".join(opinion["expanded"] for opinion in selected_opinions)

                dialog = {
                    "messages": [
                        {"role": "system", "content": self.prompt.format(topic=topic)},
                        {"role": "user", "content": f"Opinions: {concatenated_opinions}"},
                        # Answer prefix; the model's output is expected to follow it.
                        {"role": "assistant", "content": "Your answer: "},
                    ]
                }
                out_data.append({
                    "topic": topic,
                    "concatenated_opinions": concatenated_opinions,
                    "gt_num_opinion": gt_num_opinion,
                    # NOTE(review): a list of dicts stored in a Value("string")
                    # column; `datasets` presumably stringifies it — confirm
                    # whether JSON would be preferable for downstream parsing.
                    "selected_opinions": selected_opinions,
                    "dialog": dialog,
                })

        features = Features({
            "topic": Value("string"),
            "concatenated_opinions": Value("string"),
            "gt_num_opinion": Value("int32"),
            "selected_opinions": Value("string"),
            "dialog": Features({
                "messages": Sequence(
                    Features({
                        "role": Value("string"),
                        "content": Value("string"),
                    })
                )
            }),
        })
        self.ds = Dataset.from_list(out_data, features=features)
        self.gt = [item["gt_num_opinion"] for item in out_data]
        self.dialogs = [item["dialog"] for item in out_data]

    def evaluation(self, pred):
        """Score model outputs against ``self.gt``.

        Args:
            pred: Model outputs, aligned index-by-index with ``self.gt`` and
                ``self.dialogs`` (extra entries on either side are ignored by
                ``zip``).

        Returns:
            ``(metrics, saved)``. ``metrics`` holds ``accuracy``, ``mae``,
            ``normalized_inverse_error`` and ``valid_count``. Outputs whose
            extracted count is missing or non-numeric are excluded from the
            metrics and from ``valid_count``. Every output is still recorded in
            ``saved``.
        """
        filtered_pred = []
        filtered_gt = []
        saved = []
        for p, g, dialog in zip(pred, self.gt, self.dialogs):
            extracted = re_response_number_extraction(llm_as_judge(p, self.task_name))
            saved.append({
                "dialog": dialog,
                "output": p,
                "gt": g,
                "extracted": extracted,
            })
            # Drop unparseable extractions instead of crashing on int(None).
            value = self._parse_prediction(extracted)
            if value is None:
                continue
            filtered_pred.append(value)
            filtered_gt.append(g)

        # NOTE(review): behaviour of the metric helpers on empty lists (no
        # valid predictions) is defined in utils — confirm.
        return_dict = {
            "accuracy": accuracy(filtered_pred, filtered_gt),
            "mae": mean_absolute_error(filtered_pred, filtered_gt),
            "normalized_inverse_error": normalized_inverse_error(filtered_pred, filtered_gt),
            "valid_count": len(filtered_pred),
        }
        return return_dict, saved
    

    
class OpinionMatching:
    """Opinion-matching task: map an expanded text back to its source opinion.

    The model sees one expanded opinion plus the topic's full numbered option
    list (pros first, then cons). It must answer with the option index.
    Evaluation also reports whether the chosen option at least has the correct
    stance.
    """

    def __init__(self, num_data_per_topic=5):
        self.task_name = TASK_NAME.OPINION_MATCHING
        self.prompt = PROMPT.opinion_matching
        # Samples generated per topic.
        self.num_data_per_topic = num_data_per_topic

    @staticmethod
    def _parse_prediction(extracted):
        """Return ``extracted`` as an int, or None if it is missing or not numeric."""
        if extracted is None:
            return None
        try:
            return int(extracted)
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _stance_matches(mapping, gt_idx, pred_idx):
        """Return True if ``pred_idx`` is a valid option whose stance equals option ``gt_idx``'s.

        ``mapping`` is one sample's option-index -> stance list. Out-of-range
        predicted indices (including negative ones) count as a mismatch.
        """
        if not 0 <= pred_idx < len(mapping):
            return False
        return mapping[gt_idx] == mapping[pred_idx]

    def get_dataset(self):
        """Build ``self.ds``, ``self.gt`` and ``self.dialogs`` from the topic JSON.

        ``self.gt`` holds, per sample, the index of the correct option within
        that topic's option list.
        """
        out_data = []
        data = load_json()

        for item in data:
            topic = item["topic"]
            # Option list: all pro opinions, then all con opinions; the parallel
            # list records each option's stance.
            options = []
            option_stance_mapping = []
            for pro in item["pros"]:
                options.append(pro["opinion"])
                option_stance_mapping.append("pro")
            for con in item["cons"]:
                options.append(con["opinion"])
                option_stance_mapping.append("con")

            options_str = "\n".join(f"{i}) {opt}" for i, opt in enumerate(options))

            pro_pool, con_pool = get_pool(item)
            merged_pool = pro_pool + con_pool
            random.shuffle(merged_pool)
            for _ in range(self.num_data_per_topic):
                # Sampling with replacement: the same text may appear twice.
                selected = random.choice(merged_pool)

                expanded = selected["expanded"]
                gt_opinion = selected["opinion"]

                assert gt_opinion in options, f"gt_opinion {gt_opinion} not in options {options}"

                dialog = {
                    "messages": [
                        {"role": "system", "content": self.prompt.format(topic=topic)},
                        {"role": "user", "content": f"Expanded opinion: {expanded}\nOptions: {options_str}"},
                        {"role": "assistant", "content": "Your answer: "},
                    ]
                }
                out_data.append({
                    "topic": topic,
                    "expanded": expanded,
                    "gt_opinion": gt_opinion,
                    "option_stance_mapping": option_stance_mapping,
                    "options": options_str,
                    # NOTE(review): list.index returns the first match, so an
                    # opinion text duplicated across pros/cons maps to the pro
                    # slot — confirm the data has no such duplicates.
                    "gt_opinion_idx": options.index(gt_opinion),
                    "dialog": dialog,
                })

        features = Features({
            "topic": Value("string"),
            "expanded": Value("string"),
            "gt_opinion": Value("string"),
            "gt_opinion_idx": Value("int32"),
            "option_stance_mapping": Sequence(Value("string")),
            "options": Value("string"),
            "dialog": Features({
                "messages": Sequence(
                    Features({
                        "role": Value("string"),
                        "content": Value("string"),
                    })
                )
            }),
        })
        self.ds = Dataset.from_list(out_data, features=features)
        self.gt = [item["gt_opinion_idx"] for item in out_data]
        self.dialogs = [item["dialog"] for item in out_data]

    def evaluation(self, pred):
        """Score model outputs against ``self.gt``.

        Args:
            pred: Model outputs, aligned index-by-index with ``self.gt``,
                ``self.dialogs`` and the rows of ``self.ds``.

        Returns:
            ``(metrics, saved)``. ``metrics`` has ``accuracy``, ``valid_count``
            and ``stance_correctness``. ``stance_correctness`` is the fraction
            of valid predictions whose chosen option has the same stance as
            the gt option, or 0.0 when there are none. Outputs with a missing
            or non-numeric extraction are excluded from the metrics but kept in
            ``saved``.
        """
        filtered_pred = []
        filtered_gt = []
        stance_correctness = []
        saved = []
        # Read the column once, outside the loop. Each row is that sample's own
        # option-index -> stance list, so stances must be looked up *within*
        # the row, not by indexing the column with an option index.
        stance_mappings = self.ds["option_stance_mapping"]
        for p, g, dialog, mapping in zip(pred, self.gt, self.dialogs, stance_mappings):
            extracted = re_response_number_extraction(llm_as_judge(p, self.task_name))
            saved.append({
                "dialog": dialog,
                "output": p,
                "gt": g,
                "extracted": extracted,
            })
            pred_idx = self._parse_prediction(extracted)
            if pred_idx is None:
                continue
            filtered_pred.append(pred_idx)
            filtered_gt.append(g)
            stance_correctness.append(1 if self._stance_matches(mapping, int(g), pred_idx) else 0)

        return_dict = {
            "accuracy": accuracy(filtered_pred, filtered_gt),
            "valid_count": len(filtered_pred),
            "stance_correctness": (
                sum(stance_correctness) / len(stance_correctness) if stance_correctness else 0.0
            ),
        }
        return return_dict, saved

class PolarityCheck:
    """Polarity-check task: decide whether the majority of opinions are pro or con.

    Each sample mixes pro and con expanded opinions for one topic, with an
    unequal split. The ground truth is the stance of the majority.
    """

    def __init__(self, num_data_per_topic=5, num_opinion_per_item=8):
        self.task_name = TASK_NAME.POLARITY_CHECK
        self.prompt = PROMPT.polarity_check
        # Samples generated per topic.
        self.num_data_per_topic = num_data_per_topic
        # Requested opinions per sample; must be >= 2 (decide_pro_con_counts
        # raises ValueError otherwise).
        self.num_opinion_per_item = num_opinion_per_item

    @staticmethod
    def _majority_stance(pro_num, con_num):
        """Return ``"pro"``/``"con"`` for the strict majority, or None on a tie."""
        if pro_num > con_num:
            return "pro"
        if con_num > pro_num:
            return "con"
        return None

    def get_dataset(self):
        """Build ``self.ds``, ``self.gt`` and ``self.dialogs`` from the topic JSON.

        ``final_stance``, ``pro_num`` and ``con_num`` describe the opinions
        actually included. A sample is skipped if a short pool leaves the
        selection tied, since its label would be undefined.
        """
        out_data = []
        data = load_json()

        for item in data:
            topic = item["topic"]

            for _ in range(self.num_data_per_topic):
                # Re-pool per sample so each sample gets a fresh shuffle.
                pro_pool, con_pool = get_pool(item)
                pro_count, con_count = decide_pro_con_counts(self.num_opinion_per_item)
                pro_selected = pro_pool[:pro_count]
                con_selected = con_pool[:con_count]

                # Label from what was actually selected: slicing a pool shorter
                # than the requested count silently yields fewer opinions.
                pro_num, con_num = len(pro_selected), len(con_selected)
                final_stance = self._majority_stance(pro_num, con_num)
                if final_stance is None:
                    continue

                opinions = [entry["expanded"] for entry in pro_selected + con_selected]
                random.shuffle(opinions)
                concatenated_opinions = "\n".join(opinions)
                dialog = {
                    "messages": [
                        {"role": "system", "content": self.prompt.format(topic=topic)},
                        {"role": "user", "content": f"Expanded opinions: {concatenated_opinions}"},
                        {"role": "assistant", "content": "Your answer: "},
                    ]
                }
                out_data.append({
                    "topic": topic,
                    "concatenated_opinions": concatenated_opinions,
                    "final_stance": final_stance,
                    "pro_num": pro_num,
                    "con_num": con_num,
                    "dialog": dialog,
                })

        features = Features({
            "topic": Value("string"),
            "concatenated_opinions": Value("string"),
            "final_stance": Value("string"),
            "pro_num": Value("int32"),
            "con_num": Value("int32"),
            "dialog": Features({
                "messages": Sequence(
                    Features({
                        "role": Value("string"),
                        "content": Value("string"),
                    })
                )
            }),
        })
        self.ds = Dataset.from_list(out_data, features=features)
        self.gt = [item["final_stance"] for item in out_data]
        self.dialogs = [item["dialog"] for item in out_data]

    def evaluation(self, pred):
        """Score model outputs against ``self.gt`` (``"pro"``/``"con"`` labels).

        Args:
            pred: Model outputs, aligned index-by-index with ``self.gt`` and
                ``self.dialogs``.

        Returns:
            ``(metrics, saved)`` with ``accuracy`` and ``valid_count``. Every
            output is counted. An output whose judged label does not equal the
            gt string (including None) is scored as incorrect by ``accuracy``
            — presumably; confirm in utils.
        """
        filtered_pred = []
        filtered_gt = []
        saved = []
        for p, g, dialog in zip(pred, self.gt, self.dialogs):
            # The judge's output is compared directly against the gt stance.
            extracted = llm_as_judge(p, self.task_name)
            filtered_pred.append(extracted)
            filtered_gt.append(g)
            saved.append({
                "dialog": dialog,
                "output": p,
                "gt": g,
                "extracted": extracted,
            })

        return_dict = {
            "accuracy": accuracy(filtered_pred, filtered_gt),
            "valid_count": len(filtered_pred),
        }
        return return_dict, saved


# Script entry point: currently a no-op placeholder. The task classes above are
# meant to be imported and driven by an external runner.
if __name__ == "__main__":
    pass