# Dataset loaders for fluid-simulation data (S3-hosted parquet/.npy files and local HDF5 files) used for training and validation.

import numpy as np
import torch.utils.data as data
import h5py
import torch

from io import BytesIO
import os.path
import os
from botocore.client import Config

import torch
from torch.utils.data import DataLoader
import pyarrow.parquet as pq
import pandas as pd
import boto3
import json
from torchvision.transforms import Resize
import random
import numpy as np

print("start link")
# Module-wide S3 client and bucket used by the `download` helper below.
# Alternate endpoint: http://100.64.0.3:9000/
# NOTE(review): the access key and secret are hard-coded in source; prefer
# loading them from the environment or a credentials file.
# NOTE(review): building the client does not contact the server, so
# "link success" only means the client object was constructed.
s3_client = boto3.client(
    's3',
    endpoint_url="http://192.168.5.174:9000/",
    aws_access_key_id='K1DH3djl9BWDEMEv28Ar',
    aws_secret_access_key='W18vVnFODYU7LtBsHAqWHI62fM64cv8BK6mHA22L',
)
bucket_name = 'gaoch'
print("link success")

def download(localpath, remotepath):
    """Download object ``remotepath`` from the module bucket to ``localpath``.

    Uses the module-level ``s3_client`` and ``bucket_name``. The parent
    directory of ``localpath`` is created first; ``download_file`` would
    otherwise fail for paths such as ``"some/dir/file.parquet"`` whose
    directory does not exist yet.
    """
    print(f"downloading {remotepath} to {localpath}")
    parent = os.path.dirname(localpath)
    if parent:
        os.makedirs(parent, exist_ok=True)
    s3_client.download_file(Bucket=bucket_name, Key=remotepath, Filename=localpath)



class S3Dataset(torch.utils.data.Dataset):
    """Pairs of frames drawn from parquet-encoded samples stored in S3.

    A local index parquet (``{data_path}_train_idx.parquet`` or
    ``{data_path}_val_idx.parquet``, fetched with ``download`` when missing)
    lists one S3 key per row in its first column. Each referenced object
    holds 21 * 128 * 128 * 4 values, reshaped to (T=21, H, W, C=4) with
    channels u, v, p, rho. Per-channel means/stds are read from
    ``{data_path}__mean_std.parquet`` in the same bucket.

    Args:
        split: ``"train"`` or ``"val"``.
        data_path: Common prefix of the index and statistics files.
    """

    def __init__(self, split, data_path):
        self.split = split
        # credential = json.load(open("./credentials.json", "r"))
        # NOTE(review): credentials are hard-coded; boto3 picks them up from
        # these environment variables, which are changed process-wide.
        os.environ['AWS_ACCESS_KEY_ID'] = 'K1DH3djl9BWDEMEv28Ar'
        os.environ['AWS_SECRET_ACCESS_KEY'] = 'W18vVnFODYU7LtBsHAqWHI62fM64cv8BK6mHA22L'
        self.s3 = boto3.resource("s3", endpoint_url="http://192.168.5.174:9000/")
        self.bucket_name = "gaoch"
        self.train_idxes = [f'{data_path}_train_idx.parquet']
        self.val_idxes = [f'{data_path}_val_idx.parquet']
        assert split == "train" or split == "val"
        self.idxes = []
        self.current_idxes = self.train_idxes if split == "train" else self.val_idxes

        # Fetch index files that are not cached locally (same path locally and in S3).
        for idx_file in self.current_idxes:
            if not os.path.exists(idx_file):
                download(idx_file, idx_file)
        for idx_file in self.current_idxes:
            print(f"reading table {idx_file}")
            idx_table = pq.read_table(idx_file)
            print(f"read table {idx_file} finish")
            # Keep only the first column; each row becomes a length-1 array
            # whose single element is the sample's S3 key.
            self.idxes.extend(pd.DataFrame(idx_table[0]).to_numpy())
        print(f"Dataset {self.split} ready! len of data is {len(self.idxes)}")

        # Row 0 holds the 4 channel means, row 1 the 4 channel stds.
        stats = torch.tensor(
            self._read_parquet_from_s3(f'{data_path}__mean_std.parquet'),
            dtype=torch.float,
        ).reshape(2, 4)
        means = stats[0].reshape(1, 4, 1, 1, 1)
        stds = stats[1].reshape(1, 4, 1, 1, 1)
        self.normalizer = UnitGaussianNormalizer(means, stds)

    def __len__(self):
        return len(self.idxes)

    def _read_parquet_from_s3(self, key):
        """Fetch parquet object ``key`` from ``self.bucket_name`` as a NumPy array."""
        body = self.s3.Object(self.bucket_name, key).get()['Body'].read()
        table = pq.read_table(source=BytesIO(body))
        return table.to_pandas().to_numpy()

    def _load_example(self, dataname):
        """Load one sample and return two of its frames, shaped (4, 2, 128, 128).

        The two time steps are drawn at random without replacement; the
        earlier one (reference) comes first along dim 1, the later one
        (driving) second.
        """
        raw = torch.tensor(self._read_parquet_from_s3(dataname))
        # (T, H, W, C) -> (1, C, T, H, W) to match the normalizer's (1, 4, 1, 1, 1) stats.
        frames = torch.reshape(raw, (21, 128, 128, 4)).permute(3, 0, 1, 2).unsqueeze(0)
        frames = self.normalizer.encode(frames).squeeze(0)

        indices = torch.randperm(frames.size(1))[:2]
        ref_idx, dri_idx = torch.min(indices), torch.max(indices)
        return torch.stack([frames[:, ref_idx], frames[:, dri_idx]], dim=1)

    def __getitem__(self, item):
        # (T,c,h,w)  c:u,v,p,rho
        return self._load_example(str(self.idxes[item][0]))


