import numpy as np
import pandas as pd
import cartopy.crs as ccrs
from typing import *
from tqdm import tqdm
import skyfield.timelib
from skyfield import api, almanac, toposlib
from skyfield.nutationlib import iau2000b
from datetime import datetime, timedelta


def project_coords(lon: np.ndarray, lat: np.ndarray, u: Optional[np.ndarray] = None, v: Optional[np.ndarray] = None,
                   src_proj: ccrs.Projection = ccrs.PlateCarree(),
                   target_proj: ccrs.Projection = ccrs.epsg(3035)):
    """
    Project coordinates lon, lat and, optionally, vectors u, v into a target projection.
    Default is PlateCarree -> EPSG:3035 (i.e. lon/lat -> Europe-specific projection).

    :param lon: x coordinates in ``src_proj``. The indexing below needs a 2-d array, such as the
                (n, 1) columns passed by the DataFrame helpers in this file.
    :param lat: y coordinates in ``src_proj``, same shape as ``lon``.
    :param u: optional x components of vectors located at (lon, lat).
    :param v: optional y components of vectors located at (lon, lat).
    :param src_proj: projection the inputs are expressed in.
    :param target_proj: projection to transform into.
    :return: ``(x, y)``, or ``(x, y, u_proj, v_proj)`` when both ``u`` and ``v`` are given.
             If only one of ``u``/``v`` is given it is ignored and ``(x, y)`` is returned.
    """
    # transform_points appends a trailing axis of length 3 (x, y, z); z is 0 for us.
    projected = target_proj.transform_points(src_proj, lon, lat)
    # Keep the first entry along axis 1 and drop z, then split into the x and y rows.
    xy = projected[:, 0, :-1]
    x, y = xy.T

    # Vectors are only transformed when both components were supplied.
    if u is None or v is None:
        return x, y

    u_proj, v_proj = target_proj.transform_vectors(src_proj, lon, lat, u, v)
    return x, y, u_proj, v_proj


def project_coords_df(df: pd.DataFrame, longitude_varname="longitude", latitude_varname="latitude"):
    """
    Return a copy of ``df`` with projected coordinate columns ``x_3035`` and ``y_3035`` added.

    :param df: frame holding the coordinate columns.
    :param longitude_varname: name of the longitude column.
    :param latitude_varname: name of the latitude column.
    :return: new DataFrame; the projection used is ``project_coords``' default (PlateCarree -> EPSG:3035).
    """
    # project_coords indexes its result as a 3-d array, so hand it (n, 1) columns.
    lon_column = df[longitude_varname].values.reshape(-1, 1)
    lat_column = df[latitude_varname].values.reshape(-1, 1)

    easting, northing = project_coords(lon_column, lat_column)
    return df.assign(x_3035=easting, y_3035=northing)


def project_coords_and_velocity_df(df: pd.DataFrame,
                                   longitude_varname="longitude", latitude_varname="latitude",
                                   u_varname="u", v_varname="v"):
    """
    Return a copy of ``df`` with projected positions and velocity components added.

    Adds the columns ``x_3035``, ``y_3035``, ``u_3035`` and ``v_3035``.

    :param df: frame holding the coordinate and velocity columns.
    :param longitude_varname: name of the longitude column.
    :param latitude_varname: name of the latitude column.
    :param u_varname: name of the column with the x component of the velocity.
    :param v_varname: name of the column with the y component of the velocity.
    :return: new DataFrame; the projection used is ``project_coords``' default (PlateCarree -> EPSG:3035).
    """
    source_columns = (longitude_varname, latitude_varname, u_varname, v_varname)
    # project_coords indexes its result as a 3-d array, so every input goes in as an (n, 1) column.
    lon, lat, u, v = (df[name].values.reshape(-1, 1) for name in source_columns)

    projected_x, projected_y, projected_u, projected_v = project_coords(lon, lat, u, v)
    return df.assign(x_3035=projected_x,
                     y_3035=projected_y,
                     u_3035=projected_u,
                     v_3035=projected_v)


