#!/usr/bin/env python
# coding: utf-8

#import climetlab as cml
import os
from datetime import timedelta, datetime

import numpy as np
from scipy.interpolate import interpn


# Root folders: raw per-year .npy inputs are read from <path><year>/ and the
# prepared arrays are written to <save_path><year>/.
path, save_path = '/home/jovyan/data/', '/home/jovyan/prepaired_data/'

# Year whose data is used for training, and year held out for testing.
year_train, year_test = '2015', '2016'

# Seed NumPy's global RNG for reproducibility.
np.random.seed(0)

import numpy as np
from sklearn.utils import shuffle

pict = 1
StBolz4rt = 0.015431332316368392  # 4th root of the Stefan-Boltzmann constant
SecondsInDay = 86400   # seconds in 24 hours

# Number of vertical levels in every WRF column profile.
wrf_nlev = 49

# Heating-rate targets.  The commented alternatives select a single band;
# `fluxes` below must then be trimmed to match, or the assert fails.
targets = ['RTHRATSW', 'RTHRATLW']
#targets = ['RTHRATLW']
#targets = ['RTHRATSW']
n_targets = len(targets)

# Boundary fluxes per target: n_toa_fluxes at the top of the atmosphere
# followed by n_srf_fluxes at the surface.
n_toa_fluxes = 1
n_srf_fluxes = 2
n_fluxes_per_target = n_toa_fluxes + n_srf_fluxes

# Flat list of boundary fluxes, grouped by target in the same order as
# `targets` (presumably suffix T = top, B = bottom/surface -- confirm).
fluxes  = ['SWUPT','SWUPB', 'SWDNB',  'LWUPT', 'LWUPB', 'LWDNB']

