import argparse
import json
import os.path
from functools import partial

import dask.dataframe as dd
import networkx as nx
import numpy as np
import osmnx as ox
import pandas as pd
import geopandas as gpd
from dask.diagnostics import ProgressBar
from osmnx import settings, plot, distance

ProgressBar().register()  # global Dask progress bar for every .compute() below

output_dir = '../data/graphs/'

parser = argparse.ArgumentParser()
parser.add_argument("--near_file", type=str, default='../data/near_filtered_salzburg')
parser.add_argument("--plot", action='store_true', help='Plot the graphs colored by the visitor counts (first 10)')
parser.add_argument("--force", action='store_true', help='Force query the osm graph')
parser.add_argument("--no_simplify", action='store_true', help='Do NOT simplify the OSM graph')
# Explicit types: without them, values supplied on the command line arrive as
# str and break range(...) / distance comparisons when linking POIs.
parser.add_argument("--poi_num_neighbours", type=int, default=5,
                    help='Maximum number of street nodes each POI is linked to')
# Distance units are whatever osmnx nearest_nodes(return_dist=True) reports for
# this graph (presumably metres for an unprojected graph; verify).
parser.add_argument("--poi_max_distance", type=float, default=80,
                    help='Link a POI to nodes beyond its nearest one only if closer than this')

args = parser.parse_args()

if not os.path.exists(args.near_file):
    # FileNotFoundError subclasses Exception, so any existing handler still matches.
    raise FileNotFoundError("near_file does not exist, run filter_near.py first!")

ox.settings.log_console = True
ox.settings.use_cache = True

# Load bounding box. It must provide 'north'/'south'/'east'/'west': they are used
# for the POI filter below and forwarded to ox.graph_from_bbox.
with open('bbox.json', encoding='utf-8') as f:
    bbox = json.load(f)

# Load POI points, indexed by name, with their coordinates as x/y columns.
points_df = gpd.read_file('poi_points.geojson').dropna().set_index('name')
points_df["x"] = points_df["geometry"].x
points_df["y"] = points_df["geometry"].y

# Keep only POIs inside the bounding box (both bounds inclusive).
inside_bbox = (points_df["x"].between(float(bbox["west"]), float(bbox["east"]))
               & points_df["y"].between(float(bbox["south"]), float(bbox["north"])))
points_df = points_df[inside_bbox]

poi_names = points_df.index
print(poi_names)

# Load graph from file or download from OSM. A freshly downloaded graph gets
# consecutive integer node labels before it is saved; the POI-linking step
# appends POI nodes after the last street-node id.
graph_file = os.path.join(output_dir, 'salzburg_osm.graphml')
if os.path.isfile(graph_file) and not args.force:
    G = ox.load_graphml(graph_file)
else:
    G = ox.graph_from_bbox(**bbox, simplify=not args.no_simplify)
    G = nx.convert_node_labels_to_integers(G)
    ox.save_graphml(G, graph_file)

# Node table of the street graph (POIs are not added yet); its index supplies
# the full set of node ids used as columns of node_attributes below.
nodes, edges = ox.graph_to_gdfs(G)

# Near file has 216864843 lines
print("Reading near parquet")
near_df = dd.read_parquet(args.near_file)
# NOTE(review): len() on a Dask DataFrame presumably triggers a full pass over
# the parquet data just for this message — confirm the cost is acceptable.
print(f"Calculating Map Matching for {len(near_df)} points")
# Snap every visit (lon/lat) to its nearest graph node, one partition at a time.
near_df = near_df.map_partitions(
    lambda df: df.assign(node_id=distance.nearest_nodes(G, df['Lon of Visit'], df['Lat of Visit'])))

# Distinct devices per (timestep, node): deduplicate (timestep, node, device)
# rows, count each group, and unstack into a timestep x node_id table in which
# missing combinations become 0.
node_attributes = near_df[['timestep', 'node_id', 'Hashed Device ID']].drop_duplicates().groupby(
    ['timestep', 'node_id']).size().compute().unstack().fillna(0.0)
# Use string column labels and include every graph node, even nodes with no visits.
node_attributes.columns = node_attributes.columns.astype(str)
node_attributes = node_attributes.reindex(columns=nodes.index.values.astype(str), fill_value=0.0)

# Per-venue values from the 'venues' group of a two-row column header; POI
# columns are later selected from here by POI name.
df = pd.read_csv('../data/dataset.csv', header=[0, 1], index_col=0, parse_dates=True)['venues']
df.index = pd.to_datetime(df.index, utc=True)
# Give node_attributes the dataset's timezone (assumes its timestep index is
# tz-naive — TODO confirm; tz_localize fails on an already tz-aware index).
node_attributes.index = node_attributes.index.tz_localize(df.index.tz)
# Keep only dataset rows whose timestamps also occur in node_attributes, in
# node_attributes' order.
df = df.loc[node_attributes.index[node_attributes.index.isin(df.index)]]

# Connect every POI to its nearest street-graph nodes.
# POI number i becomes a new node with id poi_offset + i. This assumes the
# existing nodes are labelled 0..len(G)-1 (convert_node_labels_to_integers is
# applied before the graph is saved), so POI ids cannot collide with them.
# CLI values are coerced explicitly in case argparse delivered them as str.
num_neighbours = int(args.poi_num_neighbours)
max_distance = float(args.poi_max_distance)
poi_offset = len(G)

new_edges = []  # (u, v, length) triples; every link is added in both directions
nns = []  # street nodes that received a POI link (highlighted when plotting)

