import torch
import functools
from src import utils


class CollocationPoints:
    """Training points for a PINN posed on (x, y, t) in [0, 1]^3.

    Provides regularly gridded boundary points together with their targets,
    a cached set of uniformly random interior points, and an optional,
    growing set of "adaptive" interior points selected where the pointwise
    PDE residual is largest.

    Args:
        num_boundary: Grid resolution per axis on each boundary face.
        num_interior: Number of uniformly random interior points.
        update_frequency: ``increment_adaptive_points`` runs an update on
            every call whose counter is a multiple of this value.
        update_num: Number of points appended to the adaptive set per update.
        update_candidates: Number of random candidates scored per update.
    """

    def __init__(self, num_boundary=64, num_interior=1024, *,
                 update_frequency=1000, update_num=32, update_candidates=1024):
        self.num_boundary = num_boundary
        self.num_interior = num_interior
        # Tuple (x, y, t) of detached 1-D tensors, or None before the first update.
        self.adaptive_points = None
        self.update_counter = 0
        self.update_frequency = update_frequency
        self.update_num = update_num
        self.update_candidates = update_candidates

    @staticmethod
    def stack(x, y, t):
        """Stack equal-length 1-D coordinate tensors into (N, 3) rows of (x, y, t)."""
        return torch.stack([x, y, t], dim=1)

    @staticmethod
    def ic(x, y):
        """Initial condition used as the Dirichlet target on the t = 0 face.

        Computes sqrt((sin^2(5*pi*x) + cos^2(3*pi*y))
                      * exp(-5 * ((x - 0.5)^2 + (y - 0.5)^2))).
        """
        return torch.sqrt(
            (torch.sin(5 * torch.pi * x)**2 + torch.cos(3 * torch.pi * y)**2)
            * torch.exp(5*(-(x-0.5)**2 - (y-0.5)**2))
        )

    def _boundary_grid(self):
        """Return flattened (f1, f2) coordinates of a regular grid on [0, 1]^2."""
        linspace = torch.linspace(0.0, 1.0, self.num_boundary)
        # "ij" is torch's historical default; passing it explicitly keeps the
        # same point ordering and silences the meshgrid deprecation warning.
        f1, f2 = torch.meshgrid(linspace, linspace, indexing="ij")
        return f1.reshape(-1), f2.reshape(-1)

    @functools.cached_property
    def dirichlet_boundary(self):
        """Dirichlet points and targets, built once and cached.

        Returns:
            ``(xyt, target)`` where ``xyt`` has shape (3 * num_boundary**2, 3)
            and ``target`` has shape (3 * num_boundary**2, 1). Rows are, in
            order: the t = 0 face (target ``ic(x, y)``), then the x = 0 and
            x = 1 faces (target 0).
        """
        f1, f2 = self._boundary_grid()
        zero = torch.zeros_like(f1)
        one = torch.ones_like(f1)

        t0 = self.stack(f1, f2, zero)
        x0 = self.stack(zero, f1, f2)
        x1 = self.stack(one, f1, f2)

        return (torch.cat([t0, x0, x1], dim=0),
                torch.cat([self.ic(f1, f2), zero, zero], dim=0).unsqueeze(1))

    @functools.cached_property
    def neumann_boundary(self):
        """Points on the y = 0 and y = 1 faces with zero targets, built once and cached.

        Returns:
            ``(xyt, target)`` where ``xyt`` has shape (2 * num_boundary**2, 3)
            and ``target`` is all zeros with shape (2 * num_boundary**2, 1).
        """
        f1, f2 = self._boundary_grid()
        zero = torch.zeros_like(f1)
        one = torch.ones_like(f1)

        y0 = self.stack(f1, zero, f2)
        y1 = self.stack(f1, one, f2)

        return (torch.cat([y0, y1], dim=0),
                torch.cat([zero, zero], dim=0).unsqueeze(1))

    @functools.cached_property
    def interior(self):
        """Uniform random interior coordinates (x, y, t) in [0, 1), cached.

        Each is a 1-D leaf tensor of length ``num_interior`` with
        ``requires_grad=True`` so PDE derivatives can be taken through it.
        """
        x = torch.rand(self.num_interior, requires_grad=True)
        y = torch.rand(self.num_interior, requires_grad=True)
        t = torch.rand(self.num_interior, requires_grad=True)
        return x, y, t

    @property
    def adaptive_interior(self):
        """Interior points followed by the adaptive points, as fresh leaf tensors.

        Before any adaptive update has happened, this is just ``interior``.
        Otherwise each returned coordinate is a new leaf tensor with
        ``requires_grad=True`` holding the interior values followed by the
        adaptive values.
        """
        if self.adaptive_points is None:
            return self.interior
        # Concatenating detached inputs yields a leaf tensor that we can mark
        # as requiring grad directly. This avoids toggling requires_grad on
        # the cached ``interior`` tensors, which would break later
        # derivative-based losses on them.
        return tuple(
            torch.cat([base.detach(), extra.detach()]).requires_grad_()
            for base, extra in zip(self.interior, self.adaptive_points)
        )

    def update_adaptive_points(self, net, D):
        """Append the ``update_num`` highest-residual candidates to the adaptive set.

        ``update_candidates`` points are drawn uniformly in [0, 1)^3 and
        scored with the pointwise residual from ``pde_loss``. The worst
        points are stored, detached, in ``self.adaptive_points``.
        """
        x = torch.rand(self.update_candidates, requires_grad=True)
        y = torch.rand(self.update_candidates, requires_grad=True)
        t = torch.rand(self.update_candidates, requires_grad=True)
        residuals = pde_loss(net, x, y, t, D, aggregate_output=False)
        # Only the ranking is needed. Detach, and flatten in case the residual
        # carries a trailing singleton dim (assumes one value per candidate).
        ranking = torch.argsort(residuals.detach().reshape(-1), descending=True)
        keepers = ranking[:self.update_num]
        x = x[keepers].detach()
        y = y[keepers].detach()
        t = t[keepers].detach()
        if self.adaptive_points is not None:
            x0, y0, t0 = self.adaptive_points
            x = torch.cat([x0, x])
            y = torch.cat([y0, y])
            t = torch.cat([t0, t])
        self.adaptive_points = x, y, t

    def increment_adaptive_points(self, net, D):
        """Advance the step counter and update adaptive points every ``update_frequency`` calls.

        The update fires on the first call (counter 0) and then whenever the
        counter is a multiple of ``update_frequency``.
        """
        if self.update_counter % self.update_frequency == 0:
            self.update_adaptive_points(net, D)
        self.update_counter += 1


