"""
Adapted from https://github.com/stanford-crfm/helm
"""

import random
import dataclasses
from copy import copy
from typing import List, Dict, Literal, Tuple, Optional
from dataclasses import dataclass
import pickle
import json
from tqdm import tqdm

# Seed the global `random` module once, at import time, so every generator in
# this module draws from a reproducible stream.
# NOTE(review): this seed does not fix the iteration order of a set of strings
# (that depends on PYTHONHASHSEED), so any code that samples from
# list(some_set) can still differ between runs.
SEED = 42
random.seed(SEED)

@dataclass(frozen=True)
class LanguageLogicalStatement:
    """Common base of rules and facts: who a statement is about and how it is introduced."""

    subject: str  # an individual (e.g. "Alice", "dog") or a whole category (e.g. "animal")
    subject_category: str  # the category the subject belongs to, e.g. "person"
    specifier_type: Literal["a", "the"]  # article placed in front of the subject

    def generate_specified_subject(self, upper=False, specifier_type=None) -> str:
        """Return the subject preceded by its article.

        ``specifier_type`` overrides the instance's article for this call only.
        The article "a" becomes "an" before a subject starting with a vowel
        letter (a spelling heuristic). ``upper`` capitalizes the article. Named
        people (category "person" with a subject other than "person" itself)
        are returned without any article.

        Examples:
            subject="cat", subject_category="animal", specifier_type="the"          -> "the cat"
            subject="apple", subject_category="plant", specifier_type="a", upper=True -> "An apple"
            subject="Alice", subject_category="person"                              -> "Alice"
        """
        article = specifier_type if specifier_type is not None else self.specifier_type

        # Proper names never take an article.
        if self.subject_category == "person" and self.subject != "person":
            return self.subject

        lead = article[0].upper() if upper else article[0].lower()
        if article != "a":
            # e.g. "the" -> "the"/"The": keep the rest of the article as-is.
            return f"{lead}{article[1:]} {self.subject}"

        needs_n = self.subject[0].lower() in ("a", "e", "i", "o", "u")
        prefix = f"{lead}n" if needs_n else lead
        return f"{prefix} {self.subject}"


@dataclass(frozen=True)
class LanguageRule(LanguageLogicalStatement):
    """A rule stating that some attributes of a subject imply another attribute."""

    condition: List[str]  # attributes that must hold for the rule to fire
    condition_conjunction: Literal["and", "or"]  # "and": all conditions required; "or": any one suffices
    consequent: str  # the attribute resulting from the application of the rule

    def _if_then_text(self) -> str:
        """Render the rule as one "If <a subject> is <conditions>, then <the subject> is <consequent>." sentence."""
        joined = f" {self.condition_conjunction} ".join(self.condition)
        general = self.generate_specified_subject()
        particular = self.generate_specified_subject(specifier_type="the")
        return f"If {general} is {joined}, then {particular} is {self.consequent}."

    def __str__(self) -> str:
        """Render the rule as English text.

        Rules with a single condition, and three-condition "or" rules, become
        one short sentence per condition (e.g. "Red Alice is cold."). Every
        other rule becomes a single conditional sentence, e.g. for
        subject="Alice", subject_category="person", condition=["red", "kind"],
        condition_conjunction="and", consequent="cold":
        "If Alice is red and kind, then Alice is cold."
        """
        # Computed up front (like the conditional form always was) so any error
        # in the rule's fields surfaces regardless of which form is returned.
        if_then = self._if_then_text()

        count = len(self.condition)
        one_sentence_per_condition = count == 1 or (count == 3 and self.condition_conjunction == "or")
        if not one_sentence_per_condition:
            return if_then
        return " ".join(
            f"{cond.capitalize()} {self.subject} is {self.consequent}." for cond in self.condition
        )

    def as_dict(self) -> dict:
        """Return the structured form of the rule; "text" is always the conditional sentence."""
        text = self._if_then_text()
        return {
            "subject": self.generate_specified_subject(),
            "condition": self.condition,
            "condition_conjunction": self.condition_conjunction,
            "consequent": self.consequent,
            "text": text,
        }


def expand_three_condition_and_rule(
    base_rule: LanguageRule,
    *,
    attribute_pool: Optional[Dict[str, List[str]]] = None,
) -> List[LanguageRule]:
    """Expand a three-condition "and" rule into itself plus three negated variants.

    For a base rule "If X is A and B and C, then X is D", three extra rules are
    built. Each negates exactly one condition and concludes a fresh attribute:

        If X is A and B and not C, then X is E.
        If X is A and C and not B, then X is F.
        If X is B and C and not A, then X is G.

    E, F and G are distinct keys sampled from ``attribute_pool``, excluding A,
    B, C and D. When ``attribute_pool`` is None, the module-level
    ``attribute_groups`` is used.

    Returns:
        ``[base_rule, rule_E, rule_F, rule_G]``.

    Raises:
        ValueError: if the base rule is not a three-condition "and" rule, if the
            pool is empty or unavailable, or if fewer than three unused
            attributes remain.
    """
    if base_rule.condition_conjunction != "and":
        raise ValueError("Base rule must use 'and' as condition_conjunction.")
    if len(base_rule.condition) != 3:
        raise ValueError("Base rule must have exactly three conditions.")

    if attribute_pool is None:
        pool = globals().get("attribute_groups")
    else:
        pool = attribute_pool
    if not pool:
        raise ValueError("Attribute pool is empty or unavailable.")

    cond_a, cond_b, cond_c = base_rule.condition
    taken = {cond_a, cond_b, cond_c, base_rule.consequent}
    fresh = [attr for attr in pool.keys() if attr not in taken]
    if len(fresh) < 3:
        raise ValueError("Not enough distinct attributes available to assign consequents E, F, and G.")

    cons_e, cons_f, cons_g = random.sample(fresh, 3)

    def variant(conditions: List[str], consequent: str) -> LanguageRule:
        # Same subject and article as the base rule; only the conditions and consequent differ.
        return LanguageRule(
            subject=base_rule.subject,
            subject_category=base_rule.subject_category,
            specifier_type=base_rule.specifier_type,
            condition=conditions,
            condition_conjunction="and",
            consequent=consequent,
        )

    return [
        base_rule,
        variant([cond_a, cond_b, f"not {cond_c}"], cons_e),
        variant([cond_a, cond_c, f"not {cond_b}"], cons_f),
        variant([cond_b, cond_c, f"not {cond_a}"], cons_g),
    ]


