#!/usr/bin/env python3
# build_reasoning_dataset_for_categories.py
import pickle, csv, json, random, pathlib
from tqdm import tqdm
from synthetic_reasoning import LanguageRule, LanguageFact

# Reasoning type built when this file runs as a script. Supported values:
# "independent", "equivalent", "contradictory", "alternative",
# "complementary", "entailment".
GROUP = "contradictory"

# The global RNG is seeded at import time so every shuffle/sample performed
# while building the dataset is reproducible across runs.
SEED = 42
random.seed(SEED)

def fact_to_triplets(fact):
    """Expand a fact into one ``is``-relation triplet per generic attribute.

    Reads ``fact.subject`` and iterates ``fact.generic_attributes``. Each
    returned dict has the keys "subject", "relation" (always "is"),
    "object" (the attribute) and "text" ("<subject> is <attribute>").
    An empty attribute collection yields an empty list.
    """
    return [
        {
            "subject": fact.subject,
            "relation": "is",
            "object": attribute,
            "text": f"{fact.subject} is {attribute}",
        }
        for attribute in fact.generic_attributes
    ]


def facts_group_to_triplet_json(facts):
    """Serialise the triplets of every fact in *facts* as a single JSON array.

    The result is meant to be stored in one CSV cell. Triplets keep the order
    of *facts* (and, within a fact, the order ``fact_to_triplets`` yields);
    non-ASCII characters are written verbatim (``ensure_ascii=False``).
    """
    triplets = [trip for fact in facts for trip in fact_to_triplets(fact)]
    return json.dumps(triplets, ensure_ascii=False)


def facts_group_to_text(facts):
    """Render *facts* as one string: each item's ``str()`` joined by single spaces."""
    return " ".join(map(str, facts))


def label_options(opts):
    """Pair the letters A-D with the first four entries of *opts*.

    Returns a dict such as ``{'A': opts[0], 'B': opts[1], ...}``; entries past
    the fourth are ignored and shorter inputs produce fewer keys.
    """
    return dict(zip("ABCD", opts))

def format_options(opt_dict):
    """Format labelled options as ``"Options: A) ..., B) ..."`` in dict order."""
    parts = [f"{label}) {text}" for label, text in opt_dict.items()]
    return "Options: " + ", ".join(parts)

# Fixed question stem shared by every generated sample.
QUESTION = (
    "Which of the following options can be inferred "
    "based on the given facts and rules?"
)

def _select_modality_groups(reasoning_type, mode, test_facts, distractor_groups):
    """Choose the facts shown in each of the three modalities of one sample.

    Args:
        reasoning_type: one of the types accepted by ``build_model_input_type1``.
        mode: "normal" seeds each modality with distractor facts; "easy" starts
            from empty lists (except "independent", where modalities 2 and 3
            each get one randomly sampled distractor).
        test_facts: the sample's test facts. For "alternative", "entailment"
            and "complementary" the facts are read as ``test_facts[0][i]``
            (presumably a nested list for those datasets — TODO confirm).
        distractor_groups: per-modality collections of distractor facts.

    Returns:
        ``(group1, group2, group3, group4)`` as fresh lists; ``group4`` is
        ``group1 + group2 + group3`` for "complementary" and ``None`` otherwise.

    Raises:
        ValueError: for an unsupported ``mode`` or ``reasoning_type``.
    """
    if mode not in ("normal", "easy"):
        raise ValueError(f"Mode {mode} not supported")
    easy = mode == "easy"

    if reasoning_type == "independent":
        if easy:
            groups = [
                [],
                random.sample(distractor_groups[1], k=1),
                random.sample(distractor_groups[2], k=1),
            ]
        else:
            groups = [list(distractor_groups[i]) for i in range(3)]
        additions = [test_facts[0]]  # only modality 1 receives the test fact
    elif reasoning_type == "equivalent":
        # Every modality starts from the same content (distractor group 0).
        groups = [[] if easy else list(distractor_groups[0]) for _ in range(3)]
        additions = [test_facts[0]] * 3
    elif reasoning_type == "contradictory":
        groups = [[] if easy else list(distractor_groups[i]) for i in range(3)]
        additions = [test_facts[0], test_facts[1], test_facts[2]]
    elif reasoning_type in ("alternative", "entailment", "complementary"):
        groups = [[] if easy else list(distractor_groups[i]) for i in range(3)]
        # Guard on the very fact being appended (the old code tested
        # test_facts[i] but appended test_facts[0][i]).
        additions = [test_facts[0][0], test_facts[0][1], test_facts[0][2]]
    else:
        raise ValueError(f"Type {reasoning_type} not supported")

    # zip stops at the shorter sequence, so "independent" only touches group 1.
    for group, fact in zip(groups, additions):
        if fact not in group:
            group.append(fact)

    group4 = None
    if reasoning_type == "complementary":
        # Built before any shuffling; shuffled separately by the caller.
        group4 = groups[0] + groups[1] + groups[2]
    return groups[0], groups[1], groups[2], group4


