import sympy
import re
import math

from config import Config
from typing import List, Tuple, Union

class BinaryParser:
    """Translate between binary feature vectors and templated sentences.

    Every value of every category in ``config.features`` owns one position
    in a flat binary vector (categories in dict order, values in list order)
    and one sympy symbol ``v<i>``.  The parser renders such vectors through
    the string templates found under ``config.string_templates['parsing']``
    and recovers vectors from free text.

    Attributes set by ``__init__``:
        config, templates, parsing: the configuration object, its
            ``string_templates`` and the ``'parsing'`` sub-dictionary.
        features, keys: the category -> values mapping and its key view.
        items: flat list of ``(category, value)`` pairs, one per position.
        lengths: category -> cumulative end index of that category in
            ``items``.
        total_length: number of positions (``len(items)``).
        V: list of sympy symbols ``v0 .. v<total_length-1>``.
        background: set of ``~(a & b)`` constraints for every pair of
            symbols belonging to the same category.
        pac_hypothesis_space: integer bound derived from ``epsilon``,
            ``delta`` and the product of the category sizes.
        reasoning, summary: flags initialised to ``False``.
        is_bert: selects the ``'mask'`` template in ``binary_to_sentence``.
    """

    # Patterns used to normalise bare years ("before 1900") into the longer
    # phrasing ("before the year 1900") before feature values are matched.
    _BETWEEN_YEARS = re.compile(r'\bbetween\b\s*(\d{4})\s*and\s*(\d{4})', re.IGNORECASE)
    _BEFORE_AFTER_YEAR = re.compile(r'\b(before|after)\b\s*(\d{4})', re.IGNORECASE)

    def __init__(self, config: "Config", epsilon: float = 0.2, delta: float = 0.1, is_bert: bool = False):
        """Build the index tables, sympy symbols and background constraints.

        Parameters:
        config: Object exposing ``features`` (mapping of category name to a
            sized sequence of values) and ``string_templates`` (mapping that
            contains a ``'parsing'`` entry).
        epsilon (float): Divisor-side parameter of the sample bound.
        delta (float): Parameter inside the logarithm of the sample bound.
        is_bert (bool): Stored as ``self.is_bert``.
        """
        self.config = config
        self.templates = self.config.string_templates
        self.features = self.config.features
        self.keys = self.features.keys()

        # Allow for dynamic number of categories: flatten every
        # (category, value) pair and remember where each category ends.
        self.items = []
        self.lengths = {}
        for key, values in self.features.items():
            self.items.extend((key, value) for value in values)
            self.lengths[key] = len(self.items)
        self.total_length = len(self.items)
        self.reasoning = False
        self.summary = False

        # Create one symbolic variable per value.  Symbols are built one at
        # a time because sympy.symbols('v0') returns a bare Symbol (not a
        # tuple) when only a single name is requested.
        self.V = [sympy.Symbol(f'v{i}') for i in range(self.total_length)]

        # Restrict variables to create disjoint variable sets for each category
        self.background = set()
        start = 0
        for end in self.lengths.values():
            self.background |= self.make_disjoint(self.V[start:end])
            start = end

        # Parsing to sentences
        self.parsing = self.templates['parsing']

        # Hypothesis space for masked model: one clause per combination of
        # one value from each category.
        total_clauses = math.prod(len(values) for values in self.features.values())
        self.pac_hypothesis_space = self._pac_sample_bound(total_clauses, epsilon, delta)

        self.is_bert = is_bert

    @staticmethod
    def _pac_sample_bound(total_clauses: int, epsilon: float, delta: float) -> int:
        """Return ``int((1/epsilon) * log2(2**total_clauses / delta))``.

        The logarithm is expanded to ``total_clauses - log2(delta)`` so that
        ``2**total_clauses`` is never materialised as a float; the direct
        form overflows once ``total_clauses`` exceeds 1023.
        """
        # NOTE: the original author flagged that this log2 should be ln;
        # base 2 is kept here so existing values do not change.
        return int((1 / epsilon) * (total_clauses - math.log2(delta)))

    def make_disjoint(self, V: list):
        """
        Create pairwise mutual-exclusion constraints for the list V.

        Parameters:
        V (list): A list of variables to be made disjoint.

        Returns:
        set: The set of ``~(a & b)`` for every unordered pair ``a``, ``b``
            of distinct positions in V (empty when V has fewer than two
            elements).
        """
        return {
            ~(first & second)
            for i, first in enumerate(V)
            for second in V[i + 1:]
        }

    def _render_categories(self, binary_str: List[int], template_key: str, strip_that: bool = False) -> dict:
        """Render one text fragment per category for the active bits.

        Parameters:
        binary_str: Binary vector aligned with ``self.items``.
        template_key: Which per-category template to use (``'given'`` or
            ``'mq'``).
        strip_that: When true, remove every ``'that '`` from rendered
            fragments (defaults are left untouched).

        Returns:
        dict: Category name (spaces replaced by underscores) -> fragment;
            categories without a rendered fragment get their ``'default'``
            template.
        """
        rendered = {key.replace(' ', '_'): '' for key in self.keys}

        for idx, bit in enumerate(binary_str):
            if bit != 1:
                continue
            category, name = self.items[idx]
            category = category.replace(' ', '_')
            fragment = self.parsing[category][template_key].format(value=name)
            if strip_that:
                fragment = fragment.replace('that ', '')
            rendered[category] = fragment

        for category in list(rendered):
            if rendered[category] == '':
                rendered[category] = self.parsing[category]['default']
        return rendered

    def binary_to_sentence(self, binary_str: List[int], is_positive=None, is_antecedent: bool = False, is_consequent: bool = False, is_masked: bool = False) -> str:
        """Render a binary vector as a sentence using the ``'given'`` templates.

        Parameters:
        binary_str: Binary vector aligned with ``self.items``.
        is_positive: If a ``bool``, append " is possible" / " is not
            possible"; any other value appends nothing.
        is_antecedent, is_consequent: Select the ``'antecedent'`` or
            ``'consequent'`` sentence template and strip ``'that '`` from
            rendered fragments.  They are mutually exclusive.
        is_masked: Replace the ``gender`` slot with a ``<mask>`` placeholder.

        Returns:
        str: The formatted sentence.  When ``self.is_bert`` is set the
            ``'mask'`` template is always used.
        """
        # String cannot be both antecedent and consequent
        assert not (is_antecedent and is_consequent)

        string_builder = self._render_categories(
            binary_str, 'given', strip_that=is_antecedent or is_consequent
        )

        if is_masked:
            string_builder['gender'] = "<mask>" if self.is_bert else ", their gender is <mask>."

        if self.is_bert:
            string_key = 'mask'
        elif is_antecedent:
            string_key = 'antecedent'
        elif is_consequent:
            string_key = 'consequent'
        else:
            string_key = 'sentence'

        sentence = self.parsing[string_key].format(**string_builder)
        if isinstance(is_positive, bool):
            sentence += " is possible" if is_positive else " is not possible"

        return sentence

    def binary_to_MQ(self, binary_str: List[int]) -> str:
        """Render a binary vector with the ``'mq'`` templates and ``'MQ_sentence'``.

        Parameters:
        binary_str: Binary vector aligned with ``self.items``.

        Returns:
        str: The formatted ``'MQ_sentence'`` template.
        """
        string_builder = self._render_categories(binary_str, 'mq')
        return self.parsing['MQ_sentence'].format(**string_builder)

    def index_to_name(self, index: int):
        """Return the value stored at position ``index`` of ``self.items``."""
        return self.items[index][1]

    def sentence_to_binary(self, sentence: str):
        """Recover a binary vector from free text.

        Bare years are first normalised ("between 1900 and 1950" ->
        "between the years 1900 and 1950"; otherwise "before/after 1900" ->
        "before/after the year 1900").  Each feature value is then searched
        for as a whole word, case-insensitively.

        Parameters:
        sentence: Text to scan.

        Returns:
        tuple: ``(vector, valid)`` where ``vector`` is aligned with
            ``self.items`` and ``valid`` is ``False`` when any category
            matched more than one value or when nothing matched at all.
        """
        # Special handling for time period to capture ranges.  A "between"
        # range takes precedence; before/after is only rewritten when no
        # range is present.
        if self._BETWEEN_YEARS.search(sentence):
            sentence = self._BETWEEN_YEARS.sub(r'between the years \1 and \2', sentence)
        else:
            sentence = self._BEFORE_AFTER_YEAR.sub(r'\1 the year \2', sentence)

        counterexample = []
        valid_flag = True
        for values in self.features.values():
            feat_bin = [
                1 if re.search(r'\b' + re.escape(item) + r'\b', sentence, re.IGNORECASE) else 0
                for item in values
            ]
            # A category may carry at most one value.
            if sum(feat_bin) > 1:
                valid_flag = False
            counterexample.extend(feat_bin)

        if not any(counterexample):
            valid_flag = False
        return (counterexample, valid_flag)

    def binary_category(self, binary_sentence: List[int], is_positive: bool) -> str:
        """Describe a binary vector category by category.

        Parameters:
        binary_sentence: Binary vector aligned with ``self.items``.
        is_positive: Chooses the final "possible" / "not possible" line.

        Returns:
        str: Newline-joined lines: one ``"<value> as the <category>."`` per
            active bit, one ``"The counterexample has no <category>."`` per
            category without an active bit, then the possibility statement.
        """
        return_str = []
        present = set()
        for idx, binary_value in enumerate(binary_sentence):
            key, value = self.items[idx]
            if binary_value == 1:
                return_str.append(f"{value} as the <{key}>.")
                present.add(key)

        # Track categories explicitly instead of searching the rendered text:
        # a category name can occur as a substring of an unrelated value.
        for key in self.keys:
            if key not in present:
                return_str.append(f"The counterexample has no <{key}>.")

        if is_positive:
            return_str.append("This combination is possible in the real world.")
        else:
            return_str.append("This combination is not possible in the real world.")
        return "\n".join(return_str)

    
