import torch
from torch import nn
from src.navierstokes import benchmark
from matplotlib import pyplot as plt
import functools
from src import utils


class MLP(nn.Module):
    """Fully connected tanh network mapping a coordinate pair to outputs.

    The stack is 2 -> 32 -> 32 -> 32 -> ``num_outputs`` with a ``Tanh``
    after every layer except the last one.
    """

    def __init__(self, num_outputs):
        """Build the layer stack.

        Args:
            num_outputs: Width of the final linear layer.
        """
        super().__init__()
        self.num_outputs = num_outputs

        # Input width followed by the three hidden widths.
        widths = (2, 32, 32, 32)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.Tanh())
        stack.append(nn.Linear(widths[-1], num_outputs))
        self.net = nn.Sequential(*stack)

    def forward(self, x, y):
        """Evaluate the network at the points ``(x[i], y[i])``.

        Args:
            x: 1-D tensor of first coordinates.
            y: 1-D tensor of second coordinates, same length as ``x``.

        Returns:
            Column 0 of the network output when ``num_outputs == 1``;
            otherwise the tuple (column 0, column 1).
        """
        coords = torch.stack((x, y), dim=1)
        prediction = self.net(coords)
        first = prediction[:, 0]
        if self.num_outputs == 1:
            return first
        return first, prediction[:, 1]


class AdaptiveActivationMLP(nn.Module):
    """Tanh network whose hidden pre-activations share a trainable slope.

    Layout is 2 -> 32 -> 32 -> 32 -> ``num_outputs``. Each hidden layer's
    linear output is multiplied by ``a * n`` before the tanh, where ``a`` is
    a single learnable parameter and ``n`` is the fixed factor 10. The output
    layer is not scaled.
    """

    def __init__(self, num_outputs):
        """Create the four linear layers and the slope parameter.

        Args:
            num_outputs: Width of the final linear layer.
        """
        super().__init__()
        self.num_outputs = num_outputs
        hidden = 32
        self.fc1 = nn.Linear(2, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, hidden)
        self.fc4 = nn.Linear(hidden, num_outputs)
        # One slope shared by all three hidden layers; shape (1, 1) so it
        # broadcasts against the (points, 32) activations.
        self.a = nn.Parameter(torch.ones(1, 1))
        # NOTE(review): with a = 1 the effective slope a * n starts at 10 --
        # confirm this initial scale is intended.
        self.n = 10

    def forward(self, x, y):
        """Evaluate the network at the points ``(x[i], y[i])``.

        Args:
            x: 1-D tensor of first coordinates.
            y: 1-D tensor of second coordinates, same length as ``x``.

        Returns:
            Column 0 of the network output when ``num_outputs == 1``;
            otherwise the tuple (column 0, column 1).
        """
        activation = torch.stack((x, y), dim=1)
        for layer in (self.fc1, self.fc2, self.fc3):
            activation = torch.tanh(self.a * self.n * layer(activation))
        prediction = self.fc4(activation)
        first = prediction[:, 0]
        if self.num_outputs == 1:
            return first
        return first, prediction[:, 1]


