from __future__ import annotations

import argparse
from pathlib import Path

import pandas as pd
import numpy as np
from torchvision.datasets.utils import download_url
from tqdm import tqdm
import h5py
import sys

def download_data(
    root_folder,
    pde_names=("2d_cfd", "darcy", "swe"),
    metadata_csv="./data_gen/pdebench_data_urls.csv",
):
    """
    Download the data files listed in the metadata CSV for the selected PDEs.

    Args:
        root_folder: The root folder where the data will be downloaded. Each
            file is stored under ``root_folder / <"Path" column value>``.
        pde_names: Name(s) of the PDEs whose files are downloaded. Matching
            against the "PDE" column is case-insensitive. A single string is
            treated as one name. Defaults to the three PDEs this script
            preprocesses.
        metadata_csv: Path of the CSV file describing the downloadable files.
            The columns "PDE", "Path", "URL", "Filename" and "MD5" are read.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(pde_names, str):
        pde_names = [pde_names]
    wanted = [name.lower() for name in pde_names]

    # Load the metadata and lower-case the PDE column so the comparison
    # below does not depend on the capitalisation used in the CSV.
    meta_df = pd.read_csv(metadata_csv)
    meta_df["PDE"] = meta_df["PDE"].str.lower()
    pde_df = meta_df[meta_df["PDE"].isin(wanted)]

    # Iterate the filtered dataframe and download the files; the MD5 from
    # the metadata is handed to download_url alongside each file.
    for _, row in tqdm(pde_df.iterrows(), total=pde_df.shape[0]):
        file_path = Path(root_folder) / row["Path"]
        download_url(row["URL"], file_path, row["Filename"], md5=row["MD5"])


def preprocessing_SWE():
    """
    Export the contents of ./SWE/2D_rdb_NA_NA.h5 to individual .npy files.

    Writes into ./SWE/:
      * t_grid.npy, x_grid.npy, y_grid.npy -- the grids stored under group
        "0000" (only that group's grids are exported);
      * data_<i>.npy for i in 0..999 -- the "data" dataset of the group
        whose name is i zero-padded to four digits ("0000" ... "0999").
    """
    # The context manager closes the HDF5 file even if a dataset is missing.
    with h5py.File("./SWE/2D_rdb_NA_NA.h5", "r") as h5_file:
        for axis in ("t", "x", "y"):
            np.save(
                f"./SWE/{axis}_grid.npy",
                np.array(h5_file[f"0000/grid/{axis}"]),
            )

        # Group names are four-digit, zero-padded indices; the output file
        # names use the plain (unpadded) index.
        for i in range(1000):
            np.save(f"./SWE/data_{i}.npy", np.array(h5_file[f"{i:04d}/data"]))

def _save_standardized_field(h5_file, field):
    """Standardize one dataset of the CNSE file and save it with its statistics.

    Writes ``<field>.npy`` (the standardized values), ``<field>_mean.npy`` and
    ``<field>_std.npy`` into the current working directory.
    """
    # Keep the first 1000 entries along axis 0 and every second point along
    # the last two axes, stored as float32.
    values = np.array(h5_file[field][:1000, :, ::2, ::2], dtype=np.float32)

    # The statistics must be taken BEFORE standardizing: computed afterwards
    # they are ~0 and ~1 and cannot be used to undo the transform.
    mean = np.mean(values)
    std = np.std(values)

    np.save(f"{field}.npy", (values - mean) / std)
    np.save(f"{field}_mean.npy", mean)
    np.save(f"{field}_std.npy", std)


def preprocessing_CNSE():
    """
    Export the CNSE training HDF5 file to .npy files.

    Reads CNSE/2D_CFD_Rand_M1.0_Eta0.1_Zeta0.1_periodic_128_Train.hdf5 and
    writes, into the current working directory:
      * x-coordinate.npy, y-coordinate.npy, t-coordinate.npy;
      * for each of Vx, Vy, density and pressure: the standardized,
        subsampled field plus the mean and standard deviation that were
        used to standardize it.

    NOTE(review): unlike the SWE and Darcy exports, these files are not
    written into the dataset's own folder; the file names are kept as-is so
    existing consumers are unaffected -- confirm whether that is intended.
    """
    path = "CNSE/2D_CFD_Rand_M1.0_Eta0.1_Zeta0.1_periodic_128_Train.hdf5"
    # The context manager closes the HDF5 file even if a step fails.
    with h5py.File(path, "r") as h5_file:
        for axis in ("x-coordinate", "y-coordinate", "t-coordinate"):
            np.save(f"{axis}.npy", np.array(h5_file[axis]))

        # Every field goes through the same helper, so each one is saved
        # with its own statistics.
        for field in ("Vx", "Vy", "density", "pressure"):
            _save_standardized_field(h5_file, field)

def preprocessing_Darcy():
    """
    Export the Darcy-flow training HDF5 file to float32 .npy files.

    Reads Darcy01/2D_DarcyFlow_beta0.1_Train.hdf5 and writes the datasets
    "nu", "tensor", "x-coordinate" and "y-coordinate" to Darcy01/nu.npy,
    Darcy01/tensor.npy, Darcy01/x_coordinate.npy and
    Darcy01/y_coordinate.npy respectively.
    """
    # (dataset key in the HDF5 file, output file stem); the coordinate files
    # use an underscore where the dataset key has a hyphen.
    exports = (
        ("nu", "nu"),
        ("tensor", "tensor"),
        ("x-coordinate", "x_coordinate"),
        ("y-coordinate", "y_coordinate"),
    )

    # The context manager closes the HDF5 file even if a dataset is missing.
    with h5py.File("Darcy01/2D_DarcyFlow_beta0.1_Train.hdf5", "r") as h5_file:
        for key, stem in exports:
            np.save(
                f"Darcy01/{stem}.npy",
                np.array(h5_file[key], dtype=np.float32),
            )

if __name__ == "__main__":
    # Command-line entry point: download the data, then run the three
    # preprocessing steps in a fixed order (SWE, CNSE, Darcy).
    parser = argparse.ArgumentParser(
        prog="Download Script",
        description="Helper script to download the PDEBench datasets",
        epilog="",
    )
    parser.add_argument(
        "--root_folder",
        default="./data_gen/dataset",
        type=str,
        help="Root folder where the data will be downloaded",
    )
    cli_args = parser.parse_args()

    download_data(cli_args.root_folder)

    # The preprocessing functions take no arguments; each reads from and
    # writes to its own hard-coded paths.
    for preprocess in (preprocessing_SWE, preprocessing_CNSE, preprocessing_Darcy):
        preprocess()
