import h5py
import xarray as xr
from datetime import datetime, timedelta
import pandas as pd
import numpy as np
import sys
from util_data.config import cfg
sys.path.append("/here/is/code/M2F-PINN")

from typing import Tuple, List
import torch
import random
from torch.utils import data
from torchvision import transforms as T
import os
import tables

class DataPrefetcher():
    """Iterate a loader while copying the next batch to the GPU in the background.

    Each item yielded by ``loader`` must unpack into
    ``(input, input_surface, target, target_surface, periods)``. When CUDA is
    available, the four tensors of the *next* batch are copied to the GPU on a
    side stream while the caller works on the current one; ``periods`` stays on
    the host. Without CUDA, batches are passed through unchanged.

    The loader is restarted automatically when exhausted, so ``next()`` keeps
    cycling through it. An empty loader raises ``StopIteration`` on construction.
    """

    def __init__(self, loader):
        self.loader = loader
        self.dataiter = iter(loader)
        self.length = len(self.loader)
        # Side stream for asynchronous host-to-device copies; None selects
        # the CPU-only pass-through mode.
        self.stream = torch.cuda.Stream() if torch.cuda.is_available() else None
        self.__preload__()

    def __preload__(self):
        """Fetch the next batch and, on CUDA, start copying it on the side stream."""
        try:
            batch = next(self.dataiter)
        except StopIteration:
            # Wrap around so the prefetcher can be used across epochs.
            self.dataiter = iter(self.loader)
            batch = next(self.dataiter)
        self.input, self.input_surface, self.target, self.target_surface, self.periods = batch

        if self.stream is None:
            return
        with torch.cuda.stream(self.stream):
            self.target = self.target.cuda(non_blocking=True)
            self.target_surface = self.target_surface.cuda(non_blocking=True)
            self.input = self.input.cuda(non_blocking=True)
            self.input_surface = self.input_surface.cuda(non_blocking=True)
            # periods is intentionally not moved; it stays on the host.

    def next(self):
        """Return the prefetched batch and start prefetching the one after it.

        Returns:
            ``(input, input_surface, target, target_surface, periods)``.
        """
        if self.stream is not None:
            current = torch.cuda.current_stream()
            # Make the compute stream wait for the copies issued in __preload__ ...
            current.wait_stream(self.stream)
        batch = (self.input, self.input_surface, self.target, self.target_surface, self.periods)
        if self.stream is not None:
            # ... and tell the caching allocator that these side-stream
            # allocations are now used on the compute stream, so their memory
            # is not handed out again while still in use.
            for tensor in batch[:4]:
                tensor.record_stream(current)
        # Capture the ready batch *before* preloading: __preload__ overwrites the
        # attributes with the following (still-copying) batch.
        self.__preload__()
        return batch

    def __len__(self):
        """Return ``len(loader)``, i.e. the number of batches in one pass."""
        return self.length


