"""Code related to general math expressions."""
import collections
import dataclasses
import itertools
import random
import sys
from typing import Optional, Sequence, Dict

import numpy as np
import sympy as sp

from . import constants
from . import conversion_util
from . import sympy_util as sp_util
from . import trees
from .misc_util import timeout, TimeoutError

from .cas import cas_abcs
from .cas import sympy_cas


# Local aliases for the shared lookup tables defined in the `constants` module.
OPERATOR_ARITIES = constants.OPERATOR_ARITIES
SUPPORTED_OPERATORS = constants.SUPPORTED_OPERATORS
DEFAULT_OPERATOR_WEIGHTS = constants.DEFAULT_OPERATOR_WEIGHTS
ALL_SYMPY_REWRITE_FUNCTIONS = constants.ALL_SYMPY_REWRITE_FUNCTIONS


@dataclasses.dataclass
class LeafProbabilities:
    """Relative probabilities of each kind of leaf in a generated expression.

    Any non-negative weights with a positive total may be passed; they are
    normalized in ``__post_init__`` so that the four fields sum to one.

    Raises:
        ValueError: if a weight is negative or all weights are zero.
    """

    p_variable: float = 0.75
    p_coefficient: float = 0.0
    p_integer: float = 0.25
    p_constant: float = 0.0

    def __post_init__(self):
        weights = (self.p_variable, self.p_coefficient, self.p_integer, self.p_constant)
        # Negative weights would silently skew `random.choices`, and an
        # all-zero total would otherwise surface as a bare ZeroDivisionError.
        if any(w < 0 for w in weights):
            raise ValueError(f'Leaf probabilities must be non-negative, got {weights}.')
        total = sum(weights)
        if total <= 0:
            raise ValueError('At least one leaf probability must be positive.')
        # Normalize such that the probabilities sum to one.
        self.p_variable /= total
        self.p_coefficient /= total
        self.p_integer /= total
        self.p_constant /= total

    def sample(self) -> str:
        """Draw a leaf kind: 'variable', 'coefficient', 'integer' or 'constant'."""
        choices = ['variable', 'coefficient', 'integer', 'constant']
        p = [self.p_variable, self.p_coefficient, self.p_integer, self.p_constant]
        choice, = random.choices(choices, weights=p)
        return choice


