from datetime import datetime
from functools import reduce
import itertools as it
import json
import os
import operator as op
import pickle
import pathlib

import pandas as pd
import funcy as f
from tqdm import tqdm

from rideshare_simulator.dispatch.dispatch_policy import \
    CheapestDispatchPolicy, DispatchExperimentPolicy
from rideshare_simulator.dispatch.planner import ShortestPathPlanner
from rideshare_simulator.generators.driver_generator import UniformDriverOnlineGenerator
from rideshare_simulator.generators.rider_generator import UniformRequestGenerator
from rideshare_simulator.generators.taxi import \
    NYCTaxiRequestGenerator, NYCTaxiDriverOnlineGenerator
from rideshare_simulator.rider import MaxUtilityRider
from rideshare_simulator.pricing.policy import \
    ConstantFactorPricingPolicy, PricingExperimentPolicy
from rideshare_simulator.simulator import Simulator
from rideshare_simulator.state import WorldState
from rideshare_simulator.summary import RequestSummarizer, StateSummarizer
from rideshare_simulator.experiments import Experiment, SwitchbackExperiment

import sacred
# Sacred experiment object; the @ex.capture / @ex.automain decorators used
# throughout this file are attached to it.
ex = sacred.Experiment("rideshare_simulation")
# Register config/default.yaml (a relative path) as a configuration source.
ex.add_config("config/default.yaml")

# Corner points handed to the uniform request and driver generators as their
# min/max bounds.
# NOTE(review): presumably (latitude, longitude) pairs, and the values look
# like a box around San Francisco -- confirm.
min_latlng = (37.736927, -122.512273)
max_latlng = (37.816437, -122.375974)


@ex.capture
def request_generator(rider, init, cost_per_km, cost_per_sec):
    """Build the rider request generator selected by ``rider["generator"]``.

    Args:
        rider: Config mapping. ``generator`` picks the implementation
            (``"uniform"`` or ``"taxi"``); the remaining keys read depend on
            that choice.
        init: Config mapping; ``init["min_ts"]`` is read only for the taxi
            generator.
        cost_per_km: Forwarded in the rider parameters of the uniform
            generator; unused for the taxi generator.
        cost_per_sec: Forwarded in the rider parameters of the uniform
            generator; unused for the taxi generator.

    Returns:
        A ``UniformRequestGenerator`` or ``NYCTaxiRequestGenerator``.

    Raises:
        NotImplementedError: If ``rider["generator"]`` is any other value.
    """
    kind = rider["generator"]

    if kind == "uniform":
        utility_params = {
            "mean_wtp_per_sec": rider["mean_wtp_per_sec"],
            "cost_per_km": cost_per_km,
            "cost_per_sec": cost_per_sec,
        }
        return UniformRequestGenerator(
            rider["mean_time"],
            min_latlng,
            max_latlng,
            rider_ctor=MaxUtilityRider.lognormal_utility,
            rider_params=utility_params,
        )

    if kind == "taxi":
        taxi_gen = NYCTaxiRequestGenerator.from_file(
            rider["trips_fname"],
            rider["shp_fname"],
            mean_wtp_per_sec=rider["mean_wtp_per_sec"],
            sigma=rider["sigma"],
            rel_rate=rider["rel_rate"],
            method=rider["method"],
        )
        # The taxi generator is told the configured minimum timestamp
        # before it is handed back.
        taxi_gen.set_min_ts(init["min_ts"])
        return taxi_gen

    raise NotImplementedError()


@ex.capture
def driver_generator(driver, init):
    """Build the driver-online generator selected by ``driver["generator"]``.

    Args:
        driver: Config mapping. ``generator`` picks the implementation
            (``"uniform"`` or ``"taxi"``); the remaining keys read depend on
            that choice.
        init: Config mapping; ``init["min_ts"]`` is read only for the taxi
            generator.

    Returns:
        A ``UniformDriverOnlineGenerator`` or ``NYCTaxiDriverOnlineGenerator``.

    Raises:
        NotImplementedError: If ``driver["generator"]`` is any other value.
    """
    kind = driver["generator"]

    if kind == "uniform":
        return UniformDriverOnlineGenerator(
            driver["mean_time"],
            min_latlng,
            max_latlng,
            mean_shift_length=driver["mean_shift_length"],
            capacity=driver["capacity"],
        )

    if kind == "taxi":
        taxi_gen = NYCTaxiDriverOnlineGenerator.from_file(
            driver["trips_fname"],
            driver["shp_fname"],
            capacity=driver["capacity"],
            rel_rate=driver["rel_rate"],
            mean_shift_length=driver["mean_shift_length"],
            method=driver["method"],
        )
        # The taxi generator is told the configured minimum timestamp
        # before it is handed back.
        taxi_gen.set_min_ts(init["min_ts"])
        return taxi_gen

    raise NotImplementedError()