class NetCDFDataset(data.Dataset):
    """Dataset class for the era5 upper and surface variables.

    Despite the class name, samples are read with ``h5py`` from daily files named
    ``data_YYYYMMDD.h5`` under ``nc_path``. Each file is expected to hold one
    group per timestamp (named ``YYYYMMDDHH``) containing the datasets
    ``input``, ``input_surface``, ``target`` and ``target_surface`` and a
    ``periods`` attribute.
    """

    def __init__(self,
                 nc_path='/here/your/data',
                 data_transform=None,
                 seed=1234,
                 training=True,
                 validation=False,
                 startDate='20150101',
                 endDate='20150102',
                 freq='H',
                 horizon=5):
        """Build the list of sample timestamps.

        Args:
            nc_path: Directory containing the ``data_YYYYMMDD.h5`` files.
            data_transform: Stored on the instance; not applied by ``__getitem__``.
            seed: Seed passed to the process-global ``random`` module.
            training: Split flag, stored on the instance.
            validation: Split flag, stored on the instance.
            startDate: Inclusive start passed to ``pd.date_range``.
            endDate: Inclusive end passed to ``pd.date_range``.
            freq: ``pd.date_range`` frequency string.
            horizon: Presumably the forecast horizon (TODO confirm units);
                ``horizon // 12 + 1`` trailing timestamps are excluded from the
                dataset length.
        """
        self.horizon = horizon
        self.nc_path = nc_path
        self.training = training
        self.validation = validation
        self.data_transform = data_transform

        # TODO: check that startDate/endDate are valid dates and that every
        # timestamp has a matching file on disk.
        # The training / validation / test branches used to be identical, so
        # the timestamp list is built the same way for every split.
        self.keys = list(pd.date_range(start=startDate, end=endDate, freq=freq))
        # Exclude trailing timestamps reserved for the horizon. Clamp at zero
        # so a short date range gives an empty dataset instead of a negative
        # length (which makes len() raise ValueError).
        self.length = max(0, len(self.keys) - horizon // 12 - 1)

        # NOTE: this seeds the process-global ``random`` module.
        random.seed(seed)

    def __getitem__(self, index):
        """Load the sample for ``self.keys[index]``.

        Returns:
            ``(input, input_surface, target, target_surface, periods)``: the
            first four are read in full (``[:]``) from the HDF5 group, and
            ``periods`` is a tuple built from the group's ``periods``
            attribute converted to ``str``.

        Raises:
            FileNotFoundError: If the day's HDF5 file does not exist.
            KeyError: If the file has no group for the timestamp.
        """
        key = self.keys[index]
        key_str = key.strftime('%Y%m%d%H')
        day_str = key.strftime('%Y%m%d')
        file_path = os.path.join(self.nc_path, f"data_{day_str}.h5")
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"HDF5 file not found: {file_path}")

        # Chunk-cache tuning; h5py expects an integer slot count, so pass an
        # int rather than the float literal 1e7.
        with h5py.File(file_path, 'r', swmr=True,
                       rdcc_nbytes=5 * 13 * 721 * 1440 * 4 * 8,
                       rdcc_w0=1,
                       rdcc_nslots=10_000_000) as hdf5_file:
            if key_str not in hdf5_file:
                raise KeyError(f"Time group '{key_str}' not found in HDF5 file: {file_path}")

            group = hdf5_file[key_str]
            upper_input = group['input'][:]
            surface_input = group['input_surface'][:]
            upper_target = group['target'][:]
            surface_target = group['target_surface'][:]
            periods = tuple(group.attrs['periods'].astype(str))

        return upper_input, surface_input, upper_target, surface_target, periods

    def __len__(self):
        """Return the number of usable timestamps (never negative)."""
        return self.length

    def __repr__(self):
        return self.__class__.__name__


def weatherStatistics_output(filepath="/here/your/data/aux_data", device="cpu"):
    """Load normalisation statistics laid out to broadcast against model outputs.

    Reads ``surface_mean.npy``, ``surface_std.npy``, ``upper_mean.npy`` and
    ``upper_std.npy`` from ``filepath`` and casts them to float32.

    Returns:
        ``(surface_mean, surface_std, upper_mean, upper_std)`` on ``device``.
        The surface tensors have shape ``(1, 4, 1, 1)``. The upper tensors have
        shape ``(1, 5, 13, 1, 1)``, with the size-13 axis reversed relative to
        the intermediate ``(13, 1, 1, 5)`` layout.
    """
    def _surface(name):
        flat = np.load(os.path.join(filepath, name)).astype(np.float32).reshape(4)
        return torch.from_numpy(flat).view(1, 4, 1, 1)

    def _upper(name):
        raw = np.load(os.path.join(filepath, name)).astype(np.float32)
        # Bring the file into a (13, 1, 1, 5) layout ...
        levels_first = raw.transpose(2, 3, 4, 1, 0).reshape(13, 1, 1, 5)
        # ... flip the size-13 axis (copy so torch gets positive strides) ...
        flipped = levels_first[::-1, :, :, :].copy()
        # ... then reorder to (1, 5, 13, 1) and append a trailing singleton axis.
        reordered = np.transpose(flipped, (1, 3, 0, 2))
        return torch.from_numpy(reordered)[..., None]

    surface_mean = _surface("surface_mean.npy")
    surface_std = _surface("surface_std.npy")
    upper_mean = _upper("upper_mean.npy")
    upper_std = _upper("upper_std.npy")
    return (surface_mean.to(device), surface_std.to(device),
            upper_mean.to(device), upper_std.to(device))