class ExpressionGenerator:
    """Randomly generate SymPy expressions from configurable operators and leaves.

    Trees are sampled in prefix notation with an exact number of operators
    (see `_generate_expr`). Their leaves are drawn according to
    `leaf_probabilities`, and the prefix sequence is converted to SymPy with
    `conversion_util.Converter`. Generation is best-effort: a single attempt
    may fail and return ``None``.
    """

    def __init__(
        self,
        # TODO: Add default for most of these without defaults.
        max_ops: int,
        operator_weights: Optional[Dict[str, float]] = None,
        leaf_probabilities: Optional[LeafProbabilities] = None,
        constant_names: Sequence[str] = ('pi', 'E'),
        # NOTE: There might be issues with stuff when we have more than 1 variable.
        variable_names: Sequence[str] = ('x',),
        n_coefficients: int = 10,
        max_int: int = 10_000,
        rewrite_functions: Sequence[str] = ALL_SYMPY_REWRITE_FUNCTIONS,
        cas: Optional[cas_abcs.CasAbc] = None,
    ):
        """Set up operators, leaf symbols, tree distributions and the converter.

        Args:
            max_ops: Maximum number of operators in a generated tree. It is
                also passed to the `trees` distribution builders.
            operator_weights: Operator name -> unnormalized sampling weight.
                Defaults to `DEFAULT_OPERATOR_WEIGHTS`. Every key must be in
                `SUPPORTED_OPERATORS`.
            leaf_probabilities: How leaf kinds are sampled. When omitted, a
                fresh `LeafProbabilities()` is created for this generator, so
                instances never share one mutable default object.
            constant_names: Names of constants that may appear as leaves.
            variable_names: Variable names. The first one becomes `d_variable`.
            n_coefficients: Number of coefficient symbols `a0 .. a{n-1}`.
            max_int: Integer leaves are drawn from [-max_int, -1] U [1, max_int].
            rewrite_functions: Unique names, each in `ALL_SYMPY_REWRITE_FUNCTIONS`.
            cas: Computer-algebra backend. Defaults to `sympy_cas.SympyCas()`.
        """
        if operator_weights is None:
            operator_weights = DEFAULT_OPERATOR_WEIGHTS

        # A default instance in the signature would be created once and shared
        # (and mutable) across every generator, so build one per instance.
        if leaf_probabilities is None:
            leaf_probabilities = LeafProbabilities()

        if cas is None:
            cas = sympy_cas.SympyCas()
        self.cas = cas

        self.max_ops = max_ops
        self.max_int = max_int

        self.rewrite_functions = rewrite_functions
        assert len(self.rewrite_functions) == len(set(self.rewrite_functions))
        assert all(f in ALL_SYMPY_REWRITE_FUNCTIONS for f in self.rewrite_functions)

        ###################################################
        # Set up operators.

        assert len(operator_weights) >= 1 and all(o in SUPPORTED_OPERATORS for o in operator_weights.keys())

        self.all_ops = [o for o in operator_weights.keys()]
        self.una_ops = [o for o in operator_weights.keys() if OPERATOR_ARITIES[o] == 1]
        self.bin_ops = [o for o in operator_weights.keys() if OPERATOR_ARITIES[o] == 2]

        self.all_ops_probs = np.array([w for _, w in operator_weights.items()]).astype(np.float64)
        self.una_ops_probs = np.array([w for o, w in operator_weights.items() if OPERATOR_ARITIES[o] == 1]).astype(np.float64)
        self.bin_ops_probs = np.array([w for o, w in operator_weights.items() if OPERATOR_ARITIES[o] == 2]).astype(np.float64)

        self.all_ops_probs /= self.all_ops_probs.sum()
        self.una_ops_probs /= self.una_ops_probs.sum()
        self.bin_ops_probs /= self.bin_ops_probs.sum()

        assert len(self.all_ops) == len(set(self.all_ops)) >= 1
        assert set(self.all_ops).issubset(SUPPORTED_OPERATORS)
        # Every operator must be either unary or binary.
        assert len(self.all_ops) == len(self.una_ops) + len(self.bin_ops)

        ###################################################
        # Set up symbols / elements.

        self.digits = [str(i) for i in range(10)]
        self.constants = constant_names
        self.variables = collections.OrderedDict(
            (v, sp.Symbol(v, real=True, nonzero=True))
            for v in variable_names
        )
        self.coefficients = collections.OrderedDict({
            f'a{i}': sp.Symbol(f'a{i}', real=True)
            for i in range(n_coefficients)
        })

        # The first variable is the one expressions are required to depend on
        # (see `_maybe_generate_expression`).
        self.d_variable, = itertools.islice(self.variables.values(), 1)

        self.all_leaf_symbols = frozenset(itertools.chain(
            self.variables.keys(), self.constants, self.coefficients.keys()))

        self.n_variables = len(self.variables)
        self.n_coefficients = len(self.coefficients)

        # SymPy elements
        self.local_dict = {}
        for k, v in list(self.variables.items()) + list(self.coefficients.items()):
            assert k not in self.local_dict
            self.local_dict[k] = v

        ####################################################
        # Other stuff.

        self.leaf_probabilities = leaf_probabilities
        self.n_leaves = self._compute_n_leaves()

        # TODO: Read paper and see if I need to set these.
        # generation parameters
        self.nl = 1  # self.n_leaves
        self.p1 = 1  # len(self.una_ops)
        self.p2 = 1  # len(self.bin_ops)

        # initialize distribution for binary and unary-binary trees
        self.bin_dist = trees.generate_bin_dist(max_ops)
        self.ubi_dist = trees.generate_ubi_dist(max_ops, nl=self.nl, p1=self.p1, p2=self.p2)

        self.converter = conversion_util.Converter(
            all_operators=self.all_ops,
            all_leaf_symbols=self.all_leaf_symbols,
            rewrite_functions=self.rewrite_functions,
            variable_names=self.variables.keys(),
            local_dict=self.local_dict,
        )

    def _compute_n_leaves(self) -> int:
        """Count the distinct leaf values that can be sampled.

        Variables and coefficients are always counted. Integers (both signs)
        and constants are counted only when their probability is non-zero.
        """
        leaf_probs = self.leaf_probabilities
        n_leaves = self.n_variables + self.n_coefficients
        if leaf_probs.p_integer > 0:
            # Multiply by two to handle positive and negative integers.
            n_leaves += 2 * self.max_int
        if leaf_probs.p_constant > 0:
            n_leaves += len(self.constants)
        return n_leaves

    ###########################################################################

    def _sample_n_ops(self) -> int:
        """Sample how many operators the next expression tree should contain.

        For ``max_ops >= 3``, a "small" count in {0, 1, 2} is drawn with
        probability 1/40, and otherwise a count in {3, ..., max_ops}. For
        ``max_ops < 3``, the count is drawn uniformly from {0, ..., max_ops}.
        That case previously raised ``ValueError`` (empty range) outside the
        caller's ``try``, or could exceed ``max_ops``.
        """
        if self.max_ops < 3:
            return random.randrange(0, self.max_ops + 1)
        if random.randrange(0, 40) == 0:
            return random.randrange(0, 3)
        return random.randrange(3, self.max_ops + 1)

    def _sample_next_pos_ubi(self, n_empty, n_ops):
        """
        Sample the position of the next node (unary-binary case).
        Sample a position in {0, ..., `n_empty` - 1}, along with an arity.

        Returns:
            ``(position, arity)`` with ``arity`` in {1, 2}.
        """
        assert n_empty > 0
        assert n_ops > 0
        probs = []
        # First `n_empty` entries: unary operator at position i.
        for i in range(n_empty):
            probs.append((self.nl ** i) * self.p1 * self.ubi_dist[n_empty - i][n_ops - 1])
        # Next `n_empty` entries: binary operator at position i.
        for i in range(n_empty):
            probs.append((self.nl ** i) * self.p2 * self.ubi_dist[n_empty - i + 1][n_ops - 1])
        probs = [p / self.ubi_dist[n_empty][n_ops] for p in probs]
        probs = np.array(probs, dtype=np.float64)
        e, = random.choices(range(2 * n_empty), probs)
        arity = 1 if e < n_empty else 2
        e = e % n_empty
        return e, arity

    def _get_leaf(self):
        """Generate a leaf as a list of prefix tokens.

        Variables, coefficients and constants yield a single token. Integers
        are non-zero, in [-max_int, max_int], and encoded by
        `conversion_util.write_int`.
        """
        leaf_type = self.leaf_probabilities.sample()
        if leaf_type == 'variable':
            return [list(self.variables.keys())[random.randrange(0, self.n_variables)]]
        elif leaf_type == 'coefficient':
            return [list(self.coefficients.keys())[random.randrange(0, self.n_coefficients)]]
        elif leaf_type == 'integer':
            c = random.randrange(1, self.max_int + 1)
            c = c if random.randrange(0, 2) == 0 else -c
            return conversion_util.write_int(c)
        else:
            return [self.constants[random.randrange(0, len(self.constants))]]

    def _generate_expr(self, n_total_ops: int):
        """Create a tree with exactly `n_total_ops` operators.

        Returns:
            The tree as a flat prefix-notation token list.
        """
        stack = [None]
        n_empty = 1  # number of empty nodes
        l_leaves = 0  # left leaves - None states reserved for leaves
        t_leaves = 1  # total number of leaves (just used for sanity check)

        # create tree
        for nb_ops in range(n_total_ops, 0, -1):
            # next operator, arity and position
            skipped, arity = self._sample_next_pos_ubi(n_empty, nb_ops)

            if arity == 1:
                op, = random.choices(self.una_ops, self.una_ops_probs)
            else:
                op, = random.choices(self.bin_ops, self.bin_ops_probs)

            n_empty += OPERATOR_ARITIES[op] - 1 - skipped  # created empty nodes - skipped future leaves
            t_leaves += OPERATOR_ARITIES[op] - 1            # update number of total leaves
            l_leaves += skipped                           # update number of left leaves

            # update tree: replace the chosen empty slot with the operator
            # followed by one empty slot per operand
            pos = [i for i, v in enumerate(stack) if v is None][l_leaves]
            stack = stack[:pos] + [op] + [None for _ in range(OPERATOR_ARITIES[op])] + stack[pos + 1:]

        # sanity check
        assert len([1 for v in stack if v in self.all_ops]) == n_total_ops
        assert len([1 for v in stack if v is None]) == t_leaves

        # create leaves
        leaves = [self._get_leaf() for _ in range(t_leaves)]
        random.shuffle(leaves)

        # insert leaves into tree (right to left so earlier indices stay valid)
        for pos in range(len(stack) - 1, -1, -1):
            if stack[pos] is None:
                stack = stack[:pos] + leaves.pop() + stack[pos + 1:]
        assert len(leaves) == 0

        return stack

    ###########################################################################

    def can_generate_coefficients(self) -> bool:
        """Return True if coefficient symbols exist and may be sampled as leaves."""
        return self.n_coefficients > 0 and self.leaf_probabilities.p_coefficient > 0

    ###########################################################################

    def reduce_coefficients(self, expr):
        """Apply `sp_util.reduce_coefficients` with this generator's symbols."""
        return sp_util.reduce_coefficients(expr, self.variables.values(), self.coefficients.values())

    def reindex_coefficients(self, expr):
        """Apply `sp_util.reindex_coefficients`; a no-op without coefficients."""
        if self.n_coefficients == 0:
            return expr
        return sp_util.reindex_coefficients(expr, list(self.coefficients.values())[:self.n_coefficients])

    def extract_non_constant_subtree(self, expr):
        """Apply `sp_util.extract_non_constant_subtree` w.r.t. the variables."""
        return sp_util.extract_non_constant_subtree(expr, self.variables.values())

    def simplify_const_with_coeff(self, expr, coeffs=None):
        """Apply `sp_util.simplify_const_with_coeff` once per coefficient.

        Args:
            expr: SymPy expression.
            coeffs: Coefficient symbols to use. Defaults to all coefficients.
        """
        if coeffs is None:
            coeffs = self.coefficients.values()
        for coeff in coeffs:
            expr = sp_util.simplify_const_with_coeff(expr, coeff)
        return expr

    ###########################################################################

    def _maybe_generate_expression(self):
        """Make one generation attempt.

        Returns:
            A SymPy expression that depends on `d_variable`, or ``None`` if the
            attempt failed or produced an expression without that variable.

        Raises:
            TimeoutError: re-raised (the one imported from `misc_util`) so
                that `maybe_generate_expression` can apply its time limit.
        """
        x = self.d_variable
        n_ops = self._sample_n_ops()
        try:
            # generate an expression and rewrite it,
            # avoid issues in 0 and convert to SymPy
            f_expr = self._generate_expr(n_ops)
            infix = self.converter.prefix_to_infix(f_expr)
            f = self.converter.infix_to_sympy(infix)

            # skip constant expressions
            if x not in f.free_symbols:
                return None

            # remove additive constant, re-index coefficients
            if random.randrange(2) == 0:
                f = sp_util.remove_root_constant_terms(f, x, 'add')
            if self.can_generate_coefficients():
                f = self.reduce_coefficients(f)
                f = self.simplify_const_with_coeff(f)
                f = self.reindex_coefficients(f)

            # # TODO: Maybe put this in its own function.
            # has_ead = self.cas.has_elementary_antiderivative(f, x)

            # # convert back to prefix
            # f_prefix = self.converter.sympy_to_prefix(f)

            # # skip too long sequences
            # if max(len(f_prefix), len(F_prefix)) + 2 > self.max_len:
            #     return None

            # # skip when the number of operators is too far from expected
            # real_nb_ops = sum(1 if op in self.OPERATORS else 0 for op in f_prefix)
            # if real_nb_ops < nb_ops / 2:
            #     return None

        except TimeoutError:
            raise
        except (ValueError, AttributeError, TypeError, OverflowError, NotImplementedError, conversion_util.UnknownSymPyOperator, conversion_util.ValueErrorExpression):
            return None
        except Exception as e:
            # Deliberate best-effort: any other failure while building or
            # simplifying a random expression just discards this attempt.
            # print("An unknown exception of type {0} occurred in line {1} for expression \"{2}\". Arguments:{3!r}.".format(type(e).__name__, sys.exc_info()[-1].tb_lineno, infix, e.args))
            return None

        return f

    def maybe_generate_expression(self, seconds: Optional[float] = None):
        """Make one generation attempt, optionally limited to `seconds`.

        Returns:
            The generated expression, or ``None`` on failure or timeout.
        """
        attempt = self._maybe_generate_expression
        if seconds is not None:
            attempt = timeout(seconds)(attempt)
        try:
            return attempt()
        except TimeoutError:
            return None

    def generate_expression(self, seconds_per_attempt: Optional[float] = None, max_attempts: int = 100):
        """Retry `maybe_generate_expression` up to `max_attempts` times.

        Returns:
            The first successful expression, or ``None`` if every attempt failed.
        """
        for _ in range(max_attempts):
            expr = self.maybe_generate_expression(seconds_per_attempt)
            if expr is not None:
                return expr
        # Ideally this should not happen.
        return None


###############################################################################


# TODO: Move some of the stuff from ExpressionGenerator to be here.
class ExpressionContext:
    """Placeholder for the symbol context shared by expressions.

    Intended future contents (currently held by `ExpressionGenerator`):
    constants, variables, coefficients and `d_variable`.
    """


# Open question: should expressions be represented by this data class?
@dataclasses.dataclass
class Expression:
    """Placeholder expression record; currently holds only its context."""

    ctx: ExpressionContext
