import torch
from torch import Tensor
from torch.nn import functional as F
import logging

logger = logging.getLogger(__name__)


class MinMaxScaler:
    """Rescale a tensor to [0, 1] using extrema reduced over selected dims.

    ``scale`` computes the minimum and maximum of its input by reducing over
    each entry of ``dims`` in turn (keeping the reduced dims with size 1),
    stores them on the instance and returns ``(x - min) / (max - min)``.
    ``invert`` applies the inverse affine map using the stored extrema.

    Slices whose minimum equals their maximum would give a zero denominator;
    for those the stored offset is set to 0 and the stored range to 1, so the
    values pass through ``scale`` and ``invert`` unchanged.

    :param dims: iterable of dimension indices to reduce over, applied
        sequentially. The default ``(0, -1)`` also works for 1-D input because
        the reduced dimension is kept with size 1.
    """

    def __init__(self, dims=(0, -1)):
        self.mins = None
        self.maxs = None
        self.dims = dims

    def scale(self, x):
        """Fit the extrema on ``x`` and return the rescaled tensor.

        :param x: tensor to rescale; it is not modified.
        :return: ``(x - mins) / (maxs - mins)`` with the stored extrema
            broadcast against ``x``.
        """
        mins = x
        maxs = x
        for d in self.dims:
            mins = mins.min(dim=d, keepdim=True).values
            maxs = maxs.max(dim=d, keepdim=True).values
        # Lazy %-style arguments: the message is only formatted when the
        # record is actually emitted.
        logger.info("input shape %s, extrema shape %s", x.shape, mins.shape)

        # Build new tensors instead of writing into `mins`/`maxs`: with an
        # empty `dims` both names still alias the caller's `x`.
        degenerate = mins == maxs
        self.mins = torch.where(degenerate, torch.zeros_like(mins), mins)
        self.maxs = torch.where(degenerate, torch.ones_like(maxs), maxs)
        return (x - self.mins) / (self.maxs - self.mins)

    def invert(self, x):
        """Map scaled values back to the original range.

        :param x: tensor in the scaled space produced by ``scale``.
        :return: ``(maxs - mins) * x + mins``.
        :raises RuntimeError: if ``scale`` has not been called yet.
        """
        if self.mins is None or self.maxs is None:
            raise RuntimeError("MinMaxScaler.invert() called before scale()")
        return (self.maxs - self.mins) * x + self.mins


# Normalisation constant: the integral of exp(1 / (1 + x**2)) - 1
# over x in [-5, 5], evaluated numerically.
int_minus5_5 = 3.77528


def foo(x):
    """Evaluate the bump ``exp(1 / (1 + x**2)) - 1`` divided by ``int_minus5_5``.

    The profile is symmetric in ``x`` and largest at ``x = 0``.

    :param x: input tensor.
    :return: tensor of the same shape as ``x``.
    """
    lorentzian = 1.0 / (1.0 + x ** 2)
    return (torch.exp(lorentzian) - 1) / int_minus5_5


