import heapq
import os
import random
import time

import networkx as nx
import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from scipy.spatial.distance import cdist
from tqdm import tqdm

import models
import utils


# Single seed shared by every random number generator the script draws from.
seed_n = 0

# Seed the stdlib, NumPy and PyTorch generators so runs are repeatable.
random.seed(seed_n)
np.random.seed(seed_n)
torch.manual_seed(seed_n)


def make_strongly_connected(G):
    """
    Add directed edges to G, in place, until it is strongly connected.

    On every pass the first two strongly connected components reported by
    networkx are taken, one node is drawn at random from each (using the
    stdlib `random` module), and an edge from the first node to the second
    is inserted. The same, mutated graph object is returned.
    """
    while True:
        if nx.is_strongly_connected(G):
            return G

        sccs = list(nx.strongly_connected_components(G))
        if len(sccs) < 2:
            # Nothing to link on this pass; re-check connectivity.
            continue

        # Draw order matters for reproducibility: first component, then second.
        tail = random.choice(list(sccs[0]))
        head = random.choice(list(sccs[1]))
        G.add_edge(tail, head)



def find_edge_vectors_within_threshold(
    edges, costs, graph, source, target, threshold_factor=1.3, weight='weight'):
    """
    Enumerate simple source->target paths whose cost is close to optimal.

    The per-edge `costs` are first written onto `graph` under the `weight`
    attribute (the graph is modified in place). The optimal path cost is
    then obtained from `nx.shortest_path_length`, and every simple path
    whose total cost is at most `threshold_factor * optimal_cost` is
    collected with a best-first search.

    Parameters
    ----------
    edges : sequence of (u, v) pairs
        Directed edges; every pair must already be an edge of `graph`.
    costs : sequence of scalars
        One cost per entry of `edges`, in the same order. Each value is
        converted with `float()`, so NumPy scalars and 0-dim tensors are
        accepted. Costs are assumed to be non-negative: the search prunes on
        accumulated cost and networkx's default weighted shortest path is
        Dijkstra-based.
    graph : networkx directed graph
        Graph the search runs on.
    source, target : node
        Start and end node of the paths.
    threshold_factor : float
        Multiplier applied to the optimal cost to obtain the cost budget.
    weight : str
        Name of the edge attribute that holds the cost.

    Returns
    -------
    numpy.ndarray
        Integer array of shape (n_valid_paths, len(edges)); row k holds a 1
        for every edge used by the k-th path and 0 elsewhere. Rows appear in
        non-decreasing order of path cost.

    Raises
    ------
    networkx.NetworkXNoPath
        Propagated from `nx.shortest_path_length` when `target` cannot be
        reached from `source`.
    """
    n_edges = len(edges)

    # The loop variable is deliberately NOT called `weight`: shadowing the
    # `weight` parameter here would make every later attribute lookup miss
    # and silently fall back to a unit cost per edge.
    for edge, cost in zip(edges, costs):
        graph[edge[0]][edge[1]][weight] = float(cost)

    edge_to_index = {tuple(edge): idx for idx, edge in enumerate(edges)}

    optimal_path_length = nx.shortest_path_length(graph, source, target, weight=weight)
    threshold_cost = threshold_factor * optimal_path_length

    # Heap entries: (cost so far, push counter, node, path, edge indicator).
    # The counter is unique, so when two costs tie heapq never has to compare
    # nodes, path lists or NumPy arrays.
    n_pushed = 0
    queue = [(0, n_pushed, source, [source], np.zeros(n_edges, dtype=int))]

    valid_edge_vectors = []

    while queue:
        current_cost, _, current_node, current_path, used_edges = heapq.heappop(queue)

        if current_node == target and current_cost <= threshold_cost:
            valid_edge_vectors.append(used_edges)
            continue

        for neighbor in graph.neighbors(current_node):
            # Only simple paths: never revisit a node already on the path.
            if neighbor in current_path:
                continue

            edge = (current_node, neighbor)
            if edge not in edge_to_index:
                continue

            edge_weight = graph[current_node][neighbor].get(weight, 1)
            new_cost = current_cost + edge_weight

            # Prune as soon as the partial path exceeds the cost budget.
            if new_cost <= threshold_cost:
                new_used_edges = used_edges.copy()
                new_used_edges[edge_to_index[edge]] = 1
                n_pushed += 1
                heapq.heappush(
                    queue,
                    (new_cost, n_pushed, neighbor, current_path + [neighbor], new_used_edges),
                )

    # Explicit reshape keeps the result 2-D even when no path qualified.
    return np.array(valid_edge_vectors, dtype=int).reshape(len(valid_edge_vectors), n_edges)