@dataclass(frozen=True)
class LanguageFact(LanguageLogicalStatement):
    """A statement that a subject has some attributes, e.g. "The dog is red and big."."""

    specific_attributes: List[str]  # attribute names rendered when use_specific_attributes is True
    generic_attributes: List[str]  # attribute names rendered otherwise; an empty list means nothing is stated
    use_specific_attributes: bool  # which of the two attribute lists to render
    upper: bool = True  # capitalize the article when rendering (named people get none)

    def _rendered_parts(self) -> Tuple[str, List[str], str]:
        """Return (subject with article, rendered attributes, sentence) for a non-empty fact."""
        if self.use_specific_attributes:
            chosen = self.specific_attributes
        else:
            chosen = self.generic_attributes
        head = self.generate_specified_subject(upper=self.upper)
        return head, chosen, f"{head} is {' and '.join(chosen)}."

    def __str__(self) -> str:
        """Render the fact as an English sentence, or a "nothing follows" message when empty."""
        if len(self.generic_attributes) == 0:
            return "None of the conclusions can be inferred."
        return self._rendered_parts()[2]

    def as_dict(self) -> dict:
        """Return the structured form of the fact (subject, attributes, text)."""
        if len(self.generic_attributes) == 0:
            return {"subject": None, "attributes": [], "text": "Nothing."}
        head, chosen, sentence = self._rendered_parts()
        return {"subject": head, "attributes": chosen, "text": sentence}



def get_vocab() -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
    """Return the attribute vocabulary and the subject vocabulary.

    Returns:
        ``(attribute_groups, subjects)``:

        * ``attribute_groups`` maps a general attribute to the list of specific
          attributes it covers. Every group is currently a singleton, e.g.
          ``"red" -> ["red"]``. A general attribute that also appears as a
          specific member of a *different* group is dropped, so the keys never
          duplicate another group's members.
        * ``subjects`` maps a category ("person", "animal", "fruit") to the
          subject names in it.
    """
    # A list of subjects and their categories
    subjects: Dict[str, List[str]] = {
        "person": ["Alice", "Bob", "Carol", "Dan", "Erin", "Frank", "George", "Harry", "Iris", "Jack", "Kevin", "Lance", "Miller"],
        "animal": [
            "dog",
            "cat",
            "rabbit",
            "mouse",
            "tiger",
            "lion",
            "bear",
            "squirrel",
            "cow",
            "panda",
            "hedgehog",
            "elephant",
            "giraffe",
            "hippo",
        ],
        "fruit": ["apple", "banana", "orange", "grape", "strawberry", "blueberry", "watermelon", "pineapple", "mango", "peach", "cherry", "pear", "kiwi", "lemon", "plum"],
    }

    attribute_groups: Dict[str, List[str]] = {
        "young": ["young"],
        "soft": ["soft"],
        "scary": ["scary"],
        "hot": ["hot"],
        "smart": ["smart"],
        "clean": ["clean"],
        "beautiful": ["beautiful"],
        "red": ["red"],
        "blue": ["blue"],
        "green": ["green"],
        "purple": ["purple"],
        "boring": ["boring"],
        "strong": ["strong"],
        "happy": ["happy"],
        "round": ["round"],
        "big": ["big"],
        "noisy": ["noisy"],
        "fast": ["fast"],
        "sticky": ["sticky"],
        "bouncy": ["bouncy"],
        "spiky": ["spiky"],
        "furry": ["furry"],
        "bright": ["bright"],
        "shiny": ["shiny"],
        "magical": ["magical"],
        "striped": ["striped"],
        "spotted": ["spotted"],
        "tasty": ["tasty"],
        "juicy": ["juicy"],
        "toxic": ["toxic"],
        "friendly": ["friendly"],
        "curious": ["curious"],
        "loud": ["loud"],
        "sleepy": ["sleepy"],
    }

    # Drop any general attribute that is also a specific member of another group.
    new_attribute_groups: Dict[str, List[str]] = copy(attribute_groups)
    for general_attribute, specific_attributes in attribute_groups.items():
        for specific_attribute in specific_attributes:
            if general_attribute != specific_attribute and specific_attribute in attribute_groups:
                # pop() instead of del: the same key can be reached from several
                # groups, and a second del would raise KeyError.
                new_attribute_groups.pop(specific_attribute, None)

    return new_attribute_groups, subjects