class DiffDataset(data.Dataset):
    """Single-frame dataset over the trajectories stored in ``./fluid.h5``.

    Each of the 1000 groups ``"0000"`` .. ``"0999"`` holds a ``data`` array
    laid out (T, 128, 128, 2) (T = 101 per the original shape comment). All
    cases are loaded into memory as (1000, 2, T, 128, 128). A sample is one
    frame, addressed by the flat index ``case * T + t``. The first 90% of the
    flat range is the training split; the remainder is used for any other mode.

    Args:
        mode: ``"train"`` selects the training split; anything else selects
            the held-out tail.
    """

    def __init__(self,  mode="train"):
        # self.file = h5py.File(self.data_param["data_path"], 'r')
        # file=h5py.File("C:/Users/ryb/Downloads/09-24-2023-19-53-54_files_list/2D_diff-react_NA_NA.h5", 'r')
        num_cases = 1000
        cases = []
        # Copy everything into memory, then close the HDF5 file.
        with h5py.File("./fluid.h5", 'r') as h5:
            for i in range(num_cases):
                case = torch.tensor(np.array(h5[str(i).zfill(4)]["data"]))  # (T,128,128,2)
                cases.append(case.permute(3, 0, 1, 2))  # (2,T,128,128)
        self.data = torch.stack(cases, dim=0)  # (1000,2,T,128,128)
        self.num_steps = self.data.shape[2]
        # __getitem__ decodes a flat (case, time step) index, so the length
        # counts frames, not cases.
        self.len = num_cases * self.num_steps
        # Fixed per-channel normalization constants (presumably precomputed
        # over this dataset -- TODO confirm).
        means = torch.tensor([-0.0311, -0.0199]).reshape(1, 2, 1, 1)
        stds = torch.tensor([0.1438, 0.1117]).reshape(1, 2, 1, 1)
        print(self.data.shape)
        print(means)
        print(stds)
        self.mode = mode
        self.normalizer = UnitGaussianNormalizer(means, stds)  # initialize the normalizer

    def _split_point(self):
        """Number of flat indices belonging to the training split."""
        return int(self.len * 0.9)

    def __len__(self):
        # The held-out split takes the exact remainder, so no frame is dropped by rounding.
        split = self._split_point()
        return split if self.mode == "train" else self.len - split

    def __getitem__(self, idx):
        # Non-train modes read from the tail of the flat index range,
        # consistent with the lengths reported by __len__.
        if self.mode != "train":
            idx += self._split_point()
        case_idx, step = divmod(idx, self.num_steps)
        frames = self.data[case_idx, :, step, :, :]  # (2,128,128)
        return self.normalizer.encode(frames.unsqueeze(0)).squeeze(0)