def weatherStatistics_input(filepath="/here/your/data/aux_data", device="cpu"):
    """Load normalisation statistics in the raw input layout.

    Reads ``surface_mean.npy``, ``surface_std.npy``, ``upper_mean.npy`` and
    ``upper_std.npy`` from ``filepath`` and casts them to float32.

    Returns:
        ``(surface_mean, surface_std, upper_mean, upper_std)`` on ``device``.
        The surface tensors are flat with 4 entries; the upper tensors have
        shape ``(13, 1, 1, 5)`` (no axis reversal, unlike
        ``weatherStatistics_output``).
    """
    def _read(name):
        return np.load(os.path.join(filepath, name)).astype(np.float32)

    surface_mean = torch.from_numpy(_read("surface_mean.npy").reshape(4))
    surface_std = torch.from_numpy(_read("surface_std.npy").reshape(4))

    raw_upper_mean = _read("upper_mean.npy")
    raw_upper_std = _read("upper_std.npy")
    upper_mean = torch.from_numpy(raw_upper_mean.transpose(2, 3, 4, 1, 0).reshape(13, 1, 1, 5))
    upper_std = torch.from_numpy(raw_upper_std.transpose(2, 3, 4, 1, 0).reshape(13, 1, 1, 5))

    return tuple(t.to(device) for t in (surface_mean, surface_std, upper_mean, upper_std))

def LoadConstantMask(filepath='/here/your/data/constant_masks', device="cpu"):
    """Load the land-mask, soil-type and topography maps as float32 tensors.

    Each ``.npy`` file in ``filepath`` is cast to float32 and given two leading
    singleton axes, so a 2-D map of shape ``(H, W)`` (presumably 721 x 1440)
    comes back as ``(1, 1, H, W)``.

    Returns:
        ``(land_mask, soil_type, topography)`` on ``device``.
    """
    maps = []
    for name in ("land_mask.npy", "soil_type.npy", "topography.npy"):
        array = np.load(os.path.join(filepath, name)).astype(np.float32)
        maps.append(torch.from_numpy(array))
    land_mask, soil_type, topography = (m[None, None, ...].to(device) for m in maps)
    return land_mask, soil_type, topography

def LoadConstantMask3(filepath="/here/your/data/aux_data", device="cpu"):
    """Load ``constantMaks3.npy`` as float32 and tile it twice along the first axis.

    ``repeat(2, 1, 1, 1)`` doubles the leading axis; arrays with fewer than
    four dimensions get singleton axes prepended first. Note that the on-disk
    file name is spelled ``constantMaks3``.
    """
    source = os.path.join(filepath, "constantMaks3.npy")
    tensor = torch.from_numpy(np.load(source).astype(np.float32))
    return tensor.repeat(2, 1, 1, 1).to(device)


def computeStatistics(train_loader):
    """Estimate per-channel normalisation statistics from a data loader.

    For every sample, the mean and standard deviation over the last two axes
    are computed and then averaged over all samples. The returned std is
    therefore the mean of per-sample stds, not a pooled std.

    Args:
        train_loader: Iterable of 5-tuples ``(input, input_surface, _, _, _)``
            where ``input`` broadcasts to ``(B, 5, 13, H, W)`` and
            ``input_surface`` to ``(B, 4, H, W)``. Any batch size ``B`` is
            accepted, and batches may differ in size.

    Returns:
        ``(surface_mean, surface_std, upper_mean, upper_std)`` with shapes
        ``(1, 4, 1, 1)`` and ``(1, 5, 13, 1, 1)``.

    Raises:
        ValueError: If the loader yields no samples.
    """
    surface_mean_sum, surface_std_sum = torch.zeros(1, 4, 1, 1), torch.zeros(1, 4, 1, 1)
    upper_mean_sum, upper_std_sum = torch.zeros(1, 5, 13, 1, 1), torch.zeros(1, 5, 13, 1, 1)
    n_samples = 0
    for batch in train_loader:
        upper, surface, _, _, _ = batch
        # Reduce the spatial axes per sample, then sum over the batch axis so the
        # (1, C, ...) accumulators work for any batch size (in-place += cannot
        # broadcast a (B, C, ...) result into a (1, C, ...) tensor).
        surface_mean_sum += torch.mean(surface, dim=(-1, -2), keepdim=True).sum(dim=0, keepdim=True)
        surface_std_sum += torch.std(surface, dim=(-1, -2), keepdim=True).sum(dim=0, keepdim=True)
        upper_mean_sum += torch.mean(upper, dim=(-1, -2), keepdim=True).sum(dim=0, keepdim=True)
        upper_std_sum += torch.std(upper, dim=(-1, -2), keepdim=True).sum(dim=0, keepdim=True)
        n_samples += surface.shape[0]

    if n_samples == 0:
        raise ValueError("computeStatistics: train_loader yielded no samples")

    return (surface_mean_sum / n_samples, surface_std_sum / n_samples,
            upper_mean_sum / n_samples, upper_std_sum / n_samples)