class CivilTwilightCalculator:
    """
    Calculate civil-twilight dawn and dusk times using the skyfield library.
    https://rhodesmill.org/skyfield/planets.html

    The sun is treated as "up" whenever its apparent altitude is above
    ``-DEGREES_CIVIL_TWILIGHT`` degrees, so "sunrise"/"sunset" in this class
    mean the start of civil dawn and the end of civil dusk.
    """

    # Sun-altitude thresholds in degrees. Within this class only
    # DEGREES_CIVIL_TWILIGHT is actually used.
    DEGREES_CENTER_HORIZON = 0.0
    DEGREES_TOP_HORIZON = 0.26667
    DEGREES_TOP_HORIZON_APPARENTLY = 0.8333
    DEGREES_CIVIL_TWILIGHT = 6.0
    DEGREES_NAUTICAL_TWILIGHT = 12.0
    DEGREES_ASTRONOMICAL_TWILIGHT = 18.0

    def __init__(self, data_dir: str = "/tmp"):
        """
        Load the timescale and the ``de440s.bsp`` ephemeris.

        :param data_dir: directory handed to ``skyfield.api.Loader`` for the ephemeris file
                         (presumably downloaded there on first use -- see the skyfield docs).
        """
        self.ts = api.load.timescale()
        self.load = api.Loader(data_dir)
        self.planets = self.load('de440s.bsp')
        self.sun = self.planets['sun']
        self.earth = self.planets['earth']

    def is_sun_up_closure(self, topos: toposlib.GeographicPosition) -> Callable[[skyfield.timelib.Time], bool]:
        """Build a function of time that tells whether the sun is up at `topos`.

        The function that this returns will expect a single argument that is a
        :class:`~skyfield.timelib.Time` and will return ``True`` if the sun is up
        or civil twilight has started (apparent altitude above
        ``-DEGREES_CIVIL_TWILIGHT`` degrees), else ``False``.

        The returned function carries a ``rough_period`` attribute and is what
        :meth:`calc_suntimes` passes to ``almanac.find_discrete``.
        """
        topos_at = (self.earth + topos).at

        def is_sun_up_at(t: skyfield.timelib.Time) -> bool:
            """Return `True` if the sun is above the civil-twilight threshold at time `t`."""
            # NOTE(review): this pre-seeds a private skyfield attribute with the IAU2000B
            # nutation model, presumably as a speed-up; confirm the installed skyfield
            # version still honours `_nutation_angles`.
            t._nutation_angles = iau2000b(t.tt)
            altitude = topos_at(t).observe(self.sun).apparent().altaz()[0]
            return altitude.degrees > -self.DEGREES_CIVIL_TWILIGHT

        is_sun_up_at.rough_period = 0.5  # the state flips about twice a day
        return is_sun_up_at

    def calc_suntimes(self, latitude, longitude, min_date, max_date) -> List[dict]:
        """
        Return the civil-twilight transitions at one location between two dates.

        :param latitude: latitude in degrees.
        :param longitude: longitude in degrees.
        :param min_date: start of the search window; only its year, month and day are used.
        :param max_date: end of the search window; only its year, month and day are used.
        :return: one dict per transition with the keys ``"time"`` (string from ``Time.utc_iso()``),
                 ``"is_sunrise"`` (the new state reported by ``find_discrete``, as ``bool``),
                 ``"latitude"`` and ``"longitude"``.
        """
        loc = api.wgs84.latlon(latitude_degrees=latitude, longitude_degrees=longitude)
        t0, t1 = [self.ts.utc(d.year, d.month, d.day) for d in (min_date, max_date)]
        # 1 / 1440 of a day is one minute.
        time, is_sunrise = almanac.find_discrete(t0, t1, self.is_sun_up_closure(loc), epsilon=1 / 1440)
        return [{"time": t,
                 "is_sunrise": bool(sunrise),
                 "latitude": latitude,
                 "longitude": longitude}
                for (t, sunrise) in zip(time.utc_iso(), is_sunrise)]

    def generate_suntimes_df(self, df) -> pd.DataFrame:
        """
        Calculate civil-twilight sunset and sunrise times for every distinct location in ``df``.

        :param df: frame with the columns ``time``, ``latitude`` and ``longitude``.
        :return: frame with one row per (date, latitude, longitude) and the columns
                 ``date``, ``latitude``, ``longitude``, ``sunset`` and ``sunrise``
                 (the latter two are ``datetime.time`` values taken from the UTC ISO strings).
        """
        # Pad the searched window: one day before the first and two days after the last timestamp.
        window_start = df.time.min() - np.timedelta64(1, 'D')
        window_end = df.time.max() + np.timedelta64(2, 'D')

        records = []
        df_coords = df.loc[:, ["latitude", "longitude"]].drop_duplicates()
        for r in tqdm(df_coords.itertuples(), total=df_coords.shape[0]):
            # extend() keeps the flattening linear (sum(list_of_lists, []) is quadratic).
            records.extend(self.calc_suntimes(latitude=r.latitude, longitude=r.longitude,
                                              min_date=window_start,
                                              max_date=window_end))

        # Keyword arguments are required here: pandas >= 2.0 no longer accepts
        # index/columns/values positionally in DataFrame.pivot.
        suntimes_wide = (pd.DataFrame(records)
                         .assign(datetime=lambda x: pd.to_datetime(x.time))
                         .assign(date=lambda x: x.datetime.dt.date,
                                 time=lambda x: x.datetime.dt.time)
                         .pivot(index=["date", "latitude", "longitude"], columns="is_sunrise", values="time"))
        # Relies on the pivoted columns arriving in the order (False, True).
        suntimes_wide.columns = ["sunset", "sunrise"]
        return suntimes_wide.reset_index()

    @staticmethod
    def merge_sunset_table(df):
        """
        Merge the sunset/sunrise table with the radar data.

        The sunset is looked up on ``date_of_night`` and the sunrise on the day after it.

        :param df: frame with the columns ``time``, ``latitude``, ``longitude`` and ``date_of_night``.
        :return: ``df`` with the added columns ``sunset``, ``sunrise``, ``date_of_nightp1``,
                 ``sunset_datetime`` and ``sunrise_datetime``.
        """
        civ = CivilTwilightCalculator()
        suntimes_df = civ.generate_suntimes_df(df)
        df_merged_sunset = pd.merge(df, suntimes_df.drop(columns="sunrise"), how='left',
                                    left_on=['latitude', 'longitude', 'date_of_night'],
                                    right_on=['latitude', 'longitude', 'date'])

        df_merged_both = pd.merge(
            df_merged_sunset.assign(date_of_nightp1=lambda x: x.date_of_night + timedelta(days=1)),
            suntimes_df.drop(columns="sunset"), how='left',
            left_on=['latitude', 'longitude', 'date_of_nightp1'],
            right_on=['latitude', 'longitude', 'date'])
        # Both merges bring in a "date" key column; pandas suffixes them _x/_y and they are dropped again.
        df_merged_both = df_merged_both.assign(
            sunset_datetime=lambda df_: join_datetime(df_, time_var="sunset", date_var="date_of_night"),
            sunrise_datetime=lambda df_: join_datetime(df_, time_var="sunrise", date_var="date_of_nightp1")
        ).drop(columns=["date_x", "date_y"])

        return df_merged_both


def join_datetime(df: pd.DataFrame, time_var: str, date_var: str = "date") -> pd.Series:
    """
    Combine a date column and a time column into one datetime Series, rounded to the minute.

    :param df: frame holding both columns.
    :param time_var: name of the column with ``datetime.time`` values.
    :param date_var: name of the column with ``datetime.date`` values.
    :return: datetime Series on ``df``'s index; rows where either part is missing become ``NaT``.
    """
    def _combine(day, clock):
        # A missing part (e.g. an unmatched row after a left merge) cannot be combined.
        if pd.isna(day) or pd.isna(clock):
            return pd.NaT
        return datetime.combine(day, clock)

    # Columns are addressed by label: positional `row[0]` on a string-labelled
    # Series is deprecated in recent pandas.
    stamps = [_combine(day, clock) for day, clock in zip(df[date_var], df[time_var])]
    # to_datetime also gives an empty frame a proper datetime dtype, so `.dt` always works.
    return pd.Series(pd.to_datetime(stamps), index=df.index).dt.round("1min")
