
import numpy as np
import torch.utils.data as data
import h5py
import torch

from io import BytesIO
import os.path
import os

import torch
from torch.utils.data import DataLoader
import pyarrow.parquet as pq
import pandas as pd
import boto3
import json



# Module-level S3 client and bucket name shared by `download`.
#
# Credentials are read from the standard AWS environment variables when they
# are set (and non-empty); the legacy in-source values are kept only as a
# fallback so existing setups keep working unchanged.
# NOTE(review): these fallback secrets are committed to source control; they
# should be rotated and the fallbacks removed once every deployment provides
# AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY through the environment.
print("start link")
s3_client = boto3.client(
    's3',
    aws_access_key_id=(
        os.environ.get('AWS_ACCESS_KEY_ID') or 'K1DH3djl9BWDEMEv28Ar'
    ),
    aws_secret_access_key=(
        os.environ.get('AWS_SECRET_ACCESS_KEY')
        or 'W18vVnFODYU7LtBsHAqWHI62fM64cv8BK6mHA22L'
    ),
    endpoint_url="http://192.168.5.174:9000/",
)
bucket_name = 'gaoch'
print("link success")

def download(localpath, remotepath):
    """Copy one object from the module-level bucket to a local file.

    Args:
        localpath: Destination path on the local filesystem.
        remotepath: Key of the object inside ``bucket_name``.
    """
    message = f"downloading {remotepath} to {localpath}"
    print(message)
    s3_client.download_file(
        Filename=localpath,
        Key=remotepath,
        Bucket=bucket_name,
    )



class S3Dataset(torch.utils.data.Dataset):
    """Map-style dataset whose samples are parquet objects stored in S3.

    An index parquet file per split lists the object keys of the samples.
    Index files are downloaded to the working directory on first use and
    every sample is fetched from the bucket on access.
    """

    def __init__(self, split):
        """Load the sample index for ``split``.

        Args:
            split: Either ``"train"`` or ``"val"``.

        Raises:
            ValueError: If ``split`` is neither ``"train"`` nor ``"val"``.
        """
        # Validate up front with an explicit exception: an `assert` is
        # stripped under `python -O`, which would silently map any unknown
        # split onto the validation index.
        if split not in ("train", "val"):
            raise ValueError(
                f'split must be "train" or "val", got {split!r}'
            )
        self.split = split

        # boto3 picks the credentials up from the environment. Values that
        # are already configured are respected; the legacy in-source values
        # are only used as a fallback.
        # NOTE(review): these secrets are committed to source control; rotate
        # them and drop the fallbacks once deployments set the variables.
        legacy_credentials = (
            ('AWS_ACCESS_KEY_ID', 'K1DH3djl9BWDEMEv28Ar'),
            ('AWS_SECRET_ACCESS_KEY', 'W18vVnFODYU7LtBsHAqWHI62fM64cv8BK6mHA22L'),
        )
        for variable, fallback in legacy_credentials:
            if not os.environ.get(variable):
                os.environ[variable] = fallback
        self.s3 = boto3.resource("s3", endpoint_url="http://192.168.5.174:9000/")
        self.bucket_name = "gaoch"

        self.train_idxes = [
            '2D_CFD_Rand_M0.1_Eta0.01_Zeta0.01_periodic_128_Train_train_idx.parquet',
        ]
        self.val_idxes = [
            '2D_CFD_Rand_M0.1_Eta0.01_Zeta0.01_periodic_128_Train_val_idx.parquet',
        ]
        self.current_idxes = (
            self.train_idxes if split == "train" else self.val_idxes
        )

        # Make sure every index file is available locally before reading any.
        for idx_file in self.current_idxes:
            if not os.path.exists(idx_file):
                download(idx_file, idx_file)

        # Each row of the first index column is later used as an object key
        # (see `__getitem__`).
        self.idxes = []
        for idx_file in self.current_idxes:
            print(f"reading table {idx_file}")
            idx_table = pq.read_table(idx_file)
            print(f"read table {idx_file} finish")
            idx_frame = pd.DataFrame(idx_table[0])
            self.idxes.extend(idx_frame.to_numpy())
        print(f"Dataset {self.split} ready! len of data is {len(self.idxes)}")

        # Zero mean / unit std: encoding only divides by (1 + eps).
        means = torch.tensor([0, 0, 0, 0], dtype=torch.float).reshape(4, 1, 1, 1)
        stds = torch.tensor([1, 1, 1, 1], dtype=torch.float).reshape(4, 1, 1, 1)
        self.normalizer = UnitGaussianNormalizer(means, stds)

    def __len__(self):
        """Return the number of samples listed in the split's index."""
        return len(self.idxes)

    def __getitem__(self, item):
        """Fetch and normalise the sample at position ``item``.

        Returns:
            A tensor of shape ``(4, 21, 128, 128)``. The original comment
            names the four channels as u, v, p, rho -- TODO confirm against
            the data.
        """
        return self._load_example(str(self.idxes[item][0]))

    def _load_example(self, dataname):
        """Download the parquet object ``dataname`` and return it normalised."""
        obj = self.s3.Object(self.bucket_name, dataname)
        payload = obj.get()['Body'].read()
        table = pq.read_table(source=BytesIO(payload))
        values = torch.tensor(table.to_pandas().to_numpy())
        # Stored as (21, 128, 128, 4); move the channel axis first and add a
        # batch axis so the normaliser statistics broadcast.
        values = torch.reshape(values, (21, 128, 128, 4)).permute(3, 0, 1, 2)
        return self.normalizer.encode(values.unsqueeze(0)).squeeze(0)