def generate_rules(
    attribute_groups: Dict[str, List[str]],
    subject: str,
    subject_category: str,
    max_rules: int = 8,
    specific_category: bool = False,
    use_and: bool = False,
    use_or: bool = False,
) -> List[LanguageRule]:
    """Build up to ``max_rules`` random LanguageRule objects about ``subject`` or its category.

    The attribute keys are shuffled once and consumed left to right, so no
    attribute is shared between two rules. As a result the rules cannot
    contradict one another, and one rule's consequent never feeds another
    rule's condition (single-step reasoning only).

    Each rule normally takes 2 (p=0.7) or 3 (p=0.3) attributes: all but the
    last are its conditions, the last is its consequent, and the conditions are
    joined with "or". When ``use_and`` or ``use_or`` is set, the first rule
    instead takes 4 attributes (three conditions), joined with "and" or "or";
    ``use_and`` wins if both are set.

    A rule's subject is ``subject`` when ``specific_category`` is true, and
    otherwise a random choice between ``subject_category`` and ``subject``.
    Generation stops once two or fewer unused attributes remain.
    """
    pool = list(attribute_groups)
    random.shuffle(pool)
    cursor = 0
    rules: List[LanguageRule] = []

    while len(pool) - cursor > 2 and len(rules) < max_rules:
        is_first = not rules  # exactly one rule is appended per iteration

        if specific_category:
            rule_subject = subject
        else:
            rule_subject = random.choice([subject_category, subject])

        if is_first and (use_and or use_or):
            size = 4
        else:
            size = random.choices([2, 3], weights=[0.7, 0.3])[0]
        chunk = pool[cursor:cursor + size]
        cursor += size

        if is_first and use_and:
            conjunction = "and"
        elif is_first and use_or:
            conjunction = "or"
        else:
            # Always "or". The draw is kept because it advances the RNG, and
            # removing it would change every later draw for a given seed.
            conjunction = random.choices(["or"], weights=[1.0])[0]

        rules.append(
            LanguageRule(
                subject=rule_subject,
                subject_category=subject_category,
                specifier_type="a",
                condition=chunk[:-1],
                condition_conjunction=conjunction,
                consequent=chunk[-1],
            )
        )
    return rules


def generate_test(
    attribute_groups: Dict[str, List[str]],
    subject: str,
    subject_category: str,
    rules: List[LanguageRule],
    use_specific_attributes: bool,
    excluded_attributes=None,
    p_consequenceless=0.0,
    num_attrs=2,
) -> Tuple[LanguageFact, List[LanguageRule], LanguageFact]:
    """Generate a test case: a fact about ``subject`` and what the rules let us deduce from it.

    Up to 50 attempts are made. Each attempt samples ``num_attrs`` attributes
    outside ``excluded_attributes`` for the test fact, then checks the rules:

    * a rule whose consequent is already one of the sampled attributes is skipped;
    * an "and" rule fires when all of its conditions were sampled;
    * an "or" rule fires when at least one of its conditions was sampled.

    Only the first rule that fires is used. An attempt is rejected when that
    rule's consequent is excluded. When no rule fires at all, it is rejected
    with probability ``1 - p_consequenceless``. If every attempt is rejected,
    the last sampled fact is returned with nothing inferred.

    Args:
        attribute_groups: maps each attribute to its specific variants; the
            keys are the candidate attributes for the fact.
        subject, subject_category: who the fact is about.
        rules: rules to apply to the sampled attributes.
        use_specific_attributes: passed through to the returned facts.
        excluded_attributes: attributes that may appear neither in the fact nor
            as its conclusion. ``None`` (the default) means no exclusions. The
            argument is not modified.
        p_consequenceless: probability of accepting a fact from which nothing follows.
        num_attrs: number of attributes in the test fact.

    Returns:
        ``(test_fact, rules_used, target_fact)``. ``rules_used`` holds the rule
        (at most one) that produced the conclusion. ``target_fact`` states that
        conclusion about the same subject; its attribute lists are empty when
        nothing could be inferred.

    Raises:
        ValueError: from ``random.sample`` when fewer than ``num_attrs``
            attributes remain after exclusion.
    """
    max_attempts = 50
    excluded = set(excluded_attributes) if excluded_attributes else set()
    # Candidates are listed in dict (insertion) order. Taking them from a set
    # difference would make the population order depend on per-process string
    # hashing, and the module SEED alone would not make results reproducible.
    candidates = [attr for attr in attribute_groups if attr not in excluded]

    test_attrs: List[str] = []
    test_attrs_specific: List[str] = []
    test_consequents: List[str] = []
    test_rules_used: List[LanguageRule] = []

    for _ in range(max_attempts):
        # 1) pick `num_attrs` fresh attributes for the test fact
        test_attrs = random.sample(candidates, num_attrs)
        test_attrs_specific = [random.choice(attribute_groups[a]) for a in test_attrs]

        # 2) find the first rule that fires on them
        fired = None
        for rule in rules:
            if rule.consequent in test_attrs:
                continue  # conclusion already stated in the fact
            conditions = set(rule.condition)
            if rule.condition_conjunction == "and" and conditions.issubset(test_attrs):
                fired = rule
                break
            if rule.condition_conjunction == "or" and not conditions.isdisjoint(test_attrs):
                fired = rule
                break
        # Only one conclusion is targeted. Return the rule that justifies it so
        # that the returned rules agree with target_fact.
        test_rules_used = [fired] if fired is not None else []
        test_consequents = [rule.consequent for rule in test_rules_used]

        # 3a) reject if the conclusion is an excluded attribute
        if test_consequents and test_consequents[0] in excluded:
            continue

        # 3b) reject (optionally) if nothing can be inferred
        if not test_consequents and random.random() > p_consequenceless:
            continue

        break
    else:
        # every attempt was rejected: fall back to "nothing can be inferred"
        test_consequents, test_rules_used = [], []

    # 4) build LanguageFact objects
    test_fact = LanguageFact(
        subject=subject,
        subject_category=subject_category,
        specifier_type="the",
        specific_attributes=test_attrs_specific,
        generic_attributes=test_attrs,
        use_specific_attributes=use_specific_attributes,
    )

    # The target restates the same subject with the inferred attribute(s).
    target_fact = dataclasses.replace(
        test_fact,
        specific_attributes=test_consequents,
        generic_attributes=test_consequents,
    )

    return test_fact, test_rules_used, target_fact


