import torch
from src.burgers import benchmark
import functools
from src import utils


class CollocationPoints:
    """Collocation points for a PINN solving Burgers' equation on [0, 1] x [0, 1].

    Three point sets are managed:
      * ``dirichlet_boundary``: the initial line t = 0 plus the walls x = 0 and
        x = 1, paired with their target values.
      * ``interior``: a fixed, lazily sampled set of uniform random (x, t).
      * ``adaptive_points``: points appended during training where the
        pointwise PDE loss of the current network is largest
        (see ``update_adaptive_points``).
    """

    def __init__(
        self,
        num_boundary=64,
        num_interior=1024,
        *,
        update_frequency=1000,
        update_num=32,
        update_candidates=1024,
    ):
        """
        Args:
            num_boundary: Points placed along each of the three boundary segments.
            num_interior: Size of the fixed random interior set.
            update_frequency: ``increment_adaptive_points`` refreshes the adaptive
                set once every this many calls, starting with the first call.
            update_num: Number of new adaptive points kept per refresh.
            update_candidates: Number of random candidates scored per refresh.

        Raises:
            ValueError: If ``update_frequency`` is less than 1.
        """
        if update_frequency < 1:
            raise ValueError(f"update_frequency must be >= 1, got {update_frequency}")
        self.num_boundary = num_boundary
        self.num_interior = num_interior
        # (x, t) tensors of selected high-residual points; None until the
        # first call to update_adaptive_points.
        self.adaptive_points = None
        self.update_counter = 0
        self.update_frequency = update_frequency
        self.update_num = update_num
        self.update_candidates = update_candidates

    @staticmethod
    def stack(x, t):
        """Combine equal-length 1-D ``x`` and ``t`` into rows of (x, t)."""
        return torch.stack([x, t], dim=1)

    @staticmethod
    def ic(x):
        """Initial condition u(x, 0) = exp(-50 (x - 0.6)^2) - exp(-50 (x - 0.4)^2)."""
        return torch.exp(-50 * (x - 0.6) ** 2) - torch.exp(-50 * (x - 0.4) ** 2)

    @functools.cached_property
    def dirichlet_boundary(self):
        """Boundary points and target values, computed once and cached.

        Returns:
            ``(points, values)``. ``points`` has shape ``(3 * num_boundary, 2)``
            and stacks, in order, the t = 0 line (target ``ic(x)``), the x = 0
            wall and the x = 1 wall (both with target 0). ``values`` has shape
            ``(3 * num_boundary,)``.

        Note: the corners (0, 0) and (1, 0) appear twice, once with target
        ``ic(x)`` (slightly non-zero there) and once with target 0.
        """
        fx = torch.linspace(0.0, 1.0, self.num_boundary)

        zero = torch.zeros_like(fx)
        one = torch.ones_like(fx)

        t0 = self.stack(fx, zero)
        x0 = self.stack(zero, fx)
        x1 = self.stack(one, fx)

        ic = self.ic(fx)

        return (torch.cat([t0, x0, x1], dim=0), torch.cat([ic, zero, zero], dim=0))

    @functools.cached_property
    def interior(self):
        """Fixed uniform random interior points ``(x, t)``, each of shape
        ``(num_interior,)``. They are leaf tensors with ``requires_grad=True`` so
        PDE derivatives can be taken with respect to them."""
        x = torch.rand(self.num_interior, requires_grad=True)
        t = torch.rand(self.num_interior, requires_grad=True)
        return x, t

    @property
    def adaptive_interior(self):
        """Interior points concatenated with the adaptive points.

        Each access returns new leaf tensors with ``requires_grad=True``. The
        cached ``interior`` tensors are left untouched; they keep their
        ``requires_grad`` flag. If no adaptive points have been selected yet,
        a fresh copy of the interior is returned.
        """
        x_int, t_int = self.interior
        # Detach (rather than flipping requires_grad on the cached tensors) so
        # the concatenation is a new leaf we can mark as requiring grad.
        if self.adaptive_points is None:
            x = x_int.detach().clone()
            t = t_int.detach().clone()
        else:
            x_ad, t_ad = self.adaptive_points
            x = torch.cat([x_int.detach(), x_ad.detach()])
            t = torch.cat([t_int.detach(), t_ad.detach()])
        return x.requires_grad_(True), t.requires_grad_(True)

    def update_adaptive_points(self, net, k):
        """Append the highest-loss random candidates to ``adaptive_points``.

        Samples ``update_candidates`` uniform random (x, t) pairs, scores them
        with ``pde_loss(..., aggregate_output=False)`` and keeps the
        ``update_num`` with the largest value. Kept points are stored detached.
        Earlier adaptive points are retained, so the set grows on every call.

        Args:
            net: Network passed through to ``pde_loss``.
            k: Coefficient of u_xx passed through to ``pde_loss``.
        """
        x = torch.rand(self.update_candidates, requires_grad=True)
        t = torch.rand(self.update_candidates, requires_grad=True)
        # Candidates must require grad because pde_loss differentiates with
        # respect to them; only the ranking of the scores is used afterwards.
        residuals = pde_loss(net, x, t, k, aggregate_output=False).detach()
        keepers = torch.argsort(residuals, descending=True)[: self.update_num]
        new_x = x[keepers].detach()
        new_t = t[keepers].detach()
        if self.adaptive_points is not None:
            old_x, old_t = self.adaptive_points
            new_x = torch.cat([old_x, new_x])
            new_t = torch.cat([old_t, new_t])
        self.adaptive_points = new_x, new_t

    def increment_adaptive_points(self, net, k):
        """Advance the step counter and refresh the adaptive set on calls
        0, update_frequency, 2 * update_frequency, ..."""
        if self.update_counter % self.update_frequency == 0:
            self.update_adaptive_points(net, k)
        self.update_counter += 1


def pde_loss(net, x, t, k, aggregate_output=True):
    """Squared residual of the viscous Burgers equation u_t + u*u_x = k*u_xx.

    Args:
        net: Callable applied to an ``(N, 2)`` tensor of (x, t) rows; assumed to
            return one value per row, e.g. shape ``(N,)`` or ``(N, 1)``.
        x: 1-D tensor of spatial coordinates, length N. Presumably must require
            grad, since ``utils.d`` differentiates with respect to it.
        t: 1-D tensor of time coordinates, length N (same note as ``x``).
        k: Coefficient multiplying u_xx (the diffusion/viscosity term).
        aggregate_output: If True, return the mean of the squared residual;
            otherwise return the per-point squared residuals.

    Returns:
        A scalar tensor if ``aggregate_output`` is True, else the pointwise
        squared residuals.
    """
    xs = torch.stack([x, t], dim=1)
    # reshape_as instead of a bare squeeze(): guarantees exactly one value per
    # collocation point, even when N == 1.
    u = net(xs).reshape_as(x)
    ut = utils.d(u, t)
    ux = utils.d(u, x)
    uxx = utils.d(ux, x)
    residual_sq = (ut + u * ux - k * uxx).pow(2)
    return residual_sq.mean() if aggregate_output else residual_sq
