from tqdm import tqdm
import numpy as np
from collections import deque

class Atom:
    """Atomic proposition: holds exactly when its atom appears in the labels."""

    def __init__(self, atom):
        """Remember the atom this proposition tests for."""
        self.atom = atom

    def sat(self, labels):
        """Return True if the stored atom is a member of `labels`."""
        present = self.atom in labels
        return present

class Truth:
    """Constant-true formula: satisfied by every set of labels."""

    def __init__(self):
        """Nothing to store; the formula has no parameters."""

    def sat(self, labels):
        """Return True regardless of `labels`."""
        return True

class And:
    """Conjunction: holds when the first and the second subformula both hold."""

    def __init__(self, subformula_1, subformula_2):
        """Store the two operands of the conjunction."""
        self.subformula_1, self.subformula_2 = subformula_1, subformula_2

    def sat(self, labels):
        """Evaluate the conjunction on `labels`.

        The second subformula is only evaluated when the first one holds.
        """
        left = self.subformula_1.sat(labels)
        if not left:
            return left
        return self.subformula_2.sat(labels)

class Or:
    """Disjunction: holds when at least one of the two subformulae holds."""

    def __init__(self, subformula_1, subformula_2):
        """Store the two operands of the disjunction."""
        self.subformula_1, self.subformula_2 = subformula_1, subformula_2

    def sat(self, labels):
        """Evaluate the disjunction on `labels`.

        The second subformula is only evaluated when the first one fails.
        """
        left = self.subformula_1.sat(labels)
        if left:
            return left
        return self.subformula_2.sat(labels)

class Neg:
    """Negation: holds exactly when the wrapped subformula does not hold."""

    def __init__(self, subformula):
        """Store the formula to be negated."""
        self.subformula = subformula

    def sat(self, labels):
        """Return the boolean inverse of the subformula's result on `labels`."""
        if self.subformula.sat(labels):
            return False
        return True

class Implies:
    """Implies: satisfied when subformula_2 is satisfied whenever subformula_1 is satisfied"""

    def __init__(self, subformula_1, subformula_2):
        """Store the antecedent (subformula_1) and the consequent (subformula_2)."""
        self.subformula_1 = subformula_1
        self.subformula_2 = subformula_2

    def sat(self, labels):
        """Evaluate `subformula_1 -> subformula_2` on `labels`.

        Equivalent to `(not subformula_1) or subformula_2`, evaluated
        directly so that no temporary formula objects are built per call.
        The consequent is only evaluated when the antecedent holds.
        """
        if not self.subformula_1.sat(labels):
            # A false antecedent makes the implication vacuously true.
            return True
        return self.subformula_2.sat(labels)

class DFA:

    """A deterministic finite automaton (DFA) whose edges are guarded by
       propositional formulae evaluated on sets of labels

    Input attributes:
        states: list of automata states
        initial: the initial state (must be one of `states`)
        accepting: list of accepting states

    Other attributes:
        edges: for every state, a dictionary mapping each child state to
               the condition guarding the edge to that child
        state: current state of the DFA during execution
    """

    def __init__(self, states, initial, accepting):

        assert type(states) is list
        assert initial in states
        assert type(accepting) is list
        self.states, self.initial, self.accepting = states, initial, accepting
        # every state starts out with no outgoing edges
        self.edges = {}
        for s in self.states:
            self.edges[s] = {}
        self.reset()

    def add_edge(self, parent, child, condition):
        """register (or overwrite) the edge parent -> child guarded by condition"""
        outgoing = self.edges[parent]
        outgoing[child] = condition

    def reset(self):
        """put the DFA back in its initial state and return that state"""
        self.state = self.initial
        return self.state

    def has_edge(self, state_1, state_2):
        """return True if an edge from state_1 to state_2 exists, else False"""
        outgoing = self.edges.get(state_1)
        if outgoing is None:
            # state_1 is not a known state, so it has no edges at all
            return False
        return state_2 in outgoing

    def check(self, trace):
        """run the trace from the initial state and report whether it ends
        in an accepting state (the current state of the DFA is untouched)"""
        current = self.initial
        for step_labels in trace:
            current = self.transition(current, step_labels)
        return current in self.accepting

    def transition(self, state, labels):
        """return the successor of state under labels

        The outgoing edges are tried in dictionary iteration order and the
        first one whose condition is satisfied wins; when none is satisfied
        the automaton stays in state.
        """
        for candidate, condition in self.edges[state].items():
            if condition.sat(labels):
                return candidate
        return state

    def step(self, labels):
        """advance the current state by one transition on labels and
        return (is the new state accepting, the new state)"""
        self.state = self.transition(self.state, labels)
        is_accepting = self.state in self.accepting
        return is_accepting, self.state

