import os
import random
from typing import Tuple
from jpype import JArray, JInt
from pipelines.prompta.utils import load_dfa, query2str, show_dfa
from prompta.core.language.rule_set.tasks import *
from prompta.utils.java_libs import Automata, DefaultQuery, DFAs, Word, WordBuilder, String, JavaClass
from .base_language import BaseLanguage


# Seed the global `random` generator at import time so that the random symbol
# choices made by counterexample_by_automata_ge_len are reproducible.
# NOTE: this reseeds the process-wide generator for every user of `random`.
random.seed(0)


# Regex matching a run of word characters, whitespace, '&', '|' or '!' that is
# immediately followed (lookahead, not consumed) by ';' or '}'. The ';cycle\{'
# alternative can never be selected because the bare ';' alternative already
# matches at the same position.
# NOTE(review): not referenced in the visible part of this file -- presumably
# used elsewhere to parse rule expressions; confirm before removing.
pattern = r'[\w\s&|!]+(?=;|;cycle\{|})'
def counterexample_by_automata(target, hypothesis, inputs, length=0):
    """Return a query on which ``target`` and ``hypothesis`` disagree.

    Args:
        target: Reference automaton; its output on the separating word is
            stored in the returned query.
        hypothesis: Candidate automaton compared against ``target``.
        inputs: Input alphabet handed to ``Automata.findSeparatingWord``.
        length: When greater than zero, the search is delegated to
            ``counterexample_by_automata_ge_len`` with this value as the
            requested length; otherwise any separating word is used.

    Returns:
        A ``DefaultQuery`` holding the separating word and ``target``'s
        output for it, or ``None`` when no separating word is found.
    """
    if length > 0:
        return counterexample_by_automata_ge_len(target, hypothesis, inputs, length)

    sep = Automata.findSeparatingWord(target, hypothesis, inputs)
    # Compare against None explicitly rather than relying on truthiness: the
    # empty word is a legitimate separating word (the initial states already
    # disagree) and must not be discarded if the Java object is ever treated
    # as a falsy/empty container on the Python side.
    if sep is None:
        return None
    return DefaultQuery(sep, target.computeOutput(sep))


def counterexample_by_automata_ge_len(model, hypothesis, alphabet, length):
    """Build a randomised counterexample, aiming for ``length`` symbols.

    A DFA of counterexamples is built first: the words accepted by ``model``
    but not by ``hypothesis`` (``model AND NOT hypothesis``) or, when the
    minimised result is a single non-accepting state, the symmetric
    difference (``model XOR hypothesis``). The distance from every state to
    the nearest accepting state is then computed, and a random walk from the
    initial state picks, at each step, a symbol that still allows an
    accepting state to be reached within the remaining number of steps.

    Args:
        model: Reference DFA; its verdict on the trace is stored in the query.
        hypothesis: Candidate DFA compared against ``model``.
        alphabet: Input alphabet; iterated for symbols and queried through
            ``getSymbolIndex``.
        length: Requested trace length. The walk uses the larger of this
            value and the shortest distance from the initial state to an
            accepting state, and may stop earlier on an accepting state when
            no symbol keeps an accepting state reachable in time.

    Returns:
        A ``DefaultQuery`` of the trace and ``model.accepts(trace)``;
        ``None`` when no separating word exists; or the result of
        ``counterexample_by_automata`` when the walk gets stuck on a
        non-accepting state.

    Raises:
        AssertionError: If the walk finishes on a non-accepting state.
    """
    if Automata.findSeparatingWord(model, hypothesis, alphabet) is None:
        return None

    # Sentinel distance meaning "no accepting state known to be reachable".
    inf = int(2 ** 16 - 1)

    negHyp = DFAs.complement(hypothesis, alphabet)
    ceDFA = DFAs.and_(model, negHyp, alphabet)
    ceDFA = Automata.invasiveMinimize(ceDFA, alphabet)

    # A single non-accepting state means the conjunction accepts nothing;
    # fall back to the symmetric difference of the two automata.
    if ceDFA.size() == 1 and not ceDFA.isAccepting(ceDFA.getInitialState()):
        ceDFA = DFAs.xor(model, hypothesis, alphabet)
        ceDFA = Automata.invasiveMinimize(ceDFA, alphabet)

    size = ceDFA.size()
    # accDists[s] = number of steps from state s to the nearest accepting state.
    accDists = [0 if ceDFA.isAccepting(JInt(i)) else inf for i in range(size)]

    # Relax the distances until a full pass changes nothing.
    while True:
        stable = True
        for state in range(size):
            succMinDist = inf
            for sym in alphabet:
                trans = ceDFA.getTransition(JInt(state), JInt(alphabet.getSymbolIndex(sym)))
                succ = ceDFA.getIntSuccessor(trans)
                if succ < 0:
                    # Undefined successor: a negative index would silently
                    # read the *last* list entry in Python, so skip it.
                    continue
                succMinDist = min(succMinDist, accDists[succ])
            if succMinDist == inf:
                continue
            succMinDist += 1
            if succMinDist < accDists[state]:
                accDists[state] = succMinDist
                stable = False
        if stable:
            break

    currState = ceDFA.getIntInitialState()
    # Never walk fewer steps than are needed to reach an accepting state.
    remaining = max(length, accDists[currState])

    traceBuilder = WordBuilder(remaining)
    while remaining > 0:
        remaining -= 1

        # Symbols whose (defined) successor can still reach an accepting
        # state within the steps left after taking this one.
        candidates = []
        for sym in alphabet:
            trans = ceDFA.getTransition(JInt(currState), alphabet.getSymbolIndex(sym))
            succ = ceDFA.getIntSuccessor(trans)
            if succ < 0:
                continue
            if accDists[succ] <= remaining:
                candidates.append((sym, succ))

        if not candidates:
            if ceDFA.isAccepting(currState):
                break
            # Stuck on a non-accepting state: fall back to any separating word.
            return counterexample_by_automata(model, hypothesis, alphabet)

        # Only append the symbol once it is known to lead to a valid state,
        # so a rejected choice can never leak into the trace.
        sym, currState = candidates[random.randint(0, len(candidates) - 1)]
        traceBuilder.add(sym)

    if not ceDFA.isAccepting(currState):
        raise AssertionError("counterexample walk ended in a non-accepting state")
    trace = traceBuilder.toWord()
    return DefaultQuery(trace, model.accepts(trace))


