from copy import copy
from dataclasses import dataclass
from singledispatchmethod import singledispatchmethod
from typing import Tuple

import funcy as f
import numpy as np
import pandas as pd
import shapefile
from shapely.geometry import shape, Point
from shapely.ops import transform

import rideshare_simulator.events as events
from rideshare_simulator.state import WorldState
from .routing import get_route
from .experiments import Experiment


class StateSummarizer(object):
    """Periodically snapshot the driver table while folding simulation updates.

    Used as a reducer: ``init()`` returns the accumulator (a list),
    ``reducer`` is called once per ``(state, event)`` update and appends a
    ``state.as_df()`` snapshot at most once per ``interval``, and ``finish``
    concatenates the collected snapshots into one DataFrame.
    """

    def __init__(self, interval=120.):
        """
        Args:
            interval: minimum time between consecutive snapshots, in the same
                units as ``event.ts``.
        """
        super(StateSummarizer, self).__init__()
        self.interval = interval
        # Timestamp of the most recent snapshot. Starting at 0 means the first
        # snapshot is taken at the first event with ts >= interval.
        self.last_update = 0.

    def reducer(self, prev, update: "Tuple[WorldState, events.Event]"):
        """Append a snapshot of ``state`` to ``prev`` if ``interval`` has elapsed.

        Args:
            prev: accumulator list of snapshot DataFrames (mutated in place).
            update: ``(state, event)`` pair; ``event.ts`` drives the schedule.

        Returns:
            ``prev``.
        """
        state, event = update
        if event.ts >= self.last_update + self.interval:
            self.last_update = event.ts
            # Build the snapshot lazily: most events fall inside the interval
            # and would just discard it.
            prev.append(state.as_df())
        return prev

    def finish(self, prev):
        """Concatenate the collected snapshots.

        Returns an empty DataFrame when no snapshot was taken (e.g. the run
        ended before ``interval`` elapsed) instead of letting ``pd.concat``
        raise on an empty sequence.
        """
        if not prev:
            return pd.DataFrame()
        return pd.concat(prev)

    def init(self):
        """Return a fresh, empty accumulator."""
        return []


class RequestSummarizer(object):
    """Build one summary row per rider request from a stream of events.

    Used as a reducer: ``init()`` returns a dict keyed by rider id,
    ``reducer`` merges the fields produced for each event into that rider's
    entry, and ``finish`` turns the entries into a DataFrame (one row per
    rider).

    Args:
        regions: mapping of region name to a shapely geometry; each region
            contributes a driver-count column to request rows.
        experiment: its ``is_treated(event)`` tags each request row.
    """

    def __init__(self, regions: dict, experiment: Experiment):
        self.regions = regions
        self.experiment = experiment

    @classmethod
    def from_shapefile(cls, shp_fname, *args, **kwargs):
        """Construct from a shapefile; the first record field names each region.

        Geometry coordinates are swapped from (x, y) to (y, x) — presumably
        (lng, lat) -> (lat, lng) to match the (lat, lng) order of rider
        ``src``/``dest``; TODO confirm against the shapefile's CRS. Remaining
        arguments are forwarded to ``__init__``.
        """
        def _swap_xy(x, y):
            return y, x

        # Use the reader as a context manager so its file handles are closed.
        with shapefile.Reader(shp_fname) as reader:
            regions = {rec.record[0]: transform(_swap_xy, shape(rec.shape))
                       for rec in reader}
        return cls(regions, *args, **kwargs)

    def reducer(self, requests, update: Tuple[WorldState, events.Event]):
        """Merge the summary of one event into ``requests`` (mutated in place).

        Events without a registered summarizer contribute nothing.

        Returns:
            ``requests``.
        """
        state, event = update
        result = self.summarize_event(event, state)
        if result is not None:
            rider_id, update_dict = result
            requests[rider_id] = f.merge(requests.get(rider_id, dict()),
                                         update_dict)
        return requests

    def init(self):
        """Return a fresh, empty accumulator (rider id -> field dict)."""
        return dict()

    def finish(self, prev):
        """Return the accumulated per-rider dicts as a DataFrame."""
        return pd.DataFrame(list(prev.values()))

    @staticmethod
    def count_region(state: WorldState, name: str, shp):
        """Count online drivers inside ``shp``, grouped by current capacity.

        The spatial index narrows candidates to the region's bounding box,
        then an exact ``shp.contains`` test on each driver's bbox corner
        decides membership. Keys are ``drivers_{name}_{capacity}``.
        """
        drivers = list(state.drivers.tree.intersection(
            shp.bounds, objects=True))
        in_region = (
            item.object for item in drivers
            if item.object.is_online
            and shp.contains(Point(item.bbox[0], item.bbox[1])))
        counts = count_by_capacity(state.ts, in_region)
        return f.walk_keys(lambda cnt: f"drivers_{name}_{cnt}", counts)

    @staticmethod
    def count_region_coarse(state: WorldState, name: str, shp):
        """Cheap region count: spatial-index entries within ``shp``'s bounding box.

        Unlike ``count_region`` this does no exact containment test and no
        online/capacity filtering.
        """
        return {name: state.drivers.tree.count(shp.bounds)}

    @singledispatchmethod
    def summarize_event(self, event, state):
        """Return ``(rider_id, fields)`` for ``event``, or None if unhandled."""
        return None

    @summarize_event.register(events.RequestEvent)
    def _(self, event: events.RequestEvent, state):
        """Initial row for a request: rider attributes plus a supply snapshot."""
        counts = f.join(self.count_region_coarse(state, name, shp)
                        for (name, shp) in self.regions.items())
        # f.join can return None (e.g. when there are no regions).
        counts = dict() if counts is None else counts
        online_drivers = f.lfilter(lambda d: d.is_online,
                                   state.drivers.values())
        slacks = [driver.route.slack_time(state.ts)
                  for driver in online_drivers
                  if not driver.is_idle(state.ts)]
        rider = event.rider
        summ = dict(ts=event.ts,
                    id=rider.id,
                    is_treated=self.experiment.is_treated(event),
                    src_lat=rider.src[0],
                    src_lng=rider.src[1],
                    dest_lat=rider.dest[0],
                    dest_lng=rider.dest[1],
                    v_no_purchase=rider.v_no_purchase,
                    km=get_route(event.ts, [rider.src, rider.dest]).total_kms,
                    # NaN when no online driver is busy; guarded so numpy does
                    # not emit a "Mean of empty slice" RuntimeWarning.
                    slack=np.mean(slacks) if slacks else np.nan)
        by_cap = count_by_capacity(state.ts, online_drivers)
        return (rider.id,
                f.merge(summ, counts,
                        f.walk_keys(lambda cnt: f"drivers_total_{cnt}", by_cap)))

    @summarize_event.register(events.OfferResponseEvent)
    def _(self, event: events.OfferResponseEvent, state):
        """Record the offer shown to the rider and whether it was accepted."""
        summ = dict(etd=event.offer.etd,
                    price=event.offer.price,
                    accepted=event.accepted)
        return (event.rider_id, summ)

    @summarize_event.register(events.DispatchEvent)
    def _(self, event: events.DispatchEvent, state):
        """Record the insertion cost and whether the dispatch is a shared ride.

        ``is_match`` is True when the assigned driver's route has more than
        one remaining rider at ``state.ts`` — presumably meaning the rider
        was pooled with someone else; verify against the routing model.
        """
        summ = dict(cost=event.insertion_cost,
                    is_match=len(state.drivers[event.driver_id]
                                 .route.remaining_riders(state.ts)) > 1)
        return (event.rider_id, summ)