# 2-D view of the flat `fluxes` list: fluxes[i] belongs to target targ[i] and
# is flux number flx[i] within that target.  Derived from the per-target
# counts so the mapping cannot drift out of sync with them.
targ = [i // n_fluxes_per_target for i in range(len(fluxes))]
flx  = [i % n_fluxes_per_target for i in range(len(fluxes))]

assert len(fluxes) == n_targets * (n_toa_fluxes + n_srf_fluxes)

# Per-profile scalar inputs.  Excluded because always zero in the data:
# 'ALBBCKTSK', solar_irradiance, LWP_NEW.
sca_features = ['XLAND', 'COSZEN', 'SOLCON', 'ALBEDO', 'XLONG', 'XLAT', 'HGT', 'EMISS', 'TSK']
# LW ONLY variant (no solar inputs):
#sca_features = ['XLAND', 'XLONG', 'XLAT', 'HGT','EMISS', 'TSK']

# Per-level (column) inputs.  'CLDFRA' is excluded because it is always zero.
col_features = ['SH', 'CF', 'RH', 'P_TOT', 'T_KEL', 'QCLOUD', 'QRAIN', 'QSNOW', 'QVAPOR', 'QICE']
# Unused alternative feature names (not read by this script):
#'o3_mmr', 'co2_vmr', 'n2o_vmr', 'ch4_vmr', # q: humidity
#'cloud_fraction', 'q_liquid', 'q_ice', 're_liquid', 're_ice', # 'overlap_param',
#'fractional_std'] #, 'inv_cloud_effective_size']


# XTIME is stored as minutes since this reference date (original note:
# "mins since 2015-01-15").  NOTE(review): the same epoch is applied to the
# 2016 files too -- confirm their XTIME is also counted from 2015-01-15.
xtime_epoch = datetime(2015, 1, 15)

# CAM/RRTMG ozone climatology.  It does not depend on the year being
# processed, so it is read once here rather than on every loop iteration.
lonsiz=1
levsiz=59
latsiz=64
num_months=13
plev = np.loadtxt('ozone_plev.txt')
lat_ozone = np.loadtxt('ozone_lat.txt')
ozmixin = np.loadtxt('ozone_formatted.txt')
ozmixin = np.reshape(ozmixin, (levsiz, latsiz, num_months), order='F')
ozone_grid = (plev, lat_ozone, np.arange(num_months))

for year in [year_train, year_test]:
    data_folder = path + year
    save_folder = save_path + year

    # --- targets --------------------------------------------------------
    # Each target file is transposed so the level axis comes first, then the
    # remaining two axes are flattened into a single sample axis.
    y = []
    for feature in targets:
        arr = np.load(data_folder + '/{}.npy'.format(feature))
        arr = np.transpose(arr, (1, 0, 2))
        y.append(np.reshape(arr, (wrf_nlev, -1)))
    # (n_targets, n_levels, n_samples) -> (n_samples, n_targets, n_levels),
    # then reverse the level order (seq_data is flipped the same way below).
    y = np.flip(np.transpose(np.array(y), (2, 0, 1)), 2)

    n_samples, n_targets, n_levels = y.shape
    assert n_levels == wrf_nlev

    # --- per-level (column) features ------------------------------------
    seq_data, seq_features = [], []
    for feature in col_features:
        arr = np.load(data_folder + '/{}.npy'.format(feature))
        arr = np.reshape(np.transpose(arr, (1, 0, 2)), (n_levels, n_samples))
        seq_data.append(arr)
        seq_features.append(feature)

        if feature == 'T_KEL':
            # Temperature difference between adjacent levels, T[k+1] - T[k]
            # (not divided by layer thickness), padded with a zero row at the
            # last level so the shape stays (n_levels, n_samples).
            tmp = np.concatenate((np.diff(arr, axis=0), np.zeros((1, n_samples))), axis=0)
            seq_data.append(tmp)
            seq_features.append('d' + feature)
            del tmp

    # Minutes since xtime_epoch; indexed [time step, profile] below.
    xtime = np.load(data_folder + '/XTIME.npy')
    # One time value per time step (shared by all profiles at that step).
    xtime_hour = xtime[:, 0]
    print('covering period from ', xtime_epoch + timedelta(minutes=xtime_hour[0].item()),
          ' to ', xtime_epoch + timedelta(minutes=xtime_hour[-1].item()))

    # Needed for plotting profiles and for the ozone interpolation below.
    lons = np.load(data_folder + '/XLONG.npy')
    lats = np.load(data_folder + '/XLAT.npy')
    p_tot = np.load(data_folder + '/P_TOT.npy')

    # --- ozone ------------------------------------------------------------
    # Interpolate the climatology in (pressure, latitude, month) onto every
    # (time step, level, profile) point of this year's data.
    months = np.array([(xtime_epoch + timedelta(minutes=m.item())).month
                       for m in xtime_hour])
    Nt, Nz, Nxy = p_tot.shape
    # Query points [pressure, latitude, month] of shape (Nt, Nz, Nxy, 3),
    # built by broadcasting instead of a Python triple loop over every point.
    points = np.stack(np.broadcast_arrays(p_tot,
                                          lats[:, np.newaxis, :],
                                          months[:, np.newaxis, np.newaxis]),
                      axis=-1)
    ozone = interpn(ozone_grid, ozmixin, points, bounds_error=False, fill_value=0.0)
    print("ozone data interpolated ", ozone.shape, ozone.min(), ozone.max())
    del xtime_hour, months, points

    ozone = np.reshape(np.transpose(ozone, (1, 0, 2)), (n_levels, n_samples))
    seq_data.append(ozone)
    seq_features.append('OZONE')

    lons = np.reshape(lons, n_samples)
    lats = np.reshape(lats, n_samples)
    xtime = np.reshape(xtime, n_samples)

    # (n_features, n_levels, n_samples) -> (n_samples, n_features, n_levels),
    # then reverse the level order to match y.
    seq_data = np.flip(np.transpose(np.array(seq_data), (2, 0, 1)), 2)
    print('seq_data reverted')

    # --- per-profile scalars and boundary fluxes: (n_samples, n_features) --
    sca_data = np.transpose(np.array(
        [np.reshape(np.load(data_folder + '/{}.npy'.format(feature)), n_samples)
         for feature in sca_features]))
    flux_data = np.transpose(np.array(
        [np.reshape(np.load(data_folder + '/{}.npy'.format(feature)), n_samples)
         for feature in fluxes]))

    # (n_samples, n_features, *n_layer)
    print(save_folder + ': shapes are y {}; sca_data {}; seq_data {} flux_data {}'.format(y.shape, sca_data.shape, seq_data.shape, flux_data.shape))

    # One shared permutation for every per-sample array so rows stay aligned.
    shuffled_ids = shuffle(range(n_samples), random_state=0)

    sca_data  = sca_data[shuffled_ids,  :]
    flux_data = flux_data[shuffled_ids, :]
    y         = y[shuffled_ids, :, :]
    seq_data  = seq_data[shuffled_ids, :, :]

    # NOTE(review): lons, lats and xtime are shuffled here but not saved
    # below; save them if they are needed later (e.g. for plotting) or drop them.
    lons = lons[shuffled_ids]
    lats = lats[shuffled_ids]
    xtime = xtime[shuffled_ids]

    print(save_folder + ': shapes are y {}; sca_data {}; seq_data {} flux_data {}'.format(y.shape, sca_data.shape, seq_data.shape, flux_data.shape))

    # Create the per-year output folder; opening a file inside a missing
    # directory would raise FileNotFoundError.
    os.makedirs(save_folder, exist_ok=True)
    for out_name, out_arr in (('y', y), ('sca_data', sca_data),
                              ('seq_data', seq_data), ('flux_data', flux_data)):
        with open(save_folder + '/{}.npy'.format(out_name), 'wb') as f:
            np.save(f, out_arr)

