import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import sys

from scipy.spatial.distance import cdist

import utils
import models

sys.path.append('./../dataSP/')
import data_utils

from tqdm import tqdm

import time

import jgrapht
import jgrapht.generators as gen
import jgrapht.algorithms.shortestpaths as sp

# Seed reused wherever a random_state is needed, so runs are repeatable.
seed_n = 0

# Folder holding the preprocessed cabspotting CSV files.
path_data = './cabspotting_preprocessing/'

df_features = pd.read_csv(path_data + 'features_per_trip_useful.csv')
df_trips = pd.read_csv(path_data + 'full_useful_trips.csv')
df_edges = pd.read_csv(path_data + 'graph_0010_080.csv')
df_nodes = pd.read_csv(path_data + 'nodes_0010_080.csv')
df_nodes['node_sorted'] = df_nodes['node_id_new']

# Lookup from the original node id to the re-indexed node id, then relabel
# every trip point with the re-indexed id.
node_id_pairs = df_nodes[['node_id', 'node_id_new']].drop_duplicates()
map_nodes = dict(node_id_pairs.to_numpy())
df_trips['node_id_new'] = df_trips['node_id'].replace(map_nodes)

# Pairwise euclidean distance between node coordinates, with rows and
# columns ordered by the re-indexed node id.
points = df_nodes.sort_values('node_sorted').loc[:, ['node_lon', 'node_lat']]
distance_matrix = cdist(points, points, metric='euclidean')

# Disabled filter, kept for reference:
#df_trips = df_trips[df_trips.groupby('trip_id_new').distance_km.transform('max')>2]

# Only the first 300000 trip rows are used.
# NOTE(review): a positional cut may split the last trip in two - confirm
# that this is acceptable.
df_trips = df_trips.head(300000)



# Split by driver rather than by trip: 70% of the distinct drivers are
# sampled for training, and every trip of the remaining drivers is held out.
unique_drivers = df_trips['driver'].drop_duplicates()
selected_drivers = unique_drivers.sample(frac=0.7, random_state=seed_n)

is_train_driver = df_trips['driver'].isin(selected_drivers)
df_trips_train = df_trips[is_train_driver]
df_trips_test = df_trips[~is_train_driver]

# Training trips must visit more than two distinct nodes; the surviving rows
# are then ordered by driver, trip and timestamp.
# NOTE(review): the test split is neither filtered nor sorted here - confirm
# this asymmetry is intended.
distinct_nodes_per_trip = (
    df_trips_train.groupby('trip_id_new')['node_id_new'].transform('nunique'))
df_trips_train = (
    df_trips_train[distinct_nodes_per_trip > 2]
    .sort_values(by=['driver', 'trip_id_new', 'date_time']))
# Collapse the day of week into four buckets before one-hot encoding:
# days 0-3 -> 0, day 4 -> 1, day 5 -> 2, day 6 -> 3.
df_features['day_of_Week'] = df_features['day_of_Week'].astype(int).map({
    0: 0, 1: 0, 2: 0, 3: 0, 4: 1,
    5: 2,
    6: 3 })

df_features = pd.get_dummies(df_features, columns=['day_of_Week'])

# Min-max scale the trip start time.
time_start_min = df_features['time_start'].min()
time_start_max = df_features['time_start'].max()
df_features['time_start'] = (
    (df_features['time_start'] - time_start_min)
    / (time_start_max - time_start_min))

# One-hot columns created by get_dummies above. They are cast to int by
# name instead of through a positional `iloc[:, -4:] = ...` assignment:
# the positional form depends on column order, and writing ints into the
# bool columns that recent pandas versions produce is an incompatible-dtype
# setitem.
day_of_week_cols = ['day_of_Week_0', 'day_of_Week_1',
                    'day_of_Week_2', 'day_of_Week_3']

# One feature row per training trip, in order of first appearance in
# df_trips_train.
indices_trips = df_trips_train[['trip_id','driver','trip_id_new']].drop_duplicates()
df_features_train = indices_trips.merge(df_features, on=['trip_id','driver'], how='left')
df_features_train[day_of_week_cols] = df_features_train[day_of_week_cols].astype(int)

# Model input columns.
feats = day_of_week_cols + ['is_Holiday', 'time_start']
n_features = len(feats)
n_trips_train = len(df_features_train)


# Same construction for the held-out trips.
indices_trips_test = df_trips_test[['trip_id','driver','trip_id_new']].drop_duplicates()
df_features_test = indices_trips_test.merge(df_features, on=['trip_id','driver'], how='left')
df_features_test[day_of_week_cols] = df_features_test[day_of_week_cols].astype(int)



# Graph description supplied by the project helper. Below, M_indices is
# indexed as a 2-column collection of (from_node, to_node) pairs, one row
# per edge.
prior_M, edges_prior, M_indices = data_utils.get_prior_and_M_indices(
    df_nodes, df_edges)

