import argparse
import os
import random
import time

import networkx as nx
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from scipy.spatial.distance import cdist
from tqdm import tqdm

import models
import utils
from generate_graph import gen_graph
from generate_paths import gen_paths_agents, gen_paths_agents_2, gen_paths_agents_3


def parse_arguments(argv=None):
    """Parse the command-line options that configure a training run.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` reads ``sys.argv[1:]``, which is the
            original behaviour; passing a list allows the function to be
            called programmatically or from tests.

    Returns:
        argparse.Namespace with the attributes ``method``, ``eps``, ``lr``,
        ``BS``, ``alpha_kl``, ``seed_n``, ``latent_dim`` and ``n_epochs``.
    """
    parser = argparse.ArgumentParser(description='Set parameters for the program.')

    parser.add_argument('--method', type=str, default='VAIP',
                        help="Training method; the script branches on 'VAIP', 'VAE' and 'Perturbed'.")
    parser.add_argument('--eps', type=float, default=0.05,
                        help='Scale of the Gaussian noise added to the predicted edge costs.')
    parser.add_argument('--lr', type=float, default=0.00002,
                        help='Learning rate passed to the optimizers.')
    parser.add_argument('--BS', type=int, default=200,
                        help='Mini-batch size.')
    parser.add_argument('--alpha_kl', type=float, default=0.001,
                        help='Weight of the KL-divergence term in the total loss.')
    parser.add_argument('--seed_n', type=int, default=0,
                        help='Seed for the torch, numpy and random generators.')
    parser.add_argument('--latent_dim', type=int, default=2,
                        help='Dimension of the latent space.')
    parser.add_argument('--n_epochs', type=int, default=400,
                        help='Number of training epochs.')

    return parser.parse_args(argv)


# Parse the command-line arguments and unpack them into module-level settings.
args = parse_arguments()

method = args.method

eps = args.eps
lr = args.lr
BS = args.BS

alpha_kl = args.alpha_kl
seed_n = args.seed_n
latent_dim = args.latent_dim
n_epochs = args.n_epochs

# File names are prefixed with the method name, except for 'VAIP', whose
# files carry no prefix.
mm = '' if method == 'VAIP' else method + '_'

# Tensors are converted with .numpy() directly elsewhere in this script,
# which only works for CPU tensors, so the device stays fixed to the CPU.
dev = 'cpu'

# Tag that identifies this run's hyper-parameters in every output file name.
suffix = f'{mm}ship_{eps}_{lr}_{BS}_{alpha_kl}_{seed_n}_{latent_dim}_{n_epochs}'

output_path = './outputs/'
model_path = './saved_models/'

# Create the output directories before training starts, so that a missing
# directory cannot make the run fail at save time after all epochs are done.
os.makedirs(output_path, exist_ok=True)
os.makedirs(model_path, exist_ok=True)

# Seed every random number generator used by the script.
torch.manual_seed(seed_n)
np.random.seed(seed_n)
random.seed(seed_n)


# NOTE(review): `utils_ww`, `N` and `os` are neither defined nor imported in
# the visible part of this file -- confirm where they are expected to come from.
M_indices = utils_ww.get_M_indices(N)

data_dir = f'./../dataSP/data_warcraft/{N}x{N}/'

data_suffix = "maps"
train_prefix = "train"
val_prefix = "val"


def _load_split_array(prefix, stem):
    """Load ``<data_dir>/<prefix>_<stem>.npy`` and return the stored array."""
    return np.load(os.path.join(data_dir, prefix + "_" + stem + ".npy"))


# Training split: map images (as float), vertex weights and shortest-path labels.
train_inputs = _load_split_array(train_prefix, data_suffix).astype('float')
train_weights = _load_split_array(train_prefix, "vertex_weights")
train_labels = _load_split_array(train_prefix, "shortest_paths")

# Validation split, loaded the same way.
val_inputs = _load_split_array(val_prefix, data_suffix).astype('float')
val_weights = _load_split_array(val_prefix, "vertex_weights")
val_labels = _load_split_array(val_prefix, "shortest_paths")

