# Implementation of various curiosity signals for pre-training.
# Use pretraining.py or proofsearch.py to run jobs with these curiosity signals.

from typing import Optional
import collections
import unittest
import random


class CuriositySignal:
    '''Base class for curiosity signals.

    Subclasses must implement `reward`; they may override
    `get_training_examples` to expose strings they collected.
    '''

    def __init__(self, _cfg, pi):
        # The config is accepted so every signal shares one constructor
        # signature; the base class itself does not read it.
        self._pi = pi

    def reward(self, state):
        '''Return the curiosity reward for `state`. Must be overridden.'''
        raise NotImplementedError

    def get_training_examples(self) -> list[str]:
        '''Return training strings gathered by this signal (none by default).'''
        return list()


class Constant(CuriositySignal):
    '''Uniform curiosity signal used for ablating curiosity.

    Every state receives the same reward regardless of its content.
    '''

    _REWARD = 1

    def reward(self, _state):
        return self._REWARD


class InverseFrequency(CuriositySignal):
    '''r(p) = 1/f, where f counts how many times p has been proven, including now.

    The first proof of a fact earns 1, the second 1/2, and so on. Nodes
    without a constructed fact earn 0 and are not counted.
    '''

    def __init__(self, cfg, pi):
        super().__init__(cfg, pi)
        # Maps each value of node.last_construction_dtype() to how many
        # times it has been rewarded so far.
        self._counts = collections.Counter()

    def reward(self, node):
        fact = node.last_construction_dtype()
        if not fact:
            return 0

        self._counts[fact] += 1
        return 1 / self._counts[fact]


class ConstructionLikelihood(CuriositySignal):
    '''Surprisal of the proved fact: the negated `goal_logprob` of its type.

    Config keys:
        online (default True): forwarded to `goal_logprob` as `step`.
        aggregation (default 'sum'): forwarded to `goal_logprob` as `aggregation`.
    '''

    def __init__(self, cfg, pi):
        super().__init__(cfg, pi)
        self._aggregation = cfg.get('aggregation', 'sum')
        self._online = cfg.get('online', True)

    def reward(self, node):
        fact = node.last_construction_dtype()
        if fact:
            logprob = self._pi._lm.goal_logprob(
                '<>', fact, aggregation=self._aggregation, step=self._online)
            # Less likely facts (lower log-probability) earn a higher reward.
            return -logprob
        return 0


class TransitionLikelihood(CuriositySignal):
    '''Surprisal of the last transition: the negated LM log-probability of its result.

    A transition is rendered as a preamble `A:<action>=>` followed by a
    completion (the type of the object the action constructed). The current
    transition is scored in one batch together with up to `replay_batch_size`
    transitions sampled from earlier calls. Every scored transition is kept in
    memory and exposed through `get_training_examples`.

    Config keys:
        replay_batch_size (default 4): how many past transitions to re-score
            alongside each new one.
        online (default True), aggregation (default 'sum'): stored, but not
            read by this class's methods (see NOTE in `reward`).
    '''

    def __init__(self, cfg, pi):
        super().__init__(cfg, pi)
        self._online = cfg.get('online', True)
        self._aggregation = cfg.get('aggregation', 'sum')
        self._batch_size = cfg.get('replay_batch_size', 4)
        # Every (preamble, completion) pair scored so far; never truncated.
        self._memory = []

    def reward(self, node):
        '''Return -log p(result | action) for the node's first proof state, or 0
        when there is no proof state or no renderable transition.'''
        if not node._proof_states:
            return 0

        ps = node._proof_states[0]
        action, result = self._get_transition_str(ps)
        if not (action and result):
            return 0

        preamble = f'A:{action}=>'
        completion = result

        # Replay: score a random sample of earlier transitions together with
        # the new one. The sample is drawn before appending, so it never
        # contains the current transition.
        mem = random.sample(self._memory,
                            k=min(len(self._memory), self._batch_size))
        preambles = [p for p, _ in mem] + [preamble]
        completions = [c for _, c in mem] + [completion]

        self._memory.append((preamble, completion))

        # The current transition is last in the batch, hence [-1].
        # NOTE(review): the third argument is hard-coded True and
        # self._online / self._aggregation are never read here, unlike in
        # ConstructionLikelihood, which forwards them. Confirm this is intended.
        logprob = self._pi._lm.completion_logprob(preambles, completions, True)[-1]
        return -logprob

    def get_training_examples(self) -> list[str]:
        '''Return every remembered transition as a single preamble+completion string.'''
        return [preamble + completion for preamble, completion in self._memory]

    @staticmethod
    def _get_transition_str(state) -> tuple[Optional[str], Optional[str]]:
        '''Render the state's last action and its result as strings.

        Returns:
            (action, result). `action` is "(<arg0> <arg1> ...)", where every
            argument after the first is formatted with `state.format_object`.
            For a `rewrite`, only the part of the third argument before the
            first '@' is formatted and the location suffix is kept verbatim.
            `result` is the type of the object the last action constructed.
            Returns (None, None) when the last action constructed nothing or
            its generating arguments are unavailable.
        '''
        name = state.construction_from_last_action()
        if not name:
            return None, None

        dtype = state.lookup(name).get_type()
        args = state.generating_arguments(name)
        if args is None:
            return None, None

        # Format a copy. Whether the state hands back a shared/cached list is
        # not visible here; formatting it in place could corrupt that list and
        # would double-format on a repeated call.
        args = list(args)
        for i in range(1, len(args)):
            if args[0] == 'rewrite' and i == 2:
                # Properly format location expression for specifying rewrite
                prop, loc = args[i].split('@', 1)
                args[i] = f'{state.format_object(prop)}@{loc}'
            else:
                args[i] = state.format_object(args[i])

        return '(' + ' '.join(args) + ')', dtype