class CollocationPoints:
    """Collocation points on the rectangle [0, 0.5] x [0, 0.1] with two cut-outs.

    The cut-outs are discs of radius 0.05 centred at (0.15, 0.1) and
    (0.35, 0.0). Boundary point sets and the base interior sample are built
    once and cached; ``adaptive_points`` grows over time through
    ``update_adaptive_points``.
    """

    def __init__(self, num_interior=1024):
        """Store the sampling settings.

        Args:
            num_interior: Number of uniform candidates drawn for the interior
                sample before points inside the cut-outs are discarded.
        """
        # Points per inlet / outlet line.
        self.num_boundary = 100
        # Not used inside this class.
        self.benchmark = benchmark.Benchmark()
        self.num_interior = num_interior
        # ``None`` until the first refresh, afterwards a tuple (x, y).
        self.adaptive_points = None
        # Calls made to ``increment_adaptive_points`` so far.
        self.update_counter = 0
        # A refresh happens on every ``update_frequency``-th call.
        self.update_frequency = 1000
        # Points kept per refresh.
        self.update_num = 32
        # Candidates drawn per refresh (before cut-out rejection).
        self.update_candidates = 1024

    @staticmethod
    def not_in_pipe(x, y, pipe_x, pipe_y):
        """Return a mask that is True for points outside a radius-0.05 disc.

        Args:
            x: Tensor of x coordinates.
            y: Tensor of y coordinates.
            pipe_x: x coordinate of the disc centre.
            pipe_y: y coordinate of the disc centre.

        Returns:
            Boolean tensor, True where the point is strictly farther than
            0.05 from the centre.
        """
        x_d = x - pipe_x
        y_d = y - pipe_y
        return (x_d.pow(2) + y_d.pow(2)) > 0.05**2

    @functools.cached_property
    def inlet(self):
        """Points on the line x = 0, y evenly spaced in [0, 0.1]."""
        y = torch.linspace(0.0, 0.1, self.num_boundary, requires_grad=True)
        x = torch.zeros_like(y, requires_grad=True)
        return x, y

    @functools.cached_property
    def periodic(self):
        """Points on the top and bottom edges, skipping the cut-out spans.

        The top edge (y = 0.1) omits 0.1 < x < 0.2 and the bottom edge
        (y = 0) omits 0.3 < x < 0.4.
        """
        x0 = torch.linspace(0.0, 0.1, 100)
        x1 = torch.linspace(0.2, 0.5, 300)
        top_x = torch.cat([x0, x1])
        top_y = 0.1 * torch.ones_like(top_x)

        x2 = torch.linspace(0.0, 0.3, 300)
        x3 = torch.linspace(0.4, 0.5, 100)
        bot_x = torch.cat([x2, x3])
        bot_y = torch.zeros_like(bot_x)

        x = torch.cat([top_x, bot_x])
        y = torch.cat([top_y, bot_y])
        x.requires_grad = True
        y.requires_grad = True
        return x, y

    @functools.cached_property
    def zero_velocity(self):
        """Points on the parts of the two cut-out circles inside the rectangle."""
        theta = torch.linspace(0.0, 2 * torch.pi, 200)
        r2 = 0.05
        # Circle centred on the top edge: keep the part below y = 0.1.
        circle_1_x = r2 * torch.sin(theta) + 0.15
        circle_1_y = r2 * torch.cos(theta) + 0.1
        in_domain = circle_1_y < 0.1
        circle_1_x = circle_1_x[in_domain]
        circle_1_y = circle_1_y[in_domain]

        # Circle centred on the bottom edge: keep the part above y = 0.
        circle_2_x = r2 * torch.sin(theta) + 0.35
        circle_2_y = r2 * torch.cos(theta)
        in_domain = circle_2_y > 0.0
        circle_2_x = circle_2_x[in_domain]
        circle_2_y = circle_2_y[in_domain]

        x = torch.cat([circle_1_x, circle_2_x])
        y = torch.cat([circle_1_y, circle_2_y])

        x.requires_grad = True
        y.requires_grad = True
        return x, y

    @functools.cached_property
    def outlet(self):
        """Points on the line x = 0.5, y evenly spaced in [0, 0.1]."""
        y = torch.linspace(0.0, 0.1, self.num_boundary, requires_grad=True)
        x = torch.zeros_like(y) + 0.5
        x.requires_grad = True
        return x, y

    def sample_interior(self, num_points):
        """Draw uniform points in the rectangle and drop those in a cut-out.

        Args:
            num_points: Number of candidates drawn before rejection.

        Returns:
            Tuple (x, y) of tensors without gradient tracking. Because of the
            rejection step they may hold fewer than ``num_points`` entries.
        """
        x = 0.5 * torch.rand(num_points)
        y = 0.1 * torch.rand(num_points)
        not_in_pipe_1 = self.not_in_pipe(x, y, 0.15, 0.1)
        x = x[not_in_pipe_1]
        y = y[not_in_pipe_1]
        not_in_pipe_2 = self.not_in_pipe(x, y, 0.35, 0.0)
        x = x[not_in_pipe_2]
        y = y[not_in_pipe_2]
        return x, y

    @functools.cached_property
    def interior(self):
        """Fixed random interior sample (drawn once), tracking gradients."""
        x, y = self.sample_interior(self.num_interior)
        x.requires_grad = True
        y.requires_grad = True
        return x, y

    @staticmethod
    def check_points(points):
        """Scatter-plot a tuple (x, y) of tensors and return the figure."""
        fig, ax = plt.subplots()
        x, y = points
        # ``.cpu()`` is a no-op for CPU tensors and lets device tensors plot.
        ax.scatter(x.detach().cpu().numpy(), y.detach().cpu().numpy())
        return fig

    @property
    def adaptive_interior(self):
        """Interior sample followed by the accumulated adaptive points.

        Returns:
            Tuple (x, y) of freshly built tensors with gradient tracking
            enabled. When no adaptive points exist yet, only the interior
            sample is returned.
        """
        x0, y0 = self.interior
        # Work on detached views: flipping ``requires_grad`` on the cached
        # tensors themselves would break every later use of ``self.interior``.
        xs = [x0.detach()]
        ys = [y0.detach()]
        if self.adaptive_points is not None:
            x1, y1 = self.adaptive_points
            xs.append(x1.detach())
            ys.append(y1.detach())
        x = torch.cat(xs).requires_grad_(True)
        y = torch.cat(ys).requires_grad_(True)
        return x, y

    def update_adaptive_points(self, net):
        """Append the candidates with the largest residual to ``adaptive_points``.

        Args:
            net: Model passed through to ``divergence_loss`` and ``ns_loss``.
        """
        x, y = self.sample_interior(self.update_candidates)
        x.requires_grad = True
        y.requires_grad = True
        divergence = divergence_loss(net, x, y, aggregate_output=False)
        ns = ns_loss(net, x, y, aggregate_output=False)
        residuals = divergence + ns
        # Indices of the ``update_num`` largest per-point residuals.
        keepers = torch.argsort(residuals, descending=True)[:self.update_num]
        x = x[keepers].detach()
        y = y[keepers].detach()
        if self.adaptive_points is not None:
            x0, y0 = self.adaptive_points
            x = torch.cat([x0, x])
            y = torch.cat([y0, y])
        self.adaptive_points = x, y

    def increment_adaptive_points(self, net):
        """Count one call and refresh on every ``update_frequency``-th call.

        The counter starts at 0, so the very first call triggers a refresh.

        Args:
            net: Model forwarded to ``update_adaptive_points``.
        """
        if self.update_counter % self.update_frequency == 0:
            self.update_adaptive_points(net)
        self.update_counter += 1


