import utils
import models

import torch
import torch.nn as nn
import numpy as np
from scipy.spatial.distance import cdist
from tqdm import tqdm

import networkx as nx
import random

import time

from matplotlib import pyplot as plt


seed_n = 0

# Seed torch, NumPy and Python's `random` (in that order) so that every run of
# the script draws the same random numbers.
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(seed_n)


def make_strongly_connected(G):
    """
    Make the directed graph ``G`` strongly connected by adding edges in place.

    The strongly connected components (SCCs) of ``G`` are linked into a
    directed ring. One edge is added from a randomly chosen node of component
    ``i`` to a randomly chosen node of component ``i + 1``, wrapping around to
    component 0. Each SCC is already internally strongly connected, so the ring
    lets every node reach every other node. This takes a single pass and does
    not rely on the order in which networkx reports the components.

    Parameters
    ----------
    G : networkx.DiGraph
        Graph to patch; it is modified in place.

    Returns
    -------
    networkx.DiGraph
        The same object ``G``. Graphs with zero or one SCC are returned
        unchanged.
    """
    components = [list(c) for c in nx.strongly_connected_components(G)]
    if len(components) <= 1:
        return G

    # Pair component i with component i+1 (the last one with the first).
    successors = components[1:] + components[:1]
    for src_comp, dst_comp in zip(components, successors):
        G.add_edge(random.choice(src_comp), random.choice(dst_comp))

    return G

n_paths = 4000   # total number of sampled paths (train + test)
N_train = 2000   # the first N_train paths form the training split

# Waxman random graph, converted to a networkx DiGraph and then patched so
# that it is strongly connected.
n_nodes = 200
alpha = 0.07
beta = 0.5
G = make_strongly_connected(
    nx.DiGraph(nx.waxman_graph(n_nodes, alpha=alpha, beta=beta, seed=seed_n))
)

# 2-d coordinates for every node; Euclidean distances between them serve as
# the base edge cost / heuristic further down.
pos = nx.spring_layout(G)

edges = np.array(list(G.edges))
points = np.array(list(pos.values()))
distance_matrix = cdist(points, points, metric='euclidean')
n_edges = edges.shape[0]

print('N Edges:', n_edges)

# Every ordered pair of node coordinates, one pair per row. Row i*n + j holds
# [points[j], points[i]], which is the ordering produced by flattening
# np.meshgrid(arange(n), arange(n)). The table is the input of ann_h.
n_pts = points.shape[0]
latloncomb = np.concatenate(
    (np.tile(points, (n_pts, 1)), np.repeat(points, n_pts, axis=0)),
    axis=1,
)
latloncomb_torch = torch.tensor(latloncomb, dtype=torch.float32)

# Per-sample latent draws. NOTE(review): X / X_train / X_test are not read by
# the training or evaluation loops, which draw fresh torch.randn latents. They
# are kept so that the torch RNG stream (and the Z draws below) is unchanged.
X = torch.randn((n_paths,2))
X_train = X[:N_train].clone()
X_test = X[N_train:].clone()

# Split the samples among three "agents": 50%, 30%, and the remainder.
n_paths_a1 = int(0.5*n_paths)
n_paths_a2 = int(0.3*n_paths)
n_paths_a3 = n_paths - n_paths_a1 - n_paths_a2

agent_indicator = n_paths_a1*[1] + n_paths_a2*[2] + n_paths_a3*[3]

# Per agent/feature Gaussian parameters (cov is diagonal): agent k draws
# feature f as sig_kf * N(0, 1) + mu_kf.
mu11, mu12 = 5, 10
mu21, mu22 = 5, 2
mu31, mu32 = 10, 10

sig11, sig12 = 0.5, 0.5
sig21, sig22 = 0.4, 0.4
sig31, sig32 = 1.5, 1.0


Z_a1 = np.hstack((
    sig11*torch.randn((n_paths_a1,1))+mu11,
    sig12*torch.randn((n_paths_a1,1))+mu12
))

Z_a2 = np.hstack((
    sig21*torch.randn((n_paths_a2,1))+mu21,
    sig22*torch.randn((n_paths_a2,1))+mu22
))