def generate_distractor_facts(
    subject_category: str,
    all_subjects: Dict[str, List[str]],
    attribute_groups: Dict[str, List[str]],
    used_subjects: List[str],
    used_attributes: List[str],
    min_num_distractors: int = 2,
    max_num_distractors: int = 5,
    use_specific_attributes: bool = False,
) -> List[LanguageFact]:
    """Generate one-attribute distractor facts about used and unused subjects.

    Between ``min_num_distractors`` and ``max_num_distractors`` facts
    (inclusive) are produced. Each draw does one of the following:

    * with probability 0.5 (if any exist), it picks a subject already used in
      the problem and gives it an attribute *outside* ``used_attributes``;
    * otherwise, it picks an unused subject of ``subject_category`` and gives
      it an attribute drawn from up to five attributes sampled from the used
      and unused ones together.

    Draws that cannot produce a fact are retried. Category names ("fruit",
    "person", "animal") in ``used_subjects`` are never used as distractor
    subjects.

    Returns:
        The distractor facts, or an empty list when no valid subject/attribute
        combination exists at all.
    """
    distractor_facts: List[LanguageFact] = []

    num_distractors = random.randint(min_num_distractors, max_num_distractors)

    # Use sorted lists. Callers often pass list(some_set), and the order of a
    # set of strings changes between runs (hash randomization), which would
    # make the sampled facts irreproducible even with a fixed seed.
    fresh_subjects = sorted(set(all_subjects[subject_category]) - set(used_subjects))
    # Category names can be rule subjects, but they are not individuals.
    known_subjects = sorted(set(used_subjects) - {"fruit", "person", "animal"})
    used_attr_list = sorted(used_attributes)
    unused_attrs = sorted(set(attribute_groups) - set(used_attributes))
    mixed_attrs = used_attr_list + unused_attrs

    # A known subject needs an unused attribute; a fresh subject needs any
    # attribute. If neither combination is possible, the loop below could never
    # add a fact and would spin forever.
    if not ((known_subjects and unused_attrs) or (fresh_subjects and mixed_attrs)):
        return distractor_facts

    attrs_per_fact = 1
    while len(distractor_facts) < num_distractors:
        # Decide subject type (and with it, the attribute source)
        if random.random() < 0.5 and known_subjects:
            subject = random.choice(known_subjects)
            attr_source_pool = unused_attrs
        elif fresh_subjects:
            subject = random.choice(fresh_subjects)
            # Mix used and unused attributes for more confusion
            attr_source_pool = random.sample(mixed_attrs, k=min(5, len(mixed_attrs)))
        else:
            continue  # no fresh subject available on this draw; try again

        if not attr_source_pool:
            continue  # no usable attribute for this subject type; try again

        generic_attrs = random.sample(attr_source_pool, k=min(attrs_per_fact, len(attr_source_pool)))
        specific_attrs = [random.choice(attribute_groups[attr]) for attr in generic_attrs]

        distractor_facts.append(
            LanguageFact(
                subject=subject,
                subject_category=subject_category,
                specifier_type="the",
                specific_attributes=specific_attrs,
                generic_attributes=generic_attrs,
                use_specific_attributes=use_specific_attributes,
            )
        )

    return distractor_facts

# Module-level vocabulary: read by the generate_problem_* functions below, and
# looked up through globals() as the default attribute pool of
# expand_three_condition_and_rule.
attribute_groups, subjects = get_vocab()

def generate_problem_category1() -> Tuple:
    """Generate a problem for the independent, equivalence and contradictory settings.

    A random subject gets up to four rules (``generate_rules`` with
    ``max_rules=4``) and four test facts from ``generate_test``. Each test's
    attributes and conclusion are excluded from the following tests, so tests
    do not overlap. Three groups of distractor facts are added on top.

    Returns:
        A pair ``(structured, rendered)`` of dicts with keys "rules",
        "test_facts", "target_facts" and "distractor_groups". The first dict
        holds the LanguageRule / LanguageFact objects; the second holds their
        ``str()`` renderings.
    """
    subject_category = random.choice(list(subjects.keys()))
    subject = random.choice(subjects[subject_category])

    rules = generate_rules(attribute_groups, subject, subject_category, max_rules=4, specific_category=False)

    test_facts, target_facts, excluded_attributes = [], [], set()

    for _ in range(4):
        test_fact, _, target_fact = generate_test(
            attribute_groups, subject, subject_category, rules,
            use_specific_attributes=False,
            excluded_attributes=excluded_attributes,
        )
        test_facts.append(test_fact)
        target_facts.append(target_fact)
        # Later tests must not reuse this test's attributes or its conclusion.
        excluded_attributes.update(test_fact.generic_attributes)
        excluded_attributes.update(target_fact.generic_attributes)

    # Every individual name in the vocabulary, built once instead of per rule.
    # A rule subject outside this set is a category name such as "animal".
    known_subjects = {name for names in subjects.values() for name in names}

    used_subjects = {subject}
    used_attributes = set()
    for fact in test_facts:
        used_attributes.update(fact.generic_attributes)
    for rule in rules:
        if rule.subject not in known_subjects:
            used_subjects.add(rule.subject)
        used_attributes.update(rule.condition)
        used_attributes.add(rule.consequent)

    # Sets are passed as sorted lists. The iteration order of a set of strings
    # depends on per-process hash randomization, so list(set) would make the
    # distractors differ between runs even with the same SEED.
    distractor_facts_group = [
        generate_distractor_facts(
            subject_category=subject_category,
            all_subjects=subjects,
            attribute_groups=attribute_groups,
            used_subjects=sorted(used_subjects),
            used_attributes=sorted(used_attributes),
        )
        for _ in range(3)
    ]

    return {
            "rules": rules,
            "test_facts": test_facts,
            "target_facts": target_facts,
            "distractor_groups": distractor_facts_group,
        },{
            "rules": [str(rule) for rule in rules],
            "test_facts": [str(fact) for fact in test_facts],
            "target_facts": [str(fact) for fact in target_facts],
            "distractor_groups": [[str(fact) for fact in group] for group in distractor_facts_group]
        }
        