def ns_loss(net, x, y, aggregate_output=True):
    """Squared residual of the two momentum equations at the given points.

    Each residual has the form ``u * f_x + v * f_y - mu * (f_xx + f_yy) + p_f``
    with ``f`` being ``u`` (first equation) or ``v`` (second equation) and
    ``p_f`` the matching pressure derivative.

    Args:
        net: Object exposing ``velocity(x, y) -> (u, v)`` and
            ``pressure(x, y) -> p``.
        x: Tensor of x coordinates.
        y: Tensor of y coordinates.
        aggregate_output: When true, return the sum of the two mean squared
            residuals; otherwise return the per-point sum of squares.

    Returns:
        A scalar tensor or a per-point tensor, depending on
        ``aggregate_output``.
    """
    # NOTE(review): ``utils.d(f, s)`` is presumably the autograd derivative
    # of ``f`` with respect to ``s`` -- confirm in ``src.utils``.
    u, v = net.velocity(x, y)
    p = net.pressure(x, y)
    # Coefficient on the second-derivative (diffusion) terms.
    mu = 1e-5

    du_dx = utils.d(u, x)
    du_dy = utils.d(u, y)
    d2u_dx2 = utils.d(du_dx, x)
    d2u_dy2 = utils.d(du_dy, y)

    dv_dx = utils.d(v, x)
    dv_dy = utils.d(v, y)
    d2v_dx2 = utils.d(dv_dx, x)
    d2v_dy2 = utils.d(dv_dy, y)

    dp_dx = utils.d(p, x)
    dp_dy = utils.d(p, y)

    advection_u = u * du_dx + v * du_dy
    diffusion_u = mu * (d2u_dx2 + d2u_dy2)
    momentum_x = advection_u - diffusion_u + dp_dx

    advection_v = u * dv_dx + v * dv_dy
    diffusion_v = mu * (d2v_dx2 + d2v_dy2)
    momentum_y = advection_v - diffusion_v + dp_dy

    if not aggregate_output:
        return momentum_x.pow(2) + momentum_y.pow(2)
    return momentum_x.pow(2).mean() + momentum_y.pow(2).mean()