# Same form as agents 1 and 2. The previous sig*(randn + mu) also scaled the
# mean, giving means (15, 10) instead of (mu31, mu32).
Z_a3 = np.hstack((
    sig31*torch.randn((n_paths_a3,1))+mu31,
    sig32*torch.randn((n_paths_a3,1))+mu32
))


Z = np.vstack((Z_a1, Z_a2, Z_a3))

# Shuffle samples and their agent labels together using a true permutation.
# np.random.randint would draw indices with replacement, duplicating some
# samples and dropping others.
idcs_shuf = np.random.permutation(n_paths)
Z = Z[idcs_shuf]
agent_indicator = np.array(agent_indicator)[idcs_shuf]

#map_to_edge = np.random.random((1, n_edges))

#mapping_latent_to_delta_cost = np.vstack(
#    (map_to_edge, map_to_edge)
#)

# Linear map from the 2-d latent Z to a per-edge cost offset.
mapping_latent_to_delta_cost = np.random.random((2, n_edges))

delta_cost = Z@mapping_latent_to_delta_cost

# Per-sample edge costs: Euclidean edge length (broadcast over samples) plus
# the sample's latent-dependent offset.
real_cost = np.expand_dims(distance_matrix[edges[:,0], edges[:,1]], 0) + delta_cost
#+ np.random.randint(1,1000, (edges.shape[0],))/1000.

solver_sp = utils.Dijkstra(n_nodes, edges)


# Random (start, end) node pair per sample; start and end are drawn
# independently and may coincide.
se_nodes = np.random.randint(0, n_nodes, (n_paths,2))

#se_nodes = np.repeat(np.array([[0,499]]), n_paths, 0)

#sono = np.random.randint(0,2,(n_paths,1))
#enno = np.random.randint(497,499,(n_paths,1))
#se_nodes = np.hstack((sono,enno))

# Ground-truth shortest path of every sample under its own edge costs. Each row
# has one entry per edge. Presumably this is a 0/1 edge indicator, since the
# IoU evaluation compares entries against 1 (confirm against utils.Dijkstra).
true_paths = np.zeros((n_paths, n_edges))
for i in tqdm(range(n_paths)):
    true_paths[i] = solver_sp.solve(
                        real_cost[i],
                        se_nodes[i,0],
                        se_nodes[i,1],
                    )

# Start / end node coordinates of every sample.
sn_points, en_points = (points[se_nodes[:, k]] for k in (0, 1))

# ---- training split: the first N_train samples ----
sn_train_torch, en_train_torch = (
    torch.tensor(p[:N_train], dtype=torch.float32) for p in (sn_points, en_points)
)
se_points_torch = torch.hstack((sn_train_torch, en_train_torch))
true_paths_train = true_paths[:N_train]
se_nodes_train = se_nodes[:N_train]
paths_train_torch = torch.tensor(true_paths_train, dtype=torch.float32)
agent_indicator_train = agent_indicator[:N_train]

# ---- test split: every remaining sample ----
sn_test_torch, en_test_torch = (
    torch.tensor(p[N_train:], dtype=torch.float32) for p in (sn_points, en_points)
)
se_points_test = torch.hstack((sn_test_torch, en_test_torch))
true_paths_test = true_paths[N_train:]
se_nodes_test = se_nodes[N_train:]
paths_test_torch = torch.tensor(true_paths_test, dtype=torch.float32)
agent_indicator_test = agent_indicator[N_train:]



n_noise = 1      # passed to models.PerturbedMin (presumably # of noise samples)
eps = 0.001      # passed to models.PerturbedMin (presumably perturbation scale)
lr = 0.0001
BS = 200
dev = 'cpu'

# Encoder maps the observed trajectory + start and end locations (lat,lon) to a
# 2-d latent code. NOTE(review): the encoder is built but never called or
# optimized in this script; training samples latents with torch.randn instead.
encoder = models.Encoder(input_size=n_edges + 4, output_size=2, hl_sizes=[1024, 1024])
encoder = encoder.to(dev)

# Decoder maps a 2-d latent code to one cost offset per edge.
decoder = models.ANN(input_size=2, output_size=n_edges, hl_sizes=[1024, 1024])
decoder = decoder.to(dev)

# ANN mapping a pair of node coordinates (4 inputs) to a correction that is
# added to the Euclidean heuristic.
ann_h = models.ANN(input_size=4, output_size=1, hl_sizes=[1024, 1024])
ann_h = ann_h.to(dev)  # same device as the decoder