class TestTransitionLikelihood(unittest.TestCase):
    '''Regression test for TransitionLikelihood._get_transition_str on a rewrite step.

    Requires the project's `problems` module and its 'nng' problem set.
    '''

    def _find_action(self, state, predicate):
        '''Return the first action of `state` satisfying `predicate`; fail the test otherwise.'''
        actions = list(state.actions())
        for action in actions:
            if predicate(action):
                return action
        # Report what was available instead of an opaque IndexError.
        self.fail(f'no matching action; available actions: {actions}')

    def _step(self, state, predicate):
        '''Execute the first matching action and return the first resulting state.'''
        return state.execute_action(self._find_action(state, predicate))[0]

    def test_action_string(self):
        import problems
        import proofsearch  # noqa: F401  NOTE(review): unused here; kept in case the import has needed side effects. Confirm.

        pset = problems.load_problemset('nng')
        st = pset.initialize_problem('m_mul_one')

        # Find intro.
        st = self._step(st, lambda a: a.is_intro())

        # Find show with *_s, then pick the resulting equation.
        st = self._step(st, lambda a: a.is_construct() and '*_s' in str(a))
        st = self._step(st, lambda a: '(= (* x (s z)) (+ x (* x z)))' in str(a))

        # Construct with *_z, then pick the resulting equation.
        st = self._step(st, lambda a: a.is_construct() and '*_z' in str(a))
        st = self._step(st, lambda a: '(= (* x z) z)' in str(a))

        # Construct a rewrite, then pick the rewritten equation.
        st = self._step(st, lambda a: a.is_construct() and 'rewrite' in str(a))
        st = self._step(st, lambda a: '(= (* x (s z)) (+ x z))' in str(a))

        action, result = TransitionLikelihood._get_transition_str(st)

        self.assertEqual(action, '(rewrite (= (* x z) z) (= (* x (s z)) (+ x (* x z)))@type@2@2)')
        self.assertEqual(result, '(= (* x (s z)) (+ x z))')


class ConstructionLikelihoodGradientNorm(CuriositySignal):
    '''Curiosity as the gradient norm produced by the LM on the proved fact.

    The proved fact's type is formatted with `format_provable_goal` and passed
    to `gradient_step(..., return_norm=True)`. The returned value is the
    reward. Whether that call also updates the model is not visible here.
    '''

    def __init__(self, cfg, pi):
        # Explicit override keeps `cfg` (not the base's `_cfg`) as the
        # keyword name accepted by this constructor.
        super().__init__(cfg, pi)

    def reward(self, node):
        fact = node.last_construction_dtype()
        if fact:
            lm = self._pi._lm
            goal = lm.format_provable_goal('<>', fact)
            return lm.gradient_step([goal], return_norm=True)
        return 0


def make_curiosity_signal(cfg, pi=None):
    '''Build the curiosity signal selected by `cfg.type`.

    Args:
        cfg: config object exposing a `type` attribute; it is also handed to
            the selected signal's constructor.
        pi: object whose `_lm` the likelihood-based signals query (optional).

    Raises:
        NotImplementedError: if `cfg.type` matches no known signal.
    '''
    # 'path-logprob' (PathLikelihood) is not available: that class is not defined here.
    known_signals = (
        ('constant', Constant),
        ('inverse-frequency', InverseFrequency),
        ('logprob', ConstructionLikelihood),
        ('transition-logprob', TransitionLikelihood),
        ('grad-norm', ConstructionLikelihoodGradientNorm),
    )

    requested = cfg.type
    for signal_name, signal_cls in known_signals:
        if requested == signal_name:
            return signal_cls(cfg, pi)

    raise NotImplementedError(requested)