class DoubleShockDataset(torch.utils.data.Dataset):
    """Frames of the double-shock simulations stored as ``.npy`` files in S3.

    Objects live at ``{path}{case}/{frame}.npy`` in the ``datasets`` bucket,
    with ``_NUM_CASES`` case folders (0-based names) of ``_FRAMES_PER_CASE``
    frames each (1-based file names). The flat index space is split 90/10
    into train/val by position. Each array is read as (H, W, C), returned as
    (C, H, W) and resized to 256x256.

    Args:
        split: ``"train"`` or ``"val"``.
    """

    _NUM_CASES = 20
    _FRAMES_PER_CASE = 645

    def __init__(self, split):
        # NOTE(review): credentials are hard-coded; boto3 picks them up from
        # these environment variables, which are changed process-wide.
        os.environ['AWS_ACCESS_KEY_ID'] = 'K1DH3djl9BWDEMEv28Ar'
        os.environ['AWS_SECRET_ACCESS_KEY'] = 'W18vVnFODYU7LtBsHAqWHI62fM64cv8BK6mHA22L'
        # self.s3 = boto3.resource("s3",endpoint_url="http://100.64.0.3:9000/")
        self.s3 = boto3.resource("s3", endpoint_url="http://192.168.5.174:9000/")
        self.bucket_name = "datasets"
        self.path = "FluidDataset/doubleshock2npy/"

        self.train_idxes = []
        self.bucket = self.s3.Bucket(self.bucket_name)
        self.split = split

        # NOTE(review): identity statistics (mean 0, std 1); __getitem__ does
        # not currently apply this normalizer.
        means = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).reshape(1, 7, 1, 1)
        stds = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]).reshape(1, 7, 1, 1)
        self.normalizer = UnitGaussianNormalizer(means, stds)
        self.transform = Resize((256, 256))

        assert split == "train" or split == "val"
        print(f"Dataset {self.split} ready!")

    def _split_point(self):
        """Number of flat indices belonging to the training split."""
        return int(self._NUM_CASES * self._FRAMES_PER_CASE * 0.9)

    def __len__(self):
        if self.split == "train":
            return self._split_point()
        return int(self._NUM_CASES * self._FRAMES_PER_CASE * 0.1)

    def _key_for(self, item):
        """Map a split-local index to the S3 key of its ``.npy`` file."""
        if self.split == 'val':
            item += self._split_point()
        case, frame = divmod(item, self._FRAMES_PER_CASE)
        return f'{self.path}{case}/{frame + 1}.npy'

    def _fetch(self, key):
        """Download ``key`` into an in-memory buffer positioned at its start.

        Keeping the bytes in memory (instead of a shared ``temp.npy`` on
        disk) keeps concurrent DataLoader workers from overwriting each
        other's downloads.
        """
        buffer = BytesIO()
        self.bucket.download_fileobj(key, buffer)
        buffer.seek(0)
        return buffer

    def __getitem__(self, item):
        dataname = self._key_for(item)
        try:
            raw = self._fetch(dataname)
        except Exception:
            # Best effort: report the unreadable key and substitute a known sample.
            print(dataname)
            raw = self._fetch(f'{self.path}1/1.npy')

        a = np.load(raw)
        a = torch.from_numpy(a).permute(2, 0, 1)  # (H, W, C) -> (C, H, W)
        return self.transform(a)
    