# Dataset size: total number of trips and how many of them are used for training.
n_paths, N_train = 4000, 2000

# Waxman random-graph parameters.
n_nodes, alpha, beta = 200, 0.07, 0.5

# Undirected Waxman graph -> directed graph -> extra edges added until the
# directed graph is strongly connected.
G = make_strongly_connected(
    nx.DiGraph(nx.waxman_graph(n_nodes, alpha=alpha, beta=beta, seed=seed_n))
)

# 2-D coordinates for every node, taken from a spring layout of the graph.
pos = nx.spring_layout(G)

# `edges` has one (u, v) row per directed edge; `points` has one coordinate
# row per node, in the iteration order of `pos`.
edges = np.array(list(G.edges.keys()))
points = np.array(list(pos.values()))
n_edges = len(edges)

# Pairwise Euclidean distances between node coordinates.
distance_matrix = cdist(points, points, metric='euclidean')

print('N Edges:', n_edges)

# Auxiliary Gaussian features. They are not referenced again in this script,
# but drawing them advances the torch RNG used by the latent draws below.
X = torch.randn((n_paths,2))
X_train = X[:N_train].clone()
X_test = X[N_train:].clone()

# Three agent types holding 50% / 30% / remaining share of the trips.
n_paths_a1 = int(0.5*n_paths)
n_paths_a2 = int(0.3*n_paths)
n_paths_a3 = n_paths - n_paths_a1 - n_paths_a2

agent_indicator = n_paths_a1*[1] + n_paths_a2*[2] + n_paths_a3*[3]

# Per agent / per feature mean and standard deviation of the 2-D latent
# preference vector (features are drawn independently: diagonal covariance).
mu11, mu12 = 5, 10
mu21, mu22 = 5, 2
mu31, mu32 = 10, 10

sig11, sig12 = 0.5, 0.5
sig21, sig22 = 0.4, 0.4
sig31, sig32 = 1.5, 1.0


def _sample_agent_latents(n_samples, means, stds):
    """Draw an (n_samples, len(means)) array with column j ~ N(means[j], stds[j]**2)."""
    return np.hstack([
        (std * torch.randn((n_samples, 1)) + mean).numpy()
        for mean, std in zip(means, stds)
    ])


# All three agents use the same `sig * noise + mu` form. Agent 3 previously
# used `sig * (noise + mu)`, which shifted its first-feature mean to
# sig31 * mu31 instead of mu31.
Z_a1 = _sample_agent_latents(n_paths_a1, (mu11, mu12), (sig11, sig12))
Z_a2 = _sample_agent_latents(n_paths_a2, (mu21, mu22), (sig21, sig22))
Z_a3 = _sample_agent_latents(n_paths_a3, (mu31, mu32), (sig31, sig32))

Z = np.vstack((Z_a1, Z_a2, Z_a3))

# Shuffle trips so agent types are mixed across the train/test split. A true
# permutation is required: sampling indices with replacement would duplicate
# some trips and drop others, distorting the agent proportions.
idcs_shuf = np.random.permutation(n_paths)
Z = Z[idcs_shuf]
agent_indicator = np.array(agent_indicator)[idcs_shuf]

# Random linear map from the 2-D latent vector to a per-edge cost offset.
mapping_latent_to_delta_cost = np.random.random((2, n_edges))

delta_cost = Z@mapping_latent_to_delta_cost

# Per-trip edge costs: Euclidean edge length plus the agent-specific offset.
# Shape is (n_paths, n_edges).
real_cost = np.expand_dims(distance_matrix[edges[:,0], edges[:,1]], 0) + delta_cost

solver_sp = utils.Dijkstra(n_nodes, edges)

# One random (start node, end node) pair per trip.
se_nodes = np.random.randint(0, n_nodes, (n_paths,2))

