# Original file: https://github.com/ActivityWatch/aw-research/blob/master/aw_research/util.py

from typing import List, Tuple
from datetime import datetime, time, timedelta, timezone

import pandas as pd

from aw_core import Event


# typet5: (event: Event, timestamp: datetime) -> Tuple[Event, Event]
# type4py: (event: str, timestamp: datetime.datetime) -> str
def split_event_on_time(event: Event, timestamp: datetime) -> Tuple[Event, Event]:
    """Cut *event* in two at *timestamp* and return both halves.

    The input event is not modified; two copies are built from it. The first
    copy keeps the original start and ends at *timestamp*, the second starts
    at *timestamp* and ends where the original ended.

    Raises:
        AssertionError: if *timestamp* is not strictly after the event start,
            or if the halves do not add up to the original duration.
    """
    head = Event(**event)
    tail = Event(**event)
    assert timestamp > event.timestamp

    # Capture the original end before either copy is modified.
    original_end = tail.timestamp + tail.duration

    head.duration = timestamp - head.timestamp
    tail.duration = original_end - timestamp
    tail.timestamp = timestamp

    # Sanity checks: halves are ordered and together cover the original.
    assert head.timestamp < tail.timestamp
    assert event.duration == head.duration + tail.duration
    return head, tail


# typet5: (timestamp: datetime) -> datetime
# type4py: (timestamp: datetime.datetime) -> datetime.datetime
def next_hour(timestamp: datetime) -> datetime:
    """Return the first full-hour boundary strictly after *timestamp*.

    Minutes, seconds and microseconds are dropped and one hour is added, so a
    timestamp that already sits exactly on the hour maps to the following
    hour. The tzinfo of *timestamp* (or its absence) is carried over.
    """
    hour_start = datetime(
        timestamp.year,
        timestamp.month,
        timestamp.day,
        timestamp.hour,
        tzinfo=timestamp.tzinfo,
    )
    return hour_start + timedelta(hours=1)


# typet5: () -> None
# type4py: () -> None
def test_next_hour() -> None:
    """Check next_hour within a day and across the midnight rollover."""
    cases = [
        (datetime(2019, 1, 1, 6, 23), datetime(2019, 1, 1, 7)),
        # Last hour of the day must roll over to the next date.
        (datetime(2019, 1, 1, 23, 23), datetime(2019, 1, 2, 0)),
    ]
    for given, expected in cases:
        assert next_hour(given) == expected


# typet5: (event: Event) -> List[Event]
# type4py: (event: raiden_libs.events.Event) -> str
def split_event_on_hour(event: Event) -> List[Event]:
    """Split *event* into pieces that each lie within a single clock hour.

    The event is cut at every full-hour boundary that falls strictly inside
    it. Boundaries are found by comparing the event's end with the next hour
    boundary after its start, not by comparing hour-of-day numbers, so events
    lasting 24 hours or more are split correctly as well.

    An event that fits within one hour (including one ending exactly on the
    hour, or one with a zero/negative duration) is returned unchanged as the
    only element, so no zero-length trailing piece is produced.

    Returns:
        The pieces in chronological order; their durations add up to the
        duration of *event*.
    """
    pieces: List[Event] = []
    remainder = event
    while True:
        boundary = next_hour(remainder.timestamp)
        if remainder.timestamp + remainder.duration <= boundary:
            # What is left no longer crosses an hour boundary.
            pieces.append(remainder)
            return pieces
        head, remainder = split_event_on_time(remainder, boundary)
        pieces.append(head)


# typet5: () -> None
# type4py: () -> None
def test_split_event_on_hour() -> None:
    """Check split_event_on_hour for an in-hour event and a multi-hour one."""
    start = datetime(2019, 1, 1, 11, 30, tzinfo=timezone.utc)

    # 11:30 -> 11:31 stays inside one hour: nothing to split.
    within_hour = Event(timestamp=start, duration=timedelta(minutes=1))
    assert len(split_event_on_hour(within_hour)) == 1

    # 11:30 -> 13:30 crosses 12:00 and 13:00: three pieces.
    spanning = Event(timestamp=start, duration=timedelta(hours=2))
    pieces = split_event_on_hour(spanning)
    assert len(pieces) == 3


