import torch
from torch import nn
from src import utils
from src.laplace1 import utils as laplace_utils


class Net(nn.Module):
    """Average of transformed closed-form 2D harmonic functions.

    Each of the eight ``harmonicN`` static methods is wrapped twice in a
    ``utils.TransformedFunc2D`` (once with ``orientation_preserving=True`` and
    once with ``False``), giving sixteen sub-modules. ``forward`` evaluates
    all of them on the same input and returns their arithmetic mean.
    """

    def __init__(self, num_components_per_func=16):
        """
        Build the transformed copies of every harmonic base function.

        Args:
            num_components_per_func (int, optional): number of transformations
                that each function undergoes; forwarded as ``num_components``
                to every ``utils.TransformedFunc2D``. Defaults to 16.
        """
        super().__init__()
        base_functions = [
            self.harmonic1,
            self.harmonic2,
            self.harmonic3,
            self.harmonic4,
            self.harmonic5,
            self.harmonic6,
            self.harmonic7,
            self.harmonic8,
        ]
        self.transformed_funcs = self.setup_transformed_funcs(
            base_functions, num_components_per_func
        )
        # Cached so forward() does not recompute the length on every call.
        self.num_funcs = len(self.transformed_funcs)
        # Not read or updated anywhere in this class; presumably maintained by
        # external training code -- TODO confirm.
        self.train_iteration = 0

    def forward(self, x):
        """Evaluate every transformed function at ``x`` and return their mean.

        Args:
            x: input handed unchanged to each entry of ``transformed_funcs``.

        Returns:
            Sum of all sub-module outputs divided by ``num_funcs``.
        """
        # The explicit 0.0 start value makes the accumulation begin from a
        # float rather than the builtin's default integer 0.
        total = sum((func(x) for func in self.transformed_funcs), 0.0)
        return total / self.num_funcs

    @staticmethod
    def setup_transformed_funcs(functions, num_components):
        """Wrap each base function in two ``utils.TransformedFunc2D`` modules.

        Args:
            functions: iterable of callables to wrap.
            num_components: forwarded as ``num_components`` to each wrapper.

        Returns:
            nn.ModuleList: for every function, in input order, first the
            wrapper built with ``orientation_preserving=True`` and then the
            one built with ``orientation_preserving=False``.
        """
        wrappers = [
            utils.TransformedFunc2D(
                num_components=num_components,
                orientation_preserving=keep_orientation,
                func=base_func,
                num_outputs=1,
                apply_rotation=True,
            )
            for base_func in functions
            for keep_orientation in (True, False)
        ]
        return nn.ModuleList(wrappers)

    # ------------------------------------------------------------------
    # Base functions. Writing z = u + i*v with u = x[:, 0] and v = x[:, 1],
    # each one is the real part of an analytic function of z.
    # ------------------------------------------------------------------

    @staticmethod
    def harmonic1(x):
        """Real part of ``sin(z)``: ``sin(u) * cosh(v)``."""
        u, v = x[:, 0], x[:, 1]
        return torch.sin(u) * torch.cosh(v)

    @staticmethod
    def harmonic2(x):
        """Real part of ``sin(z**2)``."""
        u, v = x[:, 0], x[:, 1]
        # z**2 = (u**2 - v**2) + i * (2*u*v)
        re_z2 = u**2 - v**2
        im_z2 = 2 * u * v
        return torch.sin(re_z2) * torch.cosh(im_z2)

    @staticmethod
    def harmonic3(x):
        """Real part of ``sin(z)**2``."""
        u, v = x[:, 0], x[:, 1]
        re_squared = torch.sin(u) ** 2 * torch.cosh(v) ** 2
        im_squared = torch.cos(u) ** 2 * torch.sinh(v) ** 2
        return re_squared - im_squared

    @staticmethod
    def harmonic4(x):
        """Real part of ``sin(sin(z))``."""
        u, v = x[:, 0], x[:, 1]
        # sin(z) = sin(u)cosh(v) + i * cos(u)sinh(v)
        re_sin = torch.sin(u) * torch.cosh(v)
        im_sin = torch.cos(u) * torch.sinh(v)
        return torch.sin(re_sin) * torch.cosh(im_sin)

    @staticmethod
    def harmonic5(x):
        """Real part of ``exp(z)``: ``exp(u) * cos(v)``."""
        u, v = x[:, 0], x[:, 1]
        return torch.exp(u) * torch.cos(v)

    @staticmethod
    def harmonic6(x):
        """Real part of ``exp(sin(z))``."""
        u, v = x[:, 0], x[:, 1]
        re_sin = torch.sin(u) * torch.cosh(v)
        im_sin = torch.cos(u) * torch.sinh(v)
        return torch.exp(re_sin) * torch.cos(im_sin)

    @staticmethod
    def harmonic7(x):
        """Real part of ``sin(exp(z))``."""
        u, v = x[:, 0], x[:, 1]
        # exp(z) = exp(u)cos(v) + i * exp(u)sin(v)
        re_exp = torch.exp(u) * torch.cos(v)
        im_exp = torch.exp(u) * torch.sin(v)
        return torch.sin(re_exp) * torch.cosh(im_exp)

    @staticmethod
    def harmonic8(x):
        """Real part of ``exp(z**2)``."""
        u, v = x[:, 0], x[:, 1]
        re_z2 = u**2 - v**2
        im_z2 = 2 * u * v
        return torch.exp(re_z2) * torch.cos(im_z2)

def get_loss_function_and_network():
    """Create a default ``Net`` and a loss that fits it to boundary data.

    Returns:
        tuple: ``(loss, net)`` where ``net`` is a freshly constructed ``Net``
        and ``loss`` is a callable taking a network and returning the mean
        squared difference between the network's output on the points of
        ``collocation_points.dirichlet_boundary`` and the matching targets.
    """
    # Build the collocation points before the network, in this order: either
    # constructor may draw from the global RNG -- TODO confirm.
    collocation_points = laplace_utils.CollocationPoints()
    network = Net()

    def loss(net):
        """Mean squared error of ``net`` on the Dirichlet boundary data."""
        # Read on every call, so the loss follows any change to the
        # attribute's value.
        boundary_points, target = collocation_points.dirichlet_boundary
        residual = net(boundary_points) - target
        # NOTE(review): relies on prediction and target shapes agreeing (or
        # broadcasting as intended) -- verify against CollocationPoints.
        return torch.mean(residual.pow(2))

    return loss, network