def _build_option_fields(reasoning_type, target_facts):
    """Shuffle and letter-label the answer options for one sample.

    Returns a dict of CSV fields in column order: "questions", "options"
    (JSON), "question_text", then "option_role_map" (JSON, label → role) for
    "contradictory" or "correct_answer" (the label holding
    ``str(target_facts[0])``) for every other type.
    """
    if reasoning_type == "contradictory":
        # target_facts[0..2] belong to one modality each; [3] is a distractor.
        roles_and_facts = [
            ("image", target_facts[0]),
            ("audio", target_facts[1]),
            ("text", target_facts[2]),
            ("distractor", target_facts[3]),
        ]
        random.shuffle(roles_and_facts)
        labelled_options = label_options([str(fact) for _, fact in roles_and_facts])
        role_map = {lab: role for lab, (role, _) in zip("ABCD", roles_and_facts)}
        extra = {"option_role_map": json.dumps(role_map, ensure_ascii=False)}
    else:
        option_strings = [str(f) for f in target_facts]
        correct_answer_text = option_strings[0]  # target_facts[0] is the answer
        random.shuffle(option_strings)
        labelled_options = label_options(option_strings)
        correct_label = next(k for k, v in labelled_options.items() if v == correct_answer_text)
        extra = {"correct_answer": correct_label}

    fields = {
        "questions": QUESTION,
        "options": json.dumps(labelled_options, ensure_ascii=False),
        "question_text": f"Question: {QUESTION}  {format_options(labelled_options)}",
    }
    fields.update(extra)
    return fields


def build_model_input_type1(type, mode="normal", *, pickle_path=None, csv_path=None):
    """Build the multiple-choice reasoning CSV for one reasoning type.

    Loads a pickled list of sample dicts (keys "rules", "test_facts",
    "target_facts", "distractor_groups"), distributes facts across three
    modalities (plus an "all" modality for "complementary"), shuffles facts,
    rules and options with the global ``random`` RNG, and writes one CSV row
    per sample.

    Args:
        type: "independent", "equivalent", "contradictory", "alternative",
            "complementary" or "entailment". (Shadows the builtin; the name is
            kept for backward compatibility.)
        mode: "normal" or "easy" (see ``_select_modality_groups``).
        pickle_path: optional override for the input pickle path.
        csv_path: optional override for the output CSV path.

    Raises:
        ValueError: for an unsupported ``type`` or ``mode``.
    """
    if type in ("independent", "equivalent", "contradictory"):
        default_pickle = pathlib.Path("/path/to/reasoning_base.pkl")  # ← your .pkl dataset
    elif type in ("alternative", "complementary", "entailment"):
        default_pickle = pathlib.Path(f"/path/to/reasoning_{type}.pkl")  # ← your .pkl dataset
    else:
        raise ValueError(f"Type {type} not supported")
    if mode not in ("normal", "easy"):
        # Fail before the (possibly large) pickle is loaded.
        raise ValueError(f"Mode {mode} not supported")

    pickle_file = pathlib.Path(pickle_path) if pickle_path is not None else default_pickle
    csv_file = (
        pathlib.Path(csv_path)
        if csv_path is not None
        else pathlib.Path(f"/path/to/reasoning_meta/reasoning_{type}_dataset.csv")
    )

    print("Loading dataset …")
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(pickle_file, "rb") as fh:
        dataset = pickle.load(fh)

    rows = []
    for idx, sample in tqdm(enumerate(dataset, start=1), total=len(dataset)):
        rules_obj = sample["rules"]
        test_facts = sample["test_facts"]
        target_facts = sample["target_facts"]
        distractor_groups = sample["distractor_groups"]

        group1, group2, group3, group4 = _select_modality_groups(
            type, mode, test_facts, distractor_groups
        )

        # RNG call order (groups, all-group, rules, options) is kept stable so
        # a fixed seed reproduces the same dataset.
        random.shuffle(group1)
        random.shuffle(group2)
        random.shuffle(group3)
        if group4 is not None:
            random.shuffle(group4)
        # Shuffle a copy so the loaded sample is not mutated in place.
        rules = list(rules_obj)
        random.shuffle(rules)

        modality_groups = (group1, group2, group3)
        row = {"id": f"{type}_{idx}"}
        for n, group in enumerate(modality_groups, start=1):
            row[f"modality{n}_triplet"] = facts_group_to_triplet_json(group)
        if group4 is not None:
            row["all_triplet"] = facts_group_to_triplet_json(group4)
        for n, group in enumerate(modality_groups, start=1):
            row[f"modality{n}_text"] = facts_group_to_text(group)
        if group4 is not None:
            row["all_text"] = facts_group_to_text(group4)
        row["rules"] = " ".join(str(r) for r in rules)
        # Branch on the requested type, not the module-level GROUP.
        row.update(_build_option_fields(type, target_facts))
        rows.append(row)

    if not rows:
        # Without a row there is no column schema to write; don't crash on rows[0].
        print(f"No samples found in {pickle_file}; nothing written.")
        return

    print(f"Writing {len(rows)} rows to {csv_file} …")
    fieldnames = list(rows[0].keys())

    csv_file.parent.mkdir(parents=True, exist_ok=True)
    with open(csv_file, "w", newline="", encoding="utf-8") as fout:
        writer = csv.DictWriter(fout, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)

if __name__ == "__main__":
    # Build the dataset for the configured GROUP using the reduced "easy" mode.
    build_model_input_type1(GROUP, mode="easy")
    print("Done!")