# Sanity check: the feature table and the trip table list the same trips in
# the same order.
assert np.array_equal(df_trips_train.trip_id_new.unique(),
                      df_features_train.trip_id_new.unique()), \
    'train trips and train features list different trips'

trip_ids = df_trips_train.trip_id_new.unique()

# Number of nodes, taken as the largest node index on any edge plus one.
V = M_indices.max()+1

# Explicit float dtype so that mixed bool/int/float feature columns cannot
# produce an object array.
X_np = df_features_train[feats].to_numpy(dtype=np.float64)

# Node sequence of every training trip. sort=False keeps the trips in order
# of first appearance, which is the row order of df_features_train; the
# default (sorting by trip id) could pair a trip's features with another
# trip's path.
node_idx_sequence_trips = df_trips_train.groupby(
    'trip_id_new', sort=False)['node_id_new'].apply(list)

assert np.array_equal(node_idx_sequence_trips.index.to_numpy(),
                      df_features_train['trip_id_new'].to_numpy()), \
    'train paths and train features are not aligned row by row'

# Consecutive node pairs (the traversed edges) plus the first and last node
# of each trip.
edges_seq_original = node_idx_sequence_trips.apply(
    lambda x: np.column_stack([x[:-1], x[1:]]))
start_nodes_original = node_idx_sequence_trips.apply(
    lambda x: x[0])
end_nodes_original = node_idx_sequence_trips.apply(
    lambda x: x[-1])

# One row per trip and one column per graph edge; filled in later with 1 for
# every edge the trip traverses.
edges_idx_on_original = np.zeros((len(edges_seq_original),
                                  len(M_indices)), dtype=int)
edges_seq_original_np = np.array(edges_seq_original)

N_train = len(edges_seq_original)



# Explicit float dtype so that mixed bool/int/float feature columns cannot
# produce an object array.
X_np_test = df_features_test[feats].to_numpy(dtype=np.float64)

# Node sequence of every held-out trip. sort=False keeps the trips in order
# of first appearance, which is the row order of df_features_test; the
# default (sorting by trip id) could pair a trip's features with another
# trip's path.
node_idx_sequence_trips_test = df_trips_test.groupby(
    'trip_id_new', sort=False)['node_id_new'].apply(list)

assert np.array_equal(node_idx_sequence_trips_test.index.to_numpy(),
                      df_features_test['trip_id_new'].to_numpy()), \
    'test paths and test features are not aligned row by row'

# Consecutive node pairs (the traversed edges) plus the first and last node
# of each trip. The edge sequence is computed once; the original computed
# the identical expression twice.
edges_seq_original_test = node_idx_sequence_trips_test.apply(
    lambda x: np.column_stack([x[:-1], x[1:]]))
start_nodes_original_test = node_idx_sequence_trips_test.apply(
    lambda x: x[0])
end_nodes_original_test = node_idx_sequence_trips_test.apply(
    lambda x: x[-1])

# One row per trip and one column per graph edge; filled in later with 1 for
# every edge the trip traverses.
edges_idx_on_original_test = np.zeros((len(edges_seq_original_test),
                                       len(M_indices)), dtype=int)
edges_seq_original_np_test = np.array(edges_seq_original_test)

print('Processing Data')


def _index_edges_by_endpoints(edge_pairs):
    """Return a dict mapping each (from_node, to_node) pair to its row positions.

    Every pair maps to the list of rows of ``edge_pairs`` that hold it, so a
    pair that occurs more than once can still be detected at lookup time.
    """
    positions = {}
    for pos, (tail, head) in enumerate(np.asarray(edge_pairs).tolist()):
        positions.setdefault((tail, head), []).append(pos)
    return positions


def _mark_traversed_edges(trips_edges, indicator, positions):
    """Set ``indicator[i, e] = 1`` for every edge ``e`` traversed by trip ``i``.

    Args:
        trips_edges: sequence with one (n_steps, 2) array of node pairs per trip.
        indicator: (n_trips, n_edges) integer array, modified in place.
        positions: lookup built by ``_index_edges_by_endpoints``.

    Raises:
        ValueError: if a traversed node pair does not match exactly one row
            of the edge list.
    """
    for trip_no, trip_edges in enumerate(tqdm(trips_edges)):
        for tail, head in np.asarray(trip_edges).tolist():
            hits = positions.get((tail, head), ())
            if len(hits) != 1:
                raise ValueError(
                    f'edge ({tail}, {head}) of trip {trip_no} matches '
                    f'{len(hits)} rows of the edge list; expected exactly 1')
            indicator[trip_no, hits[0]] = 1


# Built once and shared by both splits, replacing a full scan of the edge
# list for every traversed edge.
_edge_positions = _index_edges_by_endpoints(M_indices)

_mark_traversed_edges(edges_seq_original_np, edges_idx_on_original,
                      _edge_positions)

edges_seq_original = list(edges_seq_original)
node_idx_sequence_trips = list(node_idx_sequence_trips)