# typet5: (dt: datetime) -> datetime
# type4py: (dt: datetime.date.time) -> str
def start_of_day(dt: datetime) -> datetime:
    """Return midnight of *dt*'s calendar date as a UTC-aware datetime.

    The date is read from *dt* as-is (no timezone conversion is applied), and
    the result is always tagged with UTC regardless of ``dt.tzinfo``.
    """
    # NOTE(review): for non-UTC aware input this is midnight UTC of the
    # *local* date, not local midnight -- confirm this is what callers expect.
    return datetime.combine(dt.date(), time(), tzinfo=timezone.utc)


# typet5: (dt: datetime) -> datetime
# type4py: (dt: datetime.datetime) -> float
def end_of_day(dt: datetime) -> datetime:
    """Return the instant one day after ``start_of_day(dt)``.

    This is the exclusive upper bound of *dt*'s day: the following midnight.
    """
    midnight = start_of_day(dt)
    one_day = timedelta(days=1)
    return midnight + one_day


# typet5: (dt: datetime) -> datetime
# type4py: (dt: datetime.date) -> int
def get_week_start(dt: datetime) -> datetime:
    """Return midnight of the Monday of the week containing *dt*.

    The tzinfo of *dt* (or its absence) is carried over to the result.
    """
    # date.weekday() is 0 for Monday, so it equals the days elapsed since Monday.
    days_since_monday = dt.date().weekday()
    monday = (dt - timedelta(days=days_since_monday)).date()
    return datetime(monday.year, monday.month, monday.day, tzinfo=dt.tzinfo)


# typet5: (dt1: datetime, dt2: datetime) -> bool
# type4py: (dt1: datetime.datetime, dt2: datetime.datetime) -> str
def is_in_same_week(dt1: datetime, dt2: datetime) -> bool:
    """Return True when *dt1* and *dt2* fall in the same Monday-based week.

    Two datetimes share a week when ``get_week_start`` gives equal results
    for both.
    """
    week_of_first = get_week_start(dt1)
    week_of_second = get_week_start(dt2)
    return week_of_first == week_of_second


# typet5: (start: datetime, end: datetime) -> List[Tuple[datetime, datetime]]
# type4py: (start: Optional[int], end: datetime.datetime) -> float
def split_into_weeks(start: datetime, end: datetime) -> List[Tuple[datetime, datetime]]:
    """Split the range ``[start, end)`` at every week boundary (Monday 00:00).

    Week boundaries come from ``get_week_start``, so they carry the tzinfo of
    the datetime they were derived from.

    Returns:
        Consecutive ``(interval_start, interval_end)`` tuples covering the
        range, each lying within a single week. The list is empty when the
        range is empty or reversed (``start >= end``).

    Raises:
        TypeError: if one of *start* / *end* is naive and the other aware
            (they cannot be ordered).
    """
    intervals: List[Tuple[datetime, datetime]] = []
    cursor = start
    # Iterative on purpose: recursion depth would grow with the range, and a
    # reversed range would otherwise never terminate.
    while cursor < end:
        next_week = get_week_start(cursor) + timedelta(days=7)
        stop = min(next_week, end)
        intervals.append((cursor, stop))
        cursor = stop
    return intervals


# typet5: () -> None
# type4py: () -> None
def test_split_into_weeks() -> None:
    """Check split_into_weeks with tz-naive and tz-aware inputs."""
    # tzinfo=None exercises the tz-naive path, timezone.utc the tz-aware one.
    for tz in (None, timezone.utc):
        weeks = split_into_weeks(
            datetime(2019, 1, 3, 12, tzinfo=tz),
            datetime(2019, 1, 18, 0, 2, tzinfo=tz),
        )
        for dtstart, dtend in weeks:
            print(dtstart, dtend)
        # Jan 3 -> Jan 18 crosses two Mondays (Jan 7 and Jan 14).
        assert len(weeks) == 3


# typet5: (start: datetime, end: datetime) -> List[Tuple[datetime, datetime]]
# type4py: (start: datetime.datetime, end: datetime.datetime) -> Tuple[int, int]
def split_into_days(start: datetime, end: datetime) -> List[Tuple[datetime, datetime]]:
    """Split the range ``[start, end)`` at every midnight.

    Midnight boundaries keep the tzinfo of the datetime they were derived
    from, so tz-aware input yields tz-aware intervals and tz-naive input
    yields tz-naive ones (the two are never mixed in the result).

    Returns:
        Consecutive ``(interval_start, interval_end)`` tuples covering the
        range, each lying within a single calendar day. The list is empty
        when the range is empty or reversed (``start >= end``).

    Raises:
        TypeError: if one of *start* / *end* is naive and the other aware
            (they cannot be ordered).
    """
    intervals: List[Tuple[datetime, datetime]] = []
    cursor = start
    # Iterative on purpose: recursion depth would grow with the range, and a
    # reversed range would otherwise never terminate.
    while cursor < end:
        # Passing tzinfo keeps the boundary comparable with aware inputs.
        next_midnight = datetime.combine(
            cursor.date(), time(), tzinfo=cursor.tzinfo
        ) + timedelta(days=1)
        stop = min(next_midnight, end)
        intervals.append((cursor, stop))
        cursor = stop
    return intervals