# Ground-truth route of every trip under that trip's own edge costs. Each
# solver result is stored as one row of length n_edges (presumably a 0/1
# edge-indicator vector -- confirm against utils.Dijkstra.solve).
true_paths = np.zeros((n_paths, edges.shape[0]))
for i in tqdm(range(n_paths)):
    true_paths[i] = solver_sp.solve(real_cost[i], se_nodes[i, 0], se_nodes[i, 1])

# Layout coordinates of each trip's start and end node.
sn_points = points[se_nodes[:,0]]
en_points = points[se_nodes[:,1]]

# Training split: the first N_train trips. Start and end coordinates are
# concatenated column-wise into one tensor.
sn_train_torch = torch.tensor(sn_points[:N_train], dtype=torch.float32)
en_train_torch = torch.tensor(en_points[:N_train], dtype=torch.float32)
se_points_torch = torch.hstack((sn_train_torch, en_train_torch))

# Test split: the remaining trips, same layout as the training tensors.
sn_test_torch = torch.tensor(sn_points[N_train:], dtype=torch.float32)
en_test_torch = torch.tensor(en_points[N_train:], dtype=torch.float32)
se_points_test = torch.hstack((sn_test_torch, en_test_torch))

true_paths_train = true_paths[:N_train]
true_paths_test = true_paths[N_train:]

se_nodes_train = se_nodes[:N_train]
se_nodes_test = se_nodes[N_train:]

paths_train_torch = torch.tensor(true_paths_train, dtype=torch.float32)
paths_test_torch = torch.tensor(true_paths_test, dtype=torch.float32)

agent_indicator_train = agent_indicator[:N_train]
agent_indicator_test = agent_indicator[N_train:]



# Training hyper-parameters. `n_noise` and `eps` are not referenced in the
# rest of this script.
n_noise = 1
eps = 0.001
lr = 0.00005
BS = 32
dev = 'cpu'

# Encoder input: the observed path's edge vector (n_edges values) concatenated
# with the start and end node coordinates (2 + 2 values). The training loop
# unpacks its result as (z_mu, z_logvar, z_sample); output_size=2 is
# presumably the latent dimensionality -- confirm against models.Encoder.
encoder = models.Encoder(input_size=n_edges + 4, output_size=2, hl_sizes=[1024, 1024])
encoder = encoder.to(dev)

# Decoder: maps a 2-D latent vector to one value per edge. The training loop
# adds that value to the edge's Euclidean length to obtain the predicted cost.
decoder = models.ANN(input_size=2, output_size=n_edges, hl_sizes=[1024, 1024])
decoder = decoder.to(dev)

# Separate RMSprop optimizers for the encoder and the decoder.
opt = torch.optim.RMSprop(encoder.parameters(), lr=lr, weight_decay=1e-6)
opt_decoder = torch.optim.RMSprop(decoder.parameters(), lr, weight_decay=1e-6)

# Euclidean node-to-node distances, and the same distances looked up per edge.
heurist_M = torch.tensor(distance_matrix, dtype=torch.float32)
heurist_edges = heurist_M[edges[:,0], edges[:,1]]

# Shortest-path solver used during evaluation.
alg = 'dij'

solver_sp = None
if alg == 'dij':
    solver_sp = utils.Dijkstra(n_nodes, edges)
elif alg == 'astar':
    solver_sp = utils.Astar(n_nodes, edges, heurist_M.numpy())
else:
    # Fail loudly: a bare exit() would end the run silently with status 0.
    raise ValueError(f"Unknown shortest-path algorithm: {alg!r}")
    
# Directory the per-epoch latent-space plots are written to.
os.makedirs("./outputs", exist_ok=True)

# Number of held-out trips used for the per-epoch evaluation.
n_to_eval = 2000