# (start_node, end_node) of every training trip, one row per trip.
end_to_end_nodes_original = (np.vstack((
    np.array(start_nodes_original),
    np.array(end_nodes_original))).T).astype(np.int32)



_mark_traversed_edges(edges_seq_original_np_test, edges_idx_on_original_test,
                      _edge_positions)

edges_seq_original_test = list(edges_seq_original_test)
node_idx_sequence_trips_test = list(node_idx_sequence_trips_test)

# (start_node, end_node) of every held-out trip, one row per trip.
end_to_end_nodes_original_test = (np.vstack((
    np.array(start_nodes_original_test),
    np.array(end_nodes_original_test))).T).astype(np.int32)

# --- Hyperparameters -------------------------------------------------------
n_noise = 1      # passed to models.PerturbedMin in the training loop
eps = 0.07       # passed to models.PerturbedMin in the training loop
lr = 1e-05       # optimizer learning rate
BS = 200         # batch size
dev = 'cpu'      # device the model is moved to
n_edges = len(M_indices)

# --- Tensors ---------------------------------------------------------------
# Edge-indicator matrices of the observed paths, as float tensors. Tensors
# created this way do not track gradients.
paths_train_torch = torch.tensor(edges_idx_on_original, dtype=torch.float32)
paths_test_torch = torch.tensor(edges_idx_on_original_test, dtype=torch.float32)

X = torch.tensor(X_np, dtype=torch.float32)
X_test = torch.tensor(X_np_test, dtype=torch.float32)

N_train = len(X)
inp_s_model = X.shape[-1]

# --- Model and optimizer ---------------------------------------------------
# The network outputs one value per graph edge from the trip features.
model = models.ANN(
    input_size=inp_s_model,
    output_size=len(M_indices),
    hl_sizes=[1024, 1024],
).to(dev)
opt = torch.optim.RMSprop(model.parameters(), lr=lr, weight_decay=1e-6)

# --- Baseline edge weights and shortest-path solver ------------------------
# Node-to-node distances scaled by 1000, then gathered per edge.
heurist_M = torch.tensor(1000*distance_matrix, dtype=torch.float32)
heurist_edges = heurist_M[M_indices[:, 0], M_indices[:, 1]]

solver_sp = utils.Dijkstra(V, M_indices.numpy())
# Alternative solver, kept for reference:
#solver_sp = utils.Astar(V, M_indices.numpy(), heurist_M.numpy())

for epoch in range(30):

    # Indices are drawn with replacement (torch.randint), so an epoch is
    # N_train random draws and not a permutation of the training set.
    idcs_order = torch.randint(0, N_train, (N_train,))

    start_time = time.time()

    for it in tqdm(range(N_train//BS)):
        idcs_batch = idcs_order[it*BS:(it+1)*BS]

        X_batch = X[idcs_batch].to(dev)

        # The network predicts a per-edge correction that is added to the
        # distance-based baseline weights.
        delta_pred = model(X_batch)
        edges_pred = delta_pred + heurist_edges.unsqueeze(0)

        # NOTE(review): presumably F_eps is the perturbed shortest-path cost
        # and path_eps the matching expected path - confirm in
        # models.PerturbedMin.
        F_eps, path_eps = models.PerturbedMin.apply(
            edges_pred,
            n_noise, n_edges,
            end_to_end_nodes_original[idcs_batch],
            eps, solver_sp)

        # Cost of the observed path under the predicted weights, minus F_eps.
        loss_per_sample = (paths_train_torch[idcs_batch]*edges_pred).sum(-1) - F_eps

        # Mean loss plus a small squared penalty on the predicted weights.
        loss = loss_per_sample.mean() + .01*(edges_pred**2).mean()

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Every 30 iterations, evaluate on the first n_to_eval held-out trips.
        if it%30==0:
            n_to_eval = 1000
            with torch.no_grad():
                delta_pred_test = model(X_test[:n_to_eval])
                edges_pred_test = delta_pred_test + heurist_edges.unsqueeze(0)

                # The endpoints are sliced to the same trips as the weights;
                # the original passed the endpoints of every test trip
                # together with only n_to_eval rows of weights.
                paths_theta_test = solver_sp.batched_solver(
                    edges_pred_test.detach().numpy().astype(np.float64),
                    end_to_end_nodes_original_test[:n_to_eval])

                true_paths_test = edges_idx_on_original_test[:n_to_eval]
                union = np.where(paths_theta_test + true_paths_test>0,1,0).sum(1)
                inter = np.where((paths_theta_test == 1) & (true_paths_test==1),1,0).sum(1)

                # Trips where neither path has an edge give 0/0; they are
                # left out so a single such trip cannot turn the mean into NaN.
                has_edges = union > 0
                if has_edges.any():
                    mean_iou = (inter[has_edges]/union[has_edges]).mean()
                else:
                    mean_iou = float('nan')

                print(edges_pred_test.mean(), (heurist_edges).mean())
                print(epoch, it, mean_iou)