@ex.capture()
def my_experiment(experiment, _seed):
    """Instantiate the experiment object described by the ``experiment`` config.

    ``experiment["type"]`` selects the class and the attribute name handed to
    it as an ``attrgetter``: ``"ab"`` uses ``Experiment`` with ``rider_id``,
    ``"switchback"`` uses ``SwitchbackExperiment`` with ``ts``. Any other
    value raises ``KeyError`` from the lookup table.

    Args:
        experiment: Config mapping with ``type`` and ``salt``; every other
            entry is passed through to the constructor as a keyword argument.
        _seed: Run seed; its string form is appended to the configured salt.

    Returns:
        The constructed experiment instance.
    """
    registry = {
        "ab": (Experiment, "rider_id"),
        "switchback": (SwitchbackExperiment, "ts"),
    }
    experiment_cls, unit_attr = registry[experiment["type"]]

    seeded_salt = experiment["salt"] + str(_seed)
    unit_getter = op.attrgetter(unit_attr)
    # "type" and "salt" are consumed above; the rest goes to the constructor.
    passthrough = f.omit(experiment, ["type", "salt"])

    return experiment_cls(salt=seeded_salt,
                          attrgetter=unit_getter,
                          **passthrough)


@ex.capture
def pricing_policy(pricing, dispatcher):
    """Build the pricing policy that switches between arms ``A`` and ``B``.

    Args:
        pricing: Config mapping whose ``"A"`` and ``"B"`` entries hold the
            keyword arguments for each arm's ``ConstantFactorPricingPolicy``.
        dispatcher: Dispatch policy passed to both arms.

    Returns:
        A ``PricingExperimentPolicy`` wrapping a fresh experiment object and
        the two arm policies.
    """
    def build_arm(label):
        # Each arm receives its own cost function instance.
        return ConstantFactorPricingPolicy(
            cost_fn=cost_fn(), dispatcher=dispatcher, **pricing[label])

    policy_a = build_arm("A")
    policy_b = build_arm("B")
    return PricingExperimentPolicy(my_experiment(), policy_a, policy_b)


@ex.capture
def cost_fn(cost_per_km, cost_per_sec):
    """Return a callable that prices a route from its distance and duration.

    Args:
        cost_per_km: Multiplier applied to ``route.total_kms``.
        cost_per_sec: Multiplier applied to ``route.total_secs``.

    Returns:
        A one-argument function mapping a route to
        ``cost_per_km * route.total_kms + cost_per_sec * route.total_secs``.
    """
    def route_cost(route):
        distance_part = cost_per_km * route.total_kms
        time_part = cost_per_sec * route.total_secs
        return distance_part + time_part

    return route_cost


@ex.capture
def dispatch_policy(dispatch, experiment, _seed):
    """Build the dispatch policy that switches between arms ``A`` and ``B``.

    Both arms are ``CheapestDispatchPolicy`` instances sharing a single
    ``ShortestPathPlanner`` whose cost is the configured cost function.

    Args:
        dispatch: Config mapping whose ``"A"`` and ``"B"`` entries hold the
            keyword arguments for each arm.
        experiment: Not referenced in the body.
        _seed: Not referenced in the body.

    Returns:
        A ``DispatchExperimentPolicy`` wrapping a fresh experiment object and
        the two arm policies.
    """
    shared_planner = ShortestPathPlanner(cost=cost_fn())
    policy_a, policy_b = (
        CheapestDispatchPolicy(shared_planner, **dispatch[label])
        for label in ("A", "B")
    )
    return DispatchExperimentPolicy(my_experiment(), policy_a, policy_b)