def count_by_capacity(ts, drivers):
    """Count drivers by their capacity at time ``ts``.

    Args:
        ts: timestamp passed to each driver's ``capacity(ts)``.
        drivers: iterable of drivers; consumed exactly once.

    Returns:
        A ``defaultdict(int)`` mapping capacity -> number of drivers, in
        first-seen order; capacities not present read as 0. A defaultdict
        (rather than ``Counter``) is returned because ``funcy.walk_keys``
        rebuilds the same mapping type from key/value pairs, which a
        ``Counter`` would count instead of store.
    """
    from collections import defaultdict

    # Count directly instead of grouping drivers into lists and taking len().
    counts = defaultdict(int)
    for driver in drivers:
        counts[driver.capacity(ts)] += 1
    return counts


def init_summary():
    """Return a fresh summary dict with every field set to 0.

    ``ts`` holds the timestamp of the last folded event; the other keys are
    the driver/capacity gauges and the cumulative request, dispatch, revenue
    and cost counters maintained by ``update_summary``.
    """
    field_names = (
        "ts",
        "available_drivers",
        "available_capacity",
        "online_drivers",
        "online_capacity",
        "current_riders",
        "requests",
        "accepted_requests",
        "dispatches",
        "revenue",
        "cost",
    )
    return dict.fromkeys(field_names, 0)


def update_summary(summary: dict, update: Tuple[WorldState, events.Event]):
    """Fold one ``(state, event)`` update into a running summary.

    Returns a shallow copy of ``summary`` (the input is not mutated) in which
    the gauges (available/online driver counts and capacities, current
    riders) are recomputed from ``state`` and the cumulative counters are
    incremented according to the type of ``event``.
    """
    state, event = update
    summary = copy(summary)
    ts = state.ts
    summary["ts"] = event.ts

    available_drivers = state.get_available_drivers()
    online_drivers = [d for d in state.drivers.values() if d.is_online]

    summary["available_drivers"] = len(available_drivers)
    summary["online_drivers"] = len(online_drivers)
    summary["online_capacity"] = sum(d.max_capacity for d in online_drivers)
    # ``remaining_riders`` and ``capacity`` are time-dependent methods (see the
    # DispatchEvent summarizer and ``count_by_capacity``). Using them as plain
    # attributes passed a bound method to len()/sum() and raised TypeError.
    summary["current_riders"] = sum(
        len(d.route.remaining_riders(ts))
        for d in online_drivers
        if not d.is_idle(ts))
    summary["available_capacity"] = sum(
        driver.capacity(ts) for driver in available_drivers)

    if isinstance(event, events.RequestEvent):
        summary["requests"] += 1
    elif isinstance(event, events.DispatchEvent):
        summary["dispatches"] += 1
        # NOTE(review): assumes DispatchEvent exposes ``offer``; the request
        # summarizer only reads insertion_cost/driver_id/rider_id from it.
        summary["revenue"] += event.offer.price
        summary["cost"] += event.insertion_cost
    elif isinstance(event, events.OfferResponseEvent):
        # ``accepted`` is added as a number (a bool counts as 0/1).
        summary["accepted_requests"] += event.accepted

    return summary