# typet5: () -> None
# type4py: () -> None
def test_split_into_days() -> None:
    """Check that a range crossing three midnights yields four intervals."""
    first = datetime(2019, 1, 3, 12)
    last = datetime(2019, 1, 6, 0, 2)
    days = split_into_days(first, last)
    for interval in days:
        print(*interval)
    assert len(days) == 4


# typet5: (events: List[Event]) -> None
# type4py: (events: List[aw_core.Event]) -> None
def verify_no_overlap(events: List[Event]) -> None:
    """Print a warning when consecutive events overlap in time.

    Only adjacent pairs are compared: an overlap is counted when an event
    ends after the next one in the list starts. The list is presumably
    expected to be sorted by timestamp -- verify against callers.

    Nothing is printed (and nothing is raised) when no pair overlaps, which
    includes lists with fewer than two events.
    """
    n_overlaps = 0
    total_overlap = timedelta()
    # Plain comparison instead of assert/except: asserts are stripped under
    # `python -O`, which would silently disable the check.
    for current, following in zip(events, events[1:]):
        overlap = (current.timestamp + current.duration) - following.timestamp
        if overlap > timedelta():
            n_overlaps += 1
            total_overlap += overlap
    if n_overlaps:
        print(
            f"[WARNING] Found {n_overlaps} events overlapping, totalling: {total_overlap}"
        )


# typet5: (events: List[Event], category: str) -> float
# type4py: (events: Callable, category: float) -> dict
def categorytime_per_day(events, category):
    """Return the hours spent in *category* per calendar day.

    Events are kept when *category* is contained in their
    ``data["$category_hierarchy"]``. Each kept event contributes its duration
    (in hours) at its start timestamp; the values are then summed into
    one-day bins. Timezone information is dropped from the index before
    binning.

    Raises:
        Exception: if no event matches *category*.
    """
    matching = [ev for ev in events if category in ev.data["$category_hierarchy"]]
    if not matching:
        raise Exception("No events to calculate on")

    hours_spent = [ev.duration.total_seconds() / 3600 for ev in matching]
    # tz_localize(None) drops the timezone so the index is tz-naive.
    naive_index = pd.DatetimeIndex([ev.timestamp for ev in matching]).tz_localize(
        None
    )
    per_event = pd.Series(hours_spent, index=naive_index)
    return per_event.resample("1D").aggregate("sum")


# typet5: (events: List[Event], category: str, day: int) -> int
# type4py: (events: str, category: str, day: datetime.datetime) -> int
def categorytime_during_day(
    events: List[Event], category: str, day: datetime
) -> pd.Series:
    """Return the hours spent in *category* per clock hour, for events after *day*.

    Events are kept when *category* is contained in their
    ``data["$category_hierarchy"]`` and their timestamp is strictly later
    than *day*. Kept events are split on hour boundaries so that each piece
    is attributed to the hour it falls in, then durations (in hours) are
    summed into one-hour bins.

    Returns:
        A float Series indexed by hour; empty when no event matches.
    """
    # NOTE(review): only a lower bound is applied (`timestamp > day`), so
    # events on later days are included and an event starting exactly at
    # `day` is excluded -- confirm this matches what callers pass in.
    hourly_pieces: List[Event] = []
    for e in events:
        if category in e.data["$category_hierarchy"] and e.timestamp > day:
            hourly_pieces.extend(split_event_on_hour(e))
    ts = pd.Series(
        [piece.duration.total_seconds() / 3600 for piece in hourly_pieces],
        index=pd.DatetimeIndex([piece.timestamp for piece in hourly_pieces]),
        # Explicit dtype keeps an empty result numeric.
        dtype="float64",
    )
    # A timedelta rule avoids the "1H" string alias, which is deprecated
    # since pandas 2.2.
    return ts.resample(timedelta(hours=1)).apply("sum")
