import os.path as osp
from typing import Union
import scipy.io as sio
import numpy as np

from h5py import File

import torch
from torch.utils.data import Dataset, DataLoader

from utils.normalizer import UnitGaussianNormalizer, GaussianNormalizer


class NavierStokes3DDataset:
    """Load one field of a 3D Navier-Stokes HDF5 file and expose data loaders.

    On construction the raw file is converted into (input, target) tensors,
    split into train/valid/test parts, optionally normalized, and cached next
    to the raw file as ``<stem>_<prop>_processed.pt``. Later constructions
    with the same ``data_path`` and ``prop`` reuse that cache.

    NOTE(review): the cache file name depends only on ``data_path`` and
    ``prop``; changing the ratios or normalization settings does not
    invalidate an existing cache.

    Attributes:
        train_dataset, valid_dataset, test_dataset: ``NavierStokes3DBase``
            instances built from the processed tensors.
        train_loader, valid_loader, test_loader: ``DataLoader`` objects over
            the datasets above; only the training loader shuffles.
    """

    # Field names accepted for ``prop``; each is used as a key of the raw file.
    _VALID_PROPS = ('Vx', 'Vy', 'Vz', 'density', 'pressure')
    # Split indices are computed against this fixed count, not data.shape[0].
    _NUM_SAMPLES = 50
    # Upper bounds applied to the processed tensors before wrapping them.
    _MAX_TRAIN_SAMPLES = 6400
    _MAX_EVAL_SAMPLES = 600

    def __init__(self, data_path, sample_factor=(1, 1, 1),
                 train_batchsize=10, eval_batchsize=10,
                 train_ratio=0.8, valid_ratio=0.1, test_ratio=0.1,
                 normalize=True, normalizer_type='PGN', prop='Vx', **kwargs):
        """Build the datasets and their data loaders.

        Args:
            data_path: Path of the raw HDF5 file.
            sample_factor: Three strides forwarded to ``NavierStokes3DBase``
                for sub-sampling the three spatial axes.
            train_batchsize: Batch size of the training loader.
            eval_batchsize: Batch size of the validation and test loaders.
            train_ratio, valid_ratio, test_ratio: Fractions used to compute
                the split boundaries.
            normalize: Whether to normalize inputs and targets.
            normalizer_type: ``'PGN'`` selects ``UnitGaussianNormalizer``;
                any other value selects ``GaussianNormalizer``.
            prop: Field to load; one of ``'Vx'``, ``'Vy'``, ``'Vz'``,
                ``'density'``, ``'pressure'``.
            **kwargs: Accepted and ignored.
        """
        self.__file__ = osp.abspath(__file__)

        self.load_data(data_path=data_path, train_ratio=train_ratio,
                       valid_ratio=valid_ratio, test_ratio=test_ratio,
                       normalize=normalize, normalizer_type=normalizer_type,
                       prop=prop, sample_factor=sample_factor)

        self.train_loader = DataLoader(self.train_dataset, batch_size=train_batchsize, shuffle=True)
        self.valid_loader = DataLoader(self.valid_dataset, batch_size=eval_batchsize, shuffle=False)
        self.test_loader = DataLoader(self.test_dataset, batch_size=eval_batchsize, shuffle=False)

    @staticmethod
    def _processed_path(data_path, prop):
        """Return the cache path: ``data_path`` minus its extension plus a suffix."""
        # splitext only strips the final extension, so dots elsewhere in the
        # path (e.g. a leading "./" or a dotted directory name) are preserved.
        return osp.splitext(data_path)[0] + '_' + prop + '_processed.pt'

    @staticmethod
    def _build_grid(grid_x, grid_y, grid_z):
        """Combine three coordinate vectors into a (1, X, Y, Z, 3) grid tensor."""
        x_size = grid_x.shape[0]
        y_size = grid_y.shape[0]
        z_size = grid_z.shape[0]
        # expand() broadcasts without copying; torch.cat materializes the result.
        return torch.cat([
            grid_x.reshape(1, -1, 1, 1, 1).expand(1, -1, y_size, z_size, 1),
            grid_y.reshape(1, 1, -1, 1, 1).expand(1, x_size, -1, z_size, 1),
            grid_z.reshape(1, 1, 1, -1, 1).expand(1, x_size, y_size, -1, 1),
        ], dim=-1)

    def load_data(self, data_path, sample_factor,
                  train_ratio, valid_ratio, test_ratio,
                  normalize, normalizer_type, prop='Vx', **kwargs):
        """Populate ``train_dataset``, ``valid_dataset`` and ``test_dataset``.

        Loads the processed cache if it exists; otherwise reads ``prop`` and
        the x/y/z coordinates from the raw file, pre-processes each split and
        writes the cache.

        Args:
            data_path: Path of the raw HDF5 file.
            sample_factor: Three spatial strides forwarded to
                ``NavierStokes3DBase``.
            train_ratio, valid_ratio, test_ratio: Fractions used to compute
                the split boundaries along the first axis of the data.
            normalize: Whether to normalize inputs and targets.
            normalizer_type: Normalizer selector, see ``pre_process``.
            prop: Field to load from the raw file.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If no cache exists and ``prop`` is not a known field.
        """
        process_path = self._processed_path(data_path, prop)
        if osp.exists(process_path):
            print('Loading processed data from ', process_path)
            # NOTE(review): the cache holds pickled normalizer objects, so
            # torch.load unpickles arbitrary Python objects here. Only load
            # caches you created yourself. On PyTorch releases where
            # weights_only defaults to True this call may need
            # weights_only=False -- confirm against the installed version.
            (train_x, train_y, valid_x, valid_y, test_x, test_y,
             x_normalizer, y_normalizer) = torch.load(process_path)
        else:
            print('Processing data...')
            if prop not in self._VALID_PROPS:
                raise ValueError("Invalid property specified. Choose from 'Vx', 'Vy', 'Vz', 'density', or 'pressure'.")

            # The context manager closes the HDF5 handle once everything
            # needed has been copied into tensors.
            with File(data_path, 'r') as raw_data:
                data = torch.tensor(np.asarray(raw_data[prop]), dtype=torch.float32).unsqueeze(-1)
                grid_x = torch.tensor(np.asarray(raw_data["x-coordinate"]), dtype=torch.float32)
                grid_y = torch.tensor(np.asarray(raw_data["y-coordinate"]), dtype=torch.float32)
                grid_z = torch.tensor(np.asarray(raw_data["z-coordinate"]), dtype=torch.float32)

            grid = self._build_grid(grid_x, grid_y, grid_z)

            data_size = self._NUM_SAMPLES
            train_idx = int(data_size * train_ratio)
            valid_idx = int(data_size * (train_ratio + valid_ratio))
            test_idx = int(data_size * (train_ratio + valid_ratio + test_ratio))

            train_x, train_y, x_normalizer, y_normalizer = self.pre_process(
                data[:train_idx], grid=grid, mode='train', sample_factor=sample_factor,
                normalize=normalize, normalizer_type=normalizer_type)
            # Validation and test reuse the normalizers created for training.
            valid_x, valid_y, _, _ = self.pre_process(
                data[train_idx:valid_idx], grid=grid, mode='valid', sample_factor=sample_factor,
                normalize=normalize, x_normalizer=x_normalizer, y_normalizer=y_normalizer)
            test_x, test_y, _, _ = self.pre_process(
                data[valid_idx:test_idx], grid=grid, mode='test', sample_factor=sample_factor,
                normalize=normalize, x_normalizer=x_normalizer, y_normalizer=y_normalizer)
            print('Saving data...')
            torch.save((train_x, train_y, valid_x, valid_y, test_x, test_y, x_normalizer, y_normalizer), process_path)
            print('Data processed and saved to', process_path)

        n_train = self._MAX_TRAIN_SAMPLES
        n_eval = self._MAX_EVAL_SAMPLES
        self.train_dataset = NavierStokes3DBase(train_x[:n_train], train_y[:n_train], mode='train', x_normalizer=x_normalizer, y_normalizer=y_normalizer, sample_factor=sample_factor)
        self.valid_dataset = NavierStokes3DBase(valid_x[:n_eval], valid_y[:n_eval], mode='valid', x_normalizer=x_normalizer, y_normalizer=y_normalizer, sample_factor=sample_factor)
        self.test_dataset = NavierStokes3DBase(test_x[:n_eval], test_y[:n_eval], mode='test', x_normalizer=x_normalizer, y_normalizer=y_normalizer, sample_factor=sample_factor)

    def pre_process(self, data: torch.Tensor, grid: torch.Tensor, mode: str, normalize: bool,
                    normalizer_type: str = 'PGN',
                    x_normalizer: "Union[UnitGaussianNormalizer, GaussianNormalizer, None]" = None,
                    y_normalizer: "Union[UnitGaussianNormalizer, GaussianNormalizer, None]" = None,
                    sample_factor=(1, 1, 1),
                    **kwargs):
        """
        Pre-process the data for training, validation, or testing.

        Pairs every entry along dim 1 with its successor (input = entry i,
        target = entry i + 1), merges dims 0 and 1 into one batch axis,
        optionally normalizes both tensors, and appends the grid coordinates
        to the input channels.

        Args:
            data (torch.Tensor): 6-D tensor; dim 1 is the axis along which
                consecutive entries are paired (presumably time -- confirm).
            grid (torch.Tensor): Grid coordinates with leading dimension 1 and
                the same three middle dimensions as ``data``'s spatial axes.
            mode (str): The mode of the dataset ('train', 'valid', 'test').
                Only 'train' creates new normalizers.
            normalize (bool): Whether to normalize the data.
            normalizer_type (str): 'PGN' selects ``UnitGaussianNormalizer``;
                any other value selects ``GaussianNormalizer``.
            x_normalizer (UnitGaussianNormalizer or GaussianNormalizer, optional): Normalizer for x data.
            y_normalizer (UnitGaussianNormalizer or GaussianNormalizer, optional): Normalizer for y data.
            sample_factor: Unused here; sub-sampling happens in ``NavierStokes3DBase``.
            **kwargs: Additional keyword arguments (ignored).
        Returns:
            tuple: ``(x, y, x_normalizer, y_normalizer)``. ``x`` carries the
            grid channels followed by the data channels; both normalizers are
            ``None`` when ``normalize`` is false.
        Raises:
            ValueError: If ``normalize`` is true, ``mode`` is not 'train' and
                a normalizer is missing.
        """
        # Shifted views along dim 1 give the (input, target) pairs; merging
        # dims 0 and 1 turns every pair into its own sample.
        x = data[:, :-1].flatten(start_dim=0, end_dim=1)
        y = data[:, 1:].flatten(start_dim=0, end_dim=1)

        B, H, W, D, C = x.shape

        if normalize:
            # The normalizers are handed (batch, points, channels) tensors.
            x = x.reshape(B, -1, C)
            y = y.reshape(B, -1, C)
            if mode == 'train':
                normalizer_cls = UnitGaussianNormalizer if normalizer_type == 'PGN' else GaussianNormalizer
                x_normalizer = normalizer_cls(x)
                y_normalizer = normalizer_cls(y)
            elif x_normalizer is None or y_normalizer is None:
                raise ValueError("Normalizers must be provided for validation and test modes.")
            x = x_normalizer.encode(x).reshape(B, H, W, D, C)
            y = y_normalizer.encode(y).reshape(B, H, W, D, C)
        else:
            x_normalizer = None
            y_normalizer = None

        # Broadcast the grid over the batch without copying it; torch.cat
        # writes the final contiguous tensor.
        grid = grid.expand(B, -1, -1, -1, -1)
        x = torch.cat([grid, x], dim=-1)

        print('x shape:', x.shape)
        print('y shape:', y.shape)

        return x, y, x_normalizer, y_normalizer


