import numpy as np
import xarray as xr
import os
from glob import glob
from datetime import datetime, timedelta
from scipy.interpolate import interp1d
import netCDF4 as nc
from multiprocessing import Pool, cpu_count


# Interpolation function
def interp2(arr, x):
    """Linearly resample axis 1 of *arr* onto *x* evenly spaced points.

    Samples along axis 1 are treated as sitting at positions 0 .. n-1
    (n = arr.shape[1]); the result is evaluated at
    ``np.linspace(0, n - 1, num=x)``, so the first and last samples are
    preserved. All other axes pass through unchanged.

    Args:
        arr: array of rank >= 2. scipy's linear ``interp1d`` requires
            ``arr.shape[1] >= 2``.
        x: number of output samples along axis 1.

    Returns:
        A new float64 array whose shape equals ``arr.shape`` except that
        axis 1 has length ``x``.
    """
    arr = np.asarray(arr)
    n_src = arr.shape[1]
    # interp1d's ``axis`` argument applies the same 1-D linear interpolation
    # to every column along axis 1 in a single call, so no per-column Python
    # loop (and no per-column interpolator construction) is needed.
    interp_func = interp1d(
        np.arange(n_src), arr, kind='linear', axis=1, fill_value="extrapolate"
    )
    new_indices = np.linspace(0, n_src - 1, num=x)
    # Keep the float64 result dtype regardless of the input dtype.
    return np.asarray(interp_func(new_indices), dtype=np.float64)


# Interpolation function for 1D array
def interp1(arr, x):
    """Linearly resample the 1-D sequence *arr* onto *x* evenly spaced points.

    The input samples are placed at positions 0 .. len(arr) - 1 and the
    result is evaluated at ``np.linspace(0, len(arr) - 1, num=x)``, so the
    endpoints of *arr* are reproduced exactly.
    """
    n_src = len(arr)
    resampler = interp1d(range(n_src), arr, kind='linear', fill_value="extrapolate")
    return resampler(np.linspace(0, n_src - 1, num=x))


# Processing function for a single date
def process_date(current_date):
    date_str = current_date.strftime('%Y%m%d')
    print(f"Processing file for date: {date_str}")

    # Construct file path
    dataset_path = glob(f'/here/are/original/nc_files/{date_str[0:4]}/mercatorglorys12v1_gl12_mean_{date_str}*.nc')

    if not dataset_path or not os.path.exists(dataset_path[0]):
        print(f"File for {date_str} does not exist. Skipping.")
        return

    # Read dataset
    dataset = xr.open_dataset(dataset_path[0])

    try:
        depth = dataset['depth']
        selected_depth_indices = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 23]
        depth = depth[selected_depth_indices]
        print(list(depth.values))

        # lat = interp1(dataset['latitude'].values[0:2041:3], 721)
        # lon = dataset['longitude'].values[0:4320:3]

        # Subsample and reshape the data
        thetao = dataset['thetao'].values.reshape(50, 2041, 4320)[selected_depth_indices, 0:2041:3, 0:4320:3]
        so = dataset['so'].values.reshape(50, 2041, 4320)[selected_depth_indices, 0:2041:3, 0:4320:3]
        uo = dataset['uo'].values.reshape(50, 2041, 4320)[selected_depth_indices, 0:2041:3, 0:4320:3]
        vo = dataset['vo'].values.reshape(50, 2041, 4320)[selected_depth_indices, 0:2041:3, 0:4320:3]

        # Handle zos differently since it has different dimensions
        zos_2d = dataset['zos'].values.reshape(2041, 4320)[0:2041:3, 0:4320:3]
        # Create a 3D array by repeating the 2D zos data across all selected levels
        zos = np.repeat(zos_2d[np.newaxis, :, :], len(selected_depth_indices), axis=0)

        # Interpolate to desired shape
        thetao = interp2(thetao, 721)
        so = interp2(so, 721)
        uo = interp2(uo, 721)
        vo = interp2(vo, 721)
        zos = interp2(zos, 721)

        # Combine into one variable
        ocean = np.stack([thetao, so, uo, vo, zos], axis=0)

        # Confirm shape
        print(f"ocean shape: {ocean.shape}")

        # Save to NetCDF
        output_path = os.path.join('/here/are/outoput/files/BUT/original/deep_nc_files', f'upper_{date_str}.nc')
        nc_file = nc.Dataset(output_path, 'w', format='NETCDF4')

        # Create dimensions
        nc_file.createDimension('time', 1)
        nc_file.createDimension('variable', 5)
        nc_file.createDimension('level', len(selected_depth_indices))
        nc_file.createDimension('latitude', 721)
        nc_file.createDimension('longitude', 1440)

        # Create time variable
        time_var = nc_file.createVariable('time', 'f8', ('time',))
        time_var.units = 'hours since 2008-12-27 00:00:00'
        time_var.long_name = 'Time'
        time_var.calendar = 'gregorian'
        time_with_hour = current_date + timedelta(hours=12)
        time_var[0] = nc.date2num(time_with_hour, units=time_var.units, calendar=time_var.calendar)


        temperature = nc_file.createVariable('thetao', 'f4', ('time', 'level', 'latitude', 'longitude'))
        salinity = nc_file.createVariable('so', 'f4', ('time', 'level', 'latitude', 'longitude'))
        velocity1 = nc_file.createVariable('uo', 'f4', ('time', 'level', 'latitude', 'longitude'))
        velocity2 = nc_file.createVariable('vo', 'f4', ('time', 'level', 'latitude', 'longitude'))
        zheight = nc_file.createVariable('zos', 'f4', ('time', 'level', 'latitude', 'longitude'))


        temperature[:] = ocean[0]
        salinity[:] = ocean[1]
        velocity1[:] = ocean[2]
        velocity2[:] = ocean[3]
        zheight[:] = ocean[4]


        temperature.units = 'K'
        temperature.long_name = 'Temperature'

        salinity.units = 'psu'
        salinity.long_name = 'Salinity'

        velocity1.units = 'm/s'
        velocity1.long_name = 'Velocity 1'

        velocity2.units = 'm/s'
        velocity2.long_name = 'Velocity 2'

        zheight.units = 'm'
        zheight.long_name = 'Sea Surface Height'

    finally:
        nc_file.close()


    print(f"Finished processing file for date: {date_str}")


# Main function
def main(start_date=datetime(2004, 12, 27), end_date=datetime(2006, 1, 1), max_processes=6):
    """Run ``process_date`` for every day from *start_date* to *end_date* in parallel.

    Args:
        start_date: first date to process.
        end_date: last date to process (inclusive). If it precedes
            *start_date*, nothing is processed.
        max_processes: upper bound on worker processes. Fewer are used when
            ``cpu_count()`` is smaller.
    """
    n_days = (end_date - start_date).days + 1
    date_list = [start_date + timedelta(days=i) for i in range(n_days)]

    # Cap the pool at max_processes workers, or the CPU count if smaller.
    num_processes = min(cpu_count(), max_processes)
    with Pool(processes=num_processes) as pool:
        pool.map(process_date, date_list)


if __name__ == '__main__':
    # Entry point only when run as a script. Keeping main() behind this guard
    # also prevents worker processes that re-import the module from starting
    # their own pools.
    main()
    
