"""Examples using the exp(x + exp(x)) backbone."""
import abc
import csv
import dataclasses
import os
import random
from typing import Mapping, Optional

import sympy as sp

from em.datasets.antiderivative import misc_util

"""
This module builds the exp(x + exp(x)) dataset. Each expression is labeled by
whether it has an elementary antiderivative (EAD).

There are many statistical biases here; I either ignore them or fix them to
some arbitrary value. For example:
    - Similar-looking terms that have no EAD, e.g. e^(e^x)?

Things that influence the label:
    - The relation between a and b.
    - If adding g(x) to the expression, then whether g(x) has an EAD.

Things that don't influence the label:
    - Constant times the expression.
    - Adding a constant term c inside the outer exponential, i.e. exp(x + exp(x) + c).


General form:

d * exp(a * x + exp(b * x) + c) + g(x),
where any multiplication by a constant may instead be a division by that constant.


p(divide by constant vs multiply by constant) = 0.5
p(a constant/function other than a, b takes its relevant identity value,
  i.e. 1 for the factor d, 0 for the offset c and for g(x)) = 0.5
p(a = 1) = 0.25
p(b = 1 |a) = depends on sampled label

p(constant is a small non-zero integer, |k| <= 20) = 0.5
p(constant is a moderate non-zero integer, |k| <= 25,000) = 0.2
p(constant is otherwise rational) = 0.15
p(constant is irrational) = 0.15

"""


@misc_util.timeout(1)
def _sympify(s):
    """Parse the string ``s`` into a SymPy expression.

    Wrapped in ``misc_util.timeout(1)`` (presumably a one-second limit) so a
    pathological expression cannot stall sampling; callers in this file catch
    ``misc_util.TimeoutError`` around it.

    NOTE(review): ``sympy.sympify`` evaluates its input via ``eval``; only pass
    strings that come from trusted sources.
    """
    parsed = sp.sympify(s)
    return parsed


CONSTANT_TYPES = frozenset({
    # Non-zero int with magnitude <= 20
    'small_int',
    # Non-zero int with magnitude <= 25_000
    'moderate_int',
    'rational',
    'irrational',
})

MAX_SMALL_INT = 20
MAX_MODERATE_INT = 25_000

MAX_RATIONAL_XATOR = 7500

MAX_AB_FACTOR = 25


def _preprocess_p_constant_type(p_constant_type):
    assert set(p_constant_type.keys()).issubset(CONSTANT_TYPES)
    denom = sum(p_constant_type.values())
    return {
        k: v / denom
        for k, v in p_constant_type.items()
    }


def _sample_binary(p_true: float) -> bool:
    """Return True with probability ``p_true``.

    ``random.random()`` draws from [0.0, 1.0), so the strict ``<`` makes
    ``p_true <= 0`` never true and ``p_true >= 1`` always true (``<=`` would
    let ``p_true == 0`` fire whenever the draw is exactly 0.0).
    """
    return random.random() < p_true


def _is_positive_integer(expr) -> bool:
    """Return True iff SymPy can establish that ``expr`` is a positive integer.

    SymPy assumption queries are three-valued (True / False / None); an
    unknown (None) answer is treated as False so callers always get a bool.
    """
    return expr.is_integer is True and expr.is_positive is True


class MaxRetriesExceededException(Exception):
    """Raised when a retry/rejection-sampling loop runs out of attempts."""


###############################################################


class ExpressionsSource(abc.ABC):
    """Abstract provider of SymPy expressions selected by label."""

    @abc.abstractmethod
    def get_expression(self, label) -> sp.Expr:
        """Return an expression matching ``label``.

        Judging from its uses in this file, a truthy ``label`` requests an
        expression with an elementary antiderivative and a falsy one requests
        an expression without one.
        """
        raise NotImplementedError