# Find nearest nodes for each POI
for i, poi in enumerate(poi_names):
    x, y = points_df.loc[poi, "x"], points_df.loc[poi, "y"]
    poi_node = poi_offset + i
    # Fresh copy per POI: chosen nodes are removed so the next query returns
    # the next-nearest node, without touching G itself.
    H = G.copy()
    for k in range(num_neighbours):
        nn, dist = distance.nearest_nodes(H, x, y, return_dist=True)
        # The single nearest node is always linked; further ones only if
        # they lie within max_distance.
        if dist < max_distance or k == 0:
            new_edges.append((nn, poi_node, dist))
            new_edges.append((poi_node, nn, dist))
            H.remove_node(nn)
            nns.append(nn)
        else:
            break

# Add POIs as separate nodes and add edges to the nearest nodes
G.add_nodes_from(
    (poi_offset + j, {'x': points_df.iloc[j].x, 'y': points_df.iloc[j].y})
    for j in range(len(poi_names)))
G.add_weighted_edges_from(new_edges, weight='length')

if args.plot:
    # Highlight POI nodes (value 10) and the street nodes they were linked to
    # (value 5); every other node gets 0.
    nx.set_node_attributes(G, 0, name='visitors')
    nx.set_node_attributes(G, {n: 10 for n in range(len(G) - len(poi_names), len(G))}, name='visitors')
    nx.set_node_attributes(G, {n: 5 for n in nns}, name='visitors')
    ec = ox.plot.get_node_colors_by_attr(G, attr='visitors')
    ox.plot_graph(G, node_color=list(ec), save=True, filepath='plots/graph_POIs.png')

# Relabel the POI columns with their graph node ids (the last len(poi_names)
# ids, as strings, matching node_attributes' column labels).
df = df[poi_names]
df.columns = np.arange(len(G) - len(poi_names), len(G)).astype(str)
num_pois = df.shape[1]  # number of POI columns; len(df) would count timesteps

# Keep unmodified copies for the normalization step.
df_normalized = df.copy()
node_attributes_normalized = node_attributes.copy()

node_attributes = node_attributes.join(df, how='inner').fillna(0.0)
os.makedirs(output_dir, exist_ok=True)
out_filename = os.path.join(output_dir, "node_attrs.parq")
node_attributes.to_parquet(out_filename)
print("finished writing " + out_filename)


def min_max_encode(data, min, max):
    """Scale ``data`` linearly so that ``min`` maps to 0 and ``max`` maps to 1.

    Works element-wise for Python numbers, NumPy arrays and pandas objects.

    Args:
        data: Value(s) to scale.
        min: Lower bound of the original range. It shadows the builtin, but the
            name is kept because callers pass it by keyword.
        max: Upper bound of the original range.

    Returns:
        ``(data - min) / (max - min)``. If the range is a scalar zero (constant
        data), returns zeros shaped like ``data`` (NaN entries stay NaN).
        Dividing instead would give NaN/inf for arrays and raise
        ZeroDivisionError for plain numbers.
    """
    span = max - min
    if np.ndim(span) == 0 and span == 0:
        return (data - min) * 0.0
    return (data - min) / span


print("Normalizing data")

# POI columns: min-max scale each column independently.
for col in df_normalized.columns:
    col_min, col_max = df_normalized[col].min(), df_normalized[col].max()
    df_normalized[col] = min_max_encode(df_normalized[col], col_min, col_max)

# Street-node columns: a single global min/max over the whole table, so values
# stay comparable across nodes.
# FIXME using Dask here led to a segmentation fault, hence plain pandas.
minimum = node_attributes_normalized.min().min()
maximum = node_attributes_normalized.max().max()
# min_max_encode is element-wise, so one vectorized call replaces a row-wise apply.
node_attributes_normalized = min_max_encode(node_attributes_normalized, minimum, maximum)

node_attributes_normalized = node_attributes_normalized.join(df_normalized, how='inner').fillna(0.0)
out_filename = os.path.join(output_dir, "node_attrs_normalized.parq")
node_attributes_normalized.to_parquet(out_filename)
print("finished writing " + out_filename)

# Export the 'length' attribute of every edge of the final graph (street edges
# plus the POI links added above) as CSV.
nodes, edges = ox.graph_to_gdfs(G)
out_filename = os.path.join(output_dir, 'edge_weights.csv')
edge_lengths = edges['length']
edge_lengths.to_csv(out_filename)
print(f"finished writing {out_filename}")

if args.plot:
    # POI nodes occupy the last len(poi_names) node ids.
    poi_start = len(G) - len(poi_names)
    # Graph of the POI nodes only, drawn on top of the street graph.
    H = G.subgraph(n for n in G.nodes if n >= poi_start).copy()

    for ts in node_attributes_normalized.index[:10]:
        nx.set_node_attributes(G, 0, name='visitors')
        visitors = {int(n): node_attributes_normalized.at[ts, n]
                    for n in node_attributes_normalized.columns}
        nx.set_node_attributes(G, visitors, name='visitors')
        ec = list(ox.plot.get_node_colors_by_attr(G, attr='visitors'))
        fig, ax = ox.plot_graph(G, node_color=ec, show=False, edge_color="#777778")
        # ec follows G's node order, so its tail holds the POI colours.
        # close=True frees the figure after saving; otherwise one figure per
        # timestep stays open in matplotlib.
        ox.plot_graph(H, ax=ax, node_color=ec[poi_start:], node_size=25, node_edgecolor='#D6D6D7',
                      edge_linewidth=0, save=True, filepath=f'plots/graph_{str(ts)}.png', show=False,
                      close=True)