def loadConstMask_h(filepath="/here/your/data/aux_data", device="cpu"):
    """Load ``Constant_17_output_0.npy`` as float32, tiled twice along the first axis.

    ``repeat(2, 1, 1, 1, 1, 1)`` doubles the leading axis; arrays with fewer
    than six dimensions get singleton axes prepended first.
    """
    source = os.path.join(filepath, "Constant_17_output_0.npy")
    tensor = torch.from_numpy(np.load(source).astype(np.float32))
    return tensor.repeat(2, 1, 1, 1, 1, 1).to(device)


def loadVariableWeights(device="cpu"):
    """Build per-variable loss weight tensors from ``cfg.PG.TRAIN``.

    ``UPPER_WEIGHTS`` becomes a float32 tensor with singleton axes inserted at
    positions 0, 2, 3 and 4 (a flat list of N values gives ``(1, N, 1, 1, 1)``).
    ``SURFACE_WEIGHTS`` gets singleton axes at positions 0, 2 and 3 (N values
    give ``(1, N, 1, 1)``).

    Returns:
        ``(upper_weights, surface_weights)`` on ``device``.
    """
    upper = torch.FloatTensor(cfg.PG.TRAIN.UPPER_WEIGHTS)[None, :, None, None, None]
    surface = torch.FloatTensor(cfg.PG.TRAIN.SURFACE_WEIGHTS)[None, :, None, None]
    return upper.to(device), surface.to(device)


def loadAllConstants(device):
    """Gather the constant tensors used by the model into one dict on ``device``.

    Keys: ``weather_statistics`` (input-layout statistics),
    ``weather_statistics_last`` (output-layout statistics), ``constant_maps``,
    ``variable_weights`` and ``const_h``. All loaders use their default paths.
    """
    return {
        # Input-layout statistics; per the original note, the height axis is
        # inverted here and its order is reversed inside the model.
        'weather_statistics': weatherStatistics_input(device=device),
        'weather_statistics_last': weatherStatistics_output(device=device),
        # Original note: "not able to be equal" (meaning unclear; TODO confirm).
        'constant_maps': LoadConstantMask3(device=device),
        'variable_weights': loadVariableWeights(device=device),
        'const_h': loadConstMask_h(device=device),
    }

def normData(upper, surface, statistics):
    """Standardise ``upper`` and ``surface`` with precomputed statistics.

    ``statistics`` is indexed as ``(surface_mean, surface_std, upper_mean,
    upper_std)``; any extra trailing entries are ignored. Operands combine
    with ordinary broadcasting arithmetic.

    Returns:
        ``((upper - upper_mean) / upper_std, (surface - surface_mean) / surface_std)``.
    """
    s_mean, s_std, u_mean, u_std = (statistics[i] for i in range(4))
    normalized_upper = (upper - u_mean) / u_std
    normalized_surface = (surface - s_mean) / s_std
    return normalized_upper, normalized_surface


def normBackData(upper, surface, statistics):
    """Undo ``normData``: map standardised values back to physical units.

    ``statistics`` is indexed as ``(surface_mean, surface_std, upper_mean,
    upper_std)``; any extra trailing entries are ignored.

    Returns:
        ``(upper * upper_std + upper_mean, surface * surface_std + surface_mean)``.
    """
    s_mean, s_std, u_mean, u_std = (statistics[i] for i in range(4))
    restored_upper = upper * u_std + u_mean
    restored_surface = surface * s_std + s_mean
    return restored_upper, restored_surface

if __name__ == "__main__":
    # Smoke check: load the input-layout statistics from the default path and
    # print the shape of the surface mean.
    surface_mean, _surface_std, _upper_mean, _upper_std = weatherStatistics_input()
    print(surface_mean.shape)

