from pathlib import Path
import pandas as pd
import numpy as np
import xarray as xr
class ERA5:
    """Precipitation (``tp``) for one year over a buffered lat/lon box.

    On construction the requested bounds are widened by a fixed buffer and
    snapped outward to whole numbers, the matching subset is loaded (from a
    cached NetCDF file when one exists, otherwise from the yearly source
    file, which is then cached), and the minimum and maximum of ``tp`` are
    stored.

    Attributes:
        lat_low, lat_up, lon_low, lon_up: Buffered, outward-rounded bounds.
        year: Year used to build both the source and the cache file names.
        data_path: Root directory containing the ``ERA5`` folder.
        file_path: Cache file,
            ``<data_path>/ERA5/Processed/era5_<year>.nc``.
        data: Dataset returned by :meth:`load_ERA5`.
        tp_min, tp_max: Minimum and maximum of ``data.tp.values``.
    """

    def __init__(self, lat_low, lat_up, lon_low, lon_up, year, region='Northeastern',
                 data_path=Path('')):
        """Compute the buffered bounds, load the data and its ``tp`` range.

        Args:
            lat_low: Lower latitude bound before buffering.
            lat_up: Upper latitude bound before buffering.
            lon_low: Lower longitude bound before buffering.
            lon_up: Upper longitude bound before buffering.
            year: Year identifying the files to read and write.
            region: Forwarded to :meth:`load_ERA5`; no method of this class
                reads it.
            data_path: Root directory containing the ``ERA5`` folder.
        """
        # Margin added on every side (presumably degrees -- confirm against
        # the coordinate units of the source files).
        buffer = 1.5
        self.lat_low = np.floor(lat_low - buffer)
        self.lat_up = np.ceil(lat_up + buffer)
        self.lon_low = np.floor(lon_low - buffer)
        self.lon_up = np.ceil(lon_up + buffer)
        self.year = year
        self.data_path = data_path
        self.file_path = self.data_path / 'ERA5' / 'Processed' / f'era5_{year}.nc'
        self.data = self.load_ERA5(region)

        # Materialise the array once and reuse it for both statistics.
        tp_values = self.data.tp.values
        self.tp_min = tp_values.min()
        self.tp_max = tp_values.max()

    @staticmethod
    def _ordered_slice(coord_values, low, high):
        """Return a label slice for ``[low, high]`` matching the axis order.

        Label-based slicing follows the order of the index, so an axis stored
        in descending order needs ``slice(high, low)`` to select the same
        interval that ``slice(low, high)`` selects on an ascending axis.

        Args:
            coord_values: 1-D coordinate values of the axis being sliced.
            low: Lower bound of the interval.
            high: Upper bound of the interval.

        Returns:
            ``slice(high, low)`` when the first coordinate value is greater
            than the last one, otherwise ``slice(low, high)``.
        """
        values = np.asarray(coord_values)
        descending = values.size > 1 and values[0] > values[-1]
        return slice(high, low) if descending else slice(low, high)

    def load_ERA5(self, region):
        """Return the ``tp`` dataset for ``self.year``, caching it on disk.

        If ``self.file_path`` exists it is opened and returned unchanged.
        Otherwise the yearly source file is subset, reduced to the ``tp``
        variable, written to ``self.file_path`` and returned.

        NOTE(review): the cache file name depends on the year only, so an
        existing cache is returned even if it was written for a different
        bounding box.

        Args:
            region: Forwarded to :meth:`load_ERA5_yearly`.

        Returns:
            The dataset opened from the cache, or the freshly built subset
            containing only ``tp``.
        """
        if self.file_path.exists():
            return xr.open_dataset(self.file_path)

        data = self.load_ERA5_yearly(region)[['tp']]
        self.file_path.parent.mkdir(exist_ok=True, parents=True)
        data.to_netcdf(self.file_path)

        return data

    def load_ERA5_yearly(self, region):
        """Open the yearly source file and cut it to the buffered box.

        Reads ``<data_path>/ERA5/ifs_hres/<year>.nc``, selects the
        longitude/latitude window stored on the instance and renames
        ``total_precipitation_24hr`` to ``tp``.

        Args:
            region: Accepted for interface compatibility; not used.

        Returns:
            The spatially subset dataset with the precipitation variable
            named ``tp``.
        """
        source_path = self.data_path / 'ERA5' / 'ifs_hres' / f'{self.year}.nc'
        data = xr.open_dataset(source_path)
        # Orient each slice to the stored axis order; a plain
        # slice(low, up) selects nothing on a descending axis.
        data = data.sel(
            longitude=self._ordered_slice(
                data['longitude'].values, self.lon_low, self.lon_up
            ),
            latitude=self._ordered_slice(
                data['latitude'].values, self.lat_low, self.lat_up
            ),
        )
        data = data.rename({'total_precipitation_24hr': 'tp'})
        return data