import random
from typing import Dict
import jpype
from jpype import JImplements, JOverride
from pipelines.prompta.learner.java_utils.dfa import TTTLearnerDFA
from pipelines.prompta.oracle import BaseOracle
from pipelines.prompta.utils import save_dfa, tuple2word, word2tuple
from prompta.utils.java_libs import AcexAnalyzers, DefaultQuery, Word
from .base_learner import BaseLearner
from .l_star import LStarLearner
from .ttt import TTTLearner
from .rpni_edsm import RPNI_EDSM


def sample(data, ratio, seed=0):
    """Deterministically shuffle ``data`` and split it into two lists.

    Args:
        data: Iterable of items to split. It is copied, never modified.
        ratio: Fraction of items placed in the first part. The split index is
            ``int(len(data) * ratio)``, so a ratio of 1 or more puts every
            item in the first part.
        seed: Seed of the private random generator (default 0).

    Returns:
        A ``(head, tail)`` pair of lists that together contain every item of
        ``data`` exactly once.
    """
    # A private Random instance seeded with ``seed`` yields the same permutation
    # as seeding the module-level generator, without resetting the global RNG
    # state that the rest of the program may rely on.
    shuffled = list(data)
    random.Random(seed).shuffle(shuffled)
    split = int(len(shuffled) * ratio)
    return shuffled[:split], shuffled[split:]

