## Import libraries.
import torch
import torch.nn as nn
# Module-level shorthands for torch's element-wise trigonometric functions.
cos, sin = torch.cos, torch.sin


class PINN():
    """Physics-informed model with separate interior and exterior networks.

    ``get_loss`` combines four weighted residual terms:

    1. ``sigma * (u_xx + u_yy)`` on interior points (``sigma_c``) and on
       exterior points (``sigma_m``),
    2. ``p * (u_ext - u_int) - sigma_m * grad(u_ext) . n`` on the interface,
       where ``p`` is produced by ``get_interface``,
    3. ``sigma_m * grad(u_ext) . n - sigma_c * grad(u_int) . n`` on the
       interface,
    4. ``u_ext - background`` on exterior points, where the background field
       is selected by ``args.FIELD_list[args.arg_field]``.

    ``p`` is either a truncated Fourier-type series whose coefficients are
    ``inv_param`` (``p_nn_flag == "False"``) or the output of the callable
    ``inv_param`` (``p_nn_flag == "True"``). Note that ``p_nn_flag`` is
    compared against the *strings* ``"True"`` / ``"False"``.

    Args:
        dnn_int: callable evaluated on ``(N, 2)`` point tensors (interior net).
        dnn_ext: callable evaluated on ``(N, 2)`` point tensors (exterior net).
        args: configuration object providing ``sigma_c``, ``sigma_m``,
            ``order_interface``, ``FIELD_list``, ``arg_field``, ``weights``,
            ``R_bd`` and ``p_nn_flag``.
        inv_param: when ``p_nn_flag == "False"``, an indexable of at least
            ``2 * order_interface + 1`` scalar tensors; when ``"True"``, a
            callable applied to ``(N, 2)`` points.
    """

    def __init__(self, dnn_int, dnn_ext, args, inv_param):
        super().__init__()

        self.dnn_int = dnn_int
        self.dnn_ext = dnn_ext

        # Loss history: [0] total loss, [1]-[4] loss of conditions 1-4.
        self.hist = [[] for _ in range(5)]
        # Interface-coefficient history (only the real parts: inv_param[0] and
        # inv_param[2k-1]); filled by get_loss only when p_nn_flag == "False".
        self.hist_inv = [[] for _ in range(args.order_interface + 1)]

        self.hist_numeric = []  # not populated inside this class

        self.args = args
        self.sigma_c = args.sigma_c
        self.sigma_m = args.sigma_m

        self.order_interface = args.order_interface
        self.arg_field_str = args.FIELD_list[args.arg_field]

        self.weights = args.weights
        self.R_bd = args.R_bd
        self.inv_param = inv_param

        self.p_nn_flag = args.p_nn_flag

    def get_u_int(self, data):
        """Evaluate the interior network on ``data``."""
        return self.dnn_int(data)

    def get_u_ext(self, data):
        """Evaluate the exterior network on ``data``."""
        return self.dnn_ext(data)

    def get_interface(self, data):
        """Return the interface coefficient ``p`` at the points ``data``.

        With ``p_nn_flag == "False"`` this is::

            p = c[0] + sum_{k=1}^{K} 2 * R^k * (c[2k-1] * cos(k t) - c[2k] * sin(k t))

        with ``c = inv_param``, ``K = order_interface``, ``t = atan2(y, x)``.
        ``R`` is the fixed radius ``self.R_bd``. The modulus of each point is
        not used, only its angle.

        With ``p_nn_flag == "True"`` the first two columns of ``data`` are fed
        to ``inv_param``.

        Raises:
            ValueError: if ``p_nn_flag`` is neither ``"True"`` nor ``"False"``.
        """
        x = data[:, [0]]
        y = data[:, [1]]

        if self.p_nn_flag == "False":
            r = self.R_bd
            t = torch.atan2(y, x)
            out = self.inv_param[0]
            for deg in range(1, self.order_interface + 1):
                r_pow = r ** deg
                out = (out
                       + 2 * self.inv_param[2 * deg - 1] * r_pow * torch.cos(deg * t)
                       - 2 * self.inv_param[2 * deg] * r_pow * torch.sin(deg * t))
            return out

        if self.p_nn_flag == "True":
            return self.inv_param(torch.cat((x, y), dim=1))

        raise ValueError(f"p_nn_flag must be 'True' or 'False', got {self.p_nn_flag!r}")

    def background_field(self, data):
        """Return the background potential ``a*x + b*y`` selected by ``args``.

        ``x`` and ``y`` are the first two columns of ``data``. The result keeps
        the ``(N, 1)`` column shape.

        Raises:
            ValueError: for an unsupported field name. Silently returning a
                constant would make condition 4 train against a meaningless
                target.
        """
        x = data[:, [0]]
        y = data[:, [1]]
        field = self.arg_field_str

        if field == "along.x":
            return x
        if field == "along.y":
            return y

        # General background fields.
        if field == "along.0.5x+0.9y":
            return 0.5 * x + 0.9 * y
        if field == "along.0.7x+0.7y":
            return 0.7 * x + 0.7 * y
        if field == "along.0.25x-1y":
            return 0.25 * x - 1.0 * y

        raise ValueError(f"unsupported background field {field!r}")

    def _eval_u(self, arg_area, x, y):
        """Evaluate the ``"interior"`` or ``"exterior"`` network on columns x, y."""
        points = torch.cat([x, y], dim=1)
        if arg_area == "interior":
            return self.get_u_int(points)
        if arg_area == "exterior":
            return self.get_u_ext(points)
        raise ValueError(f"arg_area must be 'interior' or 'exterior', got {arg_area!r}")

    def _first_derivatives(self, arg_area, data):
        """Return ``(x, y, u_x, u_y)`` with x, y the grad-enabled input columns."""
        x = data[:, [0]].requires_grad_(True)
        y = data[:, [1]].requires_grad_(True)
        u = self._eval_u(arg_area, x, y)

        ones = torch.ones_like(u)
        # create_graph=True so the result can be differentiated again
        # (Laplacian) and back-propagated through in the losses.
        u_x = torch.autograd.grad(u, x, grad_outputs=ones, retain_graph=True, create_graph=True)[0]
        u_y = torch.autograd.grad(u, y, grad_outputs=ones, retain_graph=True, create_graph=True)[0]
        return x, y, u_x, u_y

    def get_divergence(self, arg_area, data):
        """Return the Laplacian ``u_xx + u_yy`` of the selected network.

        Despite the name, this is the divergence of the gradient (the
        Laplacian), shape ``(N, 1)``.

        Args:
            arg_area: ``"interior"`` or ``"exterior"``.
            data: tensor whose first two columns are the x and y coordinates.

        Raises:
            ValueError: for an unknown ``arg_area``.
        """
        x, y, u_x, u_y = self._first_derivatives(arg_area, data)

        u_xx = torch.autograd.grad(
            u_x, x, grad_outputs=torch.ones_like(u_x), retain_graph=True, create_graph=True
        )[0]
        u_yy = torch.autograd.grad(
            u_y, y, grad_outputs=torch.ones_like(u_y), retain_graph=True, create_graph=True
        )[0]
        return u_xx + u_yy

    def get_gradient(self, arg_area, data):
        """Return ``[u_x, u_y]`` (shape ``(N, 2)``) of the selected network.

        Raises:
            ValueError: for an unknown ``arg_area``.
        """
        _, _, u_x, u_y = self._first_derivatives(arg_area, data)
        return torch.cat((u_x, u_y), dim=1)

    def _normal_derivative(self, arg_area, trace_psi, normal_psi):
        """Return ``grad(u) . n`` (shape ``(N, 1)``), computing the gradient once."""
        grad = self.get_gradient(arg_area, trace_psi)
        return grad[:, [0]] * normal_psi[:, [0]] + grad[:, [1]] * normal_psi[:, [1]]

    def get_loss_condi_1(self, interior_psi, exterior_psi):
        """Mean squared ``sigma * Laplacian(u)`` residual in both subdomains.

        The inputs are all images of the conformal map.
        """
        res_int = self.sigma_c * self.get_divergence("interior", interior_psi)
        res_ext = self.sigma_m * self.get_divergence("exterior", exterior_psi)

        loss = torch.mean(torch.square(res_int)) + torch.mean(torch.square(res_ext))
        self.hist[1].append(loss.item())
        return loss

    def get_loss_condi_2(self, boundary_disk, normal_psi, trace_int_psi, trace_ext_psi):
        """Mean squared residual of ``p * (u_ext - u_int) = sigma_m * grad(u_ext) . n``."""
        interface = self.get_interface(boundary_disk)
        u_trace_int = self.get_u_int(trace_int_psi)
        u_trace_ext = self.get_u_ext(trace_ext_psi)

        lhs = interface * (u_trace_ext - u_trace_int)
        rhs = self.sigma_m * self._normal_derivative("exterior", trace_ext_psi, normal_psi)

        if self.p_nn_flag == "True":
            # Zero-pad to two columns to match the network-valued interface
            # output (presumably a two-column p; TODO confirm against inv_param).
            rhs = torch.cat((rhs, torch.zeros_like(rhs)), dim=1)

        loss = torch.mean(torch.square(lhs - rhs))
        self.hist[2].append(loss.item())
        return loss

    def get_loss_condi_3(self, normal_psi, trace_int_psi, trace_ext_psi):
        """Mean squared flux mismatch ``sigma_m * d_n u_ext - sigma_c * d_n u_int``."""
        lhs = self.sigma_m * self._normal_derivative("exterior", trace_ext_psi, normal_psi)
        rhs = self.sigma_c * self._normal_derivative("interior", trace_int_psi, normal_psi)

        loss = torch.mean(torch.square(lhs - rhs))
        self.hist[3].append(loss.item())
        return loss

    def get_loss_condi_4(self, exterior_psi):
        """Mean squared deviation of ``u_ext`` from the background field."""
        background = self.background_field(exterior_psi)
        u_exterior = self.get_u_ext(exterior_psi)

        loss = torch.mean(torch.square(u_exterior - background))
        self.hist[4].append(loss.item())
        return loss

    def get_loss(self, interior_psi, boundary_disk, exterior_psi, normal_psi, trace_int_psi, trace_ext_psi):
        """Return the weighted sum of the four condition losses.

        Appends the total to ``hist[0]``. Each condition appends its own loss to
        ``hist[1..4]``. In the Fourier-coefficient mode
        (``p_nn_flag == "False"``) it also records ``inv_param[0]`` and
        ``inv_param[2k-1]`` into ``hist_inv``.
        """
        loss_condi_1 = self.get_loss_condi_1(interior_psi, exterior_psi)
        loss_condi_2 = self.get_loss_condi_2(boundary_disk, normal_psi, trace_int_psi, trace_ext_psi)
        loss_condi_3 = self.get_loss_condi_3(normal_psi, trace_int_psi, trace_ext_psi)
        loss_condi_4 = self.get_loss_condi_4(exterior_psi)

        loss = (self.weights[0] * loss_condi_1
                + self.weights[1] * loss_condi_2
                + self.weights[2] * loss_condi_3
                + self.weights[3] * loss_condi_4)
        self.hist[0].append(loss.item())

        if self.p_nn_flag == "False":
            self.hist_inv[0].append(self.inv_param[0].detach().cpu().item())
            for k in range(1, self.order_interface + 1):
                self.hist_inv[k].append(self.inv_param[2 * k - 1].detach().cpu().item())

        return loss