# Move axis 3 to position 1 in both image arrays.
# (presumably channels-last -> channels-first; confirm against the data files)
train_inputs = np.transpose(train_inputs, (0, 3, 1, 2))
val_inputs = np.transpose(val_inputs, (0, 3, 1, 2))

# Statistics over axes 0, 2 and 3 of the *training* images; keepdims lets
# them broadcast against the validation images below.
mean = train_inputs.mean(axis=(0, 2, 3), keepdims=True)
std = train_inputs.std(axis=(0, 2, 3), keepdims=True)

# The training images are only needed for the statistics above.
del train_inputs

# Normalise the validation images in place with the training statistics.
val_inputs -= mean
val_inputs /= std

# Node sequence of every training path, obtained from its label via
# utils_ww.get_path_nodes.
train_paths = [
    utils_ww.get_path_nodes(M_indices, train_labels[sample])
    for sample in tqdm(range(N_train))
]

# Same for the first 1000 validation labels.
val_paths_nodes = [
    utils_ww.get_path_nodes(M_indices, val_labels[sample])
    for sample in tqdm(range(1000))
]

n_edges = len(M_indices)

# Consecutive node pairs (u, v) along each training path.
train_paths_edges = [list(zip(nodes[:-1], nodes[1:])) for nodes in train_paths]

# Position of every edge (row of M_indices) in the edge list.
edge_to_index = {
    tuple(pair): position for position, pair in enumerate(M_indices.numpy())
}

# Binary matrix: entry [row, col] is 1 when training path `row` uses edge `col`.
paths_train = np.zeros((N_train, n_edges), dtype=int)
for row, path_edges in enumerate(train_paths_edges):
    for step in path_edges:
        col = edge_to_index.get(step)
        # Node pairs that are not rows of M_indices are silently skipped.
        if col is not None:
            paths_train[row, col] = 1

paths_train_torch = torch.tensor(paths_train, dtype=torch.float32)


# NOTE(review): `model` is not defined in the visible file, and `opt` is
# rebound to the encoder's optimizer further down before it is ever used --
# this line looks like a leftover; confirm and remove.
opt = torch.optim.RMSprop(model.parameters(), lr, weight_decay=1e-6)

# One (source, target) pair: node 0 and node N**2 - 1.
node_pairs = np.array([(0, N**2-1)])

val_inputs_tensor = torch.from_numpy(val_inputs).float()

# Slice the memory-mapped file *before* converting it: only the first N_train
# maps are read and copied, instead of materialising the whole array as int
# and then discarding the rest. The resulting values are unchanged.
imgs_tot = np.load(
    f'./../dataSP/data_warcraft/{N}x{N}/train_maps.npy', mmap_mode='r'
)[:N_train].astype('int')


# Encoder: its input is the path's edge-indicator vector (n_edges values)
# concatenated with 4 extra values per sample (the start/end point features
# stacked on in the training loop -- presumably two coordinates each; confirm).
# It is called as `z_mu, z_logvar, z_sample = encoder(x)` in the training loop.
encoder = models.Encoder(input_size=n_edges + 4, output_size=latent_dim, hl_sizes=[1000, 1000])  
encoder = encoder.to(dev)

# Simple version with noise as input
#encoder = models.Encoder(input_size=2, output_size=2, hl_sizes=[1024, 1024])  
#encoder = encoder.to(dev)

# Decoder: maps a latent vector of size latent_dim to one value per edge.
# For method 'VAE' the ANN built first is discarded and replaced by an ANN2.
# NOTE(review): the discarded ANN is still constructed, which presumably
# consumes random draws for its weight initialisation -- keep the order if
# seeded runs must stay reproducible.
decoder = models.ANN(input_size=latent_dim, output_size=n_edges, hl_sizes=[1000, 1000])  
if method == 'VAE':
    decoder = models.ANN2(input_size=latent_dim, output_size=n_edges, hl_sizes=[1000, 1000])  
decoder = decoder.to(dev)

#decoder2 = models.ANN2(input_size=latent_dim, output_size=n_edges, hl_sizes=[1000, 1000])  
#decoder2 = decoder2.to(dev)