class EquationParser:
    """
    Parses sympy equations to a human readable format.

    Args:
        binary_parser: BinaryParser
            Contains all binary features we are solving for.

    Attributes:
        mapping: symbol name (``'v0'``, ``'v1'``, ...) -> the feature value
            at the same position, with spaces replaced by underscores.
        binary_parser: the parser passed to the constructor.
    """

    def __init__(self, binary_parser: BinaryParser):
        self.mapping = {
            str(symbol): str(value).replace(' ', '_')
            for symbol, (_, value) in zip(binary_parser.V, binary_parser.items)
        }
        self.binary_parser = binary_parser

    @staticmethod
    def _operands(expression) -> tuple:
        """Return ``expression.args``, or ``(expression,)`` when it has none.

        Expressions without arguments (a lone symbol, ``true``/``false``)
        are treated as a one-element conjunction so they are not lost.
        """
        return tuple(expression.args) if expression.args else (expression,)

    @staticmethod
    def _symbol_index(symbol) -> int:
        """Return the position encoded in a symbol named ``v<index>``."""
        return int(str(symbol).replace('v', ''))

    def parse(self, equation):
        """Pretty-print ``equation`` with feature values in place of symbols.

        For an implication, operands of the antecedent are first removed
        from the consequent so that they are not repeated on the right-hand
        side.

        Returns:
            str: the output of ``sympy.pretty`` after substituting
            ``self.mapping``.
        """
        if isinstance(equation, sympy.Implies):
            antecedent, consequent = equation.args
            ant = self._operands(antecedent)
            # Use the same "args or self" rule as for the antecedent: a
            # single-symbol consequent has empty args and would otherwise
            # vanish, collapsing the implication to True.
            remaining = set(self._operands(consequent)).difference(ant)
            equation = sympy.Implies(antecedent, sympy.And(*remaining))
        return sympy.pretty(equation.subs(self.mapping))

    def symbol_to_binary(self, clause: Union[sympy.core.Symbol, sympy.Implies, sympy.And]) -> Union[Tuple[List[int], List[int]], List[int]]:
        """Convert a symbol, conjunction or implication to binary vectors.

        Args:
            clause: A single symbol, an implication, or another expression
                whose arguments are symbols named ``v<index>``.

        Returns:
            For an implication, a tuple ``(antecedent_vector,
            consequent_vector)`` where the consequent excludes operands
            already present in the antecedent.  Otherwise a single list of
            length ``binary_parser.total_length`` with a 1 at the index of
            every symbol found; an expression without arguments (such as a
            boolean constant) yields all zeros.
        """
        return_binary = [0] * self.binary_parser.total_length

        if isinstance(clause, sympy.Symbol):
            return_binary[self._symbol_index(clause)] = 1
            return return_binary

        if isinstance(clause, sympy.Implies):
            antecedent, consequent = clause.args
            ant = self._operands(antecedent)
            remaining = set(self._operands(consequent)).difference(ant)

            binary_ant = self.symbol_to_binary(sympy.And(*ant))
            binary_con = self.symbol_to_binary(sympy.And(*remaining))
            assert isinstance(binary_ant, list) and isinstance(binary_con, list)

            return (binary_ant, binary_con)

        args = clause.args
        # When the first argument is itself a conjunction, only its operands
        # are read.  The emptiness guard covers boolean constants, which
        # have no arguments at all.
        if args and isinstance(args[0], sympy.And):
            args = args[0].args
        for arg in args:
            return_binary[self._symbol_index(arg)] = 1
        return return_binary


