import os
from urllib.request import urlopen
from tempfile import NamedTemporaryFile
from shutil import unpack_archive
import subprocess
import numpy as np
import networkx as nx
import numpy as np
import itertools
# When True, ltl2ldba() reuses an already-unpacked ltl2ldba binary under
# RABINIZER_PATH instead of downloading Rabinizer again, and caches the tool's
# output per formula in '<formula>.bck'. When False, Rabinizer is downloaded on
# every call whose formula has no '<formula>.bck' file yet.
CACHE_RABINIZER = True
# Directory into which Rabinizer is unpacked and where the '<formula>.bck'
# output caches live. Read at import time: importing this module raises
# KeyError if the RABINIZER_PATH environment variable is not set.
RABINIZER_PATH = os.environ['RABINIZER_PATH']
# Per-formula relabelling of atomic-proposition indices in edge labels.
# make_automaton() applies the mapping character by character to each edge's
# condition string, so only single-digit AP indices can be remapped.
CORRECTIONS = {

    # variable transformations to obtain an equivalent LDBA, hopefully better suited for optimization

    'GF(a&XF(b&XFc))&G!((a&b)|(a&c)|(b&c)|z)': {'0': '1', '1': '2', '2': '0'},
    'GF(a&XF(b&XF(c&XFd)))&G!((a&b)|(a&c)|(a&d)|(b&c)|(b&d)|(c&d)|z)': {'0': '1', '1': '0'},

    'GF(a&XF(b&XFc))&G!(z)': {'0': '1', '1': '2', '2': '0'},
    'GF(a&XF(b&XF(c&XFd)))&G!(z)': {'0': '1', '1': '3', '2': '0', '3': '2', },

    'GF(a&XF(b&XFc))': {'0': '2', '1': '0', '2': '1'},
    'GF(a&XF(b&XF(c&XFd)))': {'0': '1', '1': '3', '2': '0', '3': '2'},
}


def ltl2ldba(formula):
    """Run Rabinizer's ``ltl2ldba -d -e`` on ``formula`` and return its output.

    A per-formula cache file ``RABINIZER_PATH/<formula>.bck`` is returned as-is
    whenever it exists, regardless of ``CACHE_RABINIZER``. Otherwise the
    ``ltl2ldba`` binary under ``RABINIZER_PATH/rabinizer4/bin`` is reused when
    ``CACHE_RABINIZER`` is set and the binary is already unpacked; in every
    other case Rabinizer is downloaded and unpacked first. When
    ``CACHE_RABINIZER`` is set, the fresh output is written to the cache file.

    Args:
        formula: LTL formula in Rabinizer syntax. It is also used verbatim as
            the cache file name, so it must not contain path separators.

    Returns:
        ``str()`` of the raw ``bytes`` printed by the tool, i.e. the bytes
        *repr* (``"b'HOA: v1\\n...'"``), in which line breaks appear as the two
        characters backslash and ``n``. ``make_automaton`` splits on exactly
        that sequence, so this format is deliberately preserved.

    Raises:
        subprocess.CalledProcessError: if ``ltl2ldba`` exits with non-zero status.
        urllib.error.URLError: if Rabinizer has to be downloaded and that fails.
    """
    ZIPURL = 'https://www7.in.tum.de/~kretinsk/rabinizer4.zip'
    binary = os.path.join(RABINIZER_PATH, 'rabinizer4', 'bin', 'ltl2ldba')
    cache_file = RABINIZER_PATH + f'/{formula}.bck'

    if os.path.exists(cache_file):
        with open(cache_file, 'r') as reader:
            return reader.read()

    if not (CACHE_RABINIZER and os.path.isfile(binary)):
        with urlopen(ZIPURL) as zipresp, NamedTemporaryFile() as tfile:
            print('Downloading rabinizer.')
            tfile.write(zipresp.read())
            # Ensure every byte is on disk before the archive is reopened by name.
            tfile.flush()
            unpack_archive(tfile.name, RABINIZER_PATH, format='zip')

    # Zip extraction via shutil does not restore Unix permission bits, so make
    # the launcher script executable before running it.
    os.chmod(binary, 0o700)
    output = str(subprocess.check_output([binary, '-d', '-e', formula]))

    if CACHE_RABINIZER:
        with open(cache_file, 'w') as writer:
            writer.write(output)
    return output


