from torch_geometric.data import Dataset
import torch
from zipfile import ZipFile
import io
import numpy as np

class CFDGraphsDataset(Dataset):
    """ Dataset comprised of 2D OpenFOAM CFD simulations.

    Every (non-directory) member of the zip archive is one graph serialized with
    ``torch.save``. Graphs are loaded lazily and preprocessed in :meth:`get`; the PyG base
    class' ``__getitem__`` / ``__len__`` call :meth:`get` / :meth:`len`, apply ``transform``
    and support slicing / shuffling.
    """

    def __init__(self, zip_path: str, random_masking: bool, farfield_mag_aoa=True, one_hot_node_type=True,
                 transform=None, pre_transform=None):
        """
        Args:
            zip_path: path to the zip file containing the cfd simulations parsed to graphs
            random_masking: if True, a fraction drawn uniformly from [0.7, 1.0) of the fluid nodes
                is masked per sample; if False, every fluid node is masked
            farfield_mag_aoa: if the farfield context data should be converted to velocity magnitude
                and angle of attack (in degrees)
            one_hot_node_type: if the node type feature should be passed as a one hot vector
            transform: optional callable applied to every preprocessed sample
            pre_transform: forwarded to the PyG base class; this class never calls it
        """

        super().__init__(None, transform, pre_transform)
        self.__zip_path = zip_path
        with ZipFile(zip_path, 'r') as zf:
            # Read the member list once instead of on every access. Directory entries hold
            # no graph (torch.load would fail on them), so they are not counted.
            self.__names = [name for name in zf.namelist() if not name.endswith('/')]
        self.__num_graphs = len(self.__names)

        self.random_masking = random_masking
        self.farfield_mag_aoa = farfield_mag_aoa
        self.one_hot_node_type = one_hot_node_type

    def __first_sample(self):
        """Return the first preprocessed sample, unwrapping it if it is a tuple."""
        data = self[0]
        return data[0] if isinstance(data, tuple) else data

    @property
    def num_glob_features(self) -> int:
        r"""Returns the number of global features in the dataset."""
        return self.__first_sample().globals.shape[1]

    @property
    def num_glob_output_features(self) -> int:
        r"""Returns the number of global output features in the dataset."""
        return self.__first_sample().globals_y.shape[1]

    @property
    def num_node_output_features(self) -> int:
        r"""Returns the number of node output features in the dataset."""
        return self.__first_sample().y.shape[1]

    def len(self) -> int:
        r"""Returns the number of graphs in the archive (PyG hook used by ``len(dataset)``)."""
        return self.__num_graphs

    def __load_graph(self, idx):
        """Deserialize archive member ``idx`` (an index into the cached member names)."""
        # NOTE(review): torch.load unpickles arbitrary objects - only open trusted archives.
        # Recent torch versions default to weights_only=True, which may reject stored PyG
        # Data objects - TODO confirm against the pinned torch version.
        with ZipFile(self.__zip_path, 'r') as zf:
            with zf.open(self.__names[idx]) as item:
                return torch.load(io.BytesIO(item.read()))

    def get(self, idx):
        """Load graph ``idx`` from the zip archive and preprocess it into a training sample.

        Called by PyG's ``Dataset.__getitem__``, which applies ``self.transform`` afterwards.

        Returns:
            The graph, with:
              * ``x``: [normalized pressure (NaN where masked), node-type feature(s)]
              * ``y``: copy of all node feature columns after normalizing the first three
              * ``globals``: estimated [U_inf, Cn], shape (1, 2), used as network input
              * ``globals_y``: the stored globals (optionally as magnitude/AoA), with a leading
                dim of 1 (assumes the stored ``globals`` is 1-D - TODO confirm)
              * ``normalization_values``: (3, 2) tensor of (subtract, divide) per target column
              * ``edge_attr`` / ``edge_s``: last normalized edge feature column;
                ``edge_rd``: first two normalized edge feature columns
              * ``known_feature_mask``: True where ``x`` is not NaN
        """
        data = self.__load_graph(idx)

        # the legacy farfield node (node type 2) is not used in this work, so remove it
        self.__delete_farfield_node(data)

        # make sure features are float; globals is converted before the in-place writes
        # below so they are not truncated to an integer dtype
        data.x = data.x.float()
        data.edge_attr = data.edge_attr.float()
        data.pos = data.pos.float()
        data.globals = data.globals.float()

        # if globals should have magnitude and angle of attack instead of x and y velocities
        if self.farfield_mag_aoa:
            u, v = data.globals[0], data.globals[1]
            # compute both values before overwriting either slot (u and v are views)
            speed = torch.hypot(u, v)
            aoa_deg = torch.rad2deg(torch.atan2(v, u))
            data.globals[0] = speed
            data.globals[1] = aoa_deg

        # set target globals
        data.globals_y = data.globals.clone().unsqueeze(0)

        # estimate initial globals (from the un-normalized features) and overwrite for inputting
        uinf_est = self.compute_U_inf(data.x, data.node_type)
        cn_est = self.compute_force_coeff(data.x, data.edge_index, data.edge_attr, data.node_type)
        data.globals = torch.tensor([float(uinf_est), float(cn_est)]).unsqueeze(0)

        # normalize the pressure with airfoil statistics and the velocities with the estimated
        # U_inf, and store the (subtract, divide) pair used for each column
        airfoil_pressure = data.x[data.node_type == 1, 0]
        airfoil_pressure_mean = airfoil_pressure.mean(dim=0)
        airfoil_pressure_std = airfoil_pressure.std(dim=0)
        data.x[:, 0] = self.__normalize(data.x[:, 0], subtract=airfoil_pressure_mean, divide=airfoil_pressure_std)
        data.x[:, 1] = self.__normalize(data.x[:, 1], subtract=0, divide=uinf_est)
        data.x[:, 2] = self.__normalize(data.x[:, 2], subtract=0, divide=uinf_est)
        data.normalization_values = torch.tensor([
            [float(airfoil_pressure_mean), float(airfoil_pressure_std)],
            [0.0, float(uinf_est)],
            [0.0, float(uinf_est)],
        ])

        # normalize the edge features per column
        data.edge_attr = self.__normalize(data.edge_attr)

        # copy the node target values [pressure, x-velocity, y-velocity, (roughness)]
        data.y = data.x.clone()

        # keep only the pressure as the input
        data.x = data.x[:, 0:1]

        # keep only the edge flux distance as the edge feature input, but keep the others for div loss computation
        data.edge_rd = data.edge_attr[:, 0:2]
        data.edge_s = data.edge_attr[:, -1:]
        data.edge_attr = data.edge_attr[:, -1:]

        # mask the unknown fluid nodes (node type 0) with NaNs
        masked_fluid_proportion = np.random.uniform(0.7, 1.0) if self.random_masking else 1.0
        draws = np.random.choice([0, 1], size=data.x.size(0),
                                 p=[1 - masked_fluid_proportion, masked_fluid_proportion])
        mask = torch.tensor(draws) > 0
        self.__mask_nodes(data.x, mask & (data.node_type == 0))

        # add node type to node features passed as input to network
        if self.one_hot_node_type:
            # one_hot infers the class count as max(node_type) + 1 per graph.
            # NOTE(review): the width differs if a graph lacks the highest type - TODO confirm.
            node_type_feat = torch.nn.functional.one_hot(data.node_type.long())
        else:
            node_type_feat = data.node_type.unsqueeze(1)
        # cast explicitly so the concatenation keeps x's float dtype (and its NaNs)
        data.x = torch.cat((data.x, node_type_feat.to(data.x.dtype)), dim=1)

        # get the known feature mask (NaN marks masked / unknown entries)
        data.known_feature_mask = ~torch.isnan(data.x)

        return data

    def __delete_farfield_node(self, data):
        """Remove farfield nodes (node type 2) and every edge touching them, in place.

        The remaining nodes are renumbered contiguously and ``edge_index`` is remapped to the
        new numbering, so the result is correct wherever the removed node sits in the node
        order (for nodes preceding it the mapping is the identity).
        """
        keep = data.node_type != 2
        row, col = data.edge_index
        edge_keep = keep[row] & keep[col]
        # new index of a kept node = number of kept nodes before it
        new_index = torch.cumsum(keep.long(), dim=0) - 1
        data.edge_index = new_index[data.edge_index[:, edge_keep]]
        data.edge_attr = data.edge_attr[edge_keep]
        data.x = data.x[keep]
        data.pos = data.pos[keep]
        data.node_type = data.node_type[keep]

    def __mask_nodes(self, x, mask_tensor, mask_value=float('nan')):
        """Overwrite, in place, the rows of ``x`` selected by boolean ``mask_tensor`` with ``mask_value``."""
        x[mask_tensor] = mask_value

    def __normalize(self, x, subtract=None, divide=None):
        """Return ``(x - subtract) / divide``.

        When ``subtract`` / ``divide`` are omitted, the per-column (dim 0) mean / std of ``x``
        are used. A column with zero std is divided by 1 instead of 0, so constant features
        become 0 rather than NaN/inf.
        """
        if subtract is None:
            # per feature normalization
            subtract = x.mean(dim=0)
        if divide is None:
            divide = x.std(dim=0)
            divide = torch.where(divide == 0, torch.ones_like(divide), divide)
        return (x - subtract) / divide

    def compute_U_inf(self, x, node_type, rho=1.2250):
        """Estimate the freestream speed from the airfoil surface pressure.

        Computes ``sqrt(2 * p_max / rho)``, i.e. treats the highest airfoil pressure as the
        stagnation pressure ``0.5 * rho * U_inf**2``.

        Args:
            x: node features; column 0 is taken as the pressure
            node_type: per-node type; airfoil nodes have type 1
            rho: fluid density; default is air at 20 degC, which was set in the CFD

        Returns:
            0-d tensor with the estimated speed (NaN if the peak pressure is negative).
        """
        pressure_distrib = x[node_type == 1, 0]
        return torch.sqrt(2 * pressure_distrib.max() / rho)

    def compute_force_coeff(self, x, edge_index, edge_attr, node_type):
        """Estimate a normal-force value from the airfoil surface pressure.

        For every edge from an airfoil node (type 1) to a fluid node (type 0), the sender's
        pressure (``x[:, 0]``) is multiplied by ``ds = edge_attr[:, 3]`` and
        ``costheta = -edge_attr[:, 1]``. The products are summed and divided by
        :meth:`compute_U_inf`.

        Each edge is paired with the pressure of its own sender node. When every airfoil
        node has exactly one such edge, this is the same sum as pairing the i-th airfoil
        node with the i-th edge, but it does not depend on the edge ordering.
        """
        row, col = edge_index
        surface_edges = (node_type == 1)[row] & (node_type == 0)[col]
        pressure = x[row[surface_edges], 0]
        costheta = -edge_attr[surface_edges, 1]
        ds = edge_attr[surface_edges, 3]
        normal_force = torch.sum(pressure * ds * costheta)
        return normal_force / self.compute_U_inf(x, node_type)