class CastroRMDataset(torch.utils.data.Dataset):
    """Frames of the Castro simulations stored as ``.npy`` files in S3.

    Objects live at ``{path}{case}/{frame}.npy`` in the ``datasets`` bucket
    (``path`` is ``FluidDataset/castroRT2npy/`` despite the class name), with
    ``_NUM_CASES`` case folders (0-based names) of ``_FRAMES_PER_CASE``
    frames each (1-based file names). The flat index space is split 90/10
    into train/val by position. Each array is read as (H, W, C), returned as
    (C, H, W) and resized to 256x256.

    Args:
        split: ``"train"`` or ``"val"``.
    """

    _NUM_CASES = 5
    _FRAMES_PER_CASE = 9024

    def __init__(self, split):
        # NOTE(review): credentials are hard-coded; boto3 picks them up from
        # these environment variables, which are changed process-wide.
        os.environ['AWS_ACCESS_KEY_ID'] = 'K1DH3djl9BWDEMEv28Ar'
        os.environ['AWS_SECRET_ACCESS_KEY'] = 'W18vVnFODYU7LtBsHAqWHI62fM64cv8BK6mHA22L'
        # self.s3 = boto3.resource("s3",endpoint_url="http://100.64.0.3:9000/")
        self.s3 = boto3.resource("s3", endpoint_url="http://192.168.5.174:9000/")
        self.bucket_name = "datasets"
        self.path = "FluidDataset/castroRT2npy/"

        self.train_idxes = []
        self.bucket = self.s3.Bucket(self.bucket_name)
        self.split = split

        # NOTE(review): identity statistics (mean 0, std 1); __getitem__ does
        # not currently apply this normalizer.
        means = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).reshape(1, 7, 1, 1)
        stds = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]).reshape(1, 7, 1, 1)
        self.normalizer = UnitGaussianNormalizer(means, stds)
        self.transform = Resize((256, 256))

        assert split == "train" or split == "val"
        print(f"Dataset {self.split} ready!")

    def _split_point(self):
        """Number of flat indices belonging to the training split."""
        return int(self._NUM_CASES * self._FRAMES_PER_CASE * 0.9)

    def __len__(self):
        if self.split == "train":
            return self._split_point()
        return int(self._NUM_CASES * self._FRAMES_PER_CASE * 0.1)

    def _key_for(self, item):
        """Map a split-local index to the S3 key of its ``.npy`` file."""
        if self.split == 'val':
            item += self._split_point()
        case, frame = divmod(item, self._FRAMES_PER_CASE)
        return f'{self.path}{case}/{frame + 1}.npy'

    def _fetch(self, key):
        """Download ``key`` into an in-memory buffer positioned at its start.

        Keeping the bytes in memory (instead of a shared ``temp.npy`` on
        disk) keeps concurrent DataLoader workers from overwriting each
        other's downloads.
        """
        buffer = BytesIO()
        self.bucket.download_fileobj(key, buffer)
        buffer.seek(0)
        return buffer

    def __getitem__(self, item):
        dataname = self._key_for(item)
        try:
            raw = self._fetch(dataname)
        except Exception:
            # Best effort: report the unreadable key and substitute a known sample.
            print(dataname)
            raw = self._fetch(f'{self.path}1/1.npy')

        a = np.load(raw)
        a = torch.from_numpy(a).permute(2, 0, 1)  # (H, W, C) -> (C, H, W)
        return self.transform(a)
    # def download(self,localpath, remotepath):
    #     print(f"downloading {remotepath} to {localpath}")
    #     self.s3_client.download_file(Bucket=self.bucket_nameG, Key=remotepath, Filename=localpath)
# class CastroRTDataset(torch.utils.data.Dataset):
#     def __init__(self, split):
#         # print("start link")
#         # self.s3_client = boto3.client('s3', aws_access_key_id='K1DH3djl9BWDEMEv28Ar', aws_secret_access_key='W18vVnFODYU7LtBsHAqWHI62fM64cv8BK6mHA22L',
#         #                         endpoint_url="http://192.168.5.174:9000/")
#         # self.bucket_nameG = 'gaoch'
#         # print("link success")