def get_condition(cond):
    """Compile an HOA edge label into a predicate over atomic-proposition values.

    Handled label syntax: AP indices (``0``, ``12``), ``t`` (true), ``f``
    (false), ``!``, ``&``, ``|`` and parentheses, with or without surrounding
    spaces. The literal string ``'False'`` (used by ``make_automaton`` for
    epsilon edges) contains none of these tokens and evaluates to ``False``.

    Args:
        cond: label text, e.g. ``'0 & !(1 | 2)'``.

    Returns:
        A function ``condition(x)`` where ``x[i]`` is the value of AP ``i``.
        It returns the result of the Python boolean expression, which, as with
        ``and``/``or`` in general, may be an element of ``x`` rather than a bool.
    """
    import re

    replacements = {'t': 'True', 'f': 'False', '!': 'not ', '&': ' and ', '|': ' or '}

    def _translate(match):
        # One regex pass: digit runs become indexing, operator chars become keywords.
        token = match.group(0)
        if token.isdigit():
            return 'x[' + token + ']'
        return replacements[token]

    parsed_cond = re.sub(r'\d+|[tf!&|]', _translate, cond).strip()
    # Compile once instead of re-parsing the source text on every call.
    # SECURITY: this evaluates text produced by an external tool; builtins are
    # removed so only the translated boolean expression over `x` can run.
    code = compile(parsed_cond, '<ldba-condition>', 'eval')
    return lambda x: eval(code, {'__builtins__': {}}, {'x': x})