class RuleBasedLanguage(BaseLanguage):
    """Language whose membership is decided by a rule-set context class.

    The context class is looked up by the stem of ``target_path``'s file
    name. When ``target_path`` exists on disk, the DFA stored there is loaded
    and used as the reference automaton for counterexample search.
    """

    def __init__(self, target_path: str, alphabet: List[str]=None, ce_len=0) -> None:
        """Create the rule context and, if available, load the target DFA.

        Args:
            target_path: Path whose file-name stem names the context class;
                if the path exists it is also loaded with ``load_dfa``.
            alphabet: Optional alphabet forwarded to the base class and, when
                not ``None``, to the context class constructor.
            ce_len: Requested counterexample length passed to
                ``counterexample_by_automata`` (0 means any length).
        """
        super().__init__(alphabet)
        ctx_name = target_path.split(os.path.sep)[-1].split('.')[0]
        # SECURITY: eval() executes whatever expression the file-name stem
        # spells. This is only safe while ``target_path`` comes from trusted
        # configuration; never pass user-controlled paths here.
        self.ctx_cls = eval(ctx_name)
        if alphabet is not None:
            self.ctx = self.ctx_cls(alphabet)
        else:
            self.ctx = self.ctx_cls()

        if os.path.exists(target_path):
            target = load_dfa(target_path)
            self.target_aut = target.model
            self.target_alphabet = target.alphabet
            print(self.target_alphabet)
        else:
            # No target automaton is loaded in this mode, so
            # ``target_aut`` / ``target_alphabet`` stay undefined.
            print("Assume using expert oracle")
        self.ce_len = ce_len

    def in_language(self, input_str: Tuple[str]):
        """Return the context's membership verdict for ``input_str``."""
        return self.ctx.in_language(input_str)

    def counterexample(self, aut, _type=str):
        """Search for a counterexample separating ``aut`` from the target DFA.

        Args:
            aut: Hypothesis automaton to compare with the loaded target.
            _type: When ``str`` (the default) the result is converted with
                ``query2str``; any other value returns the raw query.

        Returns:
            The counterexample, converted with ``query2str`` when
            ``_type == str``.

        Raises:
            AttributeError: If no target automaton was loaded because
                ``target_path`` did not exist at construction time.
        """
        target_aut = getattr(self, 'target_aut', None)
        target_alphabet = getattr(self, 'target_alphabet', None)
        if target_aut is None or target_alphabet is None:
            # Same exception type as the failing attribute lookup would
            # raise, but with a message that explains the cause.
            raise AttributeError(
                "RuleBasedLanguage has no target automaton: target_path did "
                "not exist when the instance was created"
            )
        ce = counterexample_by_automata(
            target=target_aut,
            hypothesis=aut,
            inputs=target_alphabet,
            length=self.ce_len,
        )
        if _type == str:
            # NOTE(review): ``ce`` is None when the automata are equivalent;
            # confirm query2str accepts None.
            return query2str(ce)
        return ce

    