def generate_problem_category2(type) -> Tuple:
    """Generate a problem for the alternative or complementary setting.

    The first rule has three conditions joined by "or" (``type="alternative"``)
    or by "and" (``type="complementary"``). The first test item is a *list* of
    facts, one per condition of that rule. Each fact is optionally (p=0.5)
    padded with a noise attribute that no rule mentions. The target of this
    first item is the first rule's consequent. Three further single test facts
    come from ``generate_test``, followed by three groups of distractor facts.

    Args:
        type: "alternative" or "complementary". The parameter name shadows the
            builtin but is kept for keyword-argument compatibility.

    Returns:
        ``(structured, rendered)`` dicts with keys "rules", "test_facts",
        "target_facts" and "distractor_groups". In both, ``test_facts[0]`` is a
        list of facts and the remaining entries are single facts.

    Raises:
        ValueError: if ``type`` is not one of the two supported settings.
    """
    # Validate first, so an unknown type fails clearly before any random draws.
    if type == "alternative":
        conjunction_flags = {"use_or": True}
    elif type == "complementary":
        conjunction_flags = {"use_and": True}
    else:
        raise ValueError(
            f"Unknown problem type {type!r}; expected 'alternative' or 'complementary'."
        )

    subject_category = random.choice(list(subjects.keys()))
    subject = random.choice(subjects[subject_category])

    rules = generate_rules(
        attribute_groups, subject, subject_category, specific_category=False, **conjunction_flags
    )
    first_rule = rules[0]

    # Noise attributes must avoid every rule attribute. Otherwise a noise
    # attribute could be the first rule's consequent itself (stating the
    # answer), or could trigger another rule and make a second conclusion
    # derivable. Sorted because set order of strings varies between runs.
    rule_attributes = set()
    for rule in rules:
        rule_attributes.update(rule.condition)
        rule_attributes.add(rule.consequent)
    noise_attributes = sorted(set(attribute_groups) - rule_attributes)

    excluded_attributes = set()
    alternative_test_facts = []
    for attribute in first_rule.condition:
        fact_attributes = [attribute]
        if random.random() < 0.5 and noise_attributes:
            fact_attributes.append(random.choice(noise_attributes))
        alternative_test_facts.append(
            LanguageFact(
                subject=subject,
                subject_category=subject_category,
                specifier_type="the",
                specific_attributes=list(fact_attributes),
                generic_attributes=list(fact_attributes),
                use_specific_attributes=False,
            )
        )
        # Add whole attribute names; update() with a bare string would add its characters.
        excluded_attributes.update(fact_attributes)

    target_fact = LanguageFact(
        subject=subject,
        subject_category=subject_category,
        specifier_type="the",
        specific_attributes=[first_rule.consequent],
        generic_attributes=[first_rule.consequent],
        use_specific_attributes=False,
    )
    test_facts = [alternative_test_facts]
    target_facts = [target_fact]
    excluded_attributes.update(target_fact.generic_attributes)

    for _ in range(3):
        test_fact, _, target_fact = generate_test(
            attribute_groups, subject, subject_category, rules,
            use_specific_attributes=False,
            excluded_attributes=excluded_attributes,
        )
        test_facts.append(test_fact)
        target_facts.append(target_fact)
        excluded_attributes.update(test_fact.generic_attributes)
        excluded_attributes.update(target_fact.generic_attributes)

    # Every individual name in the vocabulary; a rule subject outside this set
    # is a category name such as "animal".
    known_subjects = {name for names in subjects.values() for name in names}

    used_subjects = {subject}
    used_attributes = set()
    for fact in test_facts[0]:
        used_attributes.update(fact.generic_attributes)
    for fact in test_facts[1:]:
        used_attributes.update(fact.generic_attributes)
    for rule in rules:
        if rule.subject not in known_subjects:
            used_subjects.add(rule.subject)
        used_attributes.update(rule.condition)
        used_attributes.add(rule.consequent)

    # Sorted lists rather than list(set), so the output depends only on SEED.
    distractor_facts_group = [
        generate_distractor_facts(
            subject_category=subject_category,
            all_subjects=subjects,
            attribute_groups=attribute_groups,
            used_subjects=sorted(used_subjects),
            used_attributes=sorted(used_attributes),
        )
        for _ in range(3)
    ]

    return {
            "rules": rules,
            "test_facts": test_facts,
            "target_facts": target_facts,
            "distractor_groups": distractor_facts_group,
        },{
            "rules": [str(rule) for rule in rules],
            "test_facts": [[str(fact) for fact in test_facts[0]]] + [str(fact) for fact in test_facts[1:]],
            "target_facts": [str(fact) for fact in target_facts],
            "distractor_groups": [[str(fact) for fact in group] for group in distractor_facts_group]
        }

