import torch
from torch.utils.data import Dataset
import torchvision

from petrel_client.client import Client
# Module-level petrel client, constructed as a side effect of importing this
# module. NOTE(review): nothing in the visible part of this file uses it --
# confirm whether other modules rely on it before removing.
client = Client()
from datetime import datetime, timedelta
import calendar
import io
import numpy as np
import matplotlib.pyplot as plt
import json
import random
from multiprocessing import Pool
from multiprocessing.pool import ThreadPool
import os
from os import listdir
from os.path import isfile
import h5py
class era5_latent(Dataset):
    """Dataset of paired latents stored one sample per HDF5 file.

    Every regular file located directly inside ``data_dir`` is treated as one
    sample and is expected to contain the datasets ``"bkg_latent"`` and
    ``"ana_latent"``. Both are standardised with the ``"mean"`` and ``"std"``
    arrays that are read once, at construction time, from a statistics file.

    The file list is shuffled once in ``__init__`` using the global ``random``
    module state; seed ``random`` beforehand if a reproducible order is needed.

    Args:
        data_dir: Directory whose files are the samples. Sub-directories are
            ignored (the listing is not recursive).
        stats_path: HDF5 file providing the ``"mean"`` and ``"std"`` datasets.
            Defaults to ``DEFAULT_STATS_PATH`` when omitted or ``None``.

    Attributes:
        dir: The ``data_dir`` argument, stored unchanged.
        input_names: Tuple of shuffled sample file paths (empty when the
            directory contains no files).
        transforms: A ``Compose([ToTensor()])`` pipeline. It is not applied by
            any method of this class; kept for callers that may use it.
        latent_mean: Contents of the ``"mean"`` dataset.
        latent_std: Contents of the ``"std"`` dataset.
    """

    # Statistics file used when the caller does not pass ``stats_path``.
    DEFAULT_STATS_PATH = "/mnt/petrelfs/sunjingan/DiffDpo-DA/48la-norm/48vaelatent_mean_std_mode.h5"

    # Added to the std before dividing so that a zero std cannot divide by zero.
    _EPS = 1e-8

    def __init__(self, data_dir, stats_path=None):
        super().__init__()
        self.dir = data_dir
        self.input_names = self._list_files(data_dir)
        self.transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

        if stats_path is None:
            stats_path = self.DEFAULT_STATS_PATH
        with h5py.File(stats_path, "r") as f:
            # ``[:]`` copies the data into memory so it outlives the file handle.
            self.latent_mean = f["mean"][:]
            self.latent_std = f["std"][:]

    @staticmethod
    def _list_files(data_dir):
        """Return the regular files directly inside ``data_dir``, shuffled.

        The result is a tuple of joined paths. An empty directory yields an
        empty tuple rather than raising.
        """
        root = os.fspath(data_dir)
        candidates = (os.path.join(root, name) for name in os.listdir(root))
        paths = [path for path in candidates if os.path.isfile(path)]
        random.shuffle(paths)
        return tuple(paths)

    def _normalize(self, latent):
        """Standardise ``latent`` with the stored mean and std."""
        return (latent - self.latent_mean) / (self.latent_std + self._EPS)

    def get_profiles(self, index):
        """Load and standardise the sample stored in ``input_names[index]``.

        Returns:
            A ``(bkg_latent, ana_latent)`` pair, each with its leading axis
            squeezed away.
            NOTE(review): ``squeeze(0)`` assumes the stored arrays have a
            leading axis of length 1 -- confirm against the files on disk.
        """
        with h5py.File(self.input_names[index], "r") as f:
            bkg_latent = self._normalize(f["bkg_latent"][:])
            ana_latent = self._normalize(f["ana_latent"][:])
        return bkg_latent.squeeze(0), ana_latent.squeeze(0)

    def __getitem__(self, index):
        """Return the ``(bkg_latent, ana_latent)`` pair for sample ``index``."""
        return self.get_profiles(index)

    def __len__(self):
        """Return the number of sample files found at construction time."""
        return len(self.input_names)