def boundary_loss(net, collocation_points: CollocationPoints):
    """Return the sum of the mean-squared Dirichlet and Neumann boundary losses.

    Dirichlet: ``net(xyt)`` is matched to the targets of
    ``collocation_points.dirichlet_boundary``.
    Neumann: ``utils.d(out, xyt)[:, 1]``, the derivative with respect to the
    y column of the (x, y, t) input, is matched to the targets of
    ``collocation_points.neumann_boundary``.
    """
    xyt, target = collocation_points.dirichlet_boundary
    # No input derivative is needed here, so the cached points are used as-is.
    # Marking them requires_grad would only make backward() accumulate an
    # unused .grad on the cached tensor.
    dirichlet_loss = (net(xyt) - target).pow(2).mean()

    xyt, target = collocation_points.neumann_boundary
    # Fresh leaf over the cached points' storage. This enables d(out)/d(xyt)
    # without flipping requires_grad on the cached tensor, and any input
    # gradient from backward() lands on this temporary, not the cache.
    xyt = xyt.detach().requires_grad_()
    out = net(xyt)
    dout_dy = utils.d(out, xyt)[:, 1].unsqueeze(1)
    neumann_loss = (dout_dy - target).pow(2).mean()
    return dirichlet_loss + neumann_loss


def pde_loss(net, x, y, t, D, aggregate_output=True):
    """Squared residual of phi_t - D * (phi_xx + phi_yy) for phi = net(x, y, t).

    The network is evaluated on ``torch.stack([x, y, t], dim=1)``.
    Derivatives come from ``utils.d``, presumably autograd-based, so
    ``x``, ``y`` and ``t`` are expected to have ``requires_grad=True``.
    Callers should confirm this.

    Returns:
        The mean squared residual if ``aggregate_output`` is true, otherwise
        the un-reduced squared residual.
    """
    phi = net(torch.stack([x, y, t], dim=1))
    phi_xx = utils.d(utils.d(phi, x), x)
    phi_yy = utils.d(utils.d(phi, y), y)
    phi_t = utils.d(phi, t)
    squared_residual = (phi_t - D*(phi_xx + phi_yy)).pow(2)
    return squared_residual.mean() if aggregate_output else squared_residual