def generate_problem_category3() -> Tuple:
    """Generate a problem for the entailment (multi-step) setting.

    Base rules come from ``generate_rules``. A one-attribute fact M and its
    conclusion are sampled with ``generate_test``. Two single-condition rules
    ``X -> M`` and ``Y -> X`` are then appended, where X and Y are attributes
    not mentioned by any base rule or earlier fact. The first test item is the
    list of facts [M, X, Y]. Three further single test facts come from
    ``generate_test``, followed by three groups of distractor facts.

    Returns:
        ``(structured, rendered)`` dicts with keys "rules", "test_facts",
        "target_facts" and "distractor_groups". In both, ``test_facts[0]`` is
        the list of chain facts and the remaining entries are single facts.
    """
    subject_category = random.choice(list(subjects.keys()))
    subject = random.choice(subjects[subject_category])

    rules = generate_rules(attribute_groups, subject, subject_category, specific_category=False)

    # Every individual name in the vocabulary; a rule subject outside this set
    # is a category name such as "animal".
    known_subjects = {name for names in subjects.values() for name in names}

    used_subjects = {subject}
    used_attributes = set()
    for rule in rules:
        if rule.subject not in known_subjects:
            used_subjects.add(rule.subject)
        used_attributes.update(rule.condition)
        used_attributes.add(rule.consequent)

    target_facts = []
    excluded_attributes = set()

    # Step 1: a single-attribute fact plus what the base rules derive from it.
    multi_step_fact, _, target_fact = generate_test(
        attribute_groups, subject, subject_category, rules,
        use_specific_attributes=False,
        excluded_attributes=excluded_attributes,
        num_attrs=1,
    )
    multi_step_facts = [multi_step_fact]
    target_facts.append(target_fact)
    excluded_attributes.update(multi_step_fact.generic_attributes)
    excluded_attributes.update(target_fact.generic_attributes)

    # Steps 2-3: prepend two more links, each through a new one-condition rule
    # "new attribute -> first attribute of the previous fact".
    for _ in range(2):
        # A fresh attribute must be unused by the base rules AND by earlier
        # chain facts. Otherwise the second link could reuse the first one's
        # attribute and create a rule like "X -> X". Dict order keeps this
        # reproducible under SEED.
        fresh_attributes = [
            attr for attr in attribute_groups
            if attr not in used_attributes and attr not in excluded_attributes
        ]
        test_attr = random.choice(fresh_attributes)
        previous_fact = multi_step_facts[-1]
        rules.append(
            LanguageRule(
                subject=previous_fact.subject,
                subject_category=previous_fact.subject_category,
                specifier_type="a",
                condition=[test_attr],
                condition_conjunction="or",
                consequent=previous_fact.generic_attributes[0],
            )
        )
        multi_step_facts.append(
            LanguageFact(
                subject=subject,
                subject_category=subject_category,
                specifier_type="the",
                specific_attributes=[test_attr],
                generic_attributes=[test_attr],
                use_specific_attributes=False,
            )
        )
        excluded_attributes.add(test_attr)

    test_facts = [multi_step_facts]
    for _ in range(3):
        test_fact, _, target_fact = generate_test(
            attribute_groups, subject, subject_category, rules,
            use_specific_attributes=False,
            excluded_attributes=excluded_attributes,
        )
        test_facts.append(test_fact)
        target_facts.append(target_fact)
        excluded_attributes.update(test_fact.generic_attributes)
        excluded_attributes.update(target_fact.generic_attributes)

    for fact in multi_step_facts:
        used_attributes.update(fact.generic_attributes)
    for fact in test_facts[1:]:
        used_attributes.update(fact.generic_attributes)

    # Report every attribute the problem relies on (rule attributes, test-fact
    # attributes, conclusions) as used. Then distractors about the main subject
    # cannot restate a rule condition and make an extra conclusion derivable.
    # Sorted lists keep the output independent of set ordering.
    forbidden_attributes = used_attributes | excluded_attributes
    distractor_facts_group = [
        generate_distractor_facts(
            subject_category=subject_category,
            all_subjects=subjects,
            attribute_groups=attribute_groups,
            used_subjects=sorted(used_subjects),
            used_attributes=sorted(forbidden_attributes),
        )
        for _ in range(3)
    ]

    return {
            "rules": rules,
            "test_facts": test_facts,
            "target_facts": target_facts,
            "distractor_groups": distractor_facts_group,
        },{
            "rules": [str(rule) for rule in rules],
            "test_facts": [[str(fact) for fact in test_facts[0]]] + [str(fact) for fact in test_facts[1:]],
            "target_facts": [str(fact) for fact in target_facts],
            "distractor_groups": [[str(fact) for fact in group] for group in distractor_facts_group]
        }

def generate_problem_category1_no_noise() -> Tuple:
    """Generate a noise-free problem for the independent, equivalence and contradictory settings.

    Up to four rules are generated. For each rule, the test fact states one
    randomly chosen condition of that rule and the target fact states its
    consequent. One condition is enough for the "or" rules that
    ``generate_rules`` builds when ``use_and`` is not set. Three groups of
    distractor facts are added on top.

    Returns:
        A pair ``(structured, rendered)`` of dicts with keys "rules",
        "test_facts", "target_facts" and "distractor_groups". The first dict
        holds the LanguageRule / LanguageFact objects; the second holds their
        ``str()`` renderings.
    """
    subject_category = random.choice(list(subjects.keys()))
    subject = random.choice(subjects[subject_category])

    rules = generate_rules(attribute_groups, subject, subject_category, max_rules=4, specific_category=False)

    test_facts, target_facts = [], []
    for rule in rules:
        test_attrs = random.sample(rule.condition, 1)
        test_facts.append(
            LanguageFact(
                subject=subject,
                subject_category=subject_category,
                specifier_type="the",
                specific_attributes=test_attrs,
                generic_attributes=test_attrs,
                use_specific_attributes=False,
            )
        )
        target_facts.append(
            LanguageFact(
                subject=subject,
                subject_category=subject_category,
                specifier_type="the",
                specific_attributes=[rule.consequent],
                generic_attributes=[rule.consequent],
                use_specific_attributes=False,
            )
        )

    # Every individual name in the vocabulary; a rule subject outside this set
    # is a category name such as "animal".
    known_subjects = {name for names in subjects.values() for name in names}

    used_subjects = {subject}
    used_attributes = set()
    for fact in test_facts:
        used_attributes.update(fact.generic_attributes)
    for rule in rules:
        if rule.subject not in known_subjects:
            used_subjects.add(rule.subject)
        used_attributes.update(rule.condition)
        used_attributes.add(rule.consequent)

    # Sorted lists rather than list(set): set order of strings varies between
    # runs, which would make the distractors irreproducible under SEED.
    distractor_facts_group = [
        generate_distractor_facts(
            subject_category=subject_category,
            all_subjects=subjects,
            attribute_groups=attribute_groups,
            used_subjects=sorted(used_subjects),
            used_attributes=sorted(used_attributes),
        )
        for _ in range(3)
    ]

    return {
            "rules": rules,
            "test_facts": test_facts,
            "target_facts": target_facts,
            "distractor_groups": distractor_facts_group,
        },{
            "rules": [str(rule) for rule in rules],
            "test_facts": [str(fact) for fact in test_facts],
            "target_facts": [str(fact) for fact in target_facts],
            "distractor_groups": [[str(fact) for fact in group] for group in distractor_facts_group]
        }

