# Behaviour policies for off-policy RL,
# i.e. when the behaviour policy differs from the target policy.

import copy
import math
import random
from enum import Enum, auto

import numpy as np
import scipy
import scipy.stats

from .equation import Equation, optimise_eq_consts


class BPStrategy(Enum):
    """Names of the strategies a behaviour policy can follow.

    Members are looked up by name, so the member names are part of the
    public interface. The values are the consecutive integers that
    ``auto()`` would assign (1, 2, 3).
    """

    # Sample from, and score with, the target policy itself.
    target = 1
    # Sample uniformly over a pre-enumerated collection of equations.
    enumerate_all = 2
    # Build equations token by token with equal probability per token.
    equal_prob_tokens = 3


class BehaviourPolicy:
    """Policy that proposes equations for off-policy learning.

    Three strategies are supported (named after the ``BPStrategy`` members):

    * ``target`` - sampling and densities are delegated to the target
      policy, so every importance weight is 1.0.
    * ``enumerate_all`` - an equation is chosen uniformly from the supplied
      equations that pass ``valid_eq``.
    * ``equal_prob_tokens`` - an equation is built token by token, choosing
      uniformly among the tokens allowed by the mask at each step.
    """

    def __init__(self, name, target_policy, token_set, max_num_tokens,
                 net_masks, all_eqs=None):
        """Create a behaviour policy.

        Args:
            name: Name of a ``BPStrategy`` member.
            target_policy: The target policy. This class calls its
                ``sample``, ``pdf`` and ``get_consts_params`` methods and
                reads its ``net_masks`` attribute.
            token_set: Sequence of token dicts; the keys ``'op'`` and
                ``'type'`` are read here.
            max_num_tokens: Forwarded to ``valid_eq``, to the mask
                computation and to ``optimise_eq_consts``.
            net_masks: Forwarded to ``valid_eq`` and ``optimise_eq_consts``.
            all_eqs: Optional iterable of candidate equations; required by
                the ``enumerate_all`` strategy.

        Raises:
            KeyError: If ``name`` is not a ``BPStrategy`` member name.
            ValueError: If ``enumerate_all`` is requested without
                ``all_eqs`` or with a ``distr_const`` token in ``token_set``.
        """
        if name not in BPStrategy.__members__:
            raise KeyError(f'{name} not a behaviour policy strategy')

        self._strategy = BPStrategy[name]
        enumerating = self._strategy == BPStrategy.enumerate_all

        # All enumerated equations need to be provided for the
        # 'enumerate_all' behaviour policy
        if enumerating and all_eqs is None:
            raise ValueError('Must provided enumerate equations in order to '
                             'use the enumerate_all behaviour strategy')

        # Enumerate all policy cannot be used if there are distr_consts in
        # the token set. Checked before filtering the equations so that an
        # invalid configuration fails without doing that work first.
        if enumerating and any(t['op'] == 'distr_const' for t in token_set):
            raise ValueError(
                'Enumerate all policy cannot be used if there '
                'are distr_const in the token set'
            )

        self._all_eqs = all_eqs

        # Determine all valid expressions. The attribute is always defined
        # (empty when no equations were supplied) so that later accesses
        # cannot fail with AttributeError.
        if all_eqs is None:
            self._all_valid_eqs = []
        else:
            self._all_valid_eqs = [e for e in all_eqs
                                   if e.valid_eq(max_num_tokens, net_masks)]

        self._target_policy = target_policy

        self._token_set = token_set
        self._max_num_tokens = max_num_tokens
        self._net_masks = net_masks

        # Precompute and store number of constants and unary operators
        self._num_consts = sum(1 for t in token_set if t['type'] == 'const')
        self._num_unary_ops = sum(1 for t in token_set
                                  if t['type'] == 'un_op')

    def sample(self):
        """Sample an equation according to the configured strategy."""
        if self._strategy == BPStrategy.target:
            return self._target_policy.sample()
        if self._strategy == BPStrategy.enumerate_all:
            return self._sample_enum_all()
        if self._strategy == BPStrategy.equal_prob_tokens:
            return self._sample_eq_prob_tokens()

        # Fail loudly rather than silently returning None if a strategy is
        # ever added to the enum without being handled here.
        raise NotImplementedError(
            f'Unhandled behaviour policy strategy: {self._strategy}')

    def sample_and_optimise(self, data, log_likelihood_func):
        """Sample an equation and pass it to ``optimise_eq_consts``.

        Returns whatever ``optimise_eq_consts`` returns for the sampled
        equation, ``data``, ``log_likelihood_func`` and the stored
        ``max_num_tokens`` / ``net_masks``.
        """
        return optimise_eq_consts(self.sample(), data, log_likelihood_func,
                                  self._max_num_tokens, self._net_masks)

    def pdf(self, z):
        """Return the density of equation ``z`` under the behaviour policy."""
        if self._strategy == BPStrategy.target:
            return self._target_policy.pdf(z)
        if self._strategy == BPStrategy.enumerate_all:
            return self._pdf_enum_all(z)
        if self._strategy == BPStrategy.equal_prob_tokens:
            return self._pdf_eq_prob_tokens(z)

        raise NotImplementedError(
            f'Unhandled behaviour policy strategy: {self._strategy}')

    def importance_weight(self, z):
        """Return the importance weight target_pdf(z) / behaviour_pdf(z).

        When the target policy is itself the behaviour policy the weight is
        exactly 1.0 and no densities are evaluated.
        """
        if self._strategy == BPStrategy.target:
            return 1.0

        # NOTE(review): ``.item()`` is called on the ratio, so the target
        # policy's pdf presumably returns a tensor / NumPy scalar - confirm.
        return (self._target_policy.pdf(z) / self.pdf(z)).item()

    def _draw_distr_consts(self, eq):
        """Sample values for the distribution constants of ``eq`` in place.

        The parameters come from the target policy; the first two entries of
        each parameter are passed to ``scipy.stats.norm.rvs``.
        """
        consts_params = self._target_policy.get_consts_params(eq)
        consts = [scipy.stats.norm.rvs(p[0], p[1]) for p in consts_params]
        eq.set_distr_consts(consts)

    def _require_valid_eqs(self):
        """Return the valid enumerated equations, raising if there are none."""
        if not self._all_valid_eqs:
            raise ValueError('No valid equations are available for the '
                             'enumerate_all behaviour strategy')
        return self._all_valid_eqs

    def _sample_eq_prob_tokens(self):
        """Sample an equation giving every allowed token equal probability.

        Tokens are drawn one at a time until no more constants are required.
        Each stored token is a deep copy of the token-set entry, annotated
        with the ``pre_softmax_mask`` that was in force when it was drawn.
        """
        tokens = []
        num_consts_required = 1

        while num_consts_required > 0:

            # Determine parameters of categorical distribution
            p, pre_softmax_mask = self._determine_p(num_consts_required,
                                                    tokens)

            # Sample token (copied so the shared token set is not mutated)
            token = copy.deepcopy(np.random.choice(self._token_set, 1, p=p)[0])
            token['pre_softmax_mask'] = copy.deepcopy(pre_softmax_mask)
            tokens.append(token)

            # Increase or decrease the number of constants required
            # depending on the sample token type
            if token['type'] == 'bin_op':
                num_consts_required += 1
            elif token['type'] == 'const':
                num_consts_required -= 1

        eq = Equation(tokens)

        # Sample constant values for constants with distributions
        self._draw_distr_consts(eq)

        return eq

    def _sample_enum_all(self):
        """Sample uniformly from the valid enumerated equations.

        Raises:
            ValueError: If there are no valid equations to choose from.
        """
        z = copy.deepcopy(random.choice(self._require_valid_eqs()))

        if z.num_distr_consts() > 0:
            # Sample distributional constants from target policy
            self._draw_distr_consts(z)

        return z

    def _pdf_eq_prob_tokens(self, z):
        """Return the density of ``z`` under the equal-probability policy.

        The result is the product of the per-step token probabilities and,
        for every ``distr_const`` token, the normal density of its value
        under the target policy's parameters for that constant.
        """
        pdfs = []
        num_consts_required = 1

        distr_const_idx = 0
        consts_params = self._target_policy.get_consts_params(z)

        # Fetch the token list once instead of on every iteration
        tokens = z.tokens()

        for i, t in enumerate(tokens):

            # Determine parameters of categorical distribution given the
            # tokens that precede this one
            p, _ = self._determine_p(num_consts_required, tokens[:i])

            # Determine probability of selecting token
            pdfs.append(p[t['id']])

            # If const was sampled from distribution
            if t['op'] == 'distr_const':
                params = consts_params[distr_const_idx]
                pdfs.append(
                    scipy.stats.norm.pdf(t['value'], params[0], params[1])
                )
                distr_const_idx += 1

            # Increase or decrease the number of constants required
            # depending on the sample token type
            if t['type'] == 'bin_op':
                num_consts_required += 1
            elif t['type'] == 'const':
                num_consts_required -= 1

        return math.prod(pdfs)

    def _pdf_enum_all(self, z):
        """Return the density of ``z`` under the uniform enumeration policy.

        Every valid equation has probability ``1 / number of valid
        equations``; this is multiplied by the normal density of each
        distribution constant of ``z``, if it has any.

        Raises:
            ValueError: If there are no valid equations.
        """
        prob = 1.0 / len(self._require_valid_eqs())

        if z.num_distr_consts() > 0:

            # Get pdf of each individual constant sampled
            consts_params = self._target_policy.get_consts_params(z)

            # Multiply overall probability by pdf of const value
            for t, p in zip(z.distr_const_tokens(), consts_params):
                prob *= scipy.stats.norm.pdf(t['value'], p[0], p[1])

        return prob

    def _determine_p(self, num_consts_required, sampled_tokens):
        """Return ``(p, pre_softmax_mask)`` for the next token.

        ``p`` is a list with one probability per token in the token set:
        uniform over all tokens when there is no mask, otherwise uniform
        over the tokens the mask allows and 0.0 elsewhere.

        Raises:
            ValueError: If the mask allows no token at all.
        """
        # First determine whether there are any token constraints
        masks = self._target_policy.net_masks
        pre_softmax_mask = masks.compose_mask(
            masks.determine_masks(
                self._max_num_tokens, sampled_tokens, num_consts_required
            )
        )

        # The default case - sample from all tokens equally
        if pre_softmax_mask is None:
            num_tokens = len(self._token_set)
            return [1 / num_tokens] * num_tokens, pre_softmax_mask

        # Determine which tokens can be sampled from: a mask entry above
        # -0.1 marks the token as allowed.
        # NOTE(review): presumably disallowed entries hold large negative
        # values - confirm against the net_masks implementation.
        tokens_to_sample = [m > -0.1 for m in pre_softmax_mask]

        # Count number of tokens that can potentially be sampled from
        num_possible_tokens = sum(1 for t in tokens_to_sample if t)

        if num_possible_tokens == 0:
            raise ValueError('Token mask does not allow any token to be '
                             'sampled')

        # Determine probability of each token
        p = [1 / num_possible_tokens if t else 0.0 for t in tokens_to_sample]

        return p, pre_softmax_mask
