import os
from collections import Counter
from pathlib import Path

import geopandas as gpd
import pandas as pd
import xarray as xr
import numpy as np
from shapely import points


class MetaStation:
    """Station metadata table backed by a shapefile cache.

    On construction the table of all stations is read from
    ``<data_path>/station/stations/stations_<start_year>_2023_<file_name>.shp``
    when that file exists; otherwise it is built from the yearly NetCDF
    files by ``generate_station_table`` and written to that path.

    Attributes:
        lat_low, lat_up, lon_low, lon_up: Bounds passed by the caller. They
            are overwritten with the extent of the station points whenever
            ``generate_station_table`` runs.
        n_years, control_ratio: Stored as given; not used by the code
            visible in this class.
        start_year: First year of yearly data (fixed to 2016).
        data_path: Root directory of the station data, as a ``Path``.
        shapefile_path: Optional region-of-interest shapefile used to clip
            a newly generated station table; ``None`` disables clipping.
        all_station_file, station_file: Shapefile paths for the complete
            and the filtered station tables.
        data_dir: Path of the yearly NetCDF file for ``start_year``.
        stations_raw, stations: The loaded or generated station table
            (both names refer to the same object).
    """

    def __init__(self, lat_low, lat_up, lon_low, lon_up, n_years=5, control_ratio=0.9, shapefile_path=None,
                 data_path=Path('')):
        """Store the configuration and load or generate the station table.

        Args:
            lat_low: Lower latitude bound.
            lat_up: Upper latitude bound.
            lon_low: Lower longitude bound.
            lon_up: Upper longitude bound.
            n_years: Stored on the instance; unused in this class.
            control_ratio: Stored on the instance; unused in this class.
            shapefile_path: Optional path of a region-of-interest shapefile.
            data_path: Root data directory (``str`` or ``Path``).
        """
        # Accept plain strings as well: the `/` joins below require a Path.
        data_path = Path(data_path)

        self.lat_low = lat_low
        self.lat_up = lat_up
        self.lon_low = lon_low
        self.lon_up = lon_up
        self.n_years = n_years
        self.control_ratio = control_ratio
        self.start_year = 2016
        self.data_path = data_path
        self.file_name = 'china'
        self.shapefile_path = shapefile_path
        self.filtered_file_name = self.file_name + '_filtered'

        station_dir = data_path / 'station' / 'stations'
        self.all_station_file = station_dir / f'stations_{self.start_year}_2023_{self.file_name}.shp'
        self.station_file = station_dir / f'stations_{self.start_year}_2023_{self.filtered_file_name}.shp'
        self.data_dir = self._yearly_data_file(self.start_year)
        # The directory must exist before a generated table can be written.
        station_dir.mkdir(exist_ok=True, parents=True)

        if self.all_station_file.exists():
            self.stations_raw = gpd.read_file(self.all_station_file)
        else:
            print(f"All station file not found at {self.all_station_file}, generating new station table...")
            self.stations_raw = self.generate_station_table()
        self.stations = self.stations_raw

    def _yearly_data_file(self, year):
        """Return the path of the yearly station NetCDF file for ``year``."""
        return self.data_path / f'station_{year}_yearly.nc4'

    def generate_station_table(self):
        """Build the all-station table from the yearly files and save it.

        For every year in ``range(self.start_year, 2022)`` the yearly dataset
        is opened and its distinct ``(lon, lat)`` pairs are converted to
        points. Each point's ``num`` is the number of yearly files in which
        it occurs.

        Side effects:
            * ``self.lat_low``, ``self.lat_up``, ``self.lon_low`` and
              ``self.lon_up`` are overwritten with the extent of all counted
              points (computed before any region-of-interest clipping).
            * The resulting table is written to ``self.all_station_file``.

        Returns:
            A ``GeoDataFrame`` with columns ``geometry`` and ``num`` in
            EPSG:4326, clipped to the dissolved region of interest when
            ``self.shapefile_path`` is not ``None``.

        Raises:
            ValueError: If none of the yearly files contains a station.
        """
        counter = Counter()

        # NOTE(review): the end year is hard-coded to 2022 (exclusive) while
        # the output file names say 2023 -- confirm the intended range.
        for year in range(self.start_year, 2022):
            print(f'Generating station table for year {year}', flush=True)
            # The context manager closes the file handle once the
            # coordinates have been pulled into memory.
            with xr.open_dataset(self._yearly_data_file(year)) as data:
                if len(data.number) == 0:
                    continue
                coords_sub = (
                    data[['lon', 'lat']]
                    .to_pandas()
                    .reset_index(drop=True)
                    .drop_duplicates()[['lon', 'lat']]
                    .values
                )
            # In-place update avoids rebuilding the whole counter each year.
            counter.update(points(coords_sub))

        if not counter:
            raise ValueError(
                f'No stations found in the yearly files under {self.data_path}'
            )

        table = gpd.GeoDataFrame(
            {'geometry': list(counter.keys()), 'num': list(counter.values())}
        )
        self.lat_low = table.geometry.y.min()
        self.lat_up = table.geometry.y.max()
        self.lon_low = table.geometry.x.min()
        self.lon_up = table.geometry.x.max()
        print(f"lat_low: {self.lat_low}, lat_up: {self.lat_up}, lon_low: {self.lon_low}, lon_up: {self.lon_up}")
        table = table.set_crs(epsg=4326)

        if self.shapefile_path is not None:
            # Merge all features of the shapefile into one geometry and keep
            # only the stations that lie within it.
            roi = gpd.read_file(self.shapefile_path).dissolve()
            table = table[table.geometry.within(roi.iloc[0].geometry)]

        table.to_file(self.all_station_file, index=False)

        return table

    def generate_filtered_station_table(self, counter):
        """Return ``counter`` unchanged; no filtering is implemented here."""
        return counter
