import re
import json
import random
from typing import List, Dict, Tuple
from copy import deepcopy
from itertools import product

import numpy as np
from tqdm.auto import tqdm

from .prover9_interface import prover9_prove, FOL2Prover9Converter
from .my_dataclass import Problems


class ProblemGenerator:
    
    def __init__(self, args) -> None:
        self.args = args
        
        # problem property
        self.reasoning_depth = {
            'easy': [1, 2],
            'medium': [3, 4, 5],
            'hard': [6, 7],
            'ultra_hard': [6, 7, 8, 9]
        }
        
        # goal value property
        self.goal_value_dict = {0: 'True', 1: 'False', 2: 'Uncertain'}
        self.goal_value_probs = args.goal_value_probs
        
        # rule generator
        self.rule_generator = RuleGenerator(args)
        
    def generate_problems(self, verbose: bool = False) -> List[Problems]:
        problems = []
        problem_generation_bar = tqdm(range(self.args.num))
        
        while len(problems) < self.args.num:
            # problem property
            rule_id = 0
            fact_num = 0
            problem_rules = []
            
            # goal property
            goal_value = self.sample_goal_value()
            rule_as_goal = np.random.choice(
                a=np.array([0, 1]),
                size=1,
                replace=True,
                p=self.args.rule_as_goal_proportion
            )
            
            # generate root rule
            if rule_as_goal == 0:
                rule_expression, rule_requirement, goal_index, fact_num = self.rule_generator.generate_normal_rule(goal_expression="", gvalue=goal_value, fact_num=fact_num)
                
                if rule_expression is None:
                    continue
                
                root_rule = {'id': rule_id, 'next_rule': None, 'expression': rule_expression, 'value': self.get_fact_value(rule_expression, rule_requirement, goal_index), 'conclusion': {'index': goal_index, 'expression': self.extract_facts(rule_expression)[goal_index], 'value': goal_value}}
                
                # get reasoning step and final goal
                problem_answers = [self.get_single_deduction_step(rule_expression=rule_expression, rule_requirement=rule_requirement, conclusion_value=root_rule['conclusion']['value'], goal_index=goal_index)]
                step_cnt = 1
            else:
                rule_expression, rule_requirement, fact_num = self.rule_generator.generate_goal_rule(gvalue=goal_value, fact_num=fact_num)
                root_rule = {'id': rule_id, 'next_rule': None, 'expression': rule_expression, 'value': self.get_fact_value(rule_expression, rule_requirement), 'conclusion': {'index': None, 'expression': rule_expression, 'value': goal_value}}
                
                # get reasoning step and final goal
                problem_answers = [self.get_single_deduction_step(rule_expression=rule_expression, rule_requirement=rule_requirement, conclusion_value=root_rule['conclusion']['value'])]
                step_cnt = 0
                
            # get the goal of the problem
            final_goal = root_rule['conclusion']
                
            problem_rules.append(root_rule)
            root_rule_facts = self.extract_facts(rule_expression)
            facts_pool = [{'rule': rule_id, 'fact_position': i, 'expression': root_rule_facts[i], 'value': root_rule['value'][i]} for i in root_rule['value'].keys()]
            rule_id += 1
            
            # generate remaining rules
            steps = random.sample(self.reasoning_depth[self.args.mode], 1)[0]
            final_facts = []
            dead_end = False
            
            while step_cnt < steps:
                # randomly select a fact for proving
                goal = facts_pool.pop(random.sample(list(range(len(facts_pool))), 1)[0])
                
                if (len(facts_pool) >= self.args.fact_num_threshold) and (np.random.uniform() < self.args.fact_num_prob):
                    final_facts.append(goal)
                    continue
                else:
                    rule_expression, rule_requirement, goal_index, fact_num = self.rule_generator.generate_normal_rule(goal_expression=goal['expression'], gvalue=goal['value'], fact_num=fact_num)
                    
                    if rule_expression is None:
                        dead_end = True
                        break
                    
                    # generate distracting rule
                    distracting_rule_expression, distracting_rule_requirement, distracting_goal_index = self.rule_generator.generate_distracting_rules(goal_expression=goal['expression'])
                    distracting_rule = {'expression': distracting_rule_expression, 'value': self.get_fact_value(distracting_rule_expression, distracting_rule_requirement, distracting_goal_index)}
                    
                    # create current_rule
                    current_rule_facts = self.extract_facts(rule_expression)
                    current_rule = {'id': rule_id, 'next_rule': goal['rule'], 'expression': rule_expression, 'value': self.get_fact_value(rule_expression, rule_requirement, goal_index), 'conclusion': {'index': goal_index, 'expression': current_rule_facts[goal_index], 'value': goal['value']}, 'distracting_rule': distracting_rule}
                    
                    problem_rules.append(current_rule)
                    facts_pool.extend([{'rule': rule_id, 'fact_position': i, 'expression': current_rule_facts[i], 'value': current_rule['value'][i]} for i in current_rule['value'].keys()])
                    rule_id += 1

                    # get problem answer
                    problem_answers.insert(0, self.get_single_deduction_step(rule_expression=rule_expression, rule_requirement=rule_requirement, conclusion_value=current_rule['conclusion']['value'], goal_index=goal_index))

                    step_cnt += 1
            
            if dead_end:
                continue
            # save problem
            final_facts.extend(facts_pool)
            problems.append(Problems(
                id=len(problems),
                goal=final_goal,
                facts=final_facts,
                rules=problem_rules,
                reasoning_chain=problem_answers,
            ))
            
            # print current problems
            if verbose:
                print("Facts:")
                for item in problems[-1].facts:
                    print(f"{item['expression']}: {item['value']}")
                    
                print("\nRules:")
                for item in problems[-1].rules:
                    print(f"{item['expression']} |-> {item['conclusion']['expression']} | {item['conclusion']['value']}")
                    
                print("\nConclusion:")
                print(f"{problems[-1].goal['expression']}: {problems[-1].goal['value']}")
                
                print("\nReasoning Chain:")
                print("==========================")
                for item in problems[-1].reasoning_chain:
                    print("fact: ", end="")
                    for fact in item['facts']:
                        print(f"{fact['expression']}|{fact['value']}", end=' ')
                    print(f"\nrule: {item['rule']}")
                    print(f"conclusion: {item['conclusion']['expression']} | {item['conclusion']['value']}")
                    print("==========================")
            
            problem_generation_bar.update(1)
            
        return problems
            
            
    def sample_goal_value(self) -> str:
        goal_value = np.random.choice(
            a=np.array([0, 1, 2]),
            size=1,
            replace=True,
            p=self.goal_value_probs
        )
        
        return self.goal_value_dict[goal_value[0]]
    
    def get_single_deduction_step(self, rule_expression: str, rule_requirement: list, conclusion_value: str, goal_index: int = None) -> Dict:
        current_facts_list = self.extract_facts(rule_expression)
        rule_fact = []
                    
        for i in range(len(current_facts_list)):
            if i == goal_index:
                continue
            else:
                rule_fact.append({'expression': current_facts_list[i], 'value': rule_requirement[len(rule_fact)]})
        
        single_step_deduction = {'facts': rule_fact, 'rule': rule_expression, 'conclusion': {'expression': rule_expression if goal_index is None else current_facts_list[goal_index], 'value': conclusion_value}}
        return single_step_deduction
    
    @staticmethod
    def extract_facts(rule_expression: str) -> List:
        pattern = r'\[F\d+\]'
        matches = re.findall(pattern, rule_expression)
        result = [match.strip('[]') for match in matches]
        
        return result
        
    @staticmethod
    def get_fact_value(rule: str, value_tuple: Tuple, gindex: int = None) -> Dict:
        result = {}
        if gindex is None:
            for i in range(rule.count('[F')):
                result[i] = value_tuple[i]
        else:
            value_list = []
            for item in value_tuple:
                value_list.append(item)
            value_list.insert(gindex, None)
            
            for i in range(rule.count('[F')):
                if i == gindex:
                    continue
                else:
                    result[i] = value_list[i]
        return result
    