class FluidDataset(data.Dataset):
    """Map-style dataset over the fields stored in ``fluid.hdf5``.

    The first 98% of the samples form the training split and the following
    2% the held-out split. Each item stacks the ``Vx``, ``Vy``, ``density``
    and ``pressure`` fields on a new leading channel axis and normalises
    them per channel.
    """

    def __init__(self, split):
        """Open ``fluid.hdf5`` and compute the normalisation statistics.

        Args:
            split: ``"train"`` selects the training split; any other value
                selects the held-out split.
        """
        self.file = h5py.File('fluid.hdf5', 'r')
        print(self.file.keys())
        self.len = self.file["Vx"].shape[0]
        self.Vx = self.file['Vx']
        self.Vy = self.file["Vy"]
        self.density = self.file["density"]
        self.pressure = self.file["pressure"]
        self.mode = split
        # Initialise the normaliser from per-channel statistics of the data.
        self.normalizer = UnitGaussianNormalizer(*self._concatenate_data())
        print(self.normalizer.means.shape)
        print(self.normalizer.stds.shape)

    def _concatenate_data(self):
        """Return per-channel ``(means, stds)`` shaped ``(1, 4, 1, 1, 1)``.

        The statistics are computed over every sample of the four fields,
        which are stacked along a new axis 1 first.
        """
        stacked = np.stack(
            (self.Vx, self.Vy, self.density, self.pressure), axis=1
        )
        reduce_axes = (0, 2, 3, 4)
        stats_shape = (1, 4, 1, 1, 1)
        means = np.reshape(np.mean(stacked, axis=reduce_axes), stats_shape)
        stds = np.reshape(np.std(stacked, axis=reduce_axes), stats_shape)
        return torch.from_numpy(means), torch.from_numpy(stds)

    def __len__(self):
        """Return the number of samples in the selected split."""
        if self.mode == "train":
            return int(self.len * 0.98)
        return int(self.len * 0.02)

    def __getitem__(self, idx):
        """Return the normalised sample ``idx`` of the selected split.

        Args:
            idx: Integer position inside the split; negative values count
                from the end of the split.

        Returns:
            A tensor with the four fields stacked on dim 0
            (``Vx``, ``Vy``, ``density``, ``pressure``).

        Raises:
            IndexError: If ``idx`` lies outside the split.
        """
        size = len(self)
        # Without this check an out-of-range training index would silently
        # return held-out samples, and plain iteration over the dataset
        # would not stop at the split boundary.
        if not -size <= idx < size:
            raise IndexError(
                f"index {idx} is out of range for the {self.mode!r} split "
                f"of size {size}"
            )
        if idx < 0:
            idx += size
        # Mirror `__len__`: every mode other than "train" reads the held-out
        # tail that starts right after the training samples.
        if self.mode != "train":
            idx += int(self.len * 0.98)

        fields = (self.Vx, self.Vy, self.density, self.pressure)
        frames = torch.stack(
            [torch.from_numpy(field[idx]) for field in fields], dim=0
        )

        # Normalise; the batch axis is added so the (1, 4, 1, 1, 1)
        # statistics broadcast, then removed again.
        frames = self.normalizer.encode(frames.unsqueeze(0)).squeeze(0)
        return frames
    

class UnitGaussianNormalizer(object):
    """Shift-and-scale normaliser using fixed means and standard deviations.

    ``encode`` maps ``x`` to ``(x - means) / (stds + eps)`` and ``decode``
    applies the exact inverse.
    """

    def __init__(self, means, stds, eps=0.00001):
        """Store the statistics.

        Args:
            means: Tensor of means, broadcastable against the inputs.
            stds: Tensor of standard deviations, broadcastable likewise.
            eps: Small constant added to ``stds`` to avoid division by zero.
        """
        super().__init__()
        self.means = means
        self.stds = stds
        self.eps = eps

    def encode(self, x):
        """Return ``x`` normalised with the stored statistics."""
        scale = self.stds.to(x.device) + self.eps
        return (x - self.means.to(x.device)) / scale

    def decode(self, x):
        """Invert ``encode``.

        The scale includes ``eps`` exactly as in ``encode``; otherwise the
        round trip is biased and collapses to ``means`` when ``stds`` is 0.
        """
        scale = self.stds.to(x.device) + self.eps
        return (x * scale) + self.means.to(x.device)



# Script entry point: intentionally a no-op, so running this file directly
# only executes the module-level setup above.
if __name__ == "__main__":
    pass

