from sympy import *
import functools
import timeit

from collections import Counter
from typing import Tuple, List, Set, Union
from Parsers import BinaryParser, EquationParser

# Type alias used to annotate the members of a hypothesis set (see `models`).
# It is an annotation only: nothing enforces it at runtime, and hypothesis sets
# built by `learn_horn_envelope` start from `binary_parser.background`, whose
# element types are not visible here.
HypothesisElement = Union[Implies, Not]

def evaluate(clause, assignment:List[int], V) -> bool:
    """
    Substitute ``assignment[i]`` for ``V[i]`` in ``clause`` and return the result.

    A plain Python ``bool`` clause is returned unchanged. Any other clause must
    provide a ``subs`` method (e.g. a sympy expression). Its result is returned
    as-is, so despite the annotation it may be a sympy boolean or number rather
    than a Python ``bool``. Entries of ``assignment`` beyond ``len(V)`` are
    ignored; a shorter ``assignment`` raises ``IndexError``.
    """
    if isinstance(clause, bool):
        return clause

    substitution = {}
    for position, variable in enumerate(V):
        substitution[variable] = assignment[position]
    return clause.subs(substitution)

def models(assignment:List[int], hypothesis:Set[HypothesisElement], V:list) -> bool:
    """
    Check whether ``assignment`` satisfies every clause of ``hypothesis``.

    Each clause is evaluated with ``evaluate`` in the hypothesis' iteration
    order. Checking stops at the first clause whose result is falsy. An empty
    hypothesis is satisfied by any assignment.

    Returns: True or False
    """
    return all(evaluate(clause, assignment, V) for clause in hypothesis)

def is_subset(subset:List[int], superset:List[int]) -> bool:
    """
    Return True when every entry of ``subset`` is ``<=`` the matching entry of ``superset``.

    For 0/1 vectors this means every position set in ``subset`` is also set in
    ``superset``. Positions are paired with ``zip``, so the comparison stops at
    the shorter of the two inputs.
    """
    for small, big in zip(subset, superset):
        if not small <= big:
            return False
    return True

def intersection_of_lists(list_of_lists:List[List[int]]) -> List[int]:
    """
    Fold the given vectors together with element-wise bitwise AND.

    A single-element input is returned as-is (the same object, not a copy).
    An empty input raises ``TypeError`` from ``functools.reduce``. Pairs are
    formed with ``zip``, so the result is as long as the shortest vector.
    """
    def bitwise_and(left, right):
        return [x & y for x, y in zip(left, right)]

    return functools.reduce(bitwise_and, list_of_lists)

def _merge_error_counts(totals, new_errors):
    """Return ``totals`` with the counts from ``new_errors`` added in.

    Every key of either dict appears in the result. Counts are combined with
    ``Counter`` addition, which discards non-positive sums, so such keys end
    up as 0. Keys keep the order of ``totals`` followed by any new keys, which
    keeps the printed ``errors`` dict stable across runs.
    """
    combined = Counter(new_errors) + Counter(totals)
    keys = list(totals) + [key for key in new_errors if key not in totals]
    return {key: combined.get(key, 0) for key in keys}


def _positive_supersets(example, positive_counterexamples):
    """Return the positive counterexamples ``pos`` for which ``is_subset(example, pos)`` holds."""
    return [pos for pos in positive_counterexamples if is_subset(example, pos)]


def _variables_where(bits, V, value):
    """Return ``V[i]`` for every index ``i`` with ``bits[i] == value``, in index order."""
    return [V[i] for i, bit in enumerate(bits) if bit == value]


def _split_non_horn(negative_counterexamples, positive_counterexamples):
    """Partition negative counterexamples into ``(still_negative, non_horn)``.

    A negative counterexample is non-Horn when it has at least one positive
    superset and equals the element-wise intersection of all of them.
    Relative order is preserved in both returned lists.
    """
    still_negative, non_horn = [], []
    for neg in negative_counterexamples:
        supersets = _positive_supersets(neg, positive_counterexamples)
        if supersets and neg == intersection_of_lists(supersets):
            non_horn.append(neg)
        else:
            still_negative.append(neg)
    return still_negative, non_horn


