# coding: utf-8

import math
import torch

from dataclasses import dataclass


# Global configuration (modifiable after importing this file):

@dataclass
class Config:
    """Settings that log() and its helper read from the module-level `config`."""

    keep_logs_finite: bool = True
    (
        "If True, log values below a finite floor (twice the log of the "
        "smallest normal number of float_dtype) are replaced by that floor, "
        "so log() of zero is finite. If False, the values computed by "
        "torch.log() are kept as they are."
    )

    cast_all_logs_to_complex: bool = True
    (
        "If True, log() always returns a complex tensor. If False, log() "
        "returns a complex tensor only when its input has at least one "
        "negative element, and a real tensor otherwise."
    )

    float_dtype: torch.dtype = torch.float32
    "Float dtype of real logarithms and of the components of complex ones."


# Shared instance consulted at call time; change its attributes to reconfigure.
config = Config()


# Helper functions for elementwise log(abs()) and exp():

class _CustomizedTorchAbs(torch.autograd.Function):
    """
    Absolute value with a hand-written backward pass.

    Forward is torch.abs(). Backward multiplies the incoming gradient by -1
    wherever the saved input is negative and by 1 everywhere else, so an
    input of exactly zero passes the gradient through unchanged.
    """
    generate_vmap_rule = True

    @staticmethod
    def forward(inp):
        return torch.abs(inp)

    @staticmethod
    def setup_context(ctx, inp_tup, out):
        # Only the input is needed later, to choose the sign in backward().
        (forward_input,) = inp_tup
        ctx.save_for_backward(forward_input)

    @staticmethod
    def backward(ctx, grad_output):
        (forward_input,) = ctx.saved_tensors
        is_negative = forward_input < 0
        sign = torch.where(is_negative, -1, 1)
        return grad_output * sign


class _CustomizedTorchLog(torch.autograd.Function):
    """
    Natural log with an optional output floor and a damped backward pass.

    Forward computes torch.log(). When config.keep_logs_finite is set, every
    result below twice the log of the smallest normal number of
    config.float_dtype is replaced by that floor. Backward divides the
    incoming gradient by (input + eps), with eps taken from the input's
    dtype, which keeps the gradient finite at an input of zero.
    """
    generate_vmap_rule = True

    @staticmethod
    def forward(inp):
        smallest_normal = torch.finfo(config.float_dtype).smallest_normal
        floor = 2 * math.log(smallest_normal)  # exps to zero in float_dtype
        raw_log = torch.log(inp)
        below_floor = raw_log < floor
        # The config flag is and-ed in, so a False flag disables the floor.
        replace_mask = below_floor & config.keep_logs_finite
        return torch.where(replace_mask, floor, raw_log)

    @staticmethod
    def setup_context(ctx, inp_tup, out):
        (forward_input,) = inp_tup
        ctx.save_for_backward(forward_input)

    @staticmethod
    def backward(ctx, grad_output):
        (forward_input,) = ctx.saved_tensors
        # Offsetting the denominator by eps avoids dividing by exactly zero.
        denominator = forward_input + torch.finfo(forward_input.dtype).eps
        return grad_output / denominator


class _CustomizedTorchExp(torch.autograd.Function):
    """
    Exponential whose backward pass never multiplies by exactly zero.

    Forward is torch.exp() and accepts float or complex tensors. Backward
    multiplies the incoming gradient by (output + signed eps), where eps comes
    from the dtype of the output's real component and takes a negative sign
    wherever that real component is negative.
    """
    generate_vmap_rule = True

    @staticmethod
    def forward(inp):
        return torch.exp(inp)

    @staticmethod
    def setup_context(ctx, inp_tup, out):
        # d/dx exp(x) = exp(x), so the output alone is enough for backward().
        ctx.save_for_backward(out)

    @staticmethod
    def backward(ctx, grad_output):
        (exp_out,) = ctx.saved_tensors
        tiny = torch.finfo(exp_out.real.dtype).eps
        real_is_negative = exp_out.real < 0
        # Push the derivative away from zero, in the direction of its sign.
        nudge = torch.where(real_is_negative, -tiny, tiny)
        return grad_output * (exp_out + nudge)


# Functions for computing elementwise log() and exp():

def log(x):
    """
    Elementwise log() of real tensor, i.e., a generalized order of magnitude.

    The real component is log(abs(x)), cast to config.float_dtype. The
    imaginary component is pi where x is negative and 0 elsewhere.

    Args:
        x: real (non-complex) tensor.

    Returns:
        A complex tensor if config.cast_all_logs_to_complex is True or x has
        at least one negative element; otherwise the real tensor log(abs(x)).

    Raises:
        AssertionError: if x is a complex tensor.
    """
    assert not x.is_complex(), "Input must be a float tensor, not a complex one."
    abs_x = _CustomizedTorchAbs.apply(x)
    log_abs_x = _CustomizedTorchLog.apply(abs_x)
    real = log_abs_x.to(config.float_dtype)
    x_is_neg = (x < 0)
    # Test the config flag first: when it is True the data-dependent
    # torch.any() check is skipped entirely, so this branch does not need to
    # read tensor values (which Python control flow on a tensor requires).
    if config.cast_all_logs_to_complex or torch.any(x_is_neg):
        return torch.complex(real, imag=x_is_neg.to(real.dtype) * torch.pi)
    return real

def exp(log_x):
    """
    Elementwise exp() of a tensor with generalized orders of magnitude.

    Applies the customized exponential and returns the real component of the
    result.
    """
    exponentiated = _CustomizedTorchExp.apply(log_x)
    return exponentiated.real


# Function for computing log-matmul-exp:

def log_matmul_exp(log_x1, log_x2):
    """
    Broadcastable log(exp(log_x1) @ exp(log_x2)).

    Before exponentiating, each operand is shifted down by a scaling
    constant: the maximum real component along the contracted dimension
    (last dim of log_x1, second-to-last dim of log_x2), detached from the
    autograd graph and clamped to be at least 0. Both constants are added
    back after taking the log of the matrix product.

    Input shapes: [..., d1, d2], [..., d2, d3].
    Output shape: [..., d1, d3].
    """
    row_max = log_x1.real.detach().max(dim=-1, keepdim=True).values
    col_max = log_x2.real.detach().max(dim=-2, keepdim=True).values
    shift1 = row_max.clamp(min=0)
    shift2 = col_max.clamp(min=0)

    scaled_x1 = exp(log_x1 - shift1)
    scaled_x2 = exp(log_x2 - shift2)
    product = torch.matmul(scaled_x1, scaled_x2)

    return log(product) + shift1 + shift2


# Other functions to be provided with camera-ready version of paper.