# Separate RMSprop optimizers for the encoder (`opt`) and the decoder
# (`opt_decoder`); both are stepped on every training iteration.
opt = torch.optim.RMSprop(encoder.parameters(), lr=lr, weight_decay=1e-7)
opt_decoder = torch.optim.RMSprop(decoder.parameters(), lr, weight_decay=1e-7)
#opt_decoder2 = torch.optim.Adam(decoder2.parameters(), lr, weight_decay=1e-7)

# Heuristic matrix (20 x distance_matrix) handed to the A* solver, and its
# values gathered at the (u, v) index pairs in `edges`.
# NOTE(review): `distance_matrix` and `edges` are not defined in the visible
# file; `heurist_edges` is only referenced by a commented-out print.
heurist_M = torch.tensor(20*distance_matrix, dtype=torch.float32)
heurist_edges = heurist_M[edges[:,0], edges[:,1]]

    
# Shortest-path algorithm used by the solver: 'dij' (utils.Dijkstra) or
# 'astar' (utils.Astar, which additionally receives the heuristic matrix).
alg = 'dij'

if alg == 'dij':
    solver_sp = utils.Dijkstra(n_nodes, edges)
elif alg == 'astar':
    solver_sp = utils.Astar(n_nodes, edges, heurist_M.numpy())
else:
    # Fail loudly: a bare exit() would end the run with status 0 and no
    # message, and relies on the `site`-provided builtin being available.
    raise ValueError(
        f"Unknown shortest-path algorithm {alg!r}; expected 'dij' or 'astar'."
    )

# Reconstruction loss used by the 'VAE' method; it expects raw (pre-sigmoid)
# scores as its first argument.
loss_vae = nn.BCEWithLogitsLoss()


# Histories with one entry appended per evaluation pass: mean loss,
# KL divergence and mean IoU.
loss_eval_list, kl_div_eval_list, iou_eval_list = [], [], []

# Latent means recorded at each evaluation pass ('VAIP' / 'VAE' only).
latent_vectors_ev = list()
# Never written to in the visible code.
agent_ev = list()

# Number of evaluation passes completed so far.
ev_step = 0