opt_decoder = torch.optim.RMSprop(decoder.parameters(), lr, weight_decay=1e-6)
opt_h = torch.optim.RMSprop(ann_h.parameters(), lr, weight_decay=1e-6)

# Euclidean node-to-node distances: the base A* heuristic, and (restricted to
# existing edges) the base edge cost.
heurist_M = torch.tensor(distance_matrix, dtype=torch.float32)
heurist_edges = heurist_M[edges[:,0], edges[:,1]]

alg = 'astar'

if alg == 'dij':
    solver_sp = utils.Dijkstra(n_nodes, edges)
elif alg == 'astar':
    solver_sp = utils.Astar(n_nodes, edges, heurist_M.numpy())
else:
    # Fail loudly: exit() would stop the run silently with status 0.
    raise ValueError(f"Unknown shortest-path algorithm {alg!r}; expected 'dij' or 'astar'")

fw_solver = utils.FW(n_nodes, edges)
    
for epoch in range(60):

    # Visit every training sample exactly once per epoch, in random order.
    # torch.randint would sample indices with replacement.
    idcs_order = torch.randperm(N_train)

    for it in tqdm(range(N_train//BS)):
        idcs_batch = idcs_order[it*BS:(it+1)*BS]

        # Latent codes are drawn from N(0, I); the encoder is not used here.
        X_batch = torch.randn((BS,2))

        paths_batch = paths_train_torch[idcs_batch].to(dev)
        se_nodes_batch = se_nodes_train[idcs_batch.numpy()]

        # Predicted edge costs = Euclidean edge length + learned offset.
        delta_pred = decoder(X_batch)
        edges_pred = delta_pred + heurist_edges.unsqueeze(0)

        F_eps, _ = models.PerturbedMin.apply(
            edges_pred,
            n_noise, n_edges,
            se_nodes_batch,
            eps, solver_sp)

        # Cost of the observed path under the predicted edge costs.
        F_path = (paths_batch*edges_pred).sum(-1)

        loss = (F_path - F_eps).mean()

        if it % 10 == 0:
            # The heuristic network is only used every 10th batch, so its
            # n_nodes**2-row forward pass is done here rather than every batch.
            heuristics_pred = ann_h(latloncomb_torch).reshape(n_nodes, n_nodes) + heurist_M

            distances_pred = torch.as_tensor(
                fw_solver.batched_solver(edges_pred.detach().numpy()),
                dtype=heuristics_pred.dtype,
            )

            # Refresh the solver's heuristic matrix (presumably read by
            # utils.Astar on subsequent solves).
            solver_sp.M_euclidean = heuristics_pred.detach().numpy()

            # Broadcasting over the leading batch axis gives the same mean as
            # materialising BS copies of the n_nodes x n_nodes prediction.
            loss = loss + ((distances_pred - heuristics_pred.unsqueeze(0))**2).mean()

        opt_decoder.zero_grad()
        opt_h.zero_grad()
        loss.backward()
        opt_decoder.step()
        opt_h.step()

    # Evaluate on (up to) 2000 test samples with freshly sampled latents.
    n_to_eval = min(2000, len(se_nodes_test))
    with torch.no_grad():
        se_nodes_test_ = se_nodes_test[:n_to_eval]
        true_paths_eval = true_paths_test[:n_to_eval]

        X_test_ = torch.randn((n_to_eval,2))

        delta_pred_test = decoder(X_test_)
        edges_pred_test = delta_pred_test + heurist_edges.unsqueeze(0)

        paths_theta_test = solver_sp.batched_solver(
            edges_pred_test.detach().numpy().astype(np.float64),
            se_nodes_test_)

        # Per-sample IoU of predicted vs. true edge sets. Samples with an empty
        # union score 0; np.divide's `where` avoids evaluating 0/0.
        union = np.where(paths_theta_test + true_paths_eval > 0, 1, 0).sum(1)
        inter = np.where((paths_theta_test == 1) & (true_paths_eval == 1), 1, 0).sum(1)

        iou = np.divide(inter, union, out=np.zeros(union.shape), where=union > 0).mean()

        print(edges_pred_test.mean(), (heurist_edges).mean())
        print(epoch, it, iou)