import torch
import numpy as np


def transform_one(x):
    """Squash ``x`` with the invertible value transform.

    Computes ``sign(x) * (sqrt(|x| + 1) - 1) + 0.001 * x`` using NumPy
    element-wise operations, so ``x`` may be a Python number or an array.

    Args:
        x: Scalar or array-like accepted by NumPy ufuncs.

    Returns:
        The transformed value(s), with NumPy's result type for the input.
    """
    compressed = np.sqrt(np.abs(x) + 1.0) - 1
    signed = np.sign(x) * compressed
    # The small linear term keeps the transform strictly monotonic.
    return signed + 0.001 * x

def atanh(x):
    """Approximate the inverse hyperbolic tangent of a tensor.

    Evaluates ``0.5 * (log(1 + x) - log(1 - x))`` with ``1e-6`` added inside
    each logarithm, so the result stays finite at ``x = 1`` and ``x = -1``.

    Args:
        x: Input ``torch.Tensor``.

    Returns:
        Tensor with the same shape as ``x``.
    """
    log_plus = torch.log(1 + x + 1e-6)
    log_minus = torch.log(1 - x + 1e-6)
    return 0.5 * (log_plus - log_minus)

def symlog(x):
    """Apply the symmetric logarithm ``sign(x) * log(|x| + 1)`` element-wise.

    Args:
        x: Input ``torch.Tensor``.

    Returns:
        Tensor with the same shape as ``x``; zero maps to zero.
    """
    magnitude = torch.log(torch.abs(x) + 1)
    return torch.sign(x) * magnitude

class DiscreteSupport(object):
    """Two-hot encoding of scalars on an evenly spaced discrete support.

    Scalars are squashed with the invertible transform
    ``h(x) = sign(x) * (sqrt(|x| + 1) - 1) + epsilon * x`` and then spread
    over the two neighbouring points of a grid of ``bins`` points spanning
    ``[h(range[0]), h(range[1])]``.

    Keyword Args:
        range: ``(low, high)`` pair of untransformed bounds.
            Defaults to ``(-1, 1)``.
        bins: Number of support points. Defaults to ``20``.
        softmax_temp: Temperature the logits are divided by in
            ``vector_to_scalar``. Defaults to ``1.0``.

    Any other keyword arguments are ignored.
    """

    # Must match the 0.001 coefficient hard-coded in ``transform_one``.
    _EPSILON = 0.001

    def __init__(self, **kwargs):
        self.range = kwargs.get('range', (-1, 1))
        self.bins = kwargs.get('bins', 20)
        self.softmax_temp = kwargs.get('softmax_temp', 1.0)

    def scalar_to_vector(self, x):
        """Encode scalars as two-hot distributions over the support.

        Reference from MuZero: Appendix F => Network Architecture
        & Appendix A : Proposition A.2 in https://arxiv.org/pdf/1805.11593.pdf (Page-11)

        Args:
            x: ``torch.Tensor`` of scalar values (any shape). Values outside
                ``self.range`` are clamped to the nearest end of the support.

        Returns:
            Tensor of shape ``x.shape + (bins,)`` on ``x``'s device. Each
            vector along the last dimension sums to 1 and has at most two
            non-zero entries.
        """
        bins = self.bins
        epsilon = self._EPSILON

        low = transform_one(self.range[0])
        high = transform_one(self.range[1])
        scale = (high - low) / (bins - 1)

        # Sign that maps 0 to +1; float32 so that integer inputs get promoted.
        ones = torch.ones_like(x, dtype=torch.float32)
        sign = torch.where(x < 0, -ones, ones)

        squashed = sign * (torch.sqrt(torch.abs(x) + 1) - 1) + epsilon * x

        # Fractional bin position. Clamping *after* the shift guarantees that
        # both floor() and ceil() land inside [0, bins - 1], whatever rounding
        # happened on the way.
        position = ((squashed - low) / scale).clamp(0, bins - 1)

        low_idx = position.floor()
        high_idx = position.ceil()
        p_high = position - low_idx
        p_low = 1 - p_high

        target = torch.zeros(
            tuple(position.shape) + (bins,),
            dtype=p_high.dtype,
            device=position.device,
        )
        # Write order matters: when the position is integral both indices
        # coincide and p_high == 0, so the low write (p_low == 1) must be last.
        target.scatter_(-1, high_idx.long().unsqueeze(-1), p_high.unsqueeze(-1))
        target.scatter_(-1, low_idx.long().unsqueeze(-1), p_low.unsqueeze(-1))

        return target

    def vector_to_scalar(self, logits):
        """Decode logits over the support back to scalar values.

        Inverse operation of ``scalar_to_vector``: takes the softmax of the
        logits (divided by ``self.softmax_temp``), computes the expected
        support value and undoes the squashing transform.

        Args:
            logits: ``torch.Tensor`` whose last dimension has size ``bins``.

        Returns:
            ``float64`` tensor of shape ``logits.shape[:-1] + (1,)``.
        """
        bins = self.bins
        epsilon = self._EPSILON

        low = transform_one(self.range[0])
        high = transform_one(self.range[1])
        # linspace always yields exactly `bins` points; arange with a float
        # step can produce one point too many because of rounding.
        support_points = np.linspace(low, high, bins)

        # Work in float64 for higher precision.
        logits = logits.to(dtype=torch.float64)
        # Dividing by a temperature below 1 sharpens the distribution.
        probs = torch.softmax(logits / self.softmax_temp, dim=-1)
        support = torch.tensor(
            support_points, dtype=torch.float64, device=logits.device
        ).expand(probs.shape)
        value = (support * probs).sum(-1, keepdim=True)

        # Closed-form inverse of the squashing transform.
        root = torch.sqrt(1 + 4 * epsilon * (torch.abs(value) + 1 + epsilon))
        magnitude = ((root - 1) / (2 * epsilon)) ** 2 - 1

        return torch.sign(value) * magnitude


def softmax(logits):
    """Return the softmax of ``logits`` computed over the whole array.

    The maximum is subtracted before exponentiation for numerical stability.
    Normalisation uses the sum over *all* elements, not a single axis.

    Args:
        logits: Array-like of numbers. The input is never modified.

    Returns:
        ``numpy.ndarray`` of the same shape whose entries sum to 1.
    """
    logits = np.asarray(logits)

    # Out-of-place subtraction: np.asarray does not copy an ndarray, so an
    # in-place ``-=`` here would silently alter the caller's data.
    shifted = logits - logits.max()
    exps = np.exp(shifted)

    return exps / exps.sum()