#  Copyright (c) 2024-2025
import math

import torch
from matplotlib import pyplot as plt, cm
from torch import Tensor
from torch.nn import Parameter


class Exponential:
    """Scaled exponential transform ``x -> exp(p * x)`` paired with its inverse.

    NOTE(review): with the default ``p=1e15``, ``exp(p * x)`` overflows to
    ``inf`` once ``x`` exceeds roughly 1e-13 (float32) / 1e-12 (float64) —
    confirm this default is intended.
    """

    def __init__(self, p=1e15):
        # Scale factor applied inside the exponent.
        self._p = p

    def forward(self, x):
        """Return ``exp(p * x)``."""
        exponent = self._p * x
        return torch.exp(exponent)

    def inverse(self, x):
        """Return ``log(x) / p``, undoing :meth:`forward`."""
        logged = torch.log(x)
        return logged / self._p


class Log:
    """Scaled logarithm transform ``x -> log(x) / p`` paired with its inverse."""

    def __init__(self, p=1):
        # Divisor applied to the logarithm (and multiplier in the inverse).
        self._p = p

    def forward(self, x):
        """Return ``log(x) / p``."""
        logged = torch.log(x)
        return logged / self._p

    def inverse(self, x):
        """Return ``exp(p * x)``, undoing :meth:`forward`."""
        exponent = self._p * x
        return torch.exp(exponent)


class Identity:
    """No-op transform: both directions return the input object unchanged."""

    def forward(self, x):
        """Return ``x`` as-is."""
        result = x
        return result

    def inverse(self, x):
        """Return ``x`` as-is."""
        result = x
        return result


class Square:
    """Squaring transform ``x -> x ** 2`` with square root as its inverse.

    The inverse only recovers non-negative inputs.
    """

    def forward(self, x):
        """Return ``x`` squared."""
        squared = x**2
        return squared

    def inverse(self, x):
        """Return the element-wise square root of ``x``."""
        root = x.sqrt()
        return root


class PowerSumAggregation(torch.nn.Module):
    """Aggregate along a dimension as ``sum(x ** t)`` with a positive exponent ``t``.

    The exponent is stored in log space (``_inner_t = log(t)``), so
    ``t = exp(_inner_t)`` stays strictly positive.

    Args:
        t: Initial exponent; must be > 0.
        learn: If True, ``log(t)`` is a trainable ``Parameter``. Otherwise it
            is a fixed, non-persistent buffer (still a tensor, moved by
            ``.to(device)``, but not saved in the state dict).

    Raises:
        ValueError: If ``t`` is not positive.
    """

    def __init__(self, t: float = 1.0, learn: bool = True):
        super().__init__()

        if t <= 0:
            raise ValueError(f"t must be positive, got {t}")
        self._init_log_t = math.log(t)
        self.learn = learn

        if learn:
            self._inner_t = Parameter(torch.empty(1))
        else:
            # A tensor is required here: torch.exp() in `t`/`forward` rejects
            # a plain Python float. Non-persistent keeps state_dict unchanged.
            self.register_buffer("_inner_t", torch.empty(1), persistent=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the log-exponent to the ``log(t)`` given at construction."""
        if isinstance(self._inner_t, Tensor):
            self._inner_t.data.fill_(self._init_log_t)

    @property
    def t(self) -> Tensor:
        """Current exponent ``exp(_inner_t)``, a tensor of shape ``(1,)``."""
        return torch.exp(self._inner_t)

    def forward(self, x: Tensor, dim: int = -2) -> Tensor:
        """Raise ``x`` element-wise to ``t`` and sum over ``dim``.

        Args:
            x: Input tensor.
            dim: Dimension to reduce; kept in the output with size 1.

        Returns:
            ``x.pow(t).sum(dim, keepdim=True)``.
        """
        return x.pow(self.t).sum(dim=dim, keepdim=True)


def abs_geometric_mean(x, dim):
    """Geometric mean of ``|x|`` along ``dim``, keeping the reduced dimension.

    Computed in log space as ``exp(mean(log|x|))`` instead of
    ``prod(|x|) ** (1/n)``. The direct product overflows to ``inf`` (or
    underflows to ``0``) for long reductions even when the geometric mean
    itself is representable. A zero entry still yields 0, because
    ``log(0) = -inf`` and ``exp(-inf) = 0``.

    Args:
        x: Input tensor.
        dim: Dimension to reduce over.

    Returns:
        Tensor with ``dim`` kept as size 1.
    """
    return torch.log(x.abs()).mean(dim=dim, keepdim=True).exp()


def plot_gen_agg(gen_agg, device, range=5, steps=50, show=False):
    """Plot a two-input aggregator as a 3-D surface.

    ``gen_agg`` is evaluated under ``torch.no_grad`` on a ``steps x steps``
    grid. The two grid coordinates are stacked on the last dimension and
    reduced with ``dim=-1``.

    Args:
        gen_agg: Callable ``(tensor, dim=...) -> tensor``; its output is
            squeezed on the last dimension before plotting.
        device: Torch device on which the grid is built.
        range: A single number ``r`` (grid spans ``[-r, r]`` on both axes)
            or an indexable ``(low, high)`` pair. The name shadows the
            builtin but is kept for existing callers.
        steps: Number of grid points per axis.
        show: If True, call ``plt.show()`` before returning.

    Returns:
        The matplotlib figure.
    """
    fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

    # A scalar range means a symmetric interval around zero.
    if isinstance(range, (int, float)):
        range = (-range, range)
    axis = torch.linspace(range[0], range[1], steps, device=device, dtype=torch.float)
    grid_x, grid_y = torch.meshgrid(axis, axis, indexing="ij")

    with torch.no_grad():
        points = torch.stack([grid_x, grid_y], dim=-1)
        surface_z = gen_agg(points, dim=-1).squeeze(-1)

    surf = ax.plot_surface(
        grid_x.cpu(),
        grid_y.cpu(),
        surface_z.cpu(),
        cmap=cm.coolwarm,
        linewidth=0,
        antialiased=False,
    )
    # Two-decimal z tick labels; matplotlib wraps the format string in a
    # StrMethodFormatter.
    ax.zaxis.set_major_formatter("{x:.02f}")
    fig.colorbar(surf, shrink=0.5, aspect=5)

    if show:
        plt.show()
    return fig


if __name__ == "__main__":
    # Demo: visualise a power-sum aggregator with t=5 over the grid [0, 2]^2.
    torch.manual_seed(0)

    def agg_mean(x, dim):
        # Despite its name, this is max minus mean along `dim`; currently unused.
        top = x.max(dim=dim, keepdim=True)[0]
        avg = x.mean(dim=dim, keepdim=True)
        return top - avg

    aggregator = PowerSumAggregation(t=5, learn=True)
    plot_gen_agg(aggregator, device="cpu", range=(0, 2), show=True)

    # Earlier experiments here used GenAgg, LAF and train_to_match, which are
    # not imported in this file; restore them from version control if needed.