def divergence_loss(net, x, y, aggregate_output=True):
    """Squared value of ``du/dx + dv/dy`` at the given points.

    Args:
        net: Object exposing ``velocity(x, y) -> (u, v)``.
        x: Tensor of x coordinates.
        y: Tensor of y coordinates.
        aggregate_output: When true, return the mean over all points;
            otherwise return the per-point squared values.

    Returns:
        A scalar tensor or a per-point tensor, depending on
        ``aggregate_output``.
    """
    # NOTE(review): ``utils.d(f, s)`` is presumably the autograd derivative
    # of ``f`` with respect to ``s`` -- confirm in ``src.utils``.
    u, v = net.velocity(x, y)
    du_dx = utils.d(u, x)
    dv_dy = utils.d(v, y)
    squared = (du_dx + dv_dy).pow(2)
    return squared.mean() if aggregate_output else squared


def dirichlet_loss(net, collocation_points: CollocationPoints):
    """Sum of mean squared boundary mismatches over all boundary point sets.

    Terms, in order:
      * inlet: ``u`` against 0.1 and ``v`` against 0,
      * outlet: ``p`` against 0,
      * ``periodic`` point set: ``v`` against 0 (``u`` is left free),
      * ``zero_velocity`` point set: ``u`` and ``v`` against 0.

    Args:
        net: Object exposing ``velocity(x, y) -> (u, v)`` and
            ``pressure(x, y) -> p``.
        collocation_points: Provider of the ``inlet``, ``outlet``,
            ``periodic`` and ``zero_velocity`` point sets.

    Returns:
        Scalar tensor holding the summed penalty.
    """
    penalties = []

    # Inlet: prescribed velocity (0.1, 0).
    u_in, v_in = net.velocity(*collocation_points.inlet)
    penalties.append((u_in - 0.1).pow(2).mean())
    penalties.append(v_in.pow(2).mean())

    # Outlet: zero pressure.
    p_out = net.pressure(*collocation_points.outlet)
    penalties.append(p_out.pow(2).mean())

    # "Periodic" set: only the second velocity component is penalised.
    _, v_edge = net.velocity(*collocation_points.periodic)
    penalties.append(v_edge.pow(2).mean())

    # Zero-velocity set: both components vanish.
    u_fix, v_fix = net.velocity(*collocation_points.zero_velocity)
    penalties.append(u_fix.pow(2).mean())
    penalties.append(v_fix.pow(2).mean())

    # Left-to-right accumulation starting from 0.0.
    return sum(penalties, 0.0)