class CsvExpressionsSource(ExpressionsSource):
    """Serves expressions read from a two-column ``expression,label`` CSV file.

    Rows labeled ``1`` go to the true pool and rows labeled ``0`` to the false
    pool (whitespace around the label is ignored). Rows with any other label,
    such as a header, are skipped, as are blank rows.

    NOTE(review): expressions are parsed with ``sympy.sympify``, which
    evaluates its input via ``eval``; only load trusted files.
    """

    def __init__(self, filepath: str, max_retries: int = 120):
        """
        Args:
            filepath: Path to the CSV file; ``~`` is expanded.
            max_retries: How many randomly chosen expressions to try in
                ``get_expression`` before giving up when parsing keeps timing
                out.
        """
        self._max_retries = max_retries
        self._true_exprs, self._false_exprs = self._read_in_expressions(filepath)

    def _read_in_expressions(self, filepath: str):
        """Return ``(true_exprs, false_exprs)`` as lists of expression strings.

        Raises:
            ValueError: if a non-blank row does not have exactly two columns.
        """
        true_exprs = []
        false_exprs = []
        with open(os.path.expanduser(filepath), newline='', encoding='utf-8') as f:
            reader = csv.reader(f)
            for row in reader:
                if not row:
                    # csv.reader yields [] for blank lines (e.g. a trailing newline).
                    continue
                if len(row) != 2:
                    raise ValueError(
                        f'{filepath}:{reader.line_num}: expected 2 columns, got {len(row)}')
                expr, label = row
                label = label.strip()
                if label == '1':
                    true_exprs.append(expr)
                elif label == '0':
                    false_exprs.append(expr)
        return true_exprs, false_exprs

    def _try_get_expression(self, label):
        """Parse one randomly chosen expression from the pool for ``label``."""
        exprs = self._true_exprs if label else self._false_exprs
        if not exprs:
            raise ValueError(
                f'No expressions with label {int(bool(label))} were loaded.')
        return _sympify(random.choice(exprs))

    def get_expression(self, label) -> sp.Expr:
        """Return a parsed expression whose CSV label matches ``label``'s truthiness.

        If parsing a drawn expression times out, a fresh expression is drawn.

        Raises:
            MaxRetriesExceededException: if every one of ``max_retries``
                attempts timed out.
            ValueError: if no expressions with that label were loaded.
        """
        for _ in range(self._max_retries):
            try:
                return self._try_get_expression(label)
            except misc_util.TimeoutError:
                continue
        raise MaxRetriesExceededException


###############################################################


@dataclasses.dataclass
class _ExexParams:
    """Concrete parameters of one sampled expression and its label.

    The expression is d * exp(a * x + exp(b * x) + c) + g(x).
    """
    a: sp.Expr
    b: sp.Expr
    c: sp.Expr
    d: sp.Expr
    g: sp.Expr

    # 1 if the whole expression is meant to have an EAD, else 0.
    label: int

    def to_expr(self) -> sp.Expr:
        """Assemble the SymPy expression in the symbol ``x``."""
        x = sp.Symbol('x')
        exponent = self.a * x + sp.exp(self.b * x) + self.c
        return self.d * sp.exp(exponent) + self.g