def _build_horn_hypothesis(background, negative_counterexamples, positive_counterexamples, V):
    """Return a copy of ``background`` plus one implication per negative counterexample.

    For a negative counterexample ``neg`` the clause is
    ``Implies(And(variables set in neg), consequent)``. ``consequent`` is the
    conjunction of the variables set in the intersection of ``neg``'s positive
    supersets, or ``False`` when ``neg`` has no positive superset.
    """
    H = background.copy()
    for neg in negative_counterexamples:
        supersets = _positive_supersets(neg, positive_counterexamples)
        antecedent = And(*_variables_where(neg, V, 1))
        if supersets:
            consequent = And(*_variables_where(intersection_of_lists(supersets), V, 1))
        else:
            consequent = False
        H.add(Implies(antecedent, consequent))
    return H


def _build_non_horn_hypothesis(non_horn_counterexamples, V):
    """Return one ``Implies(And(variables set), Or(variables unset))`` clause per non-Horn counterexample."""
    Q = set()
    for nh_counterexample in non_horn_counterexamples:
        antecedent = And(*_variables_where(nh_counterexample, V, 1))
        consequent = Or(*_variables_where(nh_counterexample, V, 0))
        Q.add(Implies(antecedent, consequent))
    return Q


def learn_horn_envelope(ask_membership_oracle, ask_equivalence_oracle,
                        parsers:Tuple[BinaryParser, EquationParser],
                        args, num_iterations:int=-1):
    """
    Learn a hypothesis from oracle queries.

    Horn clauses (H) and non-Horn clauses (Q) are kept separately.

    Each iteration asks the equivalence oracle for a counterexample to the
    current (H, Q):
    - A positive counterexample is stored.
    - A negative counterexample is first used to try to shrink an existing
      negative counterexample. This is done with a membership query on their
      intersection. If no shrink happens, the counterexample is stored as a
      new negative.
    - Negative counterexamples that equal the intersection of all positive
      counterexamples containing them are moved to the non-Horn set.
    - H and Q are then rebuilt from scratch.

    Args:
        ask_membership_oracle: called with a candidate example (a list the
            code treats as 0/1 entries). A ``str`` reply is counted as a
            'membership' error. Otherwise the reply's truthiness is used, and
            a falsy reply marks the example as negative.
        ask_equivalence_oracle: called with ``(H, Q, previous_counterexamples)``.
            Must return ``(counterexample, sample_number, is_positive, errors)``.
            - A ``bool`` counterexample ends the loop and is stored in
              ``metadata['terminated']``.
            - ``sample_number`` is added to the 'EQ-MQ' query count.
            - ``errors`` is a dict of counts merged into ``metadata['errors']``.
        parsers: ``(binary_parser, eq_parser)``.
            - ``binary_parser`` supplies ``V`` (the variables), ``background``
              (the initial clause set) and ``binary_to_sentence``.
            - ``eq_parser.parse`` is used only for verbose output.
        args: object with a ``verbose`` attribute controlling per-iteration output.
        num_iterations: iteration budget. The default -1 never equals the
            counter, so the loop runs until the oracle returns a ``bool``.

    Returns:
        ``(H, Q, metadata)``. ``metadata`` holds:
        - per-iteration times and sample numbers;
        - merged error counts;
        - query counts ('MQ', 'EQ', 'EQ-MQ');
        - 'terminated', which ends up False only when the iteration budget
          was exhausted.
    """
    # Setup
    metadata = {
        'iteration_time': [],
        'errors': {'empty': 0, 'summary format':0, 'values': 0, 'duplicate': 0, 'validity': 0, 'membership': 0},
        'terminated': False,
        'sample_numbers': []
    }
    num_queries = {'MQ': 0, 'EQ':0, 'EQ-MQ':0}

    binary_parser, eq_parser = parsers
    H = binary_parser.background.copy()
    Q = set()

    previous_counterexamples = []
    negative_counterexamples = []
    positive_counterexamples = []
    non_horn_counterexamples = []
    # Intersections already examined, so each is asked about at most once.
    counterexample_intersections = set()

    iteration = 0

    # Learning loop
    while iteration != num_iterations:

        start = timeit.default_timer()

        # Ask for new counterexample
        (counterexample, sample_number, is_positive, errors) = ask_equivalence_oracle((H, Q, previous_counterexamples))
        metadata['sample_numbers'].append(sample_number)
        num_queries['EQ'] += 1
        num_queries['EQ-MQ'] += sample_number
        metadata['errors'] = _merge_error_counts(metadata['errors'], errors)

        # Terminate if hypothesis is equivalent (or the oracle gave up)
        if type(counterexample) is bool:
            stop = timeit.default_timer()
            metadata['iteration_time'].append(stop-start)
            metadata['terminated'] = counterexample
            break

        previous_counterexamples.append((counterexample, is_positive))

        if is_positive:
            positive_counterexamples.append(counterexample)
        else:
            # Try to replace an existing negative counterexample with its
            # (smaller) intersection with the new one.
            example_replaced = False
            for index, neg_example in enumerate(negative_counterexamples):
                example_intersection = [neg_example[i] & counterexample[i] for i in range(len(binary_parser.V))]
                if tuple(example_intersection) not in counterexample_intersections and \
                    sum(example_intersection) > 0 and (example_intersection != neg_example):

                    num_queries['MQ'] += 1
                    reply = ask_membership_oracle(example_intersection)
                    if type(reply) is str:
                        # NOTE(review): prints all metadata on every malformed
                        # reply; looks like leftover debugging.
                        print(metadata)
                        metadata['errors']['membership'] += 1
                    elif (not reply) and models(example_intersection, Q, binary_parser.V):
                        # Replace this exact slot (list.index would pick the
                        # first *equal* element when duplicates exist).
                        negative_counterexamples[index] = example_intersection
                        example_replaced = True
                        break

                counterexample_intersections.add(tuple(example_intersection))

            if not example_replaced:
                negative_counterexamples.append(counterexample)

        # Non-Horn check: partition into new lists instead of removing while
        # iterating, which used to skip the element after every removal.
        negative_counterexamples, newly_non_horn = _split_non_horn(
            negative_counterexamples, positive_counterexamples)
        non_horn_counterexamples.extend(newly_non_horn)

        # Reconstruct H and Q from the current counterexample sets
        H = _build_horn_hypothesis(binary_parser.background, negative_counterexamples,
                                   positive_counterexamples, binary_parser.V)
        Q = _build_non_horn_hypothesis(non_horn_counterexamples, binary_parser.V)

        if args.verbose:
            signed_counterexample = '+' if is_positive else '-'
            print(f'\nIteration: {abs(iteration)}\n\n' +
                  f'({sample_number+1}) Counterexample: ({signed_counterexample}) {binary_parser.binary_to_sentence(counterexample)}\n\n'+
                  f'New Hypothesis H: {sorted([eq_parser.parse(h) for h in H if h not in binary_parser.background])}\n\n' +
                  f'New Hypothesis Q: {[q for q in Q]}\n\n' +
                  f'New Hypothesis length: {len(H)+len(Q)-len(binary_parser.background)} + background: {len(binary_parser.background)}\n\n' +
                  f'total positive counterexamples:  {len(positive_counterexamples)}\n' +
                  f'total negative counterexamples:  {len(negative_counterexamples)} ({sum(sum(lst) for lst in negative_counterexamples)})\n' +
                  f'total non-horn counterexamples:  {len(non_horn_counterexamples)}\n\n' +
                  f'errors:  {metadata["errors"]}\n\n\n\n')

        iteration+=1
        stop = timeit.default_timer()
        metadata['iteration_time'].append(stop-start)

    metadata['queries'] = num_queries
    # NOTE(review): a `False` reply from the equivalence oracle also breaks
    # early, so it still ends up True here; confirm that is intended.
    metadata['terminated'] = metadata['terminated'] or (iteration != num_iterations)

    return (H, Q, metadata)