def make_automaton(formula):
    """Build an ``LDBAutomaton`` for ``formula`` from Rabinizer's HOA output.

    Steps:

    1. Parse the HOA text from ``ltl2ldba`` into a ``networkx.MultiDiGraph``.
       Every edge stores ``condition_as_str`` (the HOA label, or ``'False'``
       for epsilon edges), ``condition`` (from ``get_condition``) and
       ``jump_id`` (0 for labelled edges, 1..k for the k-th epsilon edge
       leaving a state). If ``formula`` is a key of ``CORRECTIONS``, AP indices
       in the labels are remapped first.
    2. Move acceptance from edges to nodes: each node entered by an accepting
       edge gets an accepting copy that takes over those edges and duplicates
       the node's outgoing edges.
    3. Drop nodes unreachable from the start state and relabel the rest
       ``0..n-1``. The start state and the epsilon table are translated with
       the same mapping so they keep referring to the same states.

    Args:
        formula: LTL formula understood by Rabinizer.

    Returns:
        LDBAutomaton over the relabelled graph, with the relabelled start
        state, the AP names and the relabelled epsilon table.
    """
    eps = {}
    # ltl2ldba() returns str(bytes), so line breaks are the two characters
    # backslash + 'n' rather than real newlines.
    lines = ltl2ldba(formula).split('\\n')
    # NOTE(review): header fields are read by fixed line position
    # (2 = Start, 4 = Acceptance, 6 = AP); confirm if Rabinizer is upgraded.
    start_node = int(lines[2].split(' ')[-1])
    # Acceptance tokens look like 'Inf(k)'; keep the acceptance-set index k.
    accepting = [int(l[4:-1]) for l in lines[4].split(' ')[2:]]
    aps = [l.strip('"') for l in lines[6].split(' ')[2:]]
    corr = CORRECTIONS.get(formula)
    graph = nx.MultiDiGraph()
    from_node = None
    # The slice skips the header and the last two entries (presumably
    # '--END--' and the closing quote of the bytes repr -- confirm).
    for line in lines[8:-2]:
        if line.startswith('State'):
            from_node = int(line.split(' ')[-1])
            if not graph.has_node(from_node):
                graph.add_node(from_node)
            jump_count = 0
            continue

        if not line.startswith('['):
            # Unlabelled line: an epsilon ("jump") transition to state int(line).
            to_node = int(line)
            condition_as_str = 'False'
            jump_count += 1
            jump_id = jump_count
            eps.setdefault(from_node, []).append(to_node)
        else:
            # Labelled line: '[label] target' optionally followed by '{k}'.
            to_node = int(line.split(']')[-1].split(' ')[1])
            condition_as_str = line.split(']')[0][1:]
            jump_id = 0

        if corr is not None:
            condition_as_str = ''.join(corr.get(c, c) for c in condition_as_str)

        accepting_edge = False
        if line.endswith('}'):
            edge_label = int(line.split('{')[-1][:-1])
            if edge_label in accepting:
                accepting_edge = True
        if not graph.has_node(to_node):
            graph.add_node(to_node)
        graph.add_edge(
            from_node,
            to_node,
            accepting=accepting_edge,
            condition_as_str=condition_as_str,
            condition=get_condition(condition_as_str),
            jump_id=jump_id
        )

    # Switch from accepting edges to accepting nodes.
    nx.set_node_attributes(graph, {n: {'accepting': False} for n in graph.nodes()})
    for v in tuple(graph.nodes()):
        new_edges, to_remove = [], []
        for u, _, k, d in graph.in_edges(v, keys=True, data=True):
            if d['accepting']:
                to_remove.append((u, v, k))
                # Redirect into v's accepting copy, which is added below with
                # id graph.order(); assumes node ids are 0..order-1 -- confirm.
                new_edges.append(dict(
                    u_for_edge=u,
                    v_for_edge=graph.order(),
                    accepting=False,
                    condition_as_str=d['condition_as_str'],
                    condition=d['condition'],
                    jump_id=d['jump_id']
                ))
        if new_edges:
            graph.add_node(graph.order(), accepting=True)
            graph.remove_edges_from(to_remove)
            for e in new_edges:
                graph.add_edge(**e)
            # The copy behaves like v from here on: duplicate v's out-edges.
            for _, v2, _, d2 in tuple(graph.out_edges(v, keys=True, data=True)):
                graph.add_edge(
                    u_for_edge=graph.order()-1,
                    v_for_edge=v2 if not (v == v2 and d2['accepting']) else graph.order()-1,
                    accepting=d2['accepting'],
                    condition_as_str=d2['condition_as_str'],
                    condition=d2['condition'],
                    jump_id=d2['jump_id']
                )

    for _, _, _, d in graph.edges(data=True, keys=True):
        del d['accepting']

    # Remove nodes no longer reachable from the start state (e.g. an original
    # state whose only incoming edges were accepting and got redirected).
    keepers = tuple(nx.dfs_postorder_nodes(graph, source=start_node))
    graph = graph.subgraph(keepers)
    mapping = {old: new for new, old in enumerate(sorted(graph.nodes))}
    graph = nx.relabel_nodes(graph, mapping)
    # start_node and eps still hold pre-relabel ids; translate them too so
    # reset()/epsilon_step() address the same states as the graph.
    start_node = mapping[start_node]
    eps = {mapping[u]: [mapping[w] for w in targets]
           for u, targets in eps.items() if u in mapping}

    return LDBAutomaton(graph, start_node, aps, eps)


