import torch
from torch import nn
from src.heat2 import benchmark as heat_benchmark
from src.heat2 import utils as heat_utils


class Net(nn.Module):
    """Sum-of-separable-modes model evaluated on ``(x, y, t)`` points.

    Every component has the form::

        w * sin(a * x' + b) * sin(c * y' + d) * exp(-D * (a**2 + c**2) * t + e) + s

    where ``(x', y')`` is either the raw ``(x, y)`` or, when
    ``apply_rotation`` is set, ``(x, y)`` multiplied by a learnable
    per-component 2x2 rotation matrix.  Because the temporal decay rate is
    tied to the spatial frequencies through ``D``, each component satisfies
    ``u_t = D * (u_xx + u_yy)`` by construction for a constant ``D``.

    ``D`` is read from ``heat_benchmark.Benchmark().D``.
    NOTE(review): presumably the diffusion coefficient of the benchmark
    problem -- confirm against ``src.heat2.benchmark``.
    """

    def __init__(self, num_components: int = 64, *, apply_rotation: bool = False):
        """Create randomly initialised component parameters.

        Args:
            num_components: Number of separable components that are averaged
                to form the output.
            apply_rotation: Initial value of the ``apply_rotation`` attribute;
                when true, ``forward`` rotates the spatial coordinates per
                component before evaluating the sines.
        """
        super().__init__()
        # One rotation angle per component, drawn from [0, 2*pi).
        self.angles = nn.Parameter(2 * torch.pi * torch.rand(num_components))
        # Shapes are (1, num_components) so they broadcast against (N, 1)
        # or (N, num_components) coordinate tensors in ``forward``.
        self.x_shift = nn.Parameter(torch.rand(1, num_components))
        self.x_scale = nn.Parameter(10 * torch.rand(1, num_components))
        self.y_shift = nn.Parameter(torch.rand(1, num_components))
        self.y_scale = nn.Parameter(10 * torch.rand(1, num_components))
        self.t_shift = nn.Parameter(torch.rand(1, num_components))
        self.out_shift = nn.Parameter(torch.rand(1, num_components))
        self.weights = nn.Parameter(torch.rand(1, num_components))
        self.D = heat_benchmark.Benchmark().D
        self.num_components = num_components
        self.apply_rotation = apply_rotation

    def rotation_matrices(self) -> torch.Tensor:
        """Return the per-component rotation matrices.

        Returns:
            Tensor of shape ``(num_components, 2, 2)`` whose ``i``-th slice is
            ``[[cos a_i, sin a_i], [-sin a_i, cos a_i]]`` for ``a_i =
            self.angles[i]``.  The result shares dtype and device with
            ``self.angles``.
        """
        sin = torch.sin(self.angles)
        cos = torch.cos(self.angles)
        # Stack instead of writing into a fresh ``torch.zeros`` buffer: the
        # latter is always created with the default dtype on the default
        # device and therefore breaks after ``.double()`` / ``.to(device)``.
        top_row = torch.stack((cos, sin), dim=-1)
        bottom_row = torch.stack((-sin, cos), dim=-1)
        return torch.stack((top_row, bottom_row), dim=-2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the model.

        Args:
            x: 2-D tensor whose columns 0, 1 and 2 are used as the ``x``,
                ``y`` and ``t`` coordinates respectively.

        Returns:
            Tensor of shape ``(N, 1)`` (``N`` being the number of rows of
            ``x``) holding the mean over all components.
        """
        t = x[:, 2].unsqueeze(1)
        if self.apply_rotation:
            # (C, 2, 2) @ (2, N) -> (C, 2, N); reorder to (N, C, 2) so each
            # component sees its own rotated coordinate pair.
            rotated = torch.matmul(self.rotation_matrices(), x[:, :2].T)
            rotated = rotated.permute([2, 0, 1])
            x_coord, y_coord = rotated[:, :, 0], rotated[:, :, 1]
        else:
            x_coord, y_coord = x[:, 0].unsqueeze(1), x[:, 1].unsqueeze(1)

        x_component = torch.sin(self.x_scale * x_coord + self.x_shift)
        y_component = torch.sin(self.y_scale * y_coord + self.y_shift)
        # Decay rate is tied to the spatial frequencies, not learned freely.
        t_scale = self.D * (self.x_scale**2 + self.y_scale**2)
        t_component = torch.exp(-t_scale * t + self.t_shift)

        components = (
            self.weights * x_component * y_component * t_component
            + self.out_shift
        )
        return components.mean(dim=1, keepdim=True)



def get_loss_function_and_network():
    """Build a loss callable together with a freshly constructed ``Net``.

    A single ``heat_utils.CollocationPoints`` instance is created here and
    captured by the returned callable, so every loss evaluation uses the
    same object.

    Returns:
        A ``(loss, net)`` tuple where ``loss(net)`` returns
        ``heat_utils.boundary_loss(net, collocation_points)`` and ``net`` is a
        default-constructed ``Net``.
    """
    points = heat_utils.CollocationPoints()

    def loss(net):
        """Return the boundary loss of ``net`` on the captured points."""
        return heat_utils.boundary_loss(net, points)

    return loss, Net()

