# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Directory to extract time covariates.

Extract time covariates from datetime.
"""

import numpy as np
import pandas as pd
from pandas.tseries.holiday import EasterMonday
from pandas.tseries.holiday import GoodFriday
from pandas.tseries.holiday import Holiday
from pandas.tseries.holiday import SU
from pandas.tseries.holiday import TH
from pandas.tseries.holiday import USColumbusDay
from pandas.tseries.holiday import USLaborDay
from pandas.tseries.holiday import USMartinLutherKingJr
from pandas.tseries.holiday import USMemorialDay
from pandas.tseries.holiday import USPresidentsDay
from pandas.tseries.holiday import USThanksgivingDay
from pandas.tseries.offsets import DateOffset
from pandas.tseries.offsets import Day
from pandas.tseries.offsets import Easter
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
from joblib import Parallel, delayed


# Half-width (in days) of the window searched around a date for a holiday.
# 183 days cover half a year in both directions, leap years included.
_HALF_YEAR_DAYS = 183
# Extra slack because Easter moves: it can fall anywhere from March 22 to
# April 25.
_EASTER_SLACK_DAYS = 17
MAX_WINDOW = _HALF_YEAR_DAYS + _EASTER_SLACK_DAYS


def _distance_to_holiday(holiday):
  """Build a function mapping a timestamp to its day offset from `holiday`.

  Args:
    holiday: object exposing `dates(start, end)`, e.g. a pandas `Holiday`.

  Returns:
    A one-argument callable. Given a timestamp it returns the signed number of
    whole days `timestamp - occurrence`, where `occurrence` is the first
    holiday date found within `MAX_WINDOW` days on either side of the
    timestamp. The callable fails an assertion when no occurrence lies inside
    that window.
  """

  def _distance_to_day(index):
    half_window = pd.Timedelta(days=MAX_WINDOW)
    occurrences = holiday.dates(index - half_window, index + half_window)
    assert (
        len(occurrences) > 0
    ), f"No closest holiday for the date index {index} found."
    # The window can contain more than one occurrence; the first entry is the
    # one used.
    # NOTE(review): an earlier comment here said the smaller distance is
    # returned in that case. Presumably `dates` yields ascending dates, so the
    # first entry is the earlier occurrence, which is not always the nearer
    # one -- confirm before relying on "closest" semantics.
    first_occurrence = occurrences[0]
    return (index - first_occurrence).days

  return _distance_to_day


# Custom holiday rules that complement the ones shipped with pandas. Every
# rule has its own name so that the entries of HOLIDAYS stay distinguishable.

# Easter Sunday: January 1st moved by the Easter offset (Day(0) adds nothing).
EasterSunday = Holiday(
    "Easter Sunday", month=1, day=1, offset=[Easter(), Day(0)]
)
NewYearsDay = Holiday("New Years Day", month=1, day=1)
# First Sunday of February.
SuperBowl = Holiday(
    "Superbowl", month=2, day=1, offset=DateOffset(weekday=SU(1))
)
# Second Sunday of May.
MothersDay = Holiday(
    "Mothers Day", month=5, day=1, offset=DateOffset(weekday=SU(2))
)
IndependenceDay = Holiday("Independence Day", month=7, day=4)
# Previously named "Christmas", which duplicated the name of ChristmasDay.
ChristmasEve = Holiday("Christmas Eve", month=12, day=24)
ChristmasDay = Holiday("Christmas", month=12, day=25)
NewYearsEve = Holiday("New Years Eve", month=12, day=31)
# The day after the fourth Thursday of November.
BlackFriday = Holiday(
    "Black Friday",
    month=11,
    day=1,
    offset=[DateOffset(weekday=TH(4)), Day(1)],
)
# Four days after the fourth Thursday of November.
CyberMonday = Holiday(
    "Cyber Monday",
    month=11,
    day=1,
    offset=[DateOffset(weekday=TH(4)), Day(4)],
)

# All holidays for which a distance feature is computed. The position in this
# list is the index used in the "hol_<i>" column names.
HOLIDAYS = [
    EasterMonday,
    GoodFriday,
    USColumbusDay,
    USLaborDay,
    USMartinLutherKingJr,
    USMemorialDay,
    USPresidentsDay,
    USThanksgivingDay,
    EasterSunday,
    NewYearsDay,
    SuperBowl,
    MothersDay,
    IndependenceDay,
    ChristmasEve,
    ChristmasDay,
    NewYearsEve,
    BlackFriday,
    CyberMonday,
]

def get_sin_cos_embedding(x: np.ndarray, dim: int) -> np.ndarray:
  """Encode a sequence of scalars as a sinusoidal embedding.

  Every value `x[t]` is multiplied by `dim // 2` geometrically decaying
  frequencies, `10000 ** (-2k / dim)` for `k = 0 .. dim // 2 - 1`. The sines of
  these products fill the first half of the output row and the cosines fill
  the second half.

  Args:
    x: values to embed, of shape `(T,)`. A 2-D array with a singleton axis,
      `(1, T)` or `(T, 1)`, is squeezed to `(T,)` first; this includes the
      single-step shape `(1, 1)`.
    dim: width of the embedding. Must be even.

  Returns:
    Array of shape `(T, dim)` laid out as `[sin | cos]`.

  Raises:
    AssertionError: if `dim` is odd.
  """
  assert dim % 2 == 0, "Embedding dimension must be even"
  if x.ndim == 2:
    # squeeze() collapses a (1, 1) input to a 0-d array, which cannot be
    # indexed below; atleast_1d restores the single time step.
    x = np.atleast_1d(x.squeeze())
  x = x[:, None]  # (T, 1)
  div_term = np.exp(np.arange(0, dim, 2) * -(np.log(10000.0) / dim))
  angles = x * div_term  # (T, dim // 2)
  return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)

class TimeCovariates:
  """Build calendar covariates, and optionally holiday distances, for an index.

  The calendar features are minute of hour, hour of day, day of month, day of
  week, day of year, month of year and week of year. When requested, one
  distance feature per entry of `HOLIDAYS` is appended, and every feature can
  be expanded into a sine/cosine embedding.
  """

  def __init__(
      self,
      datetimes,
      normalized=True,
      holiday=False,
      sincos_embed_dim: int = 0,
  ):
    """Store the configuration; nothing is computed here.

    Args:
      datetimes: pandas DatetimeIndex (lowest granularity supported is min).
      normalized: if True, each calendar feature is divided by a fixed
        per-feature constant and shifted by -0.5.
      holiday: if True, holiday distance features are added as well.
      sincos_embed_dim: width of the sine/cosine embedding applied to every
        feature. If 0, no embedding is applied.
    """
    self.dti = datetimes
    self.normalized = normalized
    self.holiday = holiday
    self.sincos_embed_dim = sincos_embed_dim

  def _normalize_time_feature(self, values, max_val):
    """Return `values / max_val - 0.5` when normalizing, else `values`."""
    if not self.normalized:
      return values
    return values / max_val - 0.5

  def _minute_of_hour(self):
    """Minute of the hour as float32, normalized with divisor 59."""
    minutes = np.array(self.dti.minute, dtype=np.float32)
    return self._normalize_time_feature(minutes, max_val=59.0)

  def _hour_of_day(self):
    """Hour of the day as float32, normalized with divisor 23."""
    hours = np.array(self.dti.hour, dtype=np.float32)
    return self._normalize_time_feature(hours, max_val=23.0)

  def _day_of_week(self):
    """Day of the week as float32, normalized with divisor 6."""
    weekdays = np.array(self.dti.dayofweek, dtype=np.float32)
    return self._normalize_time_feature(weekdays, max_val=6.0)

  def _day_of_month(self):
    """Day of the month as float32, normalized with divisor 30."""
    # NOTE(review): pandas day-of-month values start at 1, so the normalized
    # values are presumably not centred on 0 -- kept as is on purpose.
    month_days = np.array(self.dti.day, dtype=np.float32)
    return self._normalize_time_feature(month_days, max_val=30.0)

  def _day_of_year(self):
    """Day of the year as float32, normalized with divisor 364."""
    year_days = np.array(self.dti.dayofyear, dtype=np.float32)
    return self._normalize_time_feature(year_days, max_val=364.0)

  def _month_of_year(self):
    """Month of the year as float32, normalized with divisor 11."""
    months = np.array(self.dti.month, dtype=np.float32)
    return self._normalize_time_feature(months, max_val=11.0)

  def _week_of_year(self):
    """Week of the year (strftime "%U") as float32, normalized with 51."""
    week_labels = self.dti.strftime("%U")
    weeks = np.array(week_labels.astype(int), dtype=np.float32)
    return self._normalize_time_feature(weeks, max_val=51.0)

  def _get_holidays(self):
    """Compute the distance of every timestamp to every holiday.

    One joblib job is started per entry of `HOLIDAYS`.

    Returns:
      Array of shape (num_holiday, num_time_steps) with the values produced by
      `_distance_to_holiday`. The values are returned as computed: no scaling
      is applied here, and `self.normalized` is not consulted.
    """
    timestamps = self.dti.to_series()

    def _distances_for(series, holiday):
      # Map every timestamp of the series to its distance from one holiday.
      return series.apply(_distance_to_holiday(holiday)).values

    jobs = (
        delayed(_distances_for)(timestamps, holiday) for holiday in HOLIDAYS
    )
    per_holiday = Parallel(n_jobs=len(HOLIDAYS))(jobs)
    return np.vstack(per_holiday)  # (num_holiday, num_time_steps)

  def get_covariates(self):
    """Assemble all configured covariates into a DataFrame.

    Returns:
      DataFrame indexed by the stored DatetimeIndex, one row per timestamp.
      Columns are "moh", "hod", "dom", "dow", "doy", "moy", "woy", followed by
      "hol_<i>" for each holiday when holiday features are enabled. With a
      non-zero `sincos_embed_dim`, every column is replaced by
      `sincos_embed_dim` columns named "<column>_sincos<i>".
    """
    calendar_features = (
        ("moh", self._minute_of_hour),
        ("hod", self._hour_of_day),
        ("dom", self._day_of_month),
        ("dow", self._day_of_week),
        ("doy", self._day_of_year),
        ("moy", self._month_of_year),
        ("woy", self._week_of_year),
    )
    columns = [name for name, _ in calendar_features]
    # Each feature becomes a (1, T) row so that the rows can be stacked.
    all_covs = [build().reshape(1, -1) for _, build in calendar_features]

    if self.holiday:
      # Extending with a 2-D array appends one 1-D row per holiday.
      all_covs.extend(self._get_holidays())
      columns.extend(f"hol_{i}" for i in range(len(HOLIDAYS)))

    if self.sincos_embed_dim != 0:
      width = self.sincos_embed_dim
      embedded_covs = []
      embedded_names = []
      for cov, col in zip(all_covs, columns):
        # The embedding comes back as (T, width); transpose to (width, T).
        embedded_covs.append(get_sin_cos_embedding(cov, width).T)
        embedded_names.extend(f"{col}_sincos{i}" for i in range(width))
      all_covs = embedded_covs
      columns = embedded_names

    stacked = np.vstack(all_covs)  # (num_features, T)
    return pd.DataFrame(data=stacked.T, columns=columns, index=self.dti)