def generate_problem_category2_no_noise() -> Tuple[Dict, Dict]:
    """Generate one "alternative" reasoning problem without noise.

    A random subject is drawn and up to four rules are generated for it with
    ``use_or=True``. Then:

    * ``test_facts[0]`` is a *list* holding one fact per condition of the
      first rule, and that rule's consequent is ``target_facts[0]``;
    * every later rule contributes one test fact (a single randomly sampled
      condition) and one target fact (its consequent);
    * three distractor groups are generated, excluding the subjects and
      attributes used by the rules and test facts.

    Returns:
        A pair ``(objects, text)``. ``objects`` maps ``"rules"``,
        ``"test_facts"``, ``"target_facts"`` and ``"distractor_groups"`` to the
        generated objects. ``text`` has the same layout with every item
        converted via ``str``.
    """
    # generate problems for alternative setting with no noise
    subject_category = random.choice(list(subjects.keys()))
    subject = random.choice(subjects[subject_category])

    rules = generate_rules(attribute_groups, subject, subject_category, max_rules=4, specific_category=False, use_or=True)

    def make_fact(attribute: str) -> "LanguageFact":
        """Build a fact stating that the chosen subject has ``attribute``."""
        return LanguageFact(
            subject=subject,
            subject_category=subject_category,
            specifier_type="the",
            specific_attributes=[attribute],
            generic_attributes=[attribute],
            use_specific_attributes=False,
        )

    first_rule = rules[0].as_dict()
    # Each condition of the first rule becomes its own fact; together they
    # form the grouped entry test_facts[0].
    alternative_test_facts = [make_fact(attribute) for attribute in first_rule['condition']]
    test_facts = [alternative_test_facts]
    target_facts = [make_fact(first_rule['consequent'])]

    for rule in rules[1:]:
        test_attribute = random.sample(rule.condition, 1)[0]
        test_facts.append(make_fact(test_attribute))
        target_facts.append(make_fact(rule.consequent))

    # Flatten once instead of rebuilding ``sum(..., [])`` for every rule.
    known_subjects = {name for group in subjects.values() for name in group}
    used_subjects = {subject}
    used_attributes = set()
    for fact in alternative_test_facts + test_facts[1:]:
        used_attributes.update(fact.generic_attributes)
    for rule in rules:
        if rule.subject not in known_subjects:
            used_subjects.add(rule.subject)
        used_attributes.update(rule.condition)
        used_attributes.add(rule.consequent)

    # Sorted lists: iterating a set of strings yields a hash-seed-dependent
    # order, which would make the output irreproducible despite SEED.
    distractor_facts_group = [
        generate_distractor_facts(
            subject_category=subject_category,
            all_subjects=subjects,
            attribute_groups=attribute_groups,
            used_subjects=sorted(used_subjects),
            used_attributes=sorted(used_attributes),
        )
        for _ in range(3)
    ]

    return {
            "rules": rules,
            "test_facts": test_facts,
            "target_facts": target_facts,
            "distractor_groups": distractor_facts_group,
        },{
            "rules": [str(rule) for rule in rules],
            "test_facts": [[str(fact) for fact in test_facts[0]]] + [str(fact) for fact in test_facts[1:]],
            "target_facts": [str(fact) for fact in target_facts],
            "distractor_groups": [[str(fact) for fact in group] for group in distractor_facts_group]
        }

