import torch
from src.laplace1 import benchmark
import functools
from src import utils


class CollocationPoints:
    """Sample points on the unit square ``[0, 1] x [0, 1]`` for training.

    Provides three families of points:

    * ``dirichlet_boundary`` -- evenly spaced points on the four edges, paired
      with the values ``self.benchmark`` returns there;
    * ``interior`` -- a fixed random sample inside the square;
    * ``adaptive_interior`` -- the interior sample extended by points that
      ``update_adaptive_points`` selected for their large PDE residual.
    """

    def __init__(self, num_boundary=128, num_interior=1024):
        """Store the sampling sizes and the adaptive-refinement settings.

        Args:
            num_boundary: number of points placed on each of the four edges.
            num_interior: number of random interior points.
        """
        self.num_boundary = num_boundary
        self.num_interior = num_interior
        self.benchmark = benchmark.Benchmark()
        # (x, y) tuple of accumulated adaptive points; None until the first
        # call to update_adaptive_points.
        self.adaptive_points = None
        # Number of increment_adaptive_points calls made so far.
        self.update_counter = 0
        # An adaptive update is triggered every `update_frequency` calls.
        self.update_frequency = 1000
        # Number of new points kept per adaptive update.
        self.update_num = 32
        # Number of random candidates scored per adaptive update.
        self.update_candidates = 1024

    @staticmethod
    def stack(x, y):
        """Stack two 1-D coordinate tensors into an ``(N, 2)`` tensor."""
        return torch.stack([x, y], dim=1)

    @functools.cached_property
    def dirichlet_boundary(self):
        """Return ``(points, values)`` on the boundary of the unit square.

        ``points`` concatenates the edges ``x=0``, ``x=1``, ``y=0`` and
        ``y=1`` (in that order), each with ``num_boundary`` evenly spaced
        points, giving shape ``(4 * num_boundary, 2)``. ``values`` is the
        concatenation of ``self.benchmark(edge)`` for each edge in the same
        order. The result is computed once and cached on the instance.
        """
        linspace = torch.linspace(0.0, 1.0, self.num_boundary)
        zero = torch.zeros_like(linspace)
        one = torch.ones_like(linspace)

        edges = [
            self.stack(zero, linspace),  # x = 0
            self.stack(one, linspace),   # x = 1
            self.stack(linspace, zero),  # y = 0
            self.stack(linspace, one),   # y = 1
        ]
        values = [self.benchmark(edge) for edge in edges]

        return torch.cat(edges, dim=0), torch.cat(values, dim=0)

    @functools.cached_property
    def interior(self):
        """Return a fixed ``(x, y)`` sample of random interior points.

        Both tensors have ``num_interior`` entries drawn with ``torch.rand``
        and have ``requires_grad=True`` so they can be differentiated against.
        The sample is drawn once and cached on the instance.
        """
        x = torch.rand(self.num_interior, requires_grad=True)
        y = torch.rand(self.num_interior, requires_grad=True)
        return x, y

    @property
    def adaptive_interior(self):
        """Return the interior sample followed by the adaptive points.

        A fresh pair of leaf tensors with ``requires_grad=True`` is built on
        every access. If no adaptive points have been collected yet, only the
        interior sample is returned.
        """
        x0, y0 = self.interior
        # Detach instead of flipping requires_grad on x0 / y0: those are the
        # cached `interior` tensors, and mutating them would leave `interior`
        # without gradient tracking for every later use.
        xs = [x0.detach()]
        ys = [y0.detach()]
        if self.adaptive_points is not None:
            x1, y1 = self.adaptive_points
            xs.append(x1)
            ys.append(y1)
        x = torch.cat(xs).requires_grad_(True)
        y = torch.cat(ys).requires_grad_(True)
        return x, y

    def update_adaptive_points(self, net):
        """Append the worst-residual candidates to ``adaptive_points``.

        Draws ``update_candidates`` random points, scores them with
        ``pde_loss(net, x, y, aggregate_output=False)`` and keeps the
        ``update_num`` points with the largest residual.

        Args:
            net: model passed through to ``pde_loss``.
        """
        x = torch.rand(self.update_candidates, requires_grad=True)
        y = torch.rand(self.update_candidates, requires_grad=True)
        # NOTE(review): assumes the residual is 1-D with one entry per
        # candidate, so the sorted indices address x and y directly -- confirm.
        residuals = pde_loss(net, x, y, aggregate_output=False)
        keepers = torch.argsort(residuals, descending=True)[:self.update_num]
        # Store plain data; gradient tracking is re-enabled on access.
        x = x[keepers].detach()
        y = y[keepers].detach()
        if self.adaptive_points is not None:
            x0, y0 = self.adaptive_points
            x = torch.cat([x0, x])
            y = torch.cat([y0, y])
        self.adaptive_points = x, y

    def increment_adaptive_points(self, net):
        """Count one step, refreshing the adaptive points when one is due.

        An update runs when the counter is a multiple of ``update_frequency``
        (including the very first call, when the counter is 0).
        """
        if self.update_counter % self.update_frequency == 0:
            self.update_adaptive_points(net)
        self.update_counter += 1


def pde_loss(net, x, y, aggregate_output=True):
    """Return the squared PDE residual of ``net`` at the points ``(x, y)``.

    The residual is ``utils.d(utils.d(phi, x), x) + utils.d(utils.d(phi, y), y)``
    with ``phi = net(stack([x, y], dim=1))``. ``utils.d`` is not visible
    here; presumably it differentiates its first argument with respect to
    its second, making the residual the Laplacian of ``phi`` -- confirm.

    Args:
        net: callable evaluated on the stacked ``(N, 2)`` coordinates.
        x: 1-D tensor of x coordinates.
        y: 1-D tensor of y coordinates, same length as ``x``.
        aggregate_output: if true, return the mean of the squared residual;
            otherwise return the squared residual un-reduced.
    """
    phi = net(torch.stack([x, y], dim=1))
    second_x = utils.d(utils.d(phi, x), x)
    second_y = utils.d(utils.d(phi, y), y)
    squared_residual = (second_x + second_y).pow(2)
    return squared_residual.mean() if aggregate_output else squared_residual