#         os.environ['AWS_ACCESS_KEY_ID'] = 'K1DH3djl9BWDEMEv28Ar'
#         os.environ['AWS_SECRET_ACCESS_KEY'] = 'W18vVnFODYU7LtBsHAqWHI62fM64cv8BK6mHA22L'
#         self.s3 = boto3.resource("s3",endpoint_url="http://192.168.5.174:9000/")
#         self.bucket_name = "datasets"
#         self.path="FluidDataset/castroRT/"
#         # self.path="FluidDataset/castroRT/"

#         self.train_idxes=[]
#         self.bucket = self.s3.Bucket(self.bucket_name)
#         self.split=split
#         # for i in range(2,10):
#         #     self.train_idxes.append(f"{i}")
#         for i in range(9):
#             self.train_idxes.append(f"{i}")

#         # self.val_idxes=["18","19","20"]
#         self.val_idxes=["9","10"]


#         assert split == "train" or split == "valid"
#         self.idxes = []
#         if split == "train":
#             self.current_idxes = self.train_idxes
#         else:
#             self.current_idxes = self.val_idxes
#         print(f"Dataset {self.split} ready!")

#     def __len__(self):
#         return 4800 if self.split == "train" else 1800

#     def __getitem__(self, item):
#         # (5,262,262)
#         def get_example_from_s3(dataname):
#             obj = self.s3.Object(self.bucket_name, dataname)
#             data = obj.get()['Body'].read()
#             stream = BytesIO(data)
#             reader=TecReader(stream)


#             return reader.data[:,:256,:256] # (c,h,w) (7,262,262)

#         case=random.choice(self.current_idxes)
#         tecs_of_case=[]
#         number_list=[]
#         print(self.path+case)
#         for obj in self.bucket.objects.filter(Prefix=self.path+case):
#             # print(obj.key)
#             tecs_of_case.append(obj.key)
#             parts = obj.key.split('/')
#             if parts[-1] != 'final.tec':
#                 number = parts[-1].split('.')[0]
#                 # print(number,"is digit:",number.isdigit())
#                 if number.isdigit():
#                     number_list.append(int(number))
#         number_list.sort()
#         idx_ref=random.choice(number_list[:-21])
#         idx_dri=idx_ref+1
#         idx_val1=idx_ref+10
#         idx_val2=idx_ref+20

#         ref_img=get_example_from_s3(f'{self.path}{case}/{idx_ref}.tec')
#         dri_img=get_example_from_s3(f'{self.path}{case}/{idx_dri}.tec')
#         val1_img=get_example_from_s3(f'{self.path}{case}/{idx_val1}.tec')
#         val2_img=get_example_from_s3(f'{self.path}{case}/{idx_val2}.tec')
#         out={}
#         out["source"]=ref_img
#         out["driving"]=dri_img
#         out["validate1"]=val1_img
#         out["validate2"]=val2_img

#         return out
    
#     # def download(self,localpath, remotepath):
#     #     print(f"downloading {remotepath} to {localpath}")
#     #     self.s3_client.download_file(Bucket=self.bucket_nameG, Key=remotepath, Filename=localpath)


