"""Stuff used to fit hyperparameters to place predictions in a KL range."""
import dataclasses
import random
from typing import Callable, Optional, Tuple

import numpy as np
import tensorflow as tf
from transformers import TFPreTrainedModel

from em.util import flat_pack


class InvalidKlFnOutputError(Exception):
    """Signals a KL value that cannot be ordered against the target KL range.

    Raised by ``GenericKlTargeter._evaluate`` when the value returned by the
    KL function compares as neither inside, below, nor above the configured
    range (as happens for NaN, whose comparisons are all false).
    """


@dataclasses.dataclass
class GenericKlTargeter:
    """Search for a ``(delta, lmbda)`` pair whose KL lies inside ``kl_range``.

    The search alternates between adjusting ``lmbda`` (linearly) and ``delta``
    (in log space). Each adjustment first jumps half-way towards the bound of
    the working range that should move the KL in the needed direction, then
    retries with geometrically smaller jumps if the KL overshoots the target
    range. This assumes ``kl_fn`` does not decrease when either ``delta`` or
    ``lmbda`` increases.

    Attributes:
        kl_fn: Callable mapping ``(delta, lmbda)`` to a KL value.
        kl_range: ``(min_kl, max_kl)`` target interval, ``0 < min < max``.
        delta_mag_range: ``(min, max)`` bounds for ``delta``, ``0 < min < max``.
        lmbda_working_range: ``(min, max)`` bounds for ``lmbda``, strictly
            inside ``(0, 1)``.
        backoff_factor: Divisor (> 1) applied to the step after each
            overshooting attempt.
        backoff_attempts: Number of candidate steps tried per adjustment.

    After any evaluation, ``_last_delta`` and ``_last_lmbda`` hold the most
    recently evaluated pair; after a successful ``search`` that is the pair
    whose KL landed in range.
    """

    # [delta, lmbda] -> kl
    kl_fn: Callable[[float, float], float]

    kl_range: Tuple[float, float]
    delta_mag_range: Tuple[float, float]
    lmbda_working_range: Tuple[float, float] = (0.035, 1 - 0.035)

    backoff_factor: float = 2.0
    backoff_attempts: int = 10

    def __post_init__(self):
        """Validate the configured ranges and reset the last-evaluated state."""
        assert len(self.kl_range) == 2
        assert 0 < self.min_kl < self.max_kl

        assert len(self.delta_mag_range) == 2
        assert 0 < self.min_delta_mag < self.max_delta_mag

        assert len(self.lmbda_working_range) == 2
        assert 0 < self.min_lmbda < self.max_lmbda < 1

        assert self.backoff_factor > 1

        self._last_delta = None
        self._last_lmbda = None

    #################################################################

    @property
    def min_kl(self):
        """Lower bound of the target KL range."""
        return self.kl_range[0]

    @property
    def max_kl(self):
        """Upper bound of the target KL range."""
        return self.kl_range[1]

    @property
    def min_delta_mag(self):
        """Lower bound of the delta range."""
        return self.delta_mag_range[0]

    @property
    def max_delta_mag(self):
        """Upper bound of the delta range."""
        return self.delta_mag_range[1]

    @property
    def min_lmbda(self):
        """Lower bound of the lmbda working range."""
        return self.lmbda_working_range[0]

    @property
    def max_lmbda(self):
        """Upper bound of the lmbda working range."""
        return self.lmbda_working_range[1]

    #################################################################

    def _get_random_delta(self):
        """Draw a delta log-uniformly from ``delta_mag_range``."""
        # Log uniform distribution.
        log_delta = random.uniform(np.log(self.min_delta_mag), np.log(self.max_delta_mag))
        return np.exp(log_delta)

    def _get_random_lmbda(self):
        """Draw a lmbda uniformly from ``lmbda_working_range``."""
        # Uniform distribution
        return random.uniform(self.min_lmbda, self.max_lmbda)

    #################################################################

    def _evaluate(self, delta, lmbda):
        """Evaluate ``kl_fn`` at ``(delta, lmbda)`` and classify the result.

        Records the pair in ``_last_delta`` / ``_last_lmbda`` before calling
        ``kl_fn``.

        Returns:
            ``(condition, kl)`` where ``condition`` is ``0`` if the KL is in
            range, ``-1`` if it is below ``min_kl`` and ``1`` if it is above
            ``max_kl``.

        Raises:
            InvalidKlFnOutputError: If the KL matches none of the three cases
                (e.g. it is NaN).
        """
        self._last_delta = delta
        self._last_lmbda = lmbda

        kl = self.kl_fn(delta, lmbda)

        if self.min_kl <= kl <= self.max_kl:
            return (0, kl)
        if kl < self.min_kl:
            return (-1, kl)
        if self.max_kl < kl:
            return (1, kl)
        raise InvalidKlFnOutputError('This condition should not be reachable.')

    def _kl_step_coeffs_gen(self, delta, lmbda, condition, i):
        """Yield candidate ``(delta, lmbda)`` pairs for one adjustment step.

        Exactly one coordinate is changed: ``delta`` when ``i`` is odd,
        ``lmbda`` when ``i`` is even. The first candidate moves half-way
        towards the upper bound (``condition < 0``, KL too low) or the lower
        bound (``condition > 0``, KL too high); every following candidate
        divides the step by ``backoff_factor``.

        Args:
            delta: Current delta; must be positive (its log is taken).
            lmbda: Current lmbda.
            condition: Non-zero sign returned by ``_evaluate`` for the
                current point.
            i: Iteration index of the caller, used to alternate coordinates.

        Yields:
            ``backoff_attempts`` candidate ``(delta, lmbda)`` tuples.
        """
        assert condition != 0

        # Alternate deterministically between the two coordinates rather than
        # choosing one at random.
        if i % 2:
            # Do delta: steps are taken in log space.
            og_log_delta = np.log(delta)
            bound = self.max_delta_mag if condition < 0 else self.min_delta_mag
            log_diff = (np.log(bound) - og_log_delta) / 2

            for _ in range(self.backoff_attempts):
                yield np.exp(og_log_delta + log_diff), lmbda
                log_diff /= self.backoff_factor

        else:
            # Do lmbda: steps are taken linearly.
            og_lmbda = lmbda
            bound = self.max_lmbda if condition < 0 else self.min_lmbda
            diff = (bound - og_lmbda) / 2

            for _ in range(self.backoff_attempts):
                yield delta, og_lmbda + diff
                diff /= self.backoff_factor

    def search(
        self,
        max_iters: int,
        init_delta: Optional[float] = None,
        init_lmbda: Optional[float] = None,
    ) -> bool:
        """Look for a ``(delta, lmbda)`` pair whose KL is inside ``kl_range``.

        Args:
            max_iters: Maximum number of adjustment rounds.
            init_delta: Starting delta. When ``None`` a log-uniform random
                value from ``delta_mag_range`` is drawn. Must be positive.
            init_lmbda: Starting lmbda. When ``None`` a uniform random value
                from ``lmbda_working_range`` is drawn. An explicit ``0.0`` is
                honoured rather than treated as "not provided".

        Returns:
            ``True`` as soon as an evaluated pair lands in range (that pair is
            then available as ``_last_delta`` / ``_last_lmbda``), ``False`` if
            ``max_iters`` rounds pass without success.

        Raises:
            ValueError: If ``init_delta`` is given but is not positive.
            InvalidKlFnOutputError: Propagated from ``_evaluate``.
        """
        # NOTE: This is based on the assumption that the KL is monotonic in both
        # delta and lmbda.
        if init_delta is not None and not init_delta > 0:
            # Delta steps are computed in log space, so a non-positive start
            # would silently turn every later candidate into NaN.
            raise ValueError(f'init_delta must be positive, got {init_delta!r}')

        # Compare against None explicitly: `x or default` would discard a
        # caller-supplied 0.0.
        delta = self._get_random_delta() if init_delta is None else init_delta
        lmbda = self._get_random_lmbda() if init_lmbda is None else init_lmbda

        for i in range(max_iters):
            condition, _ = self._evaluate(delta, lmbda)
            if condition == 0:
                return True

            candidates = self._kl_step_coeffs_gen(delta, lmbda, condition, i)
            for delta, lmbda in candidates:
                cond, _ = self._evaluate(delta, lmbda)
                if cond == 0:
                    return True
                if cond == condition:
                    # Still on the same side of the target range: keep this
                    # point and continue from it with the other coordinate.
                    break
                # Otherwise the candidate overshot the range; try the next,
                # smaller step. If every attempt overshoots, the last
                # candidate becomes the new starting point.

        return False

