from util.logger import logger

from typing import Optional, Union, Tuple

import numpy as np

import torch

import gc


class UnimodalBetaDistribution:
    """
    Func:
        A Beta distribution whose shape parameters are parameterized as 
        `alpha = 1 + exp(a)` and `beta = 1 + exp(b)`, so both stay larger 
        than 1 (unimodal density). Samples are drawn with 
        `torch.distributions.Beta` (`ver = "torch"`) or `np.random.beta` 
        (`ver = "numpy"`). 
    """

    # Supported sampling back-ends.
    _VERSIONS = ("torch", "numpy")

    def __init__(
        self, 

        a: Optional[float] = 0.0, 
        b: Optional[float] = 0.0, 

        dtype: Optional[str] = "float32", 
        device: Optional[str] = "cpu", 

        ver: Optional[str] = "torch"  # ["torch", "numpy"]
    ):
        """
        Args:
            `a`, `b` (`float`): Unconstrained parameters, mapped to 
                `alpha = 1 + exp(a)` and `beta = 1 + exp(b)`. 
            `dtype` (`str` or `torch.dtype`): dtype of the returned samples, 
                e.g. `"float32"`. 
            `device` (`str`): Device of the samples for the torch back-end 
                (unused by the numpy back-end). 
            `ver` (`str`): Back-end, `"torch"` or `"numpy"`. 
        Raises:
            `ValueError`: If `ver` is not a supported back-end. 
        """
        if ver not in self._VERSIONS:
            raise ValueError(
                f"`ver` should be one of {self._VERSIONS}, got `{ver}`. "
            )

        self.ver = ver

        self.dtype = dtype
        self.device = device

        self.alpha = None
        self.beta = None

        # ---------= [Initialize Beta Distribution] =---------
        self.update(a = a, b = b)

        # `update()` sets `initialized = True`; it is reset here so that the 
        # constructor's parameters do not count as initialized and 
        # `get_mode()` returns `None` until `update()` / `update_with_mode()` 
        # is called explicitly. 
        self.initialized = False

        # First mode ever passed to `update_with_mode()` (Python float). 
        self.init_mode = None

        # `__init__()` done
        pass


    def get_mode(
        self
    ) -> Optional[float]:
        """
        Func:
            Get the mode `(alpha - 1) / (alpha + beta - 2)` of the current 
            Beta distribution. Return `None` if not initialized. 
        Ret:
            `mode` (`float` or `None`): The mode of the current Beta 
                distribution, or `None` before the first explicit update. 
        """

        if not self.initialized:
            return None

        # `get_mode()` done
        return (self.alpha - 1) / (self.alpha + self.beta - 2)


    def update(
        self, 

        a: Optional[float] = 0.0, 
        b: Optional[float] = 0.0, 
    ):
        """
        Func:
            Update the parameters of the Beta distribution from unconstrained 
            values: `alpha = 1 + exp(a)`, `beta = 1 + exp(b)`. Marks the 
            distribution as initialized. 
        Args:
            `a`, `b` (`float`): Unconstrained parameters; anything 
                `float(1 + np.exp(x))` accepts. 
        """

        # `1 + exp(.)` keeps both parameters above 1, i.e. a unimodal Beta. 
        self.alpha = float(1 + np.exp(a))
        self.beta = float(1 + np.exp(b))

        self.initialized = True

        # `update()` done
        pass


    def _torch_dtype(
        self
    ) -> "torch.dtype":
        """
        Func:
            Resolve `self.dtype` (a `torch.dtype`, a name such as `"float32"`, 
            or a numpy dtype-like) to a `torch.dtype`. `Tensor.to()` does not 
            accept dtype strings. 
        Raises:
            `ValueError`: If it does not name a torch dtype. 
        """

        if isinstance(self.dtype, torch.dtype):
            return self.dtype

        if isinstance(self.dtype, str):
            name = self.dtype
        else:
            name = np.dtype(self.dtype).name

        resolved = getattr(torch, name, None)
        if not isinstance(resolved, torch.dtype):
            raise ValueError(
                f"`dtype` `{self.dtype}` is not a valid torch dtype. "
            )

        return resolved


    @staticmethod
    def _to_scalar(
        value
    ) -> float:
        """
        Func:
            Convert a Python number, a single-element `torch.Tensor` (on any 
            device) or a single-element array-like to a Python float. 
        Raises:
            `ValueError` / `RuntimeError`: If `value` has more than one 
                element. 
        """

        if isinstance(value, torch.Tensor):
            return float(value.item())

        return float(np.asarray(value, dtype = np.float64).item())


    def sample(
        self, 

        shape: Tuple = (1, )
    ) -> Union["torch.Tensor", "np.ndarray"]:
        """
        Func:
            Sample from the beta distribution. 
        Args:
            `shape` (`tuple`): Shape of the returned sample. 
        Ret:
            `sample_res` (`torch.Tensor` or `np.ndarray`): The sample, cast to 
                `self.dtype` (and moved to `self.device` for torch). 
                sample_res.shape = shape. 
        Raises:
            `ValueError`: If `self.ver` is not a supported back-end. 
        """

        if self.ver == "torch":
            from torch.distributions import Beta

            # The distribution is built from Python floats on the CPU and is 
            # released by reference counting when it goes out of scope, so no 
            # explicit `gc.collect()` / CUDA cache flush is needed per call. 
            sample_res = Beta(self.alpha, self.beta).sample(shape)

            sample_res = sample_res.to(
                dtype = self._torch_dtype(), 
                device = self.device
            )
        elif self.ver == "numpy":
            sample_res = np.random.beta(
                a = self.alpha, b = self.beta, 
                size = shape
            ).astype(self.dtype)
        else:
            raise ValueError(
                f"`ver` should be one of {self._VERSIONS}, got `{self.ver}`. "
            )

        # `sample()` done
        return sample_res


    def update_with_mode(
        self, 

        mode: Union["torch.Tensor", "np.ndarray"], 
        zeta: float = None, 

        clamp_eps: Optional[float] = 1e-8
    ):
        """
        Func:
            Update the parameters so that the distribution's mode equals 
            `mode`. With `a = log(zeta)` and `b = a + log((1 - mode) / mode)`: 
            `alpha - 1 = zeta` and `beta - 1 = zeta * (1 - mode) / mode`, 
            hence `(alpha - 1) / (alpha + beta - 2) = mode`. Larger `zeta` 
            gives larger `alpha` and `beta`. 

            The first mode ever passed (after clamping) is stored in 
            `self.init_mode`. 
        Args:
            `mode`: Target mode; a Python number or a single-element 
                `np.ndarray` / `torch.Tensor`. Values outside (0, 1) are 
                clamped to `[clamp_eps, 1 - clamp_eps]` with a warning. 
            `zeta` (`float` or single-element tensor/array): Must be > 0. 
            `clamp_eps` (`float`): Margin used when clamping `mode`. 
        Raises:
            `ValueError`: If `zeta` is missing, NaN or not > 0, or if `mode` 
                is NaN. 
        """

        if zeta is None:
            raise ValueError(
                f"`zeta` should be strictly larger than 0.0, got `{zeta}`. "
            )

        zeta_value = self._to_scalar(zeta)
        # `not >` (instead of `<=`) also rejects NaN. 
        if not zeta_value > 0.0:
            raise ValueError(
                f"`zeta` should be strictly larger than 0.0, got `{zeta}`. "
            )

        # Work in float64: `alpha` / `beta` are Python floats anyway, and in 
        # float32 `1.0 - 1e-8` rounds to 1.0, which would turn the clamp below 
        # into `log(0) = -inf`. 
        mode_value = self._to_scalar(mode)
        if np.isnan(mode_value):
            raise ValueError(
                f"`mode` should be a number in (0.0, 1.0), got `{mode}`. "
            )

        if not (0.0 < mode_value < 1.0):
            logger(
                f"`mode` should be in the range (0.0, 1.0), "
                f"and it's clamped to (0.0 + eps, 1.0 - eps). ", 
                log_type = "warning"
            )

            mode_value = min(
                max(mode_value, 0.0 + clamp_eps), 
                1.0 - clamp_eps
            )

        if self.init_mode is None:
            self.init_mode = mode_value

        a = float(np.log(zeta_value))
        b = a + float(np.log((1.0 - mode_value) / mode_value))

        self.update(a = a, b = b)

        # `update_with_mode()` done
        pass
