from torch_geometric.data import Dataset
import torch_geometric.transforms as T
import torch
from zipfile import ZipFile
import io
import numpy as np

class CFDGraphsDataset(Dataset):
    """Dataset comprised of 2D OpenFOAM CFD simulations.

    Every file entry of the zip archive is loaded with ``torch.load`` and is
    expected to provide at least ``x``, ``edge_attr``, ``edge_index``, ``pos``,
    ``node_type`` and ``globals``. Samples are read lazily: the archive is
    re-opened on every access, so no file handle is kept on the instance.
    """

    def __init__(self, zip_path: str, sdf_input: bool = True, rd_in_polar_coords: bool = False, random_masking: bool = False,
                 zero_augmentation: bool = False, farfield_mag_aoa: bool = True, airfoil_coverage: float = 1, transform=None, pre_transform=None):
        """
        Args:
            zip_path: path to the zip file containing the cfd simulations parsed to graphs
            sdf_input: if True the input node features are [pressure, sdf], otherwise only [pressure]
            rd_in_polar_coords: if True the edge attributes are replaced by
                ['theta', 'edge_length', 'face_surface'] computed with ``T.Polar``
            random_masking: if True the share of fluid nodes whose pressure is hidden is drawn
                uniformly from [0.7, 1.0] per sample; otherwise every fluid node is hidden
            zero_augmentation: if True a sample has a 5% chance of getting all-zero node
                features and globals
            farfield_mag_aoa: if True the first two globals are converted from (x, y) components
                to (magnitude, angle in degrees)
            airfoil_coverage: value in (0, 1]; airfoil nodes with x-position >= this value are
                treated as unmeasured (assumes a chord length of 1)
            transform: callable applied to every sample right before it is returned
            pre_transform: forwarded to the base class; not used by this class itself
        """
        assert 0 < airfoil_coverage <= 1, "airfoil_coverage must be between 0 and 1"
        super().__init__(None, transform, pre_transform)
        self.__zip_path = zip_path
        with ZipFile(zip_path, 'r') as zf:
            # Directory entries (names ending in '/') are not graphs; skipping them keeps
            # len() and indexing aligned with the actual files. The list is cached so that
            # an index always resolves to the same entry.
            self.__graph_names = [name for name in zf.namelist() if not name.endswith('/')]
        self.__num_graphs = len(self.__graph_names)

        self.sdf_input = sdf_input
        self.random_masking = random_masking
        self.rd_in_polar_coords = rd_in_polar_coords
        self.zero_augmentation = zero_augmentation
        self.farfield_mag_aoa = farfield_mag_aoa
        self.airfoil_coverage = airfoil_coverage

    def __getitem__(self, idx):
        """Load the graph at position ``idx`` and build the model inputs and targets.

        Returns the loaded data object with, among others, these fields set:
        ``x`` (input node features followed by the one-hot node type, NaN where
        the pressure is hidden), ``y`` (normalized node targets), ``globals``
        (estimated far-field velocity), ``globals_y`` (target globals),
        ``node_norm_vals`` ([subtract, divide] per node target) and
        ``known_feature_mask``.
        """
        # read the zip file and select the data to load in by index
        with ZipFile(self.__zip_path, 'r') as zf:
            with zf.open(self.__graph_names[idx]) as item:
                stream = io.BytesIO(item.read())
                # NOTE(security): weights_only=False unpickles arbitrary objects,
                # so only archives from a trusted source must be used here.
                data = torch.load(stream, weights_only=False)

        # make sure features are float
        data.x = data.x.float()
        data.edge_attr = data.edge_attr.float()
        data.pos = data.pos.float()

        # if globals should have magnitude and angle of attack instead of x and y velocities
        if self.farfield_mag_aoa:
            Uinf = (data.globals[0]**2 + data.globals[1]**2)**0.5
            alpha = torch.rad2deg(torch.arctan2(data.globals[1], data.globals[0]))
            data.globals[0] = Uinf
            data.globals[1] = alpha

        # set target globals (data.globals itself is overwritten further down with the estimate)
        data.globals_y = data.globals.float().detach().clone().unsqueeze(0)
        # Cn_est = self.compute_force_coeff(data.x, data.edge_index, data.edge_attr, data.node_type)

        # save the relative edge vectors and the face surfaces
        data.edge_rd = data.edge_attr[:, 0:2]
        data.edge_s = data.edge_attr[:, -1:]

        # replaces the cartesian relative distance vector between nodes with polar angle if enabled
        if self.rd_in_polar_coords:
            polar_transf = T.Polar(norm=False, cat=True)
            data = polar_transf(data)
            # Polar(cat=True) appends its two columns at the end; reorder them in front of
            # the previously last column (the face surface)
            data.edge_attr = data.edge_attr[:, [-1, -2, -3]].float()
            data.edge_attr_labels = ['theta', 'edge_length', 'face_surface']

        # apply 5% zero augmentation chance if activated
        if self.zero_augmentation:
            augment = np.random.choice([False, True], 1, p=[0.95, 0.05])
            if augment:
                data.x = torch.zeros_like(data.x)
                data.globals = torch.zeros_like(data.globals)
                data.globals_y = torch.zeros_like(data.globals_y)

        # get the pressure values for the measured airfoil nodes
        airf_nodes = data.node_type == 1.0
        if self.airfoil_coverage < 1:
            measured_airf_nodes = airf_nodes & (data.pos[:, 0] < self.airfoil_coverage)  # chord len = 1
            unmeasured_airf_nodes = airf_nodes & (data.pos[:, 0] >= self.airfoil_coverage)
            airfoil_pressure = data.x[measured_airf_nodes, 0]
        else:
            airfoil_pressure = data.x[airf_nodes, 0]

        # normalization statistics come from the measured airfoil pressure only
        airfoil_pressure_mean = airfoil_pressure.mean(dim=0)
        airfoil_pressure_std = airfoil_pressure.std(dim=0)

        # estimate initial globals and overwrite for inputting
        Uinf_est = torch.sqrt(2*airfoil_pressure.max()/1.2250)  # assume incompressible flow and use Bernoulli's principle
        # data.globals = torch.tensor([Uinf_est, Cn_est]).unsqueeze(0)
        data.globals = torch.tensor([Uinf_est]).unsqueeze(0)

        # A zero scale (e.g. after zero augmentation) is replaced by 1 so the division is
        # well defined; the stored normalization values use the very same guarded scales
        # so that the normalization can be inverted consistently.
        pressure_scale = self.__safe_divisor(airfoil_pressure_std)
        velocity_scale = self.__safe_divisor(Uinf_est)

        # normalize the node features
        data.x[:, 0] = self.__normalize(data.x[:, 0], subtract=airfoil_pressure_mean, divide=pressure_scale)
        data.x[:, 1] = self.__normalize(data.x[:, 1], subtract=0, divide=velocity_scale)
        data.x[:, 2] = self.__normalize(data.x[:, 2], subtract=0, divide=velocity_scale)
        data.node_norm_vals = torch.tensor([[airfoil_pressure_mean, pressure_scale], [0, velocity_scale], [0, velocity_scale]])

        # copy the node target values [pressure, x-velocity, y-velocity] without the signed distance function [idx 3]
        data.y = data.x[:, :3].detach().clone()

        if self.sdf_input:
            data.x = data.x[:, [0, 3]]
            data.input_node_feat_labels = ['pressure', 'sdf']
        else:  # keep only the pressure as the input
            data.x = data.x[:, 0:1]
            data.input_node_feat_labels = ['pressure']

        # choose the proportion of fluid nodes whose pressure gets hidden
        if self.random_masking:
            masked_fluid_proportion = np.random.uniform(0.7, 1.0)
        else:
            masked_fluid_proportion = 1.0

        # mask the fluid nodes
        mask = torch.tensor(np.random.choice([0, 1], size=data.num_nodes, p=[1-masked_fluid_proportion, masked_fluid_proportion])) > 0
        self.__mask_nodes(data.x[:, 0], mask & (data.node_type == 0.0), mask_value=float('nan'))

        # mask the unmeasured airfoil nodes if the airfoil coverage is below 100%
        if self.airfoil_coverage < 1:
            self.__mask_nodes(data.x[:, 0], unmeasured_airf_nodes, mask_value=float('nan'))

        # add node type to node features passed as input to network
        data.x = torch.cat((data.x, torch.nn.functional.one_hot(data.node_type)), dim=1)

        # get the known feature mask (False wherever a feature was hidden with NaN)
        data.known_feature_mask = ~torch.isnan(data.x)

        # Overriding __getitem__ bypasses the base-class hook that applies the transform,
        # so it has to be applied here for the constructor argument to take effect.
        if self.transform is not None:
            data = self.transform(data)

        return data

    def __mask_nodes(self, x, mask_tensor, mask_value=float('nan')):
        """Overwrite, in place, the entries of ``x`` selected by ``mask_tensor`` with ``mask_value``."""
        x[mask_tensor] = torch.tensor(mask_value)

    @staticmethod
    def __safe_divisor(divide):
        """Return ``divide`` as a tensor with exact zeros replaced by 1, leaving the input untouched."""
        divide = torch.as_tensor(divide)
        return torch.where(divide == 0.0, torch.ones_like(divide), divide)

    def __first_sample(self):
        """Return the first sample, unwrapping it when the sample is a tuple."""
        data = self[0]
        return data[0] if isinstance(data, tuple) else data

    @property
    def num_glob_features(self) -> int:
        r"""Returns the number of global features in the dataset."""
        return self.__first_sample().globals.shape[1]

    @property
    def num_glob_output_features(self) -> int:
        r"""Returns the number of global output features in the dataset."""
        return self.__first_sample().globals_y.shape[1]

    @property
    def num_node_output_features(self) -> int:
        r"""Returns the number of node output features in the dataset."""
        return self.__first_sample().y.shape[1]

    def get_data_dims_dict(self) -> dict:
        r"""Returns a dictionary with the number of features for each type of data."""
        data = self.__first_sample()
        return {
            'node_feature_dim': data.x.shape[1],
            'edge_feature_dim': data.edge_attr.shape[1],
            'glob_feature_dim': data.globals.shape[1],
            'node_out_dim': data.y.shape[1],
            'glob_out_dim': data.globals_y.shape[1],
        }

    def __len__(self):
        """Number of graph files found in the archive."""
        return self.__num_graphs

    def __normalize(self, x: torch.Tensor, subtract=None, divide=None):
        """Return ``(x - subtract) / divide``.

        ``subtract`` defaults to the mean of ``x`` and ``divide`` to its standard
        deviation (both along dim 0). Zero divisors are replaced by 1. The
        arguments are never modified, and plain Python numbers are accepted.
        """
        if subtract is None:
            # per feature normalization
            subtract = x.mean(dim=0)
        if divide is None:
            divide = x.std(dim=0)

        # handling 0 division for the std norm
        divide = self.__safe_divisor(divide)

        return (x - subtract) / divide

    def compute_force_coeff(self, x, edge_index, edge_attr, node_type):
        """Estimate a force coefficient from the pressure on the airfoil nodes.

        The only call site visible in this class is commented out. The method
        reads column 1 of ``edge_attr`` (negated) as ``costheta`` and column 3
        as ``ds`` on the edges going from an airfoil node (type 1) to a fluid
        node (type 0).
        NOTE(review): the product below pairs per-node pressures with per-edge
        quantities, so it presumably expects exactly one such edge per airfoil
        node, in matching order -- confirm before using.

        Returns:
            The pressure-weighted sum divided by the maximum airfoil pressure,
            with NaN/inf results replaced by 0.
        """
        pressure_distrib = x[node_type==1.0, 0]
        row, col = edge_index
        airfoil_sender_edges = torch.isin(row, (node_type==1.0).nonzero())
        fluid_receiver_edges = torch.isin(col, (node_type==0.0).nonzero())
        sorted_node_to_fluid_edges = torch.sort((airfoil_sender_edges * fluid_receiver_edges).nonzero()).values
        costheta = -edge_attr[sorted_node_to_fluid_edges, 1].squeeze()
        ds = edge_attr[sorted_node_to_fluid_edges,3].squeeze()
        N = torch.sum(pressure_distrib*ds*costheta)
        Cn = N/pressure_distrib.max()
        return torch.nan_to_num(Cn, nan=0.0, posinf=0.0, neginf=0.0)