class LDBAutomaton:
    """Stateful wrapper around the automaton graph built by ``make_automaton``.

    Nodes are expected to be labelled ``0..n-1`` and carry an ``accepting``
    flag; edges carry ``condition`` (a callable over AP values),
    ``condition_as_str`` and ``jump_id`` (0 = ordinary transition,
    > 0 = epsilon jump). ``curr_state`` is set by ``reset``; the value ``-1``
    denotes the dead state entered when no ordinary edge matches.
    """

    def __init__(self, graph, start_state, aps, eps):
        self.graph, self.start_state, self.aps, self.eps = graph, start_state, aps, eps
        self.make_jump_mask()

    def step(self, ap, jump_id=0):
        """Advance the automaton by one transition.

        Args:
            ap: AP values indexed by AP number; passed to each edge's
                ``condition``.
            jump_id: 0 for an ordinary transition. A positive value follows the
                outgoing edge with that ``jump_id`` instead; the state is left
                unchanged if no such edge exists.

        Returns:
            ``(state, accepting)``: the new current state (``-1`` once dead)
            and whether it is an accepting node. Jumps always report ``False``.
        """
        if self.curr_state == -1:
            return self.curr_state, False

        if jump_id > 0:
            for _, v, d in self.graph.out_edges(self.curr_state, data=True):
                if d['jump_id'] == jump_id:
                    self.curr_state = v
                    break
            return self.get_obs(), False

        destination, accepting = [], False
        for _, v, d in self.graph.out_edges(self.curr_state, data=True):
            if d['condition'](ap):
                destination.append(v)
                accepting = self.graph.nodes[v]['accepting']

        # Ordinary transitions are expected to be deterministic.
        assert len(destination) < 2
        self.curr_state = destination[0] if destination else -1

        return self.get_obs(), accepting

    def epsilon_step(self, eps_action):
        """Take the ``eps_action``-th epsilon transition out of the current state.

        A move that does not exist (no epsilon entry for the current state, or
        ``eps_action`` out of range) is a best-effort no-op: the state stays put.

        Returns:
            The (possibly unchanged) current state.
        """
        try:
            self.curr_state = self.eps[self.curr_state][eps_action]
        except (KeyError, IndexError, TypeError):
            # Invalid epsilon move: keep the current state.
            pass
        return self.get_obs()

    def reset(self, ap):
        """Return to the start state and immediately take one step on ``ap``.

        Returns:
            The ``(state, accepting)`` pair produced by that first ``step``.
        """
        self.curr_state = self.start_state
        return self.step(ap, jump_id=0)

    def get_obs(self):
        """Return the current state id (``-1`` for the dead state)."""
        return self.curr_state

    def get_n_states(self):
        """Return the number of graph nodes plus one extra index for the dead state."""
        return self.graph.order() + 1

    def make_jump_mask(self):
        """Precompute lookup tables over states.

        Sets:
            jump_mask[s, j]: 1 if action ``j`` is available in state ``s``;
                column 0 (no jump) is always available.
            jump_map[s, j]: target of jump ``j`` from ``s``; column 0 maps ``s``
                to itself, unavailable jumps hold ``graph.order() + 1``.
            transition_map[s, w]: successor of ``s`` under the AP valuation
                whose bit ``i`` is AP ``i``; defaults to ``graph.order()``
                (the extra dead-state index).
            accepting_nodes[s]: accepting flag per node, plus ``False`` for
                the dead-state index.
        """
        n_states = self.get_n_states()
        n_aps = len(self.aps)
        max_jumps = max(
            (sum(1 for _, _, d in self.graph.out_edges(n, data=True) if d['jump_id'] > 0)
             for n in self.graph),
            default=0,
        )
        self.jump_mask = np.zeros((n_states, 1 + max_jumps))
        self.jump_map = np.zeros((n_states, 1 + max_jumps), dtype=int) + self.graph.order() + 1
        self.transition_map = np.zeros((n_states, 2 ** n_aps), dtype=int) + self.graph.order()
        self.jump_mask[:, 0] = 1
        self.jump_map[:, 0] = np.arange(n_states)
        # Bit weights: valuation (b0, b1, ...) is encoded as sum(b_i * 2**i).
        weights = 2 ** np.arange(n_aps)
        words = list(itertools.product([0, 1], repeat=n_aps))
        for u, v, d in self.graph.edges(data=True):
            if d['jump_id'] > 0:
                self.jump_mask[u][d['jump_id']] = 1
                self.jump_map[u][d['jump_id']] = v
            else:
                for word in words:
                    idx = (u, (word * weights).sum())
                    if d['condition'](word):
                        self.transition_map[idx] = v
        self.accepting_nodes = np.array(
            [d['accepting'] for _, d in sorted(self.graph.nodes(data=True))] + [False]
        )

    def get_jump_mask(self):
        """Return the ``(n_states, n_jumps)`` availability mask."""
        return self.jump_mask

    def get_n_jumps(self):
        """Return the number of jump actions, including the no-jump action 0."""
        return self.jump_mask.shape[-1]

    def get_graph(self):
        """Return the underlying automaton graph."""
        return self.graph