def logistic(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of a tensor.

    Uses ``torch.sigmoid`` rather than the explicit formula: for large
    negative ``x`` the explicit ``exp(-x)`` overflows to ``inf`` and autograd
    then produces a NaN gradient (``0 * inf``), whereas the built-in keeps
    both the value and the gradient finite.

    :param x: input tensor.
    :return: tensor of the same shape as ``x`` with values in [0, 1].
    """
    return torch.sigmoid(x)


def scaled_logistic(x, a, b, x0, scale):
    """Logistic step that rises from ``a`` to ``b`` around ``x0``.

    :param x: input tensor.
    :param a: value approached for ``x`` far below ``x0`` (positive ``scale``).
    :param b: value approached for ``x`` far above ``x0`` (positive ``scale``).
    :param x0: centre of the transition, where the result is ``(a + b) / 2``.
    :param scale: width of the transition; ``x - x0`` is divided by it.
    :return: ``a + (b - a) * logistic((x - x0) / scale)``.
    """
    z = (x - x0) / scale
    span = b - a
    return a + span * logistic(z)


def scaled_logistic_with_noise(x, a, b, x0, scale, noise_gt=0.5):
    """Scaled logistic curve that becomes noisy for ``x >= x0``.

    For ``x < x0`` the noise-free curve ``scaled_logistic(x, a, b, x0, scale)``
    is returned. Elsewhere the result is a Gaussian sample centred on that
    curve and clamped from below at ``1e-2``. The standard deviation of the
    sample is itself random: ``|noise_gt * n|`` with ``n`` drawn from a unit
    variance Gaussian centred on the curve.

    :param x: input tensor.
    :param a: lower plateau of the curve.
    :param b: upper plateau of the curve.
    :param x0: centre of the transition and threshold above which noise is applied.
    :param scale: width of the transition.
    :param noise_gt: multiplier applied to the random standard deviation.
    :return: tensor with the broadcast shape of the inputs.
    """
    mean = scaled_logistic(x, a, b, x0, scale)
    # `torch.normal` rejects negative standard deviations, and `torch.where`
    # evaluates this branch for every element (not only x >= x0), so the raw
    # product could abort the whole call. Taking the magnitude leaves every
    # previously valid draw unchanged.
    std = (noise_gt * torch.normal(mean)).abs()
    noisy = torch.clamp(torch.normal(mean, std), min=1.0e-2)
    return torch.where(x < x0, mean, noisy)


def get_data_grid(ntime=10, nspace=100, sigma=0.1):
    """Sample a noisy two-bump profile at random positions in [-5, 5].

    Amplitude and composition profiles are built on a grid of ``ntime``
    points, but only their first entries are used, so a single profile is
    generated.

    :param ntime: number of points of the amplitude/composition grid.
    :param nspace: number of random sample positions.
    :param sigma: standard deviation of the additive Gaussian noise.
    :return: tensor of shape ``(nspace, 2)`` whose last dimension holds
        ``[position, noisy value]``.
    """
    lower, upper = -5.0, 5.0

    time_axis = torch.linspace(lower, upper, ntime)
    amplitude = 2 - 0.2 * logistic(time_axis)
    mixing = 1.0 - 0.5 * logistic(time_axis)

    positions = lower + (upper - lower) * torch.rand(nspace)

    # Broad bump centred at 0 plus a narrow, half-height bump centred at 5.
    broad = foo(positions)
    narrow = 0.5 * foo((positions - 5) / 0.25)

    clean = amplitude[0] * (mixing[0] * broad + narrow)
    noise = sigma * torch.randn_like(broad)
    return torch.stack([positions, clean + noise], dim=-1)


def get_data_random(ndata=1000, seed=11, include_space_scaling=True):
    """Generate a seeded random sample of a space-dependent two-component spectrum.

    Every sample has a space coordinate and an energy-like coordinate, both
    drawn uniformly from [-5, 5]. The target is a mixture of a signal spectrum
    ``foo(p + 2.5)`` and a background spectrum ``foo(p - 2.5)`` whose mixing
    weights depend on the space coordinate through ``logistic``. Ground-truth
    curves are evaluated on a regular grid of 101 points.

    Note that this calls ``torch.manual_seed(seed)`` and therefore resets the
    global torch random number generator.

    :param ndata: number of random samples to draw.
    :param seed: seed passed to ``torch.manual_seed``.
    :param include_space_scaling: if true, multiply the mixture by the
        normalised space profile.
    :return: tuple ``(coordinates, ydata, spectrum_gt, space_gt)`` where
        ``coordinates`` has shape ``(ndata, 2)`` with columns
        ``[space, energy]``, ``ydata`` has shape ``(ndata, 1)``,
        ``spectrum_gt`` is ``(energy grid, signal, background)`` and
        ``space_gt`` is ``(space grid, normalised space profile,
        signal weight, background weight)``.
    """
    n_ground_truth = 101
    torch.manual_seed(seed)
    energy_a, energy_b = -5.0, 5.0
    space_a, space_b = -5.0, 5.0

    # pdata ~ energy coordinate (drawn first to keep the random stream order)
    parray = torch.rand(ndata)
    parray = energy_a + (energy_b - energy_a) * parray
    parray_gt = energy_a + (energy_b - energy_a) * torch.linspace(
        0.0, 1.0, n_ground_truth
    )

    # signal : \tilde g_1(p)
    spectrum_signal = foo(parray + 2.5)
    spectrum_signal_gt = foo(parray_gt + 2.5)

    # background : \tilde g_0(p)
    spectrum_bg = foo((parray - 2.5) / 1.0)
    spectrum_bg_gt = foo((parray_gt - 2.5) / 1.0)

    # sdata ~ space coordinate
    sarray = torch.rand(ndata)
    sarray = space_a + (space_b - space_a) * sarray
    sarray_gt = space_a + (space_b - space_a) * torch.linspace(0.0, 1.0, n_ground_truth)

    proportion_space = logistic(sarray)
    proportion_space_gt = logistic(sarray_gt)

    norm_space = 1.0 + proportion_space
    norm_space_gt = 1.0 + proportion_space_gt

    # space profile : \tilde s(x), normalised as \tilde f = f / (<f> * V)
    volume_space = space_b - space_a
    norm_space = norm_space / (volume_space * norm_space.sum() / ndata)
    norm_space_gt = norm_space_gt / (
        volume_space * norm_space_gt.sum() / n_ground_truth
    )

    # proportions : \beta_k(x); the two weights sum to one at every point
    weight_signal = proportion_space / (1.0 + proportion_space)
    weight_bg = 1.0 / (1.0 + proportion_space)
    weight_signal_gt = proportion_space_gt / (1.0 + proportion_space_gt)
    weight_bg_gt = 1.0 / (1.0 + proportion_space_gt)

    ydata = weight_signal * spectrum_signal + weight_bg * spectrum_bg
    if include_space_scaling:
        ydata *= norm_space

    coordinates = torch.stack([sarray, parray]).T

    return (
        coordinates,
        ydata.unsqueeze(-1),
        (parray_gt, spectrum_signal_gt, spectrum_bg_gt),
        (sarray_gt, norm_space_gt, weight_signal_gt, weight_bg_gt),
    )


def linear_aux(input, weight, bias=None):
    """Apply ``input @ weight^T (+ bias)`` with optional batched weights.

    When a bias is given and ``input`` has at least two dimensions, the
    product is computed with an einsum whose ellipsis also covers ``weight``,
    so ``weight`` may carry leading batch dimensions that broadcast against
    those of ``input``. In every other case ``weight`` is transposed with
    ``.t()`` and multiplied with ``matmul``.

    :param input: tensor whose last dimension is the feature dimension.
    :param weight: tensor of shape ``(out_features, in_features)``, optionally
        with leading batch dimensions on the einsum path.
    :param bias: optional tensor added to the result.
    :return: the transformed tensor.
    """
    operands = (input, weight)
    if not torch.jit.is_scripting():
        # Defer to ``__torch_function__`` overrides of tensor-like arguments.
        has_non_tensor = any([type(op) is not Tensor for op in operands])
        if has_non_tensor and F.has_torch_function(operands):
            return F.handle_torch_function(
                F.linear, operands, input, weight, bias=bias
            )

    if bias is not None and input.dim() >= 2:
        return torch.einsum("...k,...jk->...j", input, weight) + bias

    result = input.matmul(weight.t())
    if bias is not None:
        result += bias
    return result


def linear_aux2(input, weight, bias=None):
    """Apply ``input @ weight^T (+ bias)`` using a single einsum.

    :param input: tensor whose last dimension has size ``in_features``; any
        leading dimensions are preserved.
    :param weight: tensor of shape ``(out_features, in_features)``.
    :param bias: optional tensor added in place to the einsum result.
    :return: tensor with the last dimension replaced by ``out_features``.
    """
    projected = torch.einsum("...k,jk->...j", input, weight)
    if bias is None:
        return projected
    projected += bias
    return projected