class NavierStokes3DBase(Dataset):
    """
    Map-style dataset over pre-processed Navier-Stokes input/target tensors.

    Both tensors are sub-sampled along their three middle (spatial) axes and
    then flattened so each sample has shape ``(points, channels)``.

    Args:
        x (torch.Tensor): 5-D input tensor; first axis indexes samples, last
            axis holds channels.
        y (torch.Tensor): 5-D target tensor with the same layout as ``x``.
        mode (str, optional): The mode of the dataset. Defaults to 'train'.
        sample_factor (sequence of 3 ints, optional): Stride applied to each
            of the three spatial axes. Defaults to ``(1, 1, 1)`` (no
            sub-sampling).
        x_normalizer (optional): Normalizer associated with ``x``; stored as-is.
        y_normalizer (optional): Normalizer associated with ``y``; stored as-is.
        **kwargs: Additional keyword arguments (ignored).

    Attributes:
        mode (str): The mode of the dataset.
        x (torch.Tensor): Inputs of shape ``(samples, points, channels)``.
        y (torch.Tensor): Targets of shape ``(samples, points, channels)``.
        x_normalizer: The input normalizer passed to the constructor.
        y_normalizer: The target normalizer passed to the constructor.
    """

    def __init__(self, x, y, mode='train',
                 sample_factor=(1, 1, 1),
                 x_normalizer=None, y_normalizer=None,
                 **kwargs):
        self.mode = mode
        step_h, step_w, step_d = sample_factor[0], sample_factor[1], sample_factor[2]
        sub_x = x[:, ::step_h, ::step_w, ::step_d, :]
        sub_y = y[:, ::step_h, ::step_w, ::step_d, :]
        # A strided slice is non-contiguous whenever a step exceeds 1, and
        # .view() raises on such tensors. reshape() copies only when needed
        # and still returns a view for the default (1, 1, 1) factors.
        self.x = sub_x.reshape(x.shape[0], -1, x.shape[-1])
        self.y = sub_y.reshape(y.shape[0], -1, y.shape[-1])
        self.x_normalizer = x_normalizer
        self.y_normalizer = y_normalizer

    def __len__(self):
        """Return the number of samples."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair at ``idx``."""
        return self.x[idx], self.y[idx]
