import xarray as xr
import numpy as np
import argparse
import pickle
from pathlib import Path

# Directory containing the input NetCDF (*.nc) files. main() appends the
# glob pattern to this string directly, so the trailing slash is required.
data_dir = "/usr/commondata/public/covlstm_attention/"

def main(args):
    """Convert the NetCDF files under ``data_dir`` into one pickle per variable.

    For each entry of ``args.datasets``, the matching variable is cropped to
    the first ``args.lat_res`` latitudes and ``args.lon_res`` longitudes. It is
    then cut into non-overlapping windows of ``args.seq_len`` time steps and
    written to ``<output_dir>/dataset.pkl`` as a dict with keys:

    * ``'u'``: (samples, lat, lon, seq_len, channels) windows of the variable
    * ``'a'``: first time step of each window, (samples, lat, lon, channels)
    * ``'space_feature'``: (lat, lon, 2) grid of (latitude, longitude) pairs
    * ``'time_feature'``: (samples, seq_len, 3) of (month, day, hour)

    Each variable is assumed to be 4-D with axis 1 as the channel axis
    (presumably pressure levels -- TODO confirm against the source files).
    For ``'component_of_wind'``, ``u`` and ``v`` are concatenated along that axis.

    Args:
        args: namespace with ``datasets``, ``attri_names``, ``output_dirs``,
            ``lat_res``, ``lon_res``, ``seq_len`` and ``shuffle``.

    Raises:
        ValueError: if ``attri_names`` or ``output_dirs`` has fewer entries
            than ``datasets``, or if the number of time steps is not a
            multiple of ``seq_len``.
    """
    import pandas as pd

    lat_res = args.lat_res
    lon_res = args.lon_res
    seq_len = args.seq_len

    n_datasets = len(args.datasets)
    if len(args.attri_names) < n_datasets or len(args.output_dirs) < n_datasets:
        raise ValueError(
            'attri_names and output_dirs need one entry per dataset '
            '({} datasets, {} attri_names, {} output_dirs)'.format(
                n_datasets, len(args.attri_names), len(args.output_dirs)))

    with xr.open_mfdataset(data_dir + '*.nc', combine='by_coords') as data:
        lon = np.array(data.longitude)[:lon_res]
        lat = np.array(data.latitude)[:lat_res]

        time = data.time.values
        if time.shape[0] % seq_len:
            raise ValueError(
                'number of time steps ({}) is not a multiple of seq_len ({})'.format(
                    time.shape[0], seq_len))

        stamps = pd.Series(time)
        time_feature = np.stack([
            stamps.apply(lambda t: t.month).values,
            stamps.apply(lambda t: t.day).values,
            stamps.apply(lambda t: t.hour).values,
        ], axis=-1).reshape(-1, seq_len, 3)

        lon_grid, lat_grid = np.meshgrid(lon, lat)
        space_feature = np.stack([lat_grid, lon_grid], axis=-1)

        for name, attri_name, out_dir in zip(args.datasets, args.attri_names, args.output_dirs):
            # Crop in xarray before converting to numpy, so only the
            # requested region is ever loaded into memory.
            if name == 'component_of_wind':
                data_u = np.concatenate([
                    np.asarray(data.u[:, :, :lat_res, :lon_res]),
                    np.asarray(data.v[:, :, :lat_res, :lon_res]),
                ], axis=1)
            else:
                data_u = np.asarray(data[attri_name][:, :, :lat_res, :lon_res])

            # (time, channel, lat, lon) -> (time, lat, lon, channel)
            data_u = data_u.transpose(0, 2, 3, 1)
            # Use the real cropped grid size; lat_res/lon_res may exceed
            # what the files actually provide.
            _, n_lat, n_lon, in_size = data_u.shape
            # Split time into windows, then move the window axis next to
            # the channels: (samples, lat, lon, seq_len, channel).
            data_u = data_u.reshape(-1, seq_len, n_lat, n_lon, in_size)
            data_u = data_u.transpose(0, 2, 3, 1, 4)

            sample_time = time_feature
            if args.shuffle:
                shuffle_idx = np.random.permutation(data_u.shape[0])
                data_u = data_u[shuffle_idx]
                # Permute a per-dataset copy. Reassigning the shared array
                # would compound permutations across datasets and misalign
                # time features with samples.
                sample_time = time_feature[shuffle_idx]

            out_path = Path(out_dir)
            out_path.mkdir(exist_ok=True, parents=True)
            dataset = {
                'u': data_u,
                'a': data_u[..., 0, :],
                'space_feature': space_feature,
                'time_feature': sample_time,
            }
            with open(out_path / 'dataset.pkl', 'wb') as f:
                pickle.dump(dataset, f, protocol=4)


def str2bool(value):
    """Parse a command-line boolean such as 'True'/'False', 'yes'/'no' or '1'/'0'.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    "False") as True, so an explicit converter is needed.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def build_parser():
    """Build the command-line parser for this preprocessing script.

    List options take space-separated values, e.g.
    ``--datasets temperature cloud_cover``. ``--shuffle`` accepts an explicit
    boolean (``--shuffle False``) or can be given bare to mean True.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--datasets", nargs='+', default=[
            'component_of_wind',
            'relative_humidity',
            'temperature',
            'cloud_cover'
            ], help="dataset names; 'component_of_wind' stacks the u and v variables."
    )
    parser.add_argument(
        "--lat_res", type=int, default=60, help="number of leading latitude points to keep."
    )
    parser.add_argument(
        "--lon_res", type=int, default=120, help="number of leading longitude points to keep."
    )
    parser.add_argument(
        "--attri_names", nargs='+', default=['uv', 'r', 't', 'cc'],
        help="variable name in the NetCDF files, one per dataset."
    )
    parser.add_argument(
        "--output_dirs", nargs='+',
        default=['dataset/component_of_wind', 'dataset/humidity', 'dataset/temperature', 'dataset/cloud_cover'],
        help="output directory, one per dataset."
    )
    parser.add_argument(
        "--seq_len", type=int, default=24, help="number of time steps per sample."
    )
    parser.add_argument(
        "--shuffle", type=str2bool, nargs='?', const=True, default=True,
        help="shuffle samples before saving (True/False)."
    )
    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()
    main(args)