def generate_problem_category3_no_noise() -> Tuple:
    """Generate one "entailment" reasoning problem without noise.

    Up to four rules are generated for a random subject.

    * For the first rule, one condition attribute is sampled and then extended
      into a chain. Twice, a fresh attribute ``b`` (not used by any rule or by
      an earlier step) is drawn, the rule "if ``b`` then <previous attribute>"
      is appended to ``rules``, and a fact asserting ``b`` is recorded.
    * All facts of that chain form the list ``test_facts[0]``, and the first
      rule's consequent is ``target_facts[0]``.
    * Each remaining rule contributes one test fact (a sampled condition) and
      one target fact (its consequent).
    * Finally, three distractor groups are generated, excluding every subject
      and attribute used above.

    Returns:
        A pair ``(objects, text)``. ``objects`` maps ``"rules"``,
        ``"test_facts"``, ``"target_facts"`` and ``"distractor_groups"`` to the
        generated objects. ``text`` has the same layout with every item
        converted via ``str``.
    """
    # generate problems for entailment setting with no noise
    subject_category = random.choice(list(subjects.keys()))
    subject = random.choice(subjects[subject_category])

    rules = generate_rules(attribute_groups, subject, subject_category, max_rules=4, specific_category=False)
    # Snapshot the generated rules: the chain-building step appends to ``rules``.
    base_rules = rules[:4]

    def make_fact(attribute: str) -> "LanguageFact":
        """Build a fact stating that the chosen subject has ``attribute``."""
        return LanguageFact(
            subject=subject,
            subject_category=subject_category,
            specifier_type="the",
            specific_attributes=[attribute],
            generic_attributes=[attribute],
            use_specific_attributes=False,
        )

    # Flatten once instead of rebuilding ``sum(..., [])`` for every rule.
    known_subjects = {name for group in subjects.values() for name in group}
    used_subjects = {subject}
    used_attributes = set()
    for rule in rules:
        if rule.subject not in known_subjects:
            used_subjects.add(rule.subject)
        used_attributes.update(rule.condition)
        used_attributes.add(rule.consequent)

    test_facts, target_facts = [], []

    # First rule: one sampled condition, extended by a two-step chain of rules.
    first_rule = base_rules[0]
    multi_step_facts = [make_fact(random.sample(first_rule.condition, 1)[0])]
    target_facts.append(make_fact(first_rule.consequent))
    for _ in range(2):
        # Sort before sampling: iterating a set of strings yields a
        # hash-seed-dependent order, which would defeat SEED.
        candidates = sorted(attribute_groups.keys() - used_attributes)
        test_attr = random.sample(candidates, 1)[0]
        # Reserve it immediately so the next step cannot draw the same
        # attribute and emit a self-referential rule "if X then X".
        used_attributes.add(test_attr)
        previous = multi_step_facts[-1]
        rules.append(
            LanguageRule(
                subject=previous.subject,
                subject_category=previous.subject_category,
                specifier_type="a",
                condition=[test_attr],
                condition_conjunction="or",
                consequent=previous.generic_attributes[0],  # the attribute introduced by the previous step
            )
        )
        multi_step_facts.append(make_fact(test_attr))
    test_facts.append(multi_step_facts)

    # Remaining rules: one sampled condition as test fact, consequent as target.
    for rule in base_rules[1:]:
        test_facts.append(make_fact(random.sample(rule.condition, 1)[0]))
        target_facts.append(make_fact(rule.consequent))

    for fact in multi_step_facts + test_facts[1:]:
        used_attributes.update(fact.generic_attributes)

    # Exclude every rule condition/consequent (not only the sampled ones), so
    # distractors never reuse an attribute mentioned by any rule.
    distractor_facts_group = [
        generate_distractor_facts(
            subject_category=subject_category,
            all_subjects=subjects,
            attribute_groups=attribute_groups,
            used_subjects=sorted(used_subjects),
            used_attributes=sorted(used_attributes),
        )
        for _ in range(3)
    ]

    return {
            "rules": rules,
            "test_facts": test_facts,
            "target_facts": target_facts,
            "distractor_groups": distractor_facts_group,
        },{
            "rules": [str(rule) for rule in rules],
            "test_facts": [[str(fact) for fact in test_facts[0]]] + [str(fact) for fact in test_facts[1:]],
            "target_facts": [str(fact) for fact in target_facts],
            "distractor_groups": [[str(fact) for fact in group] for group in distractor_facts_group]
        }

def generate_problem_category4_no_noise() -> Tuple[Dict, Dict]:
    """Generate one "complementary" reasoning problem without noise.

    A single rule is generated with ``use_and=True`` and passed to
    ``expand_three_condition_and_rule``, which yields the problem's rules.

    * ``test_facts[0]`` is a *list* holding one fact per condition of the
      first rule, and that rule's consequent is ``target_facts[0]``.
    * Every later rule contributes one test fact (a single randomly sampled
      condition) and one target fact (its consequent).

    Unlike the other generators here, no distractor groups are produced.

    Returns:
        A pair ``(objects, text)``. ``objects`` maps ``"rules"``,
        ``"test_facts"`` and ``"target_facts"`` to the generated objects.
        ``text`` has the same layout with every item converted via ``str``.
    """
    # generate problems for complementary setting with no noise
    subject_category = random.choice(list(subjects.keys()))
    subject = random.choice(subjects[subject_category])

    rules_ = generate_rules(attribute_groups, subject, subject_category, max_rules=1, specific_category=False, use_and=True)

    rules = expand_three_condition_and_rule(rules_[0])

    def make_fact(attribute: str) -> "LanguageFact":
        """Build a fact stating that the chosen subject has ``attribute``."""
        return LanguageFact(
            subject=subject,
            subject_category=subject_category,
            specifier_type="the",
            specific_attributes=[attribute],
            generic_attributes=[attribute],
            use_specific_attributes=False,
        )

    first_rule = rules[0].as_dict()
    # Every condition of the first rule is stated; they are grouped as test_facts[0].
    test_facts = [[make_fact(attribute) for attribute in first_rule['condition']]]
    target_facts = [make_fact(first_rule['consequent'])]

    for rule in rules[1:]:
        test_attribute = random.sample(rule.condition, 1)[0]
        test_facts.append(make_fact(test_attribute))
        target_facts.append(make_fact(rule.consequent))

    return {
            "rules": rules,
            "test_facts": test_facts,
            "target_facts": target_facts,
        },{
            "rules": [str(rule) for rule in rules],
            "test_facts": [[str(fact) for fact in test_facts[0]]] + [str(fact) for fact in test_facts[1:]],
            "target_facts": [str(fact) for fact in target_facts]
        }

if __name__ == "__main__":
    # Script entry point: generate category-1 problems and save them both as a
    # pickle (objects) and as JSON (string renderings).
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate logical reasoning samples with generate_problem_category1_no_noise."
    )
    parser.add_argument("--num-samples", type=int, default=1000,
                        help="number of samples to generate (default: 1000)")
    parser.add_argument("--pkl-out", default="/path/to/output.pkl",
                        help="destination for the pickled objects")
    parser.add_argument("--json-out", default="/path/to/output.json",
                        help="destination for the JSON text version")
    args = parser.parse_args()

    dataset = []
    dataset_text = []
    for _ in tqdm(range(args.num_samples)):
        data, data_text = generate_problem_category1_no_noise()
        dataset.append(data)
        dataset_text.append(data_text)

    print(f"Successfully generated {len(dataset)} logical reasoning samples.")

    # Save dataset
    with open(args.pkl_out, "wb") as f:
        pickle.dump(dataset, f)

    with open(args.json_out, "w", encoding="utf-8") as f:
        json.dump(dataset_text, f, indent=2)
