import os
from pathlib import Path

import geopandas as gpd
import numpy as np
import pandas as pd
import xarray as xr
from shapely import points


class station(object):
    """Load station precipitation (`PRE_Time_0808`) and cache it as NetCDF.

    On construction the filtered dataset is read from
    ``<data_path>/station/processed/Meta--<start>--<end_year>/station_<first year>_filtered.nc4``
    if it exists. Otherwise it is built from the raw yearly files
    ``<data_path>/station/raw/<year>.nc4``. The unfiltered intermediate is
    cached next to it as ``station_<first year>.nc4``.

    Args:
        times: datetime values exposing ``.dt`` and ``.values`` (presumably a
            1-D xarray DataArray — confirm); defines the output time axis.
        coords_raw: shapely points for every candidate station; their order
            defines the rows of the raw dataset.
        coords: shapely points for the stations kept in the filtered dataset
            (expected to be a subset of ``coords_raw``).
        lat_low, lat_up, lon_low, lon_up: stored as attributes; not used by
            this class.
        file_name, filtered_file_name: accepted for interface compatibility;
            not used by this class.
        n_years: number of years covered by the cache folder name.
        data_path: root data directory (``str`` or ``Path``).
        end_year: last year in the cache folder name. The default 2023
            reproduces the historical ``Meta--{2024 - n_years}--2023`` name.

    Attributes:
        ds_xr: the filtered xarray Dataset.
    """

    # Precipitation variable read from the raw files.
    _PRECIP_VAR = 'PRE_Time_0808'
    # Readings above this value (and non-finite readings) are treated as invalid and zeroed.
    _MAX_VALID_PRECIP = 1000
    # Upper bound on time steps read from one yearly file.
    _MAX_STEPS_PER_YEAR = 366

    def __init__(self, times, coords_raw, coords, lat_low, lat_up, lon_low, lon_up, file_name, filtered_file_name,
                 n_years=5,
                 data_path=Path(''),
                 end_year=2023):
        self.times = times
        self.coords_raw = coords_raw
        self.coords = coords
        self.n_years = n_years
        self.years = self.times.dt.year.data
        self.months = self.times.dt.month.data
        self.days = self.times.dt.day.data
        self.hours = self.times.dt.hour.data

        self.lat_low = lat_low
        self.lat_up = lat_up
        self.lon_low = lon_low
        self.lon_up = lon_up

        self.lons = np.array([p.x for p in self.coords])
        self.lats = np.array([p.y for p in self.coords])
        self.lons_raw = np.array([p.x for p in self.coords_raw])
        self.lats_raw = np.array([p.y for p in self.coords_raw])

        # Accept str as well as Path; the `/` joins below require a Path.
        self.data_path = Path(data_path)

        meta_year_cover = f'Meta--{end_year + 1 - self.n_years}--{end_year}'
        meta_year_folder = self.data_path / f'station/processed/{meta_year_cover}'
        # Create the cache folder directly instead of shelling out to `mkdir -p`
        # (which broke on paths containing spaces or shell metacharacters).
        meta_year_folder.mkdir(parents=True, exist_ok=True)

        # Cache file names are keyed on the first year in `times` only.
        self.station_raw_ds_path = str(meta_year_folder / f'station_{self.years[0]}.nc4')
        self.station_ds_path = str(meta_year_folder / f'station_{self.years[0]}_filtered.nc4')

        if os.path.exists(self.station_ds_path):
            self.ds_xr = xr.open_dataset(self.station_ds_path)
        else:
            self.ds_xr = self.createFile(self.createRawFile())

    def createRawFile(self):
        """Return the unfiltered station dataset, building and caching it if needed.

        The dataset has variables ``data`` and ``data_is_real`` over
        ``(stations, time)`` and ``lon``/``lat`` over ``stations``. Stations are
        numbered 1..len(coords_raw) in ``coords_raw`` order. ``data`` is
        gap-filled; ``data_is_real`` marks the values that were observed.

        Raises:
            ValueError: if no raw data could be loaded, or if the number of
                loaded time steps does not match ``times``.
        """
        if os.path.exists(self.station_raw_ds_path):
            return xr.open_dataset(self.station_raw_ds_path)

        per_step = self.loadData()
        if not per_step:
            raise ValueError(
                f'No station data could be loaded from {self.data_path / "station/raw"}; '
                'check that the yearly raw files exist.'
            )
        # Stack the per-time-step (n_raw_stations, 1) columns into
        # (n_raw_stations, n_variables, n_time).
        station_ds = np.stack(per_step, axis=-1)
        n_loaded = station_ds.shape[-1]
        if n_loaded != len(self.times):
            raise ValueError(
                f'Loaded {n_loaded} time steps but `times` has {len(self.times)}; '
                'some yearly raw files may be missing or incomplete.'
            )
        # Record which values were actually observed before gap filling.
        station_var_is_real = ~np.isnan(station_ds)
        station_ds = self._fill_gaps(station_ds)

        # Only variable 0 is stored; load_station_time produces a single variable.
        rawData = xr.Dataset(
            {
                'data': (['stations', 'time'], station_ds[:, 0, :]),
                'data_is_real': (['stations', 'time'], station_var_is_real[:, 0, :]),
                'lon': (['stations'], self.lons_raw),
                'lat': (['stations'], self.lats_raw),
            },
            coords={
                'stations': np.arange(1, len(self.coords_raw) + 1),
                'time': self.times.values,
            },
        )
        rawData.to_netcdf(self.station_raw_ds_path)
        return rawData

    @staticmethod
    def _fill_gaps(station_ds):
        """Return a copy of a (stations, variables, time) array with NaNs filled.

        For each variable, each station's series is forward-filled and then
        back-filled along time. Any station that is still all-NaN gets the mean
        of that variable across all stations and times. The input is not
        modified.
        """
        filled = station_ds.copy()
        for i in range(filled.shape[1]):
            # DataFrame rows are time steps and columns are stations, so ffill/bfill run along time.
            var = pd.DataFrame(filled[:, i, :].T).ffill().bfill().values.T
            filled[:, i, :] = np.nan_to_num(var, nan=np.nanmean(var))
        return filled

    def loadData(self):
        """Return a list with one (len(coords_raw), 1) array per loaded time step.

        Years come from ``times`` in order of first appearance. A missing
        yearly file is reported and skipped. At most ``_MAX_STEPS_PER_YEAR``
        steps are read from each year.
        """
        station_ds = []
        unique_years = pd.to_datetime(self.times.values).to_series().dt.year.unique()
        for year in unique_years:
            print(f"Processing station {year}", flush=True)
            try:
                data = self.load_station_yearly(year)
            except FileNotFoundError:
                print(f"Warning: File not found for {year}", flush=True)
                continue
            # `sizes` is the supported mapping; `dims[...]` lookup is deprecated in xarray.
            n_time = min(self._MAX_STEPS_PER_YEAR, data.sizes['time'])
            station_ds.extend(self.load_station_time(data, t) for t in range(n_time))
        return station_ds

    @classmethod
    def _clean_precip(cls, precip):
        """Return ``precip`` with non-finite values and values above the cap replaced by 0."""
        # np.isfinite is already False for NaN, so no separate isnan check is needed.
        return precip.where(np.isfinite(precip) & (precip <= cls._MAX_VALID_PRECIP), 0)

    def load_station_yearly(self, year):
        """Load ``<data_path>/station/raw/<year>.nc4`` into memory, restricted to ``year``.

        The file is closed before returning. Invalid precipitation values are
        replaced by 0 (see ``_clean_precip``).

        Raises:
            FileNotFoundError: if the yearly file does not exist (relied on by
                ``loadData``).
        """
        data_path = self.data_path / f'station/raw/{year}.nc4'
        with xr.open_dataset(data_path, engine='netcdf4') as data:
            # Load eagerly so the file can be closed, and so lat/lon are not
            # re-read from disk on every time step in load_station_time.
            year_data = data.sel(time=data.time.dt.year == year).load()
        year_data[self._PRECIP_VAR] = self._clean_precip(year_data[self._PRECIP_VAR])
        return year_data

    def load_station_monthly(self, year, month):
        """Load ``<data_path>/station/raw/<year>.nc`` into memory and zero invalid precipitation.

        NOTE(review): ``month`` is currently unused, and this reads ``.nc`` while
        ``load_station_yearly`` reads ``.nc4`` — confirm which is intended.
        """
        data_path = self.data_path / f'station/raw/{year}.nc'
        with xr.open_dataset(data_path, engine='netcdf4') as data:
            loaded = data.load()
        loaded[self._PRECIP_VAR] = self._clean_precip(loaded[self._PRECIP_VAR])
        return loaded

    def load_station_time(self, data, t):
        """Map observations at time index ``t`` onto ``coords_raw``.

        Observations at identical locations are averaged. Returns an array of
        shape (len(coords_raw), 1), with NaN for stations that have no
        observation at exactly their location.
        """
        pre_obs = data[self._PRECIP_VAR].isel(time=t).values
        obs_points = points(np.column_stack([data['lon'].values.ravel(), data['lat'].values.ravel()]))
        df = gpd.GeoDataFrame({self._PRECIP_VAR: pre_obs}, geometry=obs_points)
        # Average observations that share the exact same location.
        df_agg = gpd.GeoDataFrame(df.groupby('geometry', as_index=False).mean())
        da = np.full((len(self.coords_raw), 1), np.nan)
        # Row 0 holds indices into coords_raw; row 1 holds indices into df_agg.
        query_idx = df_agg.sindex.query(self.coords_raw, predicate='contains')
        da[query_idx[0], 0] = df_agg[self._PRECIP_VAR].values[query_idx[1]]
        return da

    def _selected_station_mask(self):
        """Return a bool array over ``coords_raw`` that is True where the (lon, lat) appears in ``coords``."""
        wanted = set(zip(self.lons.tolist(), self.lats.tolist()))
        return np.array(
            [pt in wanted for pt in zip(self.lons_raw.tolist(), self.lats_raw.tolist())],
            dtype=bool,
        )

    def createFile(self, rawdata):
        """Keep only the stations in ``coords``, renumber them 1..len(coords), write and return the result."""
        keep = np.flatnonzero(self._selected_station_mask())
        xrdata = rawdata.isel(stations=keep)
        xrdata = xrdata.assign_coords(stations=np.arange(1, len(self.coords) + 1))
        xrdata.to_netcdf(self.station_ds_path)
        return xrdata