class LAPRLearner(BaseLearner):
    """Active TTT learner that falls back to passive RPNI-EDSM learning on conflicts.

    A TTT learner (``core_learner``) is driven by ``oracle``. Every counterexample
    is recorded as a labeled positive/negative example (stored as tuples produced
    by ``word2tuple``) and written into ``oracle.membership_query_cache``.

    When the TTT learner fails to refine its hypothesis with a counterexample, the
    learner may first try a passive hypothesis learned from the labeled examples
    (``verificator``). It then relabels the membership queries related to the
    counterexample with a passive hypothesis and restarts the TTT learner.
    """

    ID = "LAPRLearner"

    # Tuning thresholds, exposed as class attributes so subclasses or instances
    # can override them.
    # Above this value of ``oracle.me_counter``, a failed refinement is treated
    # as a conflict even for a counterexample that was not labeled before.
    MAX_ME_COUNTER = 45
    # The passive hypothesis' error rate is only estimated once more
    # counterexamples than this have been collected.
    MIN_CE_FOR_EPSILON_CHECK = 50
    # With more counterexamples / cached queries than these, a conflict always
    # reruns the passive learner with ``active=True``.
    MAX_CE = 200
    MAX_QUERY_CACHE = 2000

    def __init__(self, oracle: BaseOracle, exp_dir: str):
        super().__init__(oracle, exp_dir)
        self.oracle = oracle
        self.exp_dir = exp_dir
        # Passive learner used to build hypotheses from labeled examples.
        self.verificator = RPNI_EDSM(self.oracle, exp_dir)
        # Labeled counterexamples, as tuples (see word2tuple).
        self.positive_examples = set()
        self.negative_examples = set()
        self.active_hypothesis = None
        self.passive_hypothesis = None
        # Error-rate threshold on the membership-query cache; see learn_passive_hypothesis().
        self.epsilon = 0.2
        self.last_conflict_ce = None
        # reset() creates the TTT core learner and clears the conflict flag.
        self.reset()

    def reset(self):
        """Replace the core learner with a fresh TTT learner and clear the conflict flag."""
        self.core_learner = TTTLearner(self.oracle, self.exp_dir).learner
        self.conflict_detected = False

    def learn(self):
        """Run the learning loop, save the final DFA, and return it.

        Each round, the current TTT hypothesis is checked against the labeled
        examples and then the oracle (``check_positive_conjecture``). If no
        counterexample is found, the TTT hypothesis is final. Otherwise the
        counterexample is recorded and used to refine the TTT learner.

        If refinement raises, the counterexample is treated as a conflict when it
        was already a labeled example, or when ``oracle.me_counter`` exceeds
        ``MAX_ME_COUNTER``. In that case a passive hypothesis is learned and
        returned if the oracle finds no counterexample to it. Otherwise the
        queries related to the counterexample are relabeled and the TTT learner
        is restarted.

        Returns:
            The accepted hypothesis (active or passive). It is also saved via
            ``save_dfa`` to ``self.get_dfa_save_path()``.
        """
        self.core_learner.startLearning()
        while True:
            self.active_hypothesis = self.core_learner.getHypothesisModel()
            ce = self.check_positive_conjecture()
            if ce is None:
                final_hypothesis = self.active_hypothesis
                break

            ce_word = word2tuple(ce.getInput())
            # Must be evaluated before update_cache(), which adds ce_word to one
            # of the example sets; checked afterwards, the test is always true.
            ce_already_known = ce_word in self.positive_examples or ce_word in self.negative_examples
            self.update_cache(ce)

            try:
                self.core_learner.refineHypothesis(ce)
            except Exception:
                # Catch Exception rather than everything, so KeyboardInterrupt and
                # SystemExit still propagate. JPype maps Java exceptions to
                # Python Exception subclasses.
                if ce_already_known or self.oracle.me_counter > self.MAX_ME_COUNTER:
                    self.conflict_detected = True
                    self.learn_passive_hypothesis()
                    pce = self.check_passive_conjecture()
                    if pce is None:
                        final_hypothesis = self.passive_hypothesis
                        break
                    self.update_cache(pce)
                self.access_sequence_refine(ce)
                self.reset()
                self.core_learner.startLearning()
        save_dfa(self.get_dfa_save_path(), final_hypothesis, self.oracle.jalphabet)

        return final_hypothesis

    def check_positive_conjecture(self):
        """Look for a counterexample to the active (TTT) hypothesis.

        Labeled examples are checked first, so no oracle query is spent when the
        hypothesis already contradicts a known example. Only then is the oracle
        asked.

        Returns:
            A ``DefaultQuery`` for the first contradicted labeled example, or
            whatever ``oracle.check_conjecture`` returns (learn() treats ``None``
            as "no counterexample").
        """
        for example in self.positive_examples:
            word = tuple2word(example)
            if not self.active_hypothesis.computeOutput(word):
                return DefaultQuery(word, True)
        for example in self.negative_examples:
            word = tuple2word(example)
            if self.active_hypothesis.computeOutput(word):
                return DefaultQuery(word, False)
        ce = self.oracle.check_conjecture(self.active_hypothesis, 'DefaultQuery')
        self.conflict_detected = False
        return ce

    def check_passive_conjecture(self):
        """Ask the oracle for a counterexample to the passive hypothesis and clear the conflict flag."""
        ce = self.oracle.check_conjecture(self.passive_hypothesis, 'DefaultQuery')
        self.conflict_detected = False
        return ce

    def update_cache(self, ce):
        """Record counterexample ``ce`` as a labeled example.

        The label ``ce.getOutput()`` overwrites any cached membership answer for
        the same word. The word is then added to the positive or negative
        example set according to that label.
        """
        ce_word = word2tuple(ce.getInput())
        label = ce.getOutput()

        # The counterexample's label takes precedence over the cached answer.
        self.oracle.membership_query_cache[ce_word] = label

        if label:
            self.positive_examples.add(ce_word)
        else:
            self.negative_examples.add(ce_word)

    def access_sequence_refine(self, ce):
        """Relabel the membership queries associated with ``ce`` using a passive hypothesis.

        The passive hypothesis is trained on:
        - all labeled examples, and
        - a seeded random subset of the remaining cached membership answers.

        The cached answers exclude the queries recorded for ``ce`` in
        ``oracle.ce2queries``. For each polarity, the subset is sized to roughly
        the number of labeled examples of that polarity. The hypothesis'
        predictions then overwrite the cached answers of the related queries.
        """
        ce_word = word2tuple(ce.getInput())
        related_queries = self.oracle.ce2queries[ce_word]
        related_query_set = {query for query, _ in related_queries}
        labeled_positive_examples = set(self.positive_examples)
        labeled_negative_examples = set(self.negative_examples)

        raw_positive_examples, raw_negative_examples = set(), set()
        for query, answer in self.oracle.membership_query_cache.items():
            if query in labeled_positive_examples or query in labeled_negative_examples or query in related_query_set:
                continue
            if answer:
                raw_positive_examples.add(query)
            else:
                raw_negative_examples.add(query)

        # Sort before sampling: set iteration order varies between interpreter
        # runs when elements hash via strings (hash randomization), which would
        # defeat the fixed seed used by sample().
        positive_train_set, _ = sample(
            sorted(raw_positive_examples, key=repr),
            len(labeled_positive_examples) / (len(raw_positive_examples) + 1e-10),
        )
        negative_train_set, _ = sample(
            sorted(raw_negative_examples, key=repr),
            len(labeled_negative_examples) / (len(raw_negative_examples) + 1e-10),
        )

        positive_train_set += labeled_positive_examples
        negative_train_set += labeled_negative_examples
        self.learn_passive_hypothesis(positive_train_set, negative_train_set)

        for input_word, _ in related_queries:
            # Cache keys are tuples; convert to a Word before querying the DFA,
            # as the other computeOutput() call sites in this class do.
            self.oracle.membership_query_cache[input_word] = self.passive_hypothesis.computeOutput(
                tuple2word(input_word)
            )

    def learn_passive_hypothesis(self, positive_train_set=None, negative_train_set=None):
        """Train ``self.passive_hypothesis`` with the RPNI-EDSM verificator.

        Args:
            positive_train_set: Positive words to train on. Defaults to the
                labeled positive examples.
            negative_train_set: Negative words to train on. Defaults to the
                labeled negative examples.

        When ``conflict_detected`` is set, a hypothesis is first learned from the
        labeled examples only. The verificator is then rerun with
        ``active=True``, and that result kept, if either:
        - there are more than ``MIN_CE_FOR_EPSILON_CHECK`` counterexamples and
          the hypothesis' error on the membership-query cache exceeds
          ``epsilon``, or
        - the counterexample count exceeds ``MAX_CE``, or the cache size
          exceeds ``MAX_QUERY_CACHE``.
        """
        self.verificator.reset()
        if self.conflict_detected:
            self.verificator.add_examples(self.positive_examples, True)
            self.verificator.add_examples(self.negative_examples, False)
            self.passive_hypothesis = self.verificator.learn()
            num_ce = len(self.positive_examples) + len(self.negative_examples)
            if (
                (num_ce > self.MIN_CE_FOR_EPSILON_CHECK and self.predict_epsilon() > self.epsilon)
                or num_ce > self.MAX_CE
                or len(self.oracle.membership_query_cache) > self.MAX_QUERY_CACHE
            ):
                self.passive_hypothesis = self.verificator.learn(active=True)
                return
            # NOTE(review): otherwise execution falls through, and the examples
            # are added to the verificator again (without a reset) before
            # learning again. Confirm whether add_examples() deduplicates or
            # whether this path should return here.
        if positive_train_set is not None:
            self.verificator.add_examples(positive_train_set, True)
        else:
            self.verificator.add_examples(self.positive_examples, True)
        if negative_train_set is not None:
            self.verificator.add_examples(negative_train_set, False)
        else:
            self.verificator.add_examples(self.negative_examples, False)

        self.passive_hypothesis = self.verificator.learn()

    def predict_epsilon(self):
        """Return the fraction of cached membership answers the passive hypothesis disagrees with.

        Returns 0.0 when the membership-query cache is empty.
        """
        cache = self.oracle.membership_query_cache
        if not cache:
            return 0.0
        num_errors = sum(
            1 for query, answer in cache.items()
            if answer != self.passive_hypothesis.computeOutput(tuple2word(query))
        )
        return num_errors / len(cache)

    def get_mistake_stat(self, new_query_cache: Dict):
        """Compare ``oracle.mistake_cache`` against ``new_query_cache``.

        ``mistake_cache`` presumably holds known wrong oracle answers (TODO
        confirm). An entry counts as a mistake when ``new_query_cache`` holds the
        same answer for that query. Entries missing from ``new_query_cache``
        are not counted.

        Returns:
            A dict with the mistake count, the number of queries, the resulting
            query accuracy (0.0 for an empty cache), and the number of labeled
            counterexamples.
        """
        count = sum(
            1 for query, answer in self.oracle.mistake_cache.items()
            if query in new_query_cache and answer == new_query_cache[query]
        )
        num_queries = len(new_query_cache)
        return {
            'Number mistakes': count,
            'Number queries': num_queries,
            'Query Acc': (num_queries - count) / num_queries if num_queries else 0.0,
            'Number counter examples': len(self.positive_examples) + len(self.negative_examples),
        }

    def get_access_sequence(self, word):
        """Return the access sequence of the TTT state reached by ``word``.

        The walk starts from the state ``core_learner`` returns for the empty
        word.
        """
        initial_ttt_state = self.core_learner.getAnyState(Word.epsilon())
        ttt_state = self.core_learner.getDeterministicState(initial_ttt_state, word)
        return ttt_state.getAccessSequence()