class RuleGenerator:
    
    def __init__(self, args) -> None:
        """Keep the configuration, create the converter and load rule templates.

        Args:
            args: configuration object; ``args.rule_candidate_path`` is the
                JSON file the rule templates are loaded from.
        """
        self.args = args
        self.converter = FOL2Prover9Converter()
        # stand-in predicates substituted for "[F]" while building truth tables
        self.fact_holder = [
            "aget(x)",
            "bget(x)",
            "cget(x)",
        ]
        self.load_rule_candidate(args.rule_candidate_path)
        
        
    def generate_goal_rule(self, gvalue: str, fact_num: int) -> Tuple:
        """Pick a goal-rule template that can take the value ``gvalue``.

        Templates are drawn from ``self.goal_rules`` until one has at least
        one fact assignment producing ``gvalue``; after 50 failed draws the
        method gives up.

        Args:
            gvalue: required value of the rule.
            fact_num: number of facts already in use; the new facts are
                numbered ``fact_num``, ``fact_num + 1``, ...

        Returns:
            ``(expression, value_requirements, new_fact_num)``, or
            ``(None, None, None)`` when no suitable template was found.
        """
        solutions = []
        err_cnt = -1

        while solutions == []:
            err_cnt += 1
            if err_cnt >= 50:
                return None, None, None
            logic_expression = random.sample(self.goal_rules, 1)[0]
            solutions = self.__calculate_root_rule_truth_table(logic_expression, gvalue)

        expression_fact_num = logic_expression.count('[F]')
        fact_num_new = fact_num + expression_fact_num
        # the loop variable already starts at fact_num, so it is the fact id
        # itself (adding fact_num again skipped ids for a non-zero fact_num)
        for fact_id in range(fact_num, fact_num_new):
            logic_expression = logic_expression.replace('[F]', f'[F{fact_id}]', 1)

        value_requirements = random.sample(solutions, 1)[0]
        return logic_expression, value_requirements, fact_num_new
    
    def generate_normal_rule(self, goal_expression: str, gvalue: str, fact_num: int) -> Tuple:
        """Build a rule whose fact at ``goal_index`` takes the value ``gvalue``.

        A template is drawn from ``self.normal_rules`` together with a goal
        position, and redrawn (at most 50 times) until the truth table offers
        at least one assignment of the other facts that yields ``gvalue``.

        Args:
            goal_expression: name of an existing fact to place at the goal
                position; an empty string makes the goal a new fact as well.
            gvalue: required value of the goal fact.
            fact_num: number of facts already in use.

        Returns:
            ``(expression, value_requirements, goal_index, new_fact_num)``,
            or four ``None`` values when every draw failed.
        """
        # templates whose goal may be either of the last two facts
        two_goal_templates = ["[F] → ([F] ∧ [F])", "[F] → ([F] ∨ [F])", "[F] → ([F] ⊕ [F])"]

        solutions = []
        attempts = 0
        while solutions == []:
            if attempts >= 50:
                return None, None, None, None
            attempts += 1

            template = random.sample(self.normal_rules, 1)[0]
            slot_total = template.count('[F]')

            # choose the goal position counted from the end of the rule
            if self.args.mode == "ultra_hard":
                from_end = random.sample([-place for place in range(1, slot_total + 1)], 1)[0]
            elif template in two_goal_templates:
                from_end = random.sample([-1, -2], 1)[0]
            else:
                from_end = -1
            goal_index = from_end + slot_total

            solutions = self.__calculate_normal_rule_truth_table(template, gindex=goal_index, gvalue=gvalue)

        reuse_goal = goal_expression != ""
        if reuse_goal:
            fact_num_new = fact_num + slot_total - 1
        else:
            fact_num_new = fact_num + slot_total

        # fill the placeholders left to right; only new facts consume an id
        logic_expression = template
        next_fact_id = fact_num
        for position in range(slot_total):
            if reuse_goal and position == goal_index:
                logic_expression = logic_expression.replace('[F]', f"[{goal_expression}]", 1)
            else:
                logic_expression = logic_expression.replace('[F]', f'[F{next_fact_id}]', 1)
                next_fact_id += 1

        value_requirements = random.sample(solutions, 1)[0]
        return logic_expression, value_requirements, goal_index, fact_num_new
    
    def generate_distracting_rules(self, goal_expression: str) -> Tuple:
        """Build a rule about ``goal_expression`` from which its value stays "Uncertain".

        A template is drawn from ``self.normal_rules`` (at most 50 times)
        until the truth table has an assignment of the other facts for which
        the goal fact evaluates to "Uncertain". The goal position is filled
        with ``goal_expression``; the remaining placeholders become
        ``[D0]``, ``[D1]``, ...

        Args:
            goal_expression: name of the fact the distracting rule talks
                about; with an empty string every placeholder becomes a
                ``[D<n>]`` fact.

        Returns:
            ``(expression, value_requirements, goal_index)``, or
            ``(None, None, None)`` when every draw failed. The failure value
            has the same length as the success value so callers can always
            unpack three items.
        """
        err_cnt = -1
        solutions = []
        while solutions == []:
            err_cnt += 1
            if err_cnt >= 50:
                return None, None, None

            logic_expression = random.sample(self.normal_rules, 1)[0]
            if logic_expression in ["[F] → ([F] ∧ [F])", "[F] → ([F] ∨ [F])", "[F] → ([F] ⊕ [F])"]:
                goal_index = random.sample([-1, -2], 1)[0]
            else:
                goal_index = -1
            # turn the from-the-end position into a regular index
            goal_index += logic_expression.count('[F]')

            solutions = self.__calculate_normal_rule_truth_table(logic_expression, gindex=goal_index, gvalue="Uncertain")

        expression_fact_num = logic_expression.count('[F]')

        offset = 0  # control the ID of the fact
        for i in range(expression_fact_num):
            if i == goal_index and goal_expression != "":
                logic_expression = logic_expression.replace('[F]', f"[{goal_expression}]", 1)
            else:
                logic_expression = logic_expression.replace('[F]', f'[D{offset}]', 1)
                offset += 1

        value_requirements = random.sample(solutions, 1)[0]
        return logic_expression, value_requirements, goal_index
        
    def __calculate_normal_rule_truth_table(self, rule: str, gindex: int, gvalue: str) -> List:
        """List the assignments of the non-goal facts that give the goal ``gvalue``.

        The ``[F]`` placeholders are replaced by the predicates in
        ``self.fact_holder`` and the rule is universally quantified and
        converted. For every combination of 'True'/'False'/'Uncertain' over
        the non-goal facts, the prover is asked for the value of the goal
        fact given the rule and those facts.

        Args:
            rule: rule template containing ``[F]`` placeholders.
            gindex: position of the goal fact among the placeholders.
            gvalue: value the goal fact has to evaluate to.

        Returns:
            The value tuples (one entry per non-goal fact) whose prover
            result equals ``gvalue``.
        """
        slot_total = rule.count('[F]')
        # replace [F] with placeholder fact
        for slot in range(slot_total):
            rule = rule.replace("[F]", self.fact_holder[slot], 1)
        quantified_rule = self.converter.convert_expression(f"∀x ({rule})")

        matching_assignments = []
        for assignment in product(['True', 'False', 'Uncertain'], repeat=slot_total - 1):
            known_facts = self.fact_holder[:slot_total]
            target_fact = known_facts.pop(gindex)
            fact_premises = self.assign_rule_value(assignment, fact_list=known_facts)

            # value of the goal fact given the rule plus the assigned facts
            goal_query = [((f"all x. {target_fact}"), [quantified_rule] + fact_premises)]
            verdict = self.prover9(arguments_list=goal_query)

            # goal value is uncertain means that even if we know the rule and the remaining fact, we still cannot deduce the target's value.
            if gvalue == "Uncertain":
                # such an assignment is only kept when the rule itself
                # evaluates to 'True' from the assigned facts alone
                rule_query = [((quantified_rule), fact_premises)]
                rule_verdict = self.prover9(arguments_list=rule_query)
                if rule_verdict != 'True':
                    continue

            if verdict == gvalue:
                matching_assignments.append(assignment)

        return matching_assignments
    
    def __calculate_root_rule_truth_table(self, rule: str, gvalue: str) -> List:
        """List the fact assignments under which the whole rule evaluates to ``gvalue``.

        The ``[F]`` placeholders are replaced by the predicates in
        ``self.fact_holder`` and the rule is universally quantified and
        converted. Every combination of 'True'/'False'/'Uncertain' over all
        facts is then handed to the prover with the rule as the conclusion.

        Args:
            rule: rule template containing ``[F]`` placeholders.
            gvalue: value the rule has to evaluate to.

        Returns:
            The value tuples (one entry per fact) whose prover result equals
            ``gvalue``.
        """
        slot_total = rule.count('[F]')
        # replace [F] with placeholder fact
        for slot in range(slot_total):
            rule = rule.replace("[F]", self.fact_holder[slot], 1)
        quantified_rule = self.converter.convert_expression(f"∀x ({rule})")

        matching_assignments = []
        for assignment in product(['True', 'False', 'Uncertain'], repeat=slot_total):
            fact_premises = self.assign_rule_value(assignment, fact_list=self.fact_holder[:slot_total])
            rule_query = [((quantified_rule), fact_premises)]

            verdict = self.prover9(arguments_list=rule_query, some_is_goal=False)
            if verdict == gvalue:
                matching_assignments.append(assignment)

        return matching_assignments
        
    def load_rule_candidate(self, file_path):
        """Read the rule templates from a JSON file.

        The file has to provide the keys ``normal_rules`` and ``goal_rules``;
        their values are stored on the instance under the same names.
        """
        templates_by_kind = self.load_json(file_path)
        self.normal_rules = templates_by_kind['normal_rules']
        self.goal_rules = templates_by_kind['goal_rules']
        
    @staticmethod
    def prover9(arguments_list: List, some_is_goal: bool = False) -> str:
        """Classify a conclusion as "True", "False" or "Uncertain".

        The prover is run twice on the first ``(conclusion, premises)`` pair:
        once for the conclusion and once for its negation.

        Args:
            arguments_list: list whose first item is ``(conclusion, premises)``.
            some_is_goal: accepted but not used by this method.

        Returns:
            "Uncertain" when both runs give the same result, "True" when only
            the conclusion is proved, otherwise "False".
        """
        proved, _ = prover9_prove(arguments_list)

        conclusion = arguments_list[0][0]
        assert "all x" in conclusion
        if "all x" in conclusion:
            # negate underneath the universal quantifier
            negated = f"all x. (not ({conclusion.replace('all x. ', '')}))"
        else:
            negated = f"not ({conclusion})"

        refuted, _ = prover9_prove([(negated, arguments_list[0][1])])

        if proved == refuted:
            return "Uncertain"
        if proved == True and refuted == False:  # noqa: E712 - comparison kept as written
            return "True"
        return "False"
    
    @staticmethod
    def assign_rule_value(value: Tuple, fact_list: List) -> List[str]:
        result= []
        for i in range(len(value)):
            fact = fact_list[i]
            if value[i] == 'True':
                result.append(f"all x. ({fact})")
            elif value[i] == 'False':
                result.append(f"all x. (not ({fact}))")
            elif value[i] == 'Uncertain':
                continue
            else:
                raise ValueError(f"Unsupported value: {value[i]}")
        return result
    
    @staticmethod
    def load_json(file_path) -> Dict:
        with open(file_path, 'r') as f:
            result = json.load(f)
        return result
        
        
        
        
        
        
        

        
        
        
        
        
        
        
        
        
        
        
        
        
        
        