import torch

from matplotlib import pyplot as plt
import numpy as np


class Benchmark:
    """Analytic reference function on the unit square, used to score networks.

    The reference is ``Re(sin(1 / prod_k (z - p_k)))`` with ``z = x + i*y``
    and poles ``p_k`` at ``1.2+0.5i``, ``-0.2+0.5i``, ``0.5-0.2i`` and
    ``0.5+1.2i``.  All four poles lie outside ``[0, 1]^2``, so the function
    is smooth on the region that is sampled and plotted here.  Being the
    real part of a holomorphic function, it is harmonic away from the poles.
    """

    def __init__(self, n_eval: int = 4096):
        """Draw the fixed random points used by :meth:`evaluate_network`.

        Args:
            n_eval: Number of points sampled uniformly from ``[0, 1)^2``.
        """
        self.eval_x = torch.rand(n_eval, 2)

    @staticmethod
    def __call__(x: torch.Tensor) -> torch.Tensor:
        """Evaluate the reference function.

        Args:
            x: Tensor whose column 0 holds the x-coordinates and column 1
                the y-coordinates (only these two columns are read).

        Returns:
            Real tensor of shape ``(N, 1)`` with one value per input row.
        """
        # NOTE: disabled alternatives kept for reference:
        #   (2 / pi) * atan(sin(pi * y) / sinh(pi * x))
        #   pre-transforming z via ((1 - 1j) * z) ** 2
        i = 1j
        re, im = x[:, 0], x[:, 1]
        z = re + i * im
        trans_z = 1 / (
            (z - 1.2 - 0.5 * i)
            * (z + 0.2 - 0.5 * i)
            * (z - 0.5 + 0.2 * i)
            * (z - 0.5 - 1.2 * i)
        )
        return torch.sin(trans_z).real.unsqueeze(1)

    def plot_solution(self, side_length: int = 128, net=None):
        """Render the reference function, or a network, on a regular grid.

        Args:
            side_length: Number of grid points per axis on ``[0, 1]``.
            net: Optional callable mapping an ``(N, 2)`` tensor to one value
                per row.  When ``None`` the exact reference is plotted.

        Returns:
            The matplotlib figure containing the image and its colorbar.
        """
        axis = torch.linspace(start=0.0, end=1.0, steps=side_length)
        # Rows are (axis[i], axis[j]) with i varying slowest, i.e. the same
        # ordering as stacking the flattened outputs of an 'ij' meshgrid,
        # but without the meshgrid indexing deprecation warning.
        grid = torch.cartesian_prod(axis, axis)
        if net is None:
            values = self(grid)
        else:
            values = net(grid).detach().cpu()
        # Image rows index the first coordinate, columns the second.
        out = values.reshape(side_length, side_length).numpy()
        fig, ax = plt.subplots()
        image = ax.imshow(out)
        fig.colorbar(image)
        return fig

    def evaluate_network(self, net) -> float:
        """Return the RMSE between ``net`` and the reference on ``eval_x``.

        Args:
            net: Callable mapping an ``(N, 2)`` tensor to one value per row;
                any output shape holding exactly ``N`` elements is accepted.

        Returns:
            Root-mean-square error as a Python float.

        Raises:
            RuntimeError: If the network output does not contain exactly one
                value per evaluation point (raised by ``reshape``).
        """
        exact = self(self.eval_x)
        preds = net(self.eval_x).detach().to(exact.device)
        # Align shapes explicitly: subtracting a flat (N,) prediction from
        # the (N, 1) reference would broadcast to (N, N) and silently
        # produce a meaningless error value.
        preds = preds.reshape(exact.shape)
        return (exact - preds).pow(2).mean().sqrt().item()