if __name__ == '__main__':
    # Manual smoke test: load the training archive, print the dataset
    # dimensions and plot the surface pressure of the first two samples.
    import sys
    sys.path.append('../')
    # from utils import plot_mesh
    from plotter import plot_mesh
    import matplotlib.pyplot as plt

    dataset = CFDGraphsDataset(
        r'train_dataset.zip',
        sdf_input=True,
        random_masking=False,
        zero_augmentation=False,
        airfoil_coverage=0.1,
    )

    print(len(dataset))
    for dim_property in ('num_node_features', 'num_node_output_features',
                         'num_edge_features', 'num_glob_features'):
        print(getattr(dataset, dim_property))

    # Each field is read from a freshly loaded first sample.
    # NOTE(review): the dataset class only sets `input_node_feat_labels`;
    # `node_feat_labels` has to come from the stored graph itself -- confirm.
    for sample_field in ('node_feat_labels', 'globals', 'globals_y'):
        print(getattr(dataset[0], sample_field))

    # print 10 airfoil nodes x
    # print(g.x[g.node_type==1.0][:10])

    # print 10 fluid nodes x
    # print(g.x[g.node_type==0.0][:10])

    # other plot types: 'pressure', 'velocity_mag', 'velocity_x', 'velocity_y'
    plot_types = ['surface_pressure']
    for plot_type in plot_types:
        for sample_idx in range(2):
            plot_mesh(dataset[sample_idx], plot_type, plot_predicted=True, show=False, add_farfield_info=True)
    # figures are created with show=False above and displayed together here
    plt.show()