@ex.capture
def summarizer(output, experiment):
    """Build the summarizer selected by ``output["type"]``.

    Args:
        output: Config mapping. ``type`` is ``"requests"`` or ``"state"``.
            For ``"requests"`` an optional ``shapefile`` entry is honoured;
            for ``"state"`` an optional ``interval`` entry (default 120) is
            passed to the summarizer.
        experiment: Not referenced in the body.

    Returns:
        A ``RequestSummarizer`` or ``StateSummarizer``.

    Raises:
        NotImplementedError: If ``output["type"]`` is any other value.
    """
    kind = output["type"]

    if kind == "state":
        return StateSummarizer(output.get("interval", 120))

    if kind != "requests":
        raise NotImplementedError()

    shapefile = output.get("shapefile", None)
    if shapefile is None:
        # No shapefile configured: start from an empty mapping instead.
        return RequestSummarizer(dict(), experiment=my_experiment())
    return RequestSummarizer.from_shapefile(
        shapefile, experiment=my_experiment())


@ex.automain
def main(T, init, output, _config, _seed):
    """Run the simulation for up to ``T`` events and write results to disk.

    All files go under ``output['dir']``, which is created if missing:
    ``config.json`` first, then for every checkpoint ``state*.pkl``,
    ``intermediate-summary*.pkl``, ``summary*.csv``, ``dispatch*.csv`` and
    ``pricing*.csv``. Periodic checkpoints use the ``-ckpt`` suffix (and so
    overwrite one another); the final checkpoint uses no suffix.

    Args:
        T: Maximum number of (state, event) pairs taken from the simulator.
        init: Config mapping; when ``init["state_pkl"]`` is not None it is
            loaded as the simulator's starting state.
        output: Config mapping providing ``dir`` and ``checkpoint_interval``.
        _config: Configuration object dumped as JSON to ``config.json``.
        _seed: Not referenced in the body.
    """
    summ = summarizer()
    pathlib.Path(output['dir']).mkdir(parents=True, exist_ok=True)

    def out_path(fname):
        # Every output file lives directly inside the output directory.
        return os.path.join(output['dir'], fname)

    with open(out_path("config.json"), 'w') as file:
        json.dump(_config, file)

    dispatcher = dispatch_policy()
    pricer = pricing_policy(dispatcher=dispatcher)
    request_gen = request_generator()
    driver_gen = driver_generator()

    def save_checkpoint(state, summary, suffix=""):
        """Write state, summary and policy logs using ``suffix`` in names."""
        with open(out_path(f"state{suffix}.pkl"), "wb") as file:
            pickle.dump(state, file)

        with open(out_path(f"intermediate-summary{suffix}.pkl"),
                  "wb") as file:
            pickle.dump(summary, file)

        df = summ.finish(summary)
        df.to_csv(out_path(f"summary{suffix}.csv"), index=False)

        dispatcher.log_df().to_csv(
            out_path(f"dispatch{suffix}.csv"), index=False)
        pricer.log_df().to_csv(
            out_path(f"pricing{suffix}.csv"), index=False)

    sim = Simulator(request_gen, driver_gen, dispatcher, pricer)

    if init["state_pkl"] is not None:
        sim.state = WorldState.from_pickle(init["state_pkl"])

    # Loop-invariant: convert once instead of on every event.
    checkpoint_interval = int(output["checkpoint_interval"])

    sim_itor = it.islice(sim.run(), T)
    summary = summ.init()
    # Stays None when the simulator produces no events (e.g. T == 0), so the
    # final checkpoint below can tell there is no state to persist.
    state = None
    pbar = tqdm(enumerate(sim_itor), total=T, mininterval=0.1)
    for (i, (state, event)) in pbar:
        pbar.set_description(
            str(datetime.fromtimestamp(int(state.ts))))
        summary = summ.reducer(summary, (state, event))
        # i starts at 0, so the first checkpoint is written right after the
        # first event.
        if i % checkpoint_interval == 0:
            save_checkpoint(state, summary, suffix="-ckpt")

    if state is None:
        import warnings
        warnings.warn(
            "Simulation produced no events; final checkpoint not written.")
        return

    save_checkpoint(state, summary)