@dataclasses.dataclass
class ExexGenerator:
    """Samples labeled expressions d * exp(a * x + exp(b * x) + c) + g(x).

    The label is 1 when the expression is meant to have an elementary
    antiderivative (EAD) and 0 otherwise. Substituting u = exp(b * x) turns the
    exex term into (1/b) * integral(u**(a/b - 1) * e**u du), which is
    elementary exactly when a / b is a positive integer. g(x) is drawn from
    ``expressions_source`` with the label it should contribute.

    Attributes:
        expressions_source: Source of g(x) terms, queried by label.
        p_true: Probability that a sample is labeled 1.
        p_false_from_plus_function: Given label 0, probability that the exex
            term is built *with* an EAD and a g(x) without an EAD is always
            added to make the sum non-elementary.
        p_plus_function: Probability of adding an optional g(x) term.
        p_no_constant: Probability that a constant slot takes its identity
            value (1 for a, b and d; 0 for the additive offset c).
        p_div: Probability of dividing by d instead of multiplying, and in
            the EAD case of replacing (a, b) by (1/b, 1/a).
        p_ab_factor_by_div: In the EAD case, probability that the integer
            ratio is applied by dividing to get b rather than multiplying to
            get a.
        p_constant_type: Unnormalized weights over ``CONSTANT_TYPES``; None
            selects the defaults set in ``__post_init__``.
        max_random_retry_attempts: Retry budget for rejection sampling.
    """

    expressions_source: ExpressionsSource

    p_true: float = 0.5
    p_false_from_plus_function: float = 0.3

    p_plus_function: float = 0.5

    p_no_constant: float = 0.5

    p_div: float = 0.5

    p_ab_factor_by_div: float = 0.25

    p_constant_type: Optional[Mapping[str, float]] = None

    max_random_retry_attempts: int = 5_000

    def __post_init__(self):
        if self.p_constant_type is None:
            self.p_constant_type = {
                'small_int': 0.5,
                'moderate_int': 0.2,
                'rational': 0.15,
                'irrational': 0.15,
            }
        self.p_constant_type = _preprocess_p_constant_type(self.p_constant_type)
        # Fixed-order keys/weights for random.choices.
        self._p_constant_type_keys = tuple(self.p_constant_type.keys())
        self._p_constant_type_values = tuple(self.p_constant_type.values())

    ###########################################################

    def _sample_signed_int(self, max_magnitude: int):
        """Sample a non-zero integer with |k| <= max_magnitude and a random sign."""
        magnitude = random.randrange(1, max_magnitude + 1)
        sign = 1 if _sample_binary(0.5) else -1
        return sp.Integer(sign * magnitude)

    def _sample_small_int(self):
        return self._sample_signed_int(MAX_SMALL_INT)

    def _sample_moderate_int(self):
        return self._sample_signed_int(MAX_MODERATE_INT)

    def _sample_rational(self):
        """Signed ratio of two ints in [1, MAX_RATIONAL_XATOR]; may reduce to an int."""
        numerator = random.randrange(1, MAX_RATIONAL_XATOR + 1)
        denominator = random.randrange(1, MAX_RATIONAL_XATOR + 1)
        sign = 1 if _sample_binary(0.5) else -1
        return sp.Rational(sign * numerator, denominator)

    def _sample_irrational(self):
        """Signed square root of a non-square integer in [2, MAX_SMALL_INT]."""
        while True:
            root = sp.sqrt(random.randrange(2, MAX_SMALL_INT + 1))
            # SymPy evaluates sqrt of a perfect square to an exact Integer; the
            # square root of any other positive integer is irrational.
            if not root.is_Integer:
                break
        sign = 1 if _sample_binary(0.5) else -1
        return sign * root

    def _sample_constant(self):
        """Sample a non-identity constant whose type follows p_constant_type."""
        constant_type, = random.choices(self._p_constant_type_keys, self._p_constant_type_values)
        if constant_type == 'small_int':
            return self._sample_small_int()
        elif constant_type == 'moderate_int':
            return self._sample_moderate_int()
        elif constant_type == 'rational':
            return self._sample_rational()
        elif constant_type == 'irrational':
            return self._sample_irrational()
        else:
            raise ValueError(constant_type)

    def _sample_maybe_constant(self):
        """Multiplicative slot: 1 with probability p_no_constant, else a constant."""
        if _sample_binary(self.p_no_constant):
            return sp.Integer(1)
        return self._sample_constant()

    def _sample_maybe_offset(self):
        """Additive slot: 0 with probability p_no_constant, else a constant."""
        if _sample_binary(self.p_no_constant):
            return sp.Integer(0)
        return self._sample_constant()

    ###########################################################

    def _sample_ab_factor(self):
        return random.randrange(1, MAX_AB_FACTOR + 1)

    def _sample_ab_for_ead(self):
        """Sample (a, b) with a / b a positive integer (exex term has an EAD).

        This label concerns the exex term only and ignores g(x).
        """
        base = self._sample_maybe_constant()
        ab_factor = self._sample_ab_factor()

        if _sample_binary(self.p_ab_factor_by_div):
            # base is a; divide it down to get b.
            a = base
            b = base / ab_factor
        else:
            # base is b; multiply it up to get a.
            a = ab_factor * base
            b = base

        if _sample_binary(self.p_div):
            # (1/b) / (1/a) == a / b, so the ratio (and the label) is preserved.
            a, b = 1 / b, 1 / a

        return a, b

    def _try_sample_ab_for_no_ead(self):
        """Sample (a, b) once; return None if a / b is a positive integer."""
        a = self._sample_maybe_constant()
        b = self._sample_maybe_constant()
        if _is_positive_integer(a / b):
            return None
        return a, b

    def _sample_ab_for_no_ead(self):
        """Rejection-sample (a, b) so the exex term has no EAD (ignores g(x)).

        Raises:
            MaxRetriesExceededException: after max_random_retry_attempts rejections.
        """
        for _ in range(self.max_random_retry_attempts):
            ab = self._try_sample_ab_for_no_ead()
            if ab is not None:
                return ab
        raise MaxRetriesExceededException

    ###########################################################

    def _sample_g(self, label: int):
        """With probability p_plus_function, return a g(x) for label; else 0."""
        if not _sample_binary(self.p_plus_function):
            return sp.Integer(0)
        return self.expressions_source.get_expression(label)

    def _sample_label(self):
        return int(_sample_binary(self.p_true))

    ###########################################################

    def _sample_params(self):
        """Sample a label and matching expression parameters.

        A label-0 sample is produced in one of two ways:
          * with probability p_false_from_plus_function, the exex term has an
            EAD and a g(x) without an EAD is always added; or
          * otherwise, the exex term has no EAD and g(x) is optional.
        """
        label = self._sample_label()
        if not label and _sample_binary(self.p_false_from_plus_function):
            ab_label = 1
            force_false_g = True
        else:
            ab_label = label
            force_false_g = False

        if ab_label:
            a, b = self._sample_ab_for_ead()
        else:
            a, b = self._sample_ab_for_no_ead()

        # c is added inside the exponent, so its identity value is 0 (a 1 here
        # would silently multiply the expression by e).
        c = self._sample_maybe_offset()
        d = self._sample_maybe_constant()
        if _sample_binary(self.p_div):
            d = 1 / d

        if force_false_g:
            g = self.expressions_source.get_expression(0)
        elif label:
            g = self._sample_g(label)
        else:
            # NOTE: Not really sure if the 1 - p here is well-motivated
            # in any sense.
            g = self._sample_g(int(_sample_binary(1 - self.p_false_from_plus_function)))

        return _ExexParams(a=a, b=b, c=c, d=d, g=g, label=label)

    def sample_expression(self):
        """Return ``(expression, label)`` for one freshly sampled expression."""
        params = self._sample_params()
        return params.to_expr(), params.label