# Training loop. Each epoch draws N_train sample indices, runs one optimisation
# step per mini-batch of BS indices, and -- right after the first step of the
# epoch -- evaluates the current model on the held-out paths.
# NOTE(review): `N_train`, `se_points_torch`, `se_nodes_train`, `N_paths`,
# `paths_test_torch`, `se_points_test`, `se_nodes_test` and `true_paths_test`
# are not defined in the visible part of this file.
for epoch in range(n_epochs):
    
    # Indices are drawn uniformly *with replacement* (randint, not a
    # permutation), so within an epoch samples may repeat or be skipped.
    idcs_order = torch.randint(0, N_train, (N_train,))
    
    # Not read again in the visible code.
    start_time = time.time()

    # Only full batches are used; the last N_train % BS indices are dropped.
    for it in range(N_train//BS):
        idcs_batch = idcs_order[it*BS:(it+1)*BS] 
                
            
        paths_batch = paths_train_torch[idcs_batch].to(dev)
        se_points_batch = se_points_torch[idcs_batch].to(dev)
        
        if method in ['VAIP', 'VAE']:
            # Encode [path indicator | start/end features] and decode the
            # sampled latent vector into one value per edge.
            input_encoder = torch.hstack((paths_batch, se_points_batch))        
            z_mu, z_logvar, z_sample = encoder(input_encoder)
            edges_pred = decoder(z_sample)
        elif method == 'Perturbed':
            # No encoder: the decoder is fed a constant all-ones input, so the
            # same input is used for every sample of the batch.
            edges_pred = decoder(torch.ones(BS, latent_dim))

        if method in ['Perturbed','VAIP']:
            # Perturb the predicted edge costs with Gaussian noise scaled by
            # eps, clamp them to at least 0.0001, and solve the shortest-path
            # problems on the detached costs (no gradient through the solver).
            edges_eps = (edges_pred + eps*torch.randn_like(edges_pred)).clamp(0.0001)
            path_eps = torch.tensor(solver_sp.batched_solver(
                edges_eps.detach().numpy(), se_nodes_train[idcs_batch]), dtype=torch.float32)
        elif method == 'VAE':
            # The decoder output itself is the path prediction.
            path_eps = edges_pred
      
        # Distinct predicted / observed paths in the batch; the counts are only
        # referenced by the commented-out debug print below.
        unique_path_pred = torch.unique(path_eps, dim=0)
        num_unique_path_pred = unique_path_pred.size(0)

        unique_path = torch.unique(paths_batch, dim=0)
        num_unique_path = unique_path.size(0)

        if method in ['Perturbed','VAIP']:
            # Cost of the observed path minus cost of the solver's path, both
            # measured with the (unperturbed) predicted edge costs.
            loss_per_sample = ((paths_batch*edges_pred).sum(-1) - (path_eps*edges_pred).sum(-1))
        elif method == 'VAE':
            # BCEWithLogitsLoss with its default reduction already returns a
            # scalar, so the .mean() below leaves it unchanged.
            loss_per_sample = loss_vae(path_eps, paths_batch)
            
        
        loss = loss_per_sample.mean()
            

        if method in ['VAIP', 'VAE']:
            # KL term computed from z_mu and z_logvar, summed over the whole
            # batch and all latent dimensions (torch.sum without `dim`),
            # whereas `loss` above is a batch mean.
            kl_divergence = (-0.5 * torch.sum(1 + z_logvar - z_mu.pow(2) - z_logvar.exp()))
        elif method == 'Perturbed':
            kl_divergence = 0


        #edges_reg = ((edges_pred - 1.)**2).mean()
        
        
        total_loss = loss + alpha_kl*kl_divergence #+ 0.1*edges_reg

        

        # Training IoU between predicted and observed paths; `iou_train` is
        # only referenced by the commented-out debug print below.
        # NOTE(review): for 'VAE', path_eps holds raw decoder outputs, so the
        # `== 1` test presumably almost never matches -- confirm intent.
        union = np.where(path_eps.detach().numpy() + paths_batch.detach().numpy()>0,1,0).sum(1)
        inter = np.where((path_eps.detach().numpy() == 1) & (paths_batch.detach().numpy()==1),1,0).sum(1)

        # inter/union is evaluated for every row before np.where selects, so
        # rows with union == 0 can still trigger a divide-by-zero warning.
        iou_train = np.where(union==0, 0, inter/union).mean()

        
        #print(
        #    f'it: {it}',
        #    f'\t Loss Batch: {round(loss_per_sample.mean().detach().item(), 5)}',
        #    #f'\t LoLoss2ss2 Batch: {round(loss2_per_sample.mean().detach().item(), 5)}',
        #    f'\t KL Batch: {round(kl_divergence.detach().item(), 5)}',
        #    f'\t IOU Batch: {round(iou_train, 5)}',
        #    f'\t Norm Batch: {round(edges_reg.detach().item(), 5)}',
        #    f'\t Unique paths: Pred: {num_unique_path_pred}, Data: {num_unique_path}'
        #)

        # One gradient step on both the decoder and the encoder optimizers.
        opt_decoder.zero_grad()
        opt.zero_grad()
        total_loss.backward()
        opt_decoder.step()
        opt.step()  
        
        
        # Evaluate once per epoch, right after the epoch's first update.
        if it == 0:
    
            n_to_eval = N_paths - N_train
            with torch.no_grad():

                paths_batch_test = paths_test_torch[:n_to_eval].to(dev)
                se_points_test_ = se_points_test[:n_to_eval].to(dev)

                if method in ['VAIP', 'VAE']: 
                    # Evaluation decodes the latent mean rather than a sample,
                    # and records the means for later saving.
                    input_encoder_test = torch.hstack((paths_batch_test, se_points_test_))
                    z_mu_ev, z_logvar_ev, z_sample_ev = encoder(input_encoder_test)
                    z_mu_np = z_mu_ev.numpy()
                    latent_vectors_ev.append(z_mu_np)                
                    edges_pred_test = decoder(z_mu_ev)
                elif method == 'Perturbed':
                    edges_pred_test = decoder(torch.ones(n_to_eval, latent_dim))

                #import pdb
                #pdb.set_trace()
                
                if method in ['VAIP', 'Perturbed']:
                    # Shortest paths under the unperturbed predicted costs.
                    # NOTE(review): se_nodes_test is passed unsliced while the
                    # other test arrays are truncated to n_to_eval -- confirm
                    # that their lengths agree.
                    paths_theta_test = solver_sp.batched_solver(
                        edges_pred_test.detach().numpy().astype(np.float64), 
                        se_nodes_test)
                elif method == 'VAE':
                    # NOTE(review): the decoder output is thresholded at 0.5,
                    # but training applies BCEWithLogitsLoss to the raw output;
                    # if ANN2 emits logits the matching cut-off would be 0 --
                    # confirm ANN2's output activation.
                    paths_theta_test = torch.where(edges_pred_test<0.5, 0., 1.)
                
                if method in ['VAIP', 'Perturbed']:
                    # Same cost difference as the training loss. This mixes a
                    # torch tensor with NumPy arrays in one expression.
                    loss_per_sample_ev = \
                    (paths_batch_test*edges_pred_test.numpy()).sum(-1) \
                    - (paths_theta_test*edges_pred_test.numpy()).sum(-1)
                elif method == 'VAE':
                    # NOTE(review): the thresholded 0/1 prediction is passed
                    # as the loss input here, unlike training, which passes
                    # the raw decoder output -- confirm this is intended.
                    loss_per_sample_ev = loss_vae(paths_theta_test, paths_batch_test)

                if method in ['VAIP', 'VAE']: 
                    kl_divergence_ev = \
                    -0.5 * torch.sum(1 + z_logvar_ev - z_mu_ev.pow(2) - z_logvar_ev.exp())
                    kl_div_eval_list.append(kl_divergence_ev.item())
                elif method == 'Perturbed':
                    kl_divergence_ev = 0
                    kl_div_eval_list.append(0)
                
                loss_eval_list.append(loss_per_sample_ev.mean().item())
                

                # IoU between predicted and true test paths, per sample, then
                # averaged. As above, inter/union is evaluated for every row.
                union = np.where(paths_theta_test + true_paths_test[:n_to_eval]>0,1,0).sum(1)
                inter = np.where((paths_theta_test == 1) & (true_paths_test[:n_to_eval]==1),1,0).sum(1)

                iou = np.where(union==0, 0, inter/union).mean()

                iou_eval_list.append(iou)

                # Number of distinct predicted / true paths, for the log line.
                unique_path_pred = np.unique(paths_theta_test, axis=0)
                num_unique_path_pred = unique_path_pred.shape[0]
        
                unique_path = np.unique(true_paths_test[:n_to_eval], axis=0)
                num_unique_path = unique_path.shape[0]

                #print(edges_pred_test.mean(), (heurist_edges).mean())
                print(f'Validation: Epoch {epoch} \t It. {it} \t IOU: {iou} \t loss {loss_per_sample_ev.detach().mean().item()} \t Unique paths = ({num_unique_path_pred}, {num_unique_path})')
                
                ev_step = ev_step + 1



# Persist the per-evaluation IoU and loss histories under output_path.
iou_eval_list_np = np.array(iou_eval_list)
loss_eval_list_np = np.array(loss_eval_list)

np.save(output_path + f'iou_{suffix}.npy', iou_eval_list_np)
np.save(output_path + f'loss_{suffix}.npy', loss_eval_list_np)
#np.save(output_path + f'agents_{suffix}.npy', agent_indicator_test)    

# Latent means are only collected for the methods that use the encoder.
if method in ['VAIP', 'VAE']:
    latent_vectors_ev_np = np.array(latent_vectors_ev)
    np.save(output_path + f'latent_vector_{suffix}.npy', latent_vectors_ev_np)

# Save the weights under model_path instead of a second hard-coded copy of
# the directory, so that changing model_path actually moves the checkpoints.
torch.save(encoder.state_dict(), model_path + f'encoder_{suffix}.pkl')
torch.save(decoder.state_dict(), model_path + f'decoder_{suffix}.pkl')