class FluidDataset(data.Dataset):
    """Pairs of frames read from the local HDF5 file ``fluid.hdf5``.

    The file must contain equally shaped datasets ``Vx``, ``Vy``, ``density``
    and ``pressure``, indexed by sample along their first axis. Each sample is
    stacked into a (4, T, H, W) tensor and resized to 256x256. It is then
    normalized with per-channel statistics computed over the whole file, and
    two distinct time steps are drawn at random.

    Args:
        split: ``"train"`` uses the first 98% of samples; any other value uses
            the remaining tail.
    """

    def __init__(self, split):
        # The file stays open: samples are read lazily in __getitem__.
        self.file = h5py.File('fluid.hdf5', 'r')
        print(self.file.keys())
        self.len = self.file["Vx"].shape[0]
        self.Vx = self.file['Vx']
        self.Vy = self.file["Vy"]
        self.density = self.file["density"]
        self.pressure = self.file["pressure"]
        self.mode = split
        self.normalizer = UnitGaussianNormalizer(*self._concatenate_data())  # initialize the normalizer
        print(self.normalizer.means.shape)
        print(self.normalizer.stds.shape)
        self.transform = Resize((256, 256))

    def _concatenate_data(self):
        """Return per-channel ``(means, stds)`` tensors, each shaped (1, 4, 1, 1, 1).

        Each channel's statistics are taken over every element of its
        dataset. This equals reducing the stacked (N, 4, ...) array over all
        non-channel axes, but only one channel is held in memory at a time.
        """
        channel_means, channel_stds = [], []
        for channel in (self.Vx, self.Vy, self.density, self.pressure):
            values = np.asarray(channel)
            channel_means.append(values.mean())
            channel_stds.append(values.std())
            del values  # release before loading the next channel
        means = np.reshape(np.array(channel_means), (1, 4, 1, 1, 1))
        stds = np.reshape(np.array(channel_stds), (1, 4, 1, 1, 1))
        return torch.from_numpy(means), torch.from_numpy(stds)

    def _split_point(self):
        """Number of samples belonging to the training split."""
        return int(self.len * 0.98)

    def __len__(self):
        # The held-out split takes the exact remainder, so it is never empty
        # because of rounding while samples remain.
        split = self._split_point()
        return split if self.mode == "train" else self.len - split

    def __getitem__(self, idx):
        # Non-train modes read from the tail of the file, consistent with __len__.
        if self.mode != "train":
            idx += self._split_point()
        channels = (self.Vx, self.Vy, self.density, self.pressure)
        # Each per-channel sample is presumably (T, H, W); stacking gives (4, T, H, W).
        frames = torch.stack([torch.from_numpy(ch[idx]) for ch in channels], dim=0)

        frames = self.transform(frames)  # resize the last two (spatial) dims to 256x256
        # Normalize the data.
        frames = self.normalizer.encode(frames.unsqueeze(0)).squeeze(0)

        # Randomly select two distinct time steps; the earlier one is the reference.
        indices = torch.randperm(frames.size(1))[:2]
        ref_idx, dri_idx = torch.min(indices), torch.max(indices)

        ref_img = frames[:, ref_idx]  # Reference image
        dri_img = frames[:, dri_idx]  # Driving image
        return torch.stack([ref_img, dri_img], dim=1)
    

class UnitGaussianNormalizer(object):
    """Per-channel standardization with an exact inverse.

    ``encode`` maps ``x`` to ``(x - means) / (stds + eps)``; ``decode``
    undoes it. ``means`` and ``stds`` are tensors that must broadcast against
    the inputs; the datasets in this file pass shapes such as (1, C, 1, 1) or
    (1, C, 1, 1, 1). They are moved to the input's device on every call.

    Args:
        means: Per-channel means.
        stds: Per-channel standard deviations.
        eps: Added to ``stds`` to avoid division by zero.
    """

    def __init__(self, means, stds, eps=0.00001):
        super().__init__()
        self.means = means
        self.stds = stds
        self.eps = eps

    def encode(self, x):
        """Standardize ``x``: ``(x - means) / (stds + eps)``."""
        return (x - self.means.to(x.device)) / (self.stds.to(x.device) + self.eps)

    def decode(self, x):
        """Invert ``encode``: ``x * (stds + eps) + means``.

        Uses the same ``eps``-shifted scale as ``encode`` so that
        ``decode(encode(x))`` recovers ``x`` up to floating-point rounding.
        """
        return x * (self.stds.to(x.device) + self.eps) + self.means.to(x.device)

