import h5py
import os
from tqdm import tqdm
from datetime import datetime, timedelta
import xarray as xr
import numpy as np
import pandas as pd
import random
from torch.utils import data
import datetime as dt
from multiprocessing import Pool, cpu_count


class NetCDFDataset(data.Dataset):
    """Map-style dataset yielding (input, target) pairs read from NetCDF files.

    Each key is a timestamp produced by ``pd.date_range(startDate, endDate, freq)``.
    For a key ``t`` the sample consists of the state at ``t`` (input) and the
    state at ``t + horizon`` hours (target). Each state is split into an
    "upper" array, read from ``<nc_path>/upper_result/upper_YYYYMMDD.nc``, and a
    "surface" array, read from ``<nc_path>/surface/surface_YYYYMM.nc``.
    """

    # Variable names stacked (in this order) along axis 0 of the upper array.
    UPPER_VARS = ('thetao', 'so', 'uo', 'vo', 'zos')
    # Variable names stacked (in this order) along axis 0 of the surface array.
    SURFACE_VARS = ('thetao', 'so', 'uo', 'vo')
    # Shapes nctonumpy asserts for one time slice; axis 0 indexes the vars above.
    UPPER_SHAPE = (5, 13, 721, 1440)
    SURFACE_SHAPE = (4, 721, 1440)

    def __init__(self,
                 nc_path='/here/are/input/data/path',
                 data_transform=None,
                 seed=1234,
                 training=True,
                 validation=False,
                 startDate='20150101',
                 endDate='20150102',
                 freq='H',
                 horizon=5):
        """Build the list of sample timestamps.

        Args:
            nc_path: Root directory containing the ``surface`` and
                ``upper_result`` sub-directories.
            data_transform: Stored on the instance. NOTE(review): it is not
                applied anywhere in this class; the original transform branch
                in ``__getitem__`` was disabled.
            seed: Seed passed to the global ``random`` module.
            training: Stored as a flag only.
            validation: Stored as a flag only.
            startDate: Start of the date range, forwarded to ``pd.date_range``.
            endDate: End of the date range, forwarded to ``pd.date_range``.
            freq: Frequency of the date range, forwarded to ``pd.date_range``.
            horizon: Lead time in hours between input and target.
        """
        self.horizon = horizon
        self.nc_path = nc_path
        self.training = training
        self.validation = validation
        self.data_transform = data_transform

        # Training, validation and test splits were all built identically; the
        # split is determined solely by the date range passed in.
        self.keys = list(pd.date_range(start=startDate, end=endDate, freq=freq))

        # Drop ``horizon // 12 + 1`` trailing keys (presumably so targets stay
        # inside the available files -- confirm). Clamp at 0 because
        # ``len()`` raises ValueError on a negative ``__len__``.
        self.length = max(0, len(self.keys) - horizon // 12 - 1)

        random.seed(seed)

    def nctonumpy(self, dataset_upper, dataset_surface):
        """Convert one time slice of the upper/surface datasets to numpy.

        Args:
            dataset_upper: Dataset already selected at a single time, holding
                every variable in ``UPPER_VARS``.
            dataset_surface: Dataset already selected at a single time, holding
                every variable in ``SURFACE_VARS``.

        Returns:
            ``(upper, surface)`` float32 arrays of shapes ``UPPER_SHAPE`` and
            ``SURFACE_SHAPE``. Axis 1 of ``upper`` is reversed, and
            ``np.nan_to_num`` replaces NaN with 0 and +/-inf with finite extremes.

        Raises:
            AssertionError: if a stacked array does not have the expected shape.
        """
        upper = np.stack(
            [dataset_upper[name].values.astype(np.float32) for name in self.UPPER_VARS],
            axis=0)
        assert upper.shape == self.UPPER_SHAPE
        # Reverse axis 1 (the level axis; original note: "levels in descending
        # order"). Copy so the result owns contiguous memory, not a negative-stride view.
        upper = upper[:, ::-1, :, :].copy()

        surface = np.stack(
            [dataset_surface[name].values.astype(np.float32) for name in self.SURFACE_VARS],
            axis=0)
        assert surface.shape == self.SURFACE_SHAPE

        upper = np.nan_to_num(upper)
        surface = np.nan_to_num(surface)

        return upper, surface

    @staticmethod
    def _select_time(dataset, when):
        """Select the slice at ``when`` using the time coordinate this file uses.

        Files with an ``expver`` entry are selected on ``valid_time``; all
        others on ``time``.
        """
        if 'expver' in dataset.keys():
            return dataset.sel(valid_time=when)
        return dataset.sel(time=when)

    def _load_pair(self, when, time_str):
        """Read the (upper, surface) numpy arrays for a single timestamp.

        Args:
            when: Timestamp to select inside the files.
            time_str: ``when`` formatted as ``'%Y%m%d%H'``; used to build the
                file names.

        Both files are opened in ``with`` blocks so their handles are closed as
        soon as the selected data has been converted to numpy.
        """
        # Surface files are named by YYYYMM, e.g. surface_201501.nc.
        surface_path = os.path.join(
            self.nc_path, 'surface', 'surface_{}.nc'.format(time_str[0:6]))
        # Upper files are named by YYYYMMDD, e.g. upper_20150101.nc.
        upper_path = os.path.join(
            self.nc_path, 'upper_result', 'upper_{}.nc'.format(time_str[0:8]))

        with xr.open_dataset(surface_path) as surface_ds, \
                xr.open_dataset(upper_path) as upper_ds:
            # Each file decides its own selector (previously the target surface
            # file was checked against the *input* surface dataset).
            surface_sel = self._select_time(surface_ds, when)
            upper_sel = self._select_time(upper_ds, when)
            # make sure upper and surface variables are at the same time
            assert surface_sel['time'] == upper_sel['time']
            # Convert while the files are still open; the arrays outlive them.
            return self.nctonumpy(upper_sel, surface_sel)

    def LoadData(self, key):
        """Load the input state at ``key`` and the target state ``horizon`` hours later.

        Args:
            key: datetime-like input time.

        Returns:
            ``(input, input_surface, target, target_surface, (start_time_str,
            end_time_str))``, where the arrays come from ``nctonumpy`` and the
            strings are ``'%Y%m%d%H'`` formatted input/target times.
        """
        start_time = key
        start_time_str = start_time.strftime('%Y%m%d%H')

        # target time = start time + horizon
        end_time = start_time + timedelta(hours=self.horizon)
        end_time_str = end_time.strftime('%Y%m%d%H')

        input_upper, input_surface = self._load_pair(start_time, start_time_str)
        target_upper, target_surface = self._load_pair(end_time, end_time_str)

        return (input_upper, input_surface, target_upper, target_surface,
                (start_time_str, end_time_str))

    def __getitem__(self, index):
        """Return ``LoadData(self.keys[index])``: input/target arrays and time strings."""
        # NOTE(review): data_transform is deliberately not applied; the original
        # transform branch was disabled with ``if False``.
        return self.LoadData(self.keys[index])

    def __len__(self):
        return self.length

    def __repr__(self):
        return self.__class__.__name__


class NetCDFPreprocessor:
    """Convert the samples of a NetCDFDataset into one HDF5 file per day."""

    def __init__(self, dataset, hdf5_dir):
        """
        Prepares the dataset for conversion to HDF5 format.

        Args:
            dataset: Instance of the NetCDFDataset class; its ``keys`` and
                ``LoadData`` are used.
            hdf5_dir: Directory that receives the ``data_YYYYMMDD.h5`` files;
                created (with parents) if it does not exist.
        """
        self.keys = dataset.keys
        self.dataset = dataset
        self.hdf5_dir = hdf5_dir
        os.makedirs(hdf5_dir, exist_ok=True)

    @staticmethod
    def _group_keys_by_day(keys):
        """Return a dict mapping ``'YYYYMMDD'`` to that day's keys, in input order."""
        daily_data = {}
        for key in keys:
            daily_data.setdefault(key.strftime('%Y%m%d'), []).append(key)
        return daily_data

    def create_hdf5(self, processes=4):
        """Write one HDF5 file per calendar day, processing days in parallel.

        Args:
            processes: Number of worker processes for ``multiprocessing.Pool``
                (default 4; ``None`` lets Pool use ``os.cpu_count()``).
        """
        # NOTE(review): this converts *all* dataset keys, including the trailing
        # ones excluded by len(dataset); their target files must exist.
        daily_data = self._group_keys_by_day(self.keys)

        with Pool(processes=processes) as pool:
            pool.map(self._process_day, daily_data.items())
        # Report once, after every worker has finished.
        print("Daily HDF5 files have been created.")

    def _process_day(self, day_data):
        """Write every sample of one day to ``<hdf5_dir>/data_<day>.h5``.

        Args:
            day_data: ``(day, keys_in_day)`` pair, where ``day`` is ``'YYYYMMDD'``.

        Each key becomes an HDF5 group named ``'YYYYMMDDHH'``. The group holds
        gzip-compressed ``input``, ``input_surface``, ``target`` and
        ``target_surface`` datasets plus a ``periods`` attribute.
        """
        day, keys_in_day = day_data
        file_path = os.path.join(self.hdf5_dir, f"data_{day}.h5")

        with h5py.File(file_path, 'w') as hdf5_file:
            for key in tqdm(keys_in_day, desc=f"Processing data for {day}"):
                input_upper, input_surface, target, target_surface, periods = \
                    self.dataset.LoadData(key)

                group = hdf5_file.create_group(key.strftime('%Y%m%d%H'))
                arrays = {
                    'input': input_upper,
                    'input_surface': input_surface,
                    'target': target,
                    'target_surface': target_surface,
                }
                for name, array in arrays.items():
                    # One chunk per array (chunk shape == array shape); equals the
                    # previous hard-coded chunk sizes for the shapes LoadData asserts.
                    group.create_dataset(name, data=array, compression='gzip',
                                         chunks=array.shape)
                group.attrs['periods'] = periods


# Example usage
def main():
    """Convert the configured NetCDF date range into per-day HDF5 files."""
    # Source dataset: one sample per day from 2007-01-01 12:00 to
    # 2008-12-31 12:00, with targets 24 hours after each input.
    dataset = NetCDFDataset(
        nc_path='/here/are/input/data/path',
        data_transform=None,
        training=True,
        validation=False,
        startDate='20070101120000',
        endDate='20081231120000',
        freq='24H',
        horizon=24,
    )

    # Directory where the HDF5 output is written.
    hdf5_path = '/here/are/output/data/path'

    NetCDFPreprocessor(dataset, hdf5_path).create_hdf5()


if __name__ == "__main__":
    main()
