## Import libraries.
import math
import numpy as np

import torch
import torch.nn.functional as F

# Module-level shorthands for the constant and trigonometric functions used by CLOUD.
pi, cos, sin = math.pi, torch.cos, torch.sin


class CLOUD():
    """Generate point clouds on, inside and outside the image of a disk under a conformal map.

    The map is ``psi(w) = w + sum_{j=1}^{K} a_j * w**(-j)`` with
    ``K = Conformal_order`` and coefficients ``a_j`` chosen by ``arg_conformal``
    (see ``conformal_coefficients``). All point sets are ``(N, 2)`` tensors of
    Cartesian ``(x, y)`` rows.
    """

    # Non-zero coefficients (a_1, a_2, ...) of the built-in shapes.
    _PRESET_COEFFS = {
        "identity": (),
        "ellipse": (0.25,),
        "fish": (0.25, 0.125, 0.1),
        "kite": (0.1, 0.25, -0.05, 0.05, -0.04, 0.02),
    }

    def __init__(self, arg_conformal, eps_bd, Conformal_order, custom_coeffs, fixed):
        """
        Args:
            arg_conformal: "identity", "ellipse", "fish", "kite" or "custom".
            eps_bd: distance by which image boundary points are shifted along
                the unit normal to build the interior/exterior traces.
            Conformal_order: number K of terms a_j * w**(-j) in the map.
            custom_coeffs: sequence (a_1, ..., a_m) with m <= K; only read
                when ``arg_conformal == "custom"``.
            fixed: 1 for evenly spaced samples (``np.linspace``), 0 for
                uniform random samples (``np.random.uniform``).
        """
        super().__init__()

        self.arg_conformal = arg_conformal
        self.eps_bd = eps_bd
        self.Conformal_order = Conformal_order
        self.custom_coeffs = custom_coeffs
        self.fixed = fixed

    def _sample_axis(self, low, high, count):
        """Return ``count`` float32 samples between ``low`` and ``high``.

        Samples are evenly spaced (endpoints included) when ``fixed == 1`` and
        drawn uniformly at random when ``fixed == 0``.

        Raises:
            ValueError: if ``fixed`` is neither 0 nor 1.
        """
        if self.fixed == 1:
            samples = np.linspace(low, high, count)
        elif self.fixed == 0:
            samples = np.random.uniform(low, high, count)
        else:
            # Previously the sample array was simply left undefined here.
            raise ValueError(f"fixed must be 0 or 1, got {self.fixed!r}")
        return samples.astype(np.float32)

    def meshgrid_tensor(self, xspace, yspace):
        """Return the Cartesian product of ``xspace`` and ``yspace`` as a tensor.

        The shape is ``(len(xspace) * len(yspace), 2)``, with x varying fastest
        (NumPy's default ``"xy"`` meshgrid indexing). The dtype follows the
        input arrays.
        """
        X, Y = np.meshgrid(xspace, yspace)
        data_grid = np.vstack([X.flatten(), Y.flatten()]).T
        return torch.tensor(data_grid)

    def generate_disk_data(self, L_dom, R_bd, S_carte, S_boundary):
        """Sample the pre-image: grid points outside the disk and points on its rim.

        Args:
            L_dom: half-width of the square [-L_dom, L_dom]^2 that is sampled.
            R_bd: radius of the disk.
            S_carte: number of samples per axis of the square.
            S_boundary: the rim receives S_boundary + 1 angles in [0, 2*pi].

        Returns:
            boundary_disk: ``(S_boundary + 1, 2)`` points on ``|w| = R_bd``,
                sorted by angle.
            exterior_disk: ``(M, 2)`` grid points with ``|w| > R_bd``. M may be 0.
        """
        # Exterior over the pre-image (x samples are drawn before y samples).
        xspace = self._sample_axis(-L_dom, L_dom, S_carte)
        yspace = self._sample_axis(-L_dom, L_dom, S_carte)
        w_grids = self.meshgrid_tensor(xspace, yspace)
        outside = w_grids[:, 0] ** 2 + w_grids[:, 1] ** 2 > R_bd ** 2
        # Boolean masking yields an empty (0, 2) tensor instead of crashing on torch.stack([]).
        exterior_disk = w_grids[outside]

        # Boundary over the pre-image. With fixed == 1, linspace includes both
        # 0 and 2*pi, so the first and last points coincide.
        tspace = np.sort(self._sample_axis(0, 2 * pi, S_boundary + 1))
        rspace = np.array([R_bd], dtype=np.float32)
        boundary_disk = self.polar_to_carte(self.meshgrid_tensor(rspace, tspace))
        return boundary_disk, exterior_disk

    def carte_to_polar(self, data):
        """Convert Cartesian rows ``(x, y)`` to polar rows ``(r, theta)``.

        ``theta = atan2(y, x)``, so it lies in [-pi, pi]. Extra columns are
        ignored and an empty input gives a ``(0, 2)`` result.
        """
        x, y = data[:, 0], data[:, 1]
        return torch.stack((torch.sqrt(x ** 2 + y ** 2), torch.atan2(y, x)), dim=1)

    def polar_to_carte(self, data):
        """Convert polar rows ``(r, theta)`` to Cartesian rows ``(x, y)``."""
        r, t = data[:, 0], data[:, 1]
        return torch.stack((r * cos(t), r * sin(t)), dim=1)

    def conformal_coefficients(self):
        """Return the map coefficients as a ``(Conformal_order + 1, 1)`` tensor.

        Row j holds a_j. Row 0 is always zero and is never used by the map.

        Raises:
            ValueError: if ``arg_conformal`` is unknown, or if the shape needs
                more coefficients than ``Conformal_order`` provides.
        """
        if self.arg_conformal == "custom":
            coeffs = self.custom_coeffs
        elif self.arg_conformal in self._PRESET_COEFFS:
            coeffs = self._PRESET_COEFFS[self.arg_conformal]
        else:
            # Previously an unknown name silently produced the identity map.
            raise ValueError(f"unknown arg_conformal {self.arg_conformal!r}")

        if len(coeffs) > self.Conformal_order:
            raise ValueError(
                f"arg_conformal={self.arg_conformal!r} needs Conformal_order >= "
                f"{len(coeffs)}, got {self.Conformal_order}"
            )

        a = torch.zeros(self.Conformal_order + 1, 1)
        for j, coeff in enumerate(coeffs, start=1):
            a[j, 0] = coeff
        return a

    def _laurent_sum(self, data, weights):
        """Evaluate ``w + sum_{j=1}^{K} weights[j] * w**(-j)`` at each point.

        Args:
            data: ``(N, 2)`` Cartesian points w.
            weights: indexable by 1..K, where K = ``Conformal_order``.

        Returns:
            ``(N, 2)`` Cartesian values of the series.
        """
        polar = self.carte_to_polar(data)
        r, t = polar[:, 0], polar[:, 1]
        x_out = r * cos(t)
        y_out = r * sin(t)
        for j in range(1, self.Conformal_order + 1):
            # w**(-j) = r**(-j) * (cos(-j t) + i sin(-j t))
            x_out = x_out + weights[j] * r ** (-j) * cos(-j * t)
            y_out = y_out + weights[j] * r ** (-j) * sin(-j * t)
        return torch.stack((x_out, y_out), dim=1)

    def conformal_map(self, data):
        """Apply ``psi(w) = w + sum_j a_j * w**(-j)`` to ``(N, 2)`` Cartesian points."""
        a = self.conformal_coefficients()
        return self._laurent_sum(data, a[:, 0])

    def exterior_disk_inverse(self, data, n):
        """Return ``w**(-n)`` for each ``(N, 2)`` Cartesian point w, as ``(Re, Im)`` rows."""
        polar = self.carte_to_polar(data)
        r, t = polar[:, 0], polar[:, 1]
        x_out = r ** (-n) * cos(n * t)
        y_out = -r ** (-n) * sin(n * t)
        return torch.stack((x_out, y_out), dim=1)

    def conformal_map_normal(self, data):
        """Return the un-normalized normal ``w * psi'(w) = w - sum_j j * a_j * w**(-j)``.

        For pre-image points on a circle ``|w| = R``, this vector is perpendicular
        to the image curve ``psi(R e^{it})``. Its tangent is ``i * w * psi'(w)``.
        """
        a = self.conformal_coefficients()
        orders = torch.arange(a.shape[0], dtype=a.dtype)
        return self._laurent_sum(data, -orders * a[:, 0])

    def generate_normal_and_trace(self, boundary_disk):
        """Return unit normals and points shifted by ``eps_bd`` on both sides of the image boundary.

        Args:
            boundary_disk: ``(N, 2)`` pre-image boundary points.

        Returns:
            normal_psi: ``(N, 2)`` unit normals of the image curve.
            trace_int_psi: ``psi(w) - eps_bd * normal``.
            trace_ext_psi: ``psi(w) + eps_bd * normal``.
        """
        boundary_psi = self.conformal_map(boundary_disk)
        normal_psi = F.normalize(self.conformal_map_normal(boundary_disk), p=2, dim=1)

        trace_int_psi = boundary_psi - self.eps_bd * normal_psi
        trace_ext_psi = boundary_psi + self.eps_bd * normal_psi
        return normal_psi, trace_int_psi, trace_ext_psi

    def generate_data_interior(self, R_bd, S_int_angular, S_int_radial):
        """Sample points along rays from the origin to the image curve ``psi(|w| = R_bd)``.

        Each of the S_int_angular + 1 image boundary points p is scaled by
        S_int_radial + 1 factors s in [0, 1], giving the points s * p. These
        points fill the enclosed region only if it is star-shaped about the
        origin.

        Returns:
            ``((S_int_angular + 1) * (S_int_radial + 1), 2)`` tensor, grouped
            by boundary point (the scale factor varies fastest).
        """
        # Angles are drawn before radial factors, matching the original random draw order.
        tspace = np.sort(self._sample_axis(0, 2 * pi, S_int_angular + 1))
        rspace = np.array([R_bd], dtype=np.float32)
        bd_for_int_disk = self.polar_to_carte(self.meshgrid_tensor(rspace, tspace))
        bd_for_int_psi = self.conformal_map(bd_for_int_disk)

        ruler = self._sample_axis(0, 1, S_int_radial + 1)
        ruler_tensor = torch.tensor(ruler).reshape(1, -1, 1)

        # (N_bd, 1, 2) * (1, S_int_radial + 1, 1) -> (N_bd, S_int_radial + 1, 2)
        interior_psi = bd_for_int_psi.unsqueeze(1) * ruler_tensor
        return interior_psi.reshape(-1, 2)

    def generate_data_bd_ext(self, boundary_disk, exterior_disk):
        """Map pre-image boundary and exterior points, and build normals and traces.

        Returns:
            ``(boundary_psi, exterior_psi, normal_psi, trace_int_psi, trace_ext_psi)``.
        """
        boundary_psi = self.conformal_map(boundary_disk)
        exterior_psi = self.conformal_map(exterior_disk)

        normal_psi, trace_int_psi, trace_ext_psi = self.generate_normal_and_trace(boundary_disk)

        return boundary_psi, exterior_psi, normal_psi, trace_int_psi, trace_ext_psi