for epoch in range(60):

    # Fresh random visiting order for this epoch. A permutation guarantees
    # that no training trip is drawn twice within the epoch.
    idcs_order = torch.randperm(N_train)

    # Only full batches are processed; the trailing N_train % BS trips of the
    # order are skipped in this epoch.
    for it in tqdm(range(N_train//BS)):
        idcs_batch = idcs_order[it*BS:(it+1)*BS]

        paths_batch = paths_train_torch[idcs_batch].to(dev)
        se_points_batch = se_points_torch[idcs_batch].to(dev)
        # Start/end node ids of the batch, looked up once instead of per sample.
        se_nodes_batch = se_nodes_train[idcs_batch.numpy()]

        input_encoder = torch.hstack((paths_batch, se_points_batch))

        z_mu, z_logvar, z_sample = encoder(input_encoder)

        # Predicted edge cost = decoder output + Euclidean edge length.
        delta_pred = decoder(z_sample)
        edges_pred = delta_pred + heurist_edges.unsqueeze(0)

        loss = 0

        for bb in range(BS):
            # Candidate set: all simple paths within 1.2x of the optimum. The
            # search only needs plain numbers, so it gets a detached copy.
            some_paths = find_edge_vectors_within_threshold(
                edges, edges_pred[bb].detach().cpu().numpy(), G,
                se_nodes_batch[bb, 0], se_nodes_batch[bb, 1],
                threshold_factor=1.2, weight='weight')
            some_paths_torch = torch.tensor(some_paths, dtype=torch.float32, device=dev)

            # Softmax over negative path costs gives a probability per
            # candidate path; softmax is the numerically stable form of
            # exp(-c) / sum(exp(-c)).
            cost_some_paths = (edges_pred[bb].unsqueeze(0)*some_paths_torch).sum(1)
            path_probs = torch.softmax(-cost_some_paths, dim=0)
            # Probability-weighted sum of the path indicator rows: how likely
            # each edge is to be used.
            p_edge = path_probs @ some_paths_torch

            # Cost of the observed path minus the expected cost under the
            # candidate-path distribution.
            loss_sample = -(p_edge*edges_pred[bb]).sum() + (paths_batch[bb]*edges_pred[bb]).sum()
            loss = loss + loss_sample/BS

        opt_decoder.zero_grad()
        opt.zero_grad()
        loss.backward()
        opt_decoder.step()
        opt.step()

    # ---- per-epoch evaluation on the first n_to_eval test trips ----
    with torch.no_grad():

        paths_batch_test = paths_test_torch[:n_to_eval].to(dev)
        se_points_test_ = se_points_test[:n_to_eval].to(dev)
        # Slice every evaluation input the same way so their lengths agree
        # even when n_to_eval is smaller than the test set.
        se_nodes_eval = se_nodes_test[:n_to_eval]
        agents_eval = agent_indicator_test[:n_to_eval]
        true_paths_eval = true_paths_test[:n_to_eval]

        input_encoder_test = torch.hstack((paths_batch_test, se_points_test_))

        z_mu_ev, z_logvar_ev, z_sample_ev = encoder(input_encoder_test)

        z_mu_np = z_mu_ev.cpu().numpy()

        # Scatter plot of the latent means, coloured by the true agent type.
        color_map = {1: 'red', 2: 'green', 3: 'blue'}
        colors = np.array([color_map[agent] for agent in agents_eval])

        fig = plt.figure(figsize=(8, 6), dpi=200)
        plt.scatter(z_mu_np[:, 0], z_mu_np[:, 1], c=colors, alpha=0.6)

        plt.title("Scatter Plot of Latent Dimenions (2d)")
        plt.xlabel("Z1")
        plt.ylabel("Z2")

        plt.savefig(f"./outputs/latent_plot_{epoch}.png")
        # Close the figure; otherwise one open figure accumulates per epoch.
        plt.close(fig)

        # Decode the latent means (not samples) and route on the predicted costs.
        delta_pred_test = decoder(z_mu_ev)
        edges_pred_test = delta_pred_test + heurist_edges.unsqueeze(0)

        paths_theta_test = solver_sp.batched_solver(
            edges_pred_test.detach().cpu().numpy().astype(np.float64),
            se_nodes_eval)

        # Edge-level intersection over union between predicted and true paths.
        union = np.where(paths_theta_test + true_paths_eval > 0, 1, 0).sum(1)
        inter = np.where((paths_theta_test == 1) & (true_paths_eval == 1), 1, 0).sum(1)

        # Trips whose union is empty score 0; the guarded divide avoids a
        # divide-by-zero warning for them.
        iou = np.divide(
            inter, union,
            out=np.zeros(union.shape, dtype=np.float64),
            where=union > 0,
        ).mean()

        print(edges_pred_test.mean(), (heurist_edges).mean())
        print(epoch, it, iou)