class Cost_Function:

    """
    The cost function with respect to a given DFA

    Every step costs 1.0 when the DFA lands in an accepting state and 0.0
    otherwise, plus a shaping term computed by `potential`.

    Input attributes:
        dfa: the deterministic finite automata (DFA)
        reward_shaping: one of 'none', 'potential' or 'cycle'
        discount: discount factor applied to the next-state value in the
                  'potential' shaping term
        gamma: discount used by the value iteration over the DFA in
               'potential' mode; must be strictly smaller than `discount`

    Other attributes:
        vi_steps: maximum number of value iteration sweeps ('potential' only)
        V: value of each DFA state found by value iteration ('potential' only)
        dist_to_initial: for each DFA state, the number of edges on a
                         shortest path back to the initial state that does
                         not leave an accepting state, np.inf if there is
                         none, and 0.0 for the initial state ('cycle' only)
    """

    def __init__(self, dfa, reward_shaping='none', discount=0.95, gamma=0.9):
        self.dfa = dfa
        assert reward_shaping in ['none', 'potential', 'cycle']
        self.reward_shaping = reward_shaping
        self.discount = discount

        if self.reward_shaping == 'none':
            pass
        elif self.reward_shaping == 'potential':
            assert gamma < self.discount
            self._run_value_iteration(gamma)
        elif self.reward_shaping == 'cycle':
            self._compute_dist_to_initial()
        else:
            raise NotImplementedError(self.reward_shaping)

    def _run_value_iteration(self, gamma):
        """fill self.V by value iteration over the DFA graph"""
        self.vi_steps = 1000
        self.V = {u: 0.0 for u in self.dfa.states}
        accepting_value = 1.0 / (1.0 - gamma)

        print("reward shaping DFA ...")
        for _ in tqdm(range(self.vi_steps)):
            diff = 0.0
            for u in self.dfa.states:
                previous = self.V[u]
                if u in self.dfa.accepting:
                    self.V[u] = accepting_value
                else:
                    # default=0.0: a state without outgoing edges has no
                    # successor to take a value from (previously this
                    # raised ValueError on the empty sequence)
                    self.V[u] = max(
                        (gamma * self.V[v] for v in self.dfa.edges[u]),
                        default=0.0,
                    )
                diff = max(diff, abs(previous - self.V[u]))
            # stop early once a full sweep no longer changes any value
            if diff < 1e-8:
                break

    def _compute_dist_to_initial(self):
        """fill self.dist_to_initial with shortest path lengths to the initial state"""
        # work on a copy of the edge map in which accepting states have no
        # outgoing edges, so paths never continue through an accepting state
        edges = self.dfa.edges.copy()
        edges.update({u: {} for u in self.dfa.accepting})

        self.dist_to_initial = {
            u: self._steps_to_initial(u, edges) for u in self.dfa.states
        }

        # set the initial state distance to 0.0
        self.dist_to_initial[self.dfa.initial] = 0.0

    def _steps_to_initial(self, start, edges):
        """breadth-first search for the fewest edges (at least one) from start to the initial state"""
        u_init = self.dfa.initial
        dist = {start: 0.0}
        queue = deque([start])

        while queue:
            current = queue.popleft()
            for w in edges.get(current, ()):
                if w == u_init:
                    return dist[current] + 1
                if w not in dist:
                    dist[w] = dist[current] + 1
                    queue.append(w)

        return np.inf

    def reset(self):
        """reset the DFA"""
        return self.dfa.reset()

    def potential(self, automaton_state, next_automaton_state):
        """return the shaping term for the move automaton_state -> next_automaton_state

        'none':      always 0.0
        'potential': discount * V[next] - V[current]; when the current state
                     is accepting, V[current] is used in place of V[next]
        'cycle':     change in distance to the initial state, or 0.0 when
                     either state is accepting
        """
        if self.reward_shaping == 'none':
            return 0.0
        elif self.reward_shaping == 'potential':
            if (automaton_state in self.dfa.accepting):
                return (self.discount * self.V[automaton_state] - self.V[automaton_state])
            else:
                return (self.discount * self.V[next_automaton_state] - self.V[automaton_state])
        elif self.reward_shaping == 'cycle':
            if (automaton_state in self.dfa.accepting) or (next_automaton_state in self.dfa.accepting):
                return 0.0
            before = self.dist_to_initial[automaton_state]
            after = self.dist_to_initial[next_automaton_state]
            if before == after:
                # also covers two states that both cannot reach the initial
                # state: inf - inf would otherwise yield nan
                return 0.0
            return after - before
        else:
            raise NotImplementedError(self.reward_shaping)

    def step(self, labels):
        """evolve the DFA one step for a given set of labels

        Returns (accepting, cost, next state of the DFA), where cost is
        float(accepting) plus the shaping term from `potential`.
        """
        # read the current state before stepping: the DFA updates it in place
        automaton_state = self.dfa.state
        accepting, next_automaton_state = self.dfa.step(labels)
        cost = float(accepting) + self.potential(automaton_state, next_automaton_state)
        return accepting, cost, next_automaton_state













        