import abc
import warnings

import numpy as np
import torch
from scipy.optimize import brentq


class BetaHandler(abc.ABC):
    def __init__(self, initial_beta: float = 0.0, max_beta: float = 1.0):
        """
        Abstract class that manages beta and beta changes in DLA.

        :param initial_beta: beta to start with. Must be between 0 and max_beta (both inclusive).
        :param max_beta: final beta value. `finished` becomes True once both the current and the previous beta
            are equal to this value.

        :raises ValueError: if initial_beta is negative or greater than max_beta.
        """
        if initial_beta < 0:
            raise ValueError(f"Initial beta must be greater or equal than 0, but got {initial_beta}")
        if initial_beta > max_beta:
            # A beta above max_beta can never satisfy the equality check in `finished`.
            raise ValueError(
                f"Initial beta must be less than or equal to max_beta = {max_beta}, but got {initial_beta}"
            )
        self.beta = initial_beta
        # Sentinel: betas are validated to be non-negative, so -1 never compares equal to the current beta and
        # `finished` is False until `step` has overwritten old_beta.
        self.old_beta = -1
        self.max_beta = max_beta

    @property
    def finished(self) -> bool:
        """
        Check if DLA has finished running. This is true when the current and previous betas are both equal to
        max_beta. This means that the previous stage was run with beta = max_beta, hence running another stage
        with the same beta is redundant.

        :return: True if current and previous beta are equal to max_beta, False otherwise.
        """
        return self.beta == self.old_beta == self.max_beta

    @abc.abstractmethod
    def step(self, *args, **kwargs) -> bool:
        """
        Set new beta.
        Set current beta as the old beta.

        :return: True if there was a change in beta and False otherwise.
        """
        raise NotImplementedError


class ESSBetaHandler(BetaHandler):
    def __init__(self, target_ess_fraction: float = 0.5, initial_beta: float = 0.0):
        """
        Beta handler that determines the next beta value according to an ESS-based search.

        :param target_ess_fraction: the target ESS is this fraction multiplied by the number of particles passed
            to `step`.
        :param initial_beta: initial beta value.
        """
        super().__init__(initial_beta=initial_beta)
        self.target_ess_fraction = target_ess_fraction

    def step(self,
             log_likelihood: torch.Tensor,
             log_prior: torch.Tensor,
             log_q: torch.Tensor,
             **kwargs) -> bool:
        """
        Set the new beta value.

        The new beta is the root, searched between the current beta and max_beta, of ESS(beta) - target ESS, where
        the ESS is computed from the log weights beta * (log_likelihood + log_prior - log_q) reduced over dim 0.

        :param log_likelihood: log likelihood values for current particles.
        :param log_prior: log prior values for current particles.
        :param log_q: approximate log tempered posterior values for current particles (obtained via a normalizing flow).
        :param kwargs: unused.

        :return: True if there was a change in beta and False otherwise.
        """
        self.old_beta = self.beta
        upper = self.max_beta

        if self.beta == upper:
            # Already at the final beta: there is nothing to search for and nothing to warn about.
            return False

        target_ess = self.target_ess_fraction * len(log_likelihood)
        log_ratio = log_likelihood + log_prior - log_q

        @torch.no_grad()
        def objective(beta):
            # ESS = (sum w)^2 / sum w^2 with w = exp(beta * log_ratio), evaluated in log space.
            log_ess = 2 * torch.logsumexp(beta * log_ratio, dim=0) - torch.logsumexp(2.0 * beta * log_ratio, dim=0)
            return float(torch.exp(log_ess) - target_ess)

        objective_lower = objective(self.beta)
        objective_upper = objective(upper)

        if objective_lower > 0 and objective_upper > 0:
            # The ESS exceeds the target at both ends of the bracket, so the final beta already meets the target.
            # Jump straight to it; leaving beta untouched here would keep the schedule stuck below max_beta.
            self.beta = float(upper)
            return True

        if np.sign(objective_lower) == np.sign(objective_upper):
            # No sign change, so brentq has no valid bracket.
            warnings.warn("Cannot find beta that would increase ESS, will not update beta")
            return False

        new_beta, res = brentq(
            f=objective,
            a=self.beta,
            b=upper,
            xtol=1e-6,
            full_output=True,
            disp=False
        )

        if not res.converged:
            warnings.warn("Beta search did not converge")

        # Keep beta inside [current beta, max_beta] and store it as a plain Python float.
        self.beta = float(np.clip(new_beta, self.beta, upper))
        return self.beta != self.old_beta


class FixedIncrementBetaHandler(BetaHandler):
    # If the next beta lands within this distance below max_beta it is snapped to max_beta. Repeated float addition
    # undershoots the end point (ten additions of 0.1 starting at 0.0 give 0.9999999999999999), which would
    # otherwise cause one extra stage at a beta that is practically equal to max_beta.
    _SNAP_TOLERANCE = 1e-9

    def __init__(self, increment: float = 0.1, initial_beta: float = 0.0):
        """
        Beta handler that uses a fixed increment between beta values.

        :param increment: each new beta is computed as the old beta plus the increment.
        :param initial_beta: initial beta value.

        :raises ValueError: if increment is not positive.
        """
        super().__init__(initial_beta=initial_beta)
        if increment <= 0.0:
            raise ValueError(f"Increment must be positive, but got {increment}")
        self.increment = increment

    def step(self, *args, **kwargs) -> bool:
        """
        Set the new beta value.

        The new beta is the current beta plus the increment, capped at max_beta.

        :param args: unused.
        :param kwargs: unused.

        :return: True if there was a change in beta and False otherwise.
        """
        self.old_beta = self.beta
        if self.beta == self.max_beta:
            return False
        candidate = self.beta + self.increment
        if candidate > self.max_beta - self._SNAP_TOLERANCE:
            # Covers both overshooting max_beta and landing a rounding error below it.
            candidate = self.max_beta
        self.beta = candidate
        return True


class FixedStageBetaHandler(FixedIncrementBetaHandler):
    def __init__(self, n_stages: int = 10, initial_beta: float = 0.0):
        """
        Fixed-increment beta handler whose increment is derived from a requested number of stages.

        The increment is the distance from initial_beta to 1 divided by n_stages.

        :param n_stages: number of beta stages between initial_beta and 1.
        :param initial_beta: initial beta value.

        :raises ValueError: if n_stages is not positive.
        """
        if n_stages <= 0:
            raise ValueError(f"n_stages must be positive, but got {n_stages}")
        # Split the remaining distance to 1 into n_stages equal steps and let the parent do the stepping.
        remaining = 1 - initial_beta
        super().__init__(increment=remaining / n_stages, initial_beta=initial_beta)


class LogarithmicBetaHandler(BetaHandler):
    def __init__(self, n_stages: int = 10, initial_beta: float = 0.0):
        """
        Beta handler that makes progressively smaller changes in beta values.

        The schedule starts at initial_beta, each of the following n_stages - 1 values halves the remaining distance
        to max_beta, and the last value is max_beta itself.

        :param n_stages: number of beta stages between initial_beta and 1.
        :param initial_beta: initial beta value.

        :raises ValueError: if n_stages is smaller than 1.
        """
        super().__init__(initial_beta=initial_beta)
        if n_stages < 1:
            # Raise instead of assert: assertions are stripped when Python runs with -O.
            raise ValueError(f"n_stages must be positive, but got {n_stages}")
        self.n_stages = n_stages
        self.beta_values = [initial_beta]
        for _ in range(n_stages - 1):
            self.beta_values.append((self.beta_values[-1] + self.max_beta) / 2)
        self.beta_values.append(self.max_beta)
        # Index into beta_values of the current beta.
        self.stage = 0

    def step(self, *args, **kwargs) -> bool:
        """
        Set the new beta value.

        Moves to the next entry of the precomputed schedule. Once the last entry (max_beta) has been reached,
        further calls only copy the current beta into old_beta and leave the schedule position untouched.

        :param args: unused.
        :param kwargs: unused.

        :return: True if there was a change in beta and False otherwise.
        """
        self.old_beta = self.beta
        if self.stage >= len(self.beta_values) - 1:
            # The schedule is exhausted; never index past the final entry.
            return False
        self.stage += 1
        self.beta = self.beta_values[self.stage]
        # Consecutive schedule entries can coincide (e.g. initial_beta equal to max_beta), so report a change
        # only when the value really moved.
        return self.beta != self.old_beta


class SingleStageBetaHandler(BetaHandler):
    def __init__(self, beta_value=1.0):
        """
        Beta handler for a single DLA stage run at one fixed beta.

        The given value is used both as the initial beta and as the maximum beta.

        :param beta_value: the beta used for the single stage (1 by default).
        """
        super().__init__(max_beta=beta_value, initial_beta=beta_value)

    def step(self, *args, **kwargs) -> bool:
        """
        Record the current beta as old_beta without changing beta itself.

        After this call the current and previous betas both equal the maximum beta, which ensures DLA terminates
        after one stage.

        :param args: unused.
        :param kwargs: unused.

        :return: False.
        """
        # Beta never moves for this handler, so only the bookkeeping value is updated.
        self.old_beta = self.beta
        return False
