import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torchvision

import itertools
import functools

import argparse
import pickle

from math import sqrt

import numpy as np
import heapq
from functools import partial
from collections import namedtuple

import smooth_dp_utils
import utils
import data_utils

from tqdm import tqdm
import time

import utils_ww

import multiprocessing

import torch.optim.lr_scheduler as lr_scheduler





##################################################################################### 
#################################PARAMETERS########################################## 
##################################################################################### 
        
def parse_arguments(argv=None):
    """Build the command-line parser and parse the run configuration.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what
            the script does when called as ``parse_arguments()``.

    Returns:
        argparse.Namespace with the attributes ``prefix``, ``suboptimals``,
        ``N``, ``Vs``, ``N_train``, ``dev``, ``N_EPOCHS``, ``beta``, ``lr``,
        ``N_batches``, ``bs_X``, ``seed_n`` and ``load_model``.
    """
    parser = argparse.ArgumentParser(description='Set parameters for the program.')

    parser.add_argument('--prefix', type=str, default='softmax', help='method')

    # Integer flag: any non-zero value switches the script to the noisy paths file.
    parser.add_argument('--suboptimals', type=int, default=0, help='noise in paths')

    parser.add_argument('--N', type=int, default=18, help='Grid size')
    parser.add_argument('--Vs', type=int, default=100, help='Nr sampling nodes')

    # The help text used to be a copy of the --Vs one.
    parser.add_argument('--N_train', type=int, default=10000, help='Nr training samples')

    parser.add_argument('--dev', type=str, default='cpu', help='Device to use')
    parser.add_argument('--N_EPOCHS', type=int, default=100, help='N EPOCHS train')

    parser.add_argument('--beta', type=float, default=30., help='Beta Smooth')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning Rate')
    parser.add_argument('--N_batches', type=int, default=30, help='N Batches in one Epoch')
    parser.add_argument('--bs_X', type=int, default=8, help='How many floyd warshalls in a batch')

    parser.add_argument('--seed_n', type=int, default=0)

    # Integer flag: non-zero means "try to restore a saved checkpoint".
    parser.add_argument('--load_model', type=int, default=0, help='Load previous model?')

    return parser.parse_args(argv)

# Read the command line once; every setting below is derived from `args`.
args = parse_arguments()

# Method label and noisy-paths switch.
prefix = args.prefix
suboptimals = args.suboptimals

# Seed torch and numpy from the same user-supplied value.
seed_n = args.seed_n
torch.manual_seed(seed_n)
np.random.seed(seed_n)

# Problem size and training-set size.
N = args.N
N_train = args.N_train

# Device and optimisation hyper-parameters.
dev = args.dev
N_EPOCHS = args.N_EPOCHS
beta_smooth = args.beta
lr = args.lr
N_batches = args.N_batches
bs_X = args.bs_X

load_model = args.load_model

ps_in_batch = 1

# Node sampling is only used when the requested number of sampled nodes is
# strictly between 0 and the number of grid nodes (N**2); otherwise all
# N**2 nodes are used and sampling is switched off.
Vs = int(args.Vs)
bool_scale = 0 < Vs < N**2
if not bool_scale:
    Vs = N**2

# Number of consecutive non-improving epochs tolerated before stopping.
epochs_wait = 10

print(f'RUNNING WITH {dev}')

# Index structure for the N x N grid, produced by a project helper.
M_indices = utils_ww.get_M_indices(N)

N_val = 1000

# All dataset files for a given grid size live in the same directory.
data_dir = f'./data_warcraft/{N}x{N}/'

data_suffix = "maps"
train_prefix = "train"
val_prefix = "val"

# Only the first N_train training samples are kept.
train_inputs = np.load(os.path.join(
    data_dir, train_prefix + "_" + data_suffix + ".npy")).astype('float')[:N_train]
train_weights = np.load(
    os.path.join(data_dir, train_prefix + "_vertex_weights.npy"))[:N_train]

val_inputs = np.load(
    os.path.join(data_dir, val_prefix + "_" + data_suffix + ".npy")).astype('float')
val_weights = np.load(
    os.path.join(data_dir, val_prefix + "_vertex_weights.npy"))
val_labels = np.load(
    os.path.join(data_dir, val_prefix + "_shortest_paths.npy"))

# Move the last axis to position 1
# (presumably channels-last -> channels-first; confirm against the data files).
train_inputs = train_inputs.transpose(0, 3, 1, 2)
val_inputs = val_inputs.transpose(0, 3, 1, 2)

# Statistics are taken over axes (0, 2, 3), i.e. one mean/std per index of
# axis 1, from the training set only, and applied to both splits.
mean, std = (
    np.mean(train_inputs, axis=(0, 2, 3), keepdims=True),
    np.std(train_inputs, axis=(0, 2, 3), keepdims=True),
)
train_inputs -= mean
train_inputs /= std

val_inputs -= mean
val_inputs /= std

# Noisy (sub-optimal) paths are stored in a separate file.
suffix_noise = ''
if suboptimals:
    suffix_noise = '_noise_05'

# The paths file used to be opened from a hard-coded '18x18' directory, so a
# run with --N != 18 silently trained on trips from the wrong grid. It is now
# taken from `data_dir` like every other file (same path when N == 18).
true_paths_file = os.path.join(data_dir, f'train_nodes_true{suffix_noise}.json')

# Despite the extension, each line is a comma-separated list of integer node
# ids. At most 30 lines per requested training sample are read.
max_paths = max(30 * N_train, 0)
true_paths_nodes = []
with open(true_paths_file, 'r') as file:
    for line in itertools.islice(file, max_paths):
        true_paths_nodes.append([int(element) for element in line.strip().split(',')])


print('----- Finish Generating and Processing Data -----')

##################################################################################### 
##################################################################################### 
#####################################################################################


##################################################################################### 
######################### MODEL LOAD OR CREATE ###################################### 
#####################################################################################

print('----- Model Load or Create -----')

# Build the network and put it on the requested device in one step
# (`Module.to` returns the module itself).
model = utils_ww.CombRenset18(N**2, 3).to(dev)

# Activations; softmax normalises over the last dimension.
softmax = nn.Softmax(dim=-1)
sigmo = nn.Sigmoid()

# Element-wise (unreduced) losses.
mse = nn.MSELoss(reduction='none')
bce = nn.BCELoss(reduction='none')

# All-zero (N**2 x N**2) matrix.
prior_M = torch.zeros(N**2, N**2)

if suboptimals:
    # Noisy-path runs use a KL divergence. KLDivLoss takes log-probabilities
    # as its first argument, hence the explicit log of the prediction.
    criterion = torch.nn.KLDivLoss(reduction='none')

    def cross_entropy_cont(target, prediction):
        """Return KL(target || prediction) summed over the last dimension.

        A small constant is added to `prediction` before the log so that
        zero probabilities do not produce -inf.
        """
        log_prediction = torch.log(prediction + 0.00001)
        return criterion(log_prediction, target).sum(-1)
else:
    def cross_entropy_cont(target, prediction):
        """Return the cross-entropy of `prediction` against a soft `target`.

        Computes -sum(target * log(prediction + eps)) over the last
        dimension; eps keeps the log finite for zero probabilities.
        """
        weighted_log = target * torch.log(prediction + 0.00001)
        return -weighted_log.sum(-1)


# The checkpoint name encodes the run's hyper-parameters, so different
# configurations never overwrite each other.
model_path = f'saved_models/safw_ww_MtoM_{N}_{suffix_noise}_{beta_smooth}_{bs_X}_{lr}_{seed_n}.pkl'
print('Model path:', model_path)
if load_model:
    # Best effort: if the checkpoint cannot be restored, training continues
    # with a freshly initialised network. `Exception` (not a bare `except`)
    # keeps Ctrl-C / SystemExit working, and the reason is reported.
    try:
        model = utils_ww.CombRenset18(N**2, 3)
        model.load_state_dict(
            torch.load(model_path, map_location=torch.device(dev)))
        print('MODEL LOADED')
    except Exception as err:
        print('FAILED TO LOAD')
        print('Load error:', repr(err))
else:
    print('MODEL CREATED')

# Whatever path was taken above, make sure the parameters live on `dev`
# before the optimizer captures them.
model = model.to(dev)
# NOTE(review): 10e-4 equals 1e-3; confirm that is the intended weight decay.
opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=10e-4)
print('MODEL ON ', next(model.parameters()).device)


print('----- Model Load or Create Finished -----')

##################################################################################### 
##################################################################################### 
#####################################################################################

# Inputs stay on the CPU as float32; only each mini-batch is moved to `dev`.
train_inputs_tensor = torch.from_numpy(train_inputs).float()
val_inputs_tensor = torch.from_numpy(val_inputs).float()

M_indices = M_indices.to(dev)

elements, frequencies = utils.get_nodes_and_freqs(true_paths_nodes)

# Trip index // 30 is used as the index of the training image a trip belongs
# to (the trips file is read as 30 trips per training sample).
paths_per_image = 30

if not bool_scale:
    # Without node sampling every trip and every image is a candidate. These
    # arrays used to be missing in this branch, which raised NameError on the
    # first batch. They do not change between batches, so build them once.
    all_trip_indexes = np.arange(len(true_paths_nodes))
    all_trip_imgs = all_trip_indexes // paths_per_image
    all_img_ids = np.arange(train_inputs_tensor.shape[0])

not_best_count_accum = 0
loss_batch_avg_best = torch.inf
perc_correct_best = 0.

# Learning rate decays linearly to 1% of its start value over the first
# 50 scheduler steps (one step per epoch, see the end of the loop).
scheduler = lr_scheduler.LinearLR(opt, start_factor=1.0, end_factor=0.01, total_iters=50)

for epochs in range(0, N_EPOCHS):

    loss_batch_avg = 0

    for batch in range(0, N_batches):

        if bool_scale:
            with torch.no_grad():
                selected_indexes, selected_trips, nodes_selected, nodes_excluded = utils.selected_trips_and_idx(
                    true_paths_nodes, M_indices, elements, frequencies, Vs, N**2)
            # Nothing was selected for this draw: skip the batch.
            if selected_indexes is None or len(selected_indexes) == 0:
                continue

            selected_indexes = np.array(selected_indexes)
            # Global image id of every selected trip.
            selected_indexes_imgs_all = selected_indexes // paths_per_image
            # Sorted, unique global ids of the images that own a selected trip.
            selected_indexes_imgs = np.unique(selected_indexes_imgs_all)
        else:
            selected_trips = true_paths_nodes
            selected_indexes = all_trip_indexes
            selected_indexes_imgs_all = all_trip_imgs
            selected_indexes_imgs = all_img_ids

        with torch.no_grad():
            # Draw bs_X candidate images (with replacement) as sorted
            # positions into `selected_indexes_imgs`.
            idcs_imgs = torch.randint(0, len(selected_indexes_imgs), (bs_X,)).sort().values

        # Translate positions into global image ids. The positions used to be
        # compared directly with global ids, which is only right when every
        # image is a candidate. Both arrays are sorted, so the result is
        # non-decreasing.
        batch_img_ids = selected_indexes_imgs[idcs_imgs.numpy()]

        # Positions (into `selected_indexes` / `selected_trips`) of the trips
        # that belong to an image of this batch.
        idx_batch_paths = np.flatnonzero(np.isin(selected_indexes_imgs_all, batch_img_ids))
        if idx_batch_paths.size == 0:
            continue

        # Row of X_batch holding each trip's image. If an image was drawn
        # twice, its trips go to the first copy; the second copy gets no
        # target and is masked out below. Counting distinct images instead
        # shifted all later rows after a duplicate.
        sel_imgs_idx = selected_indexes_imgs_all[idx_batch_paths]
        batch_rows = np.searchsorted(batch_img_ids, sel_imgs_idx)

        opt.zero_grad()

        # Index the full tensor with the batch ids only; no per-batch copy of
        # all candidate images is needed.
        X_batch = train_inputs_tensor[torch.as_tensor(batch_img_ids, dtype=torch.long)].to(dev)

        # Predictions are clipped from below at 0.001.
        nodes_pred_batch = model(X_batch).clip(0.001)
        M_pred_batch = utils_ww.nodes_to_M_batch(nodes_pred_batch)

        if bool_scale:
            M_Y_pred_selected, M_sigmaY_selected, M_indices_selected_mapped = utils.select_Ms_from_selected_idx_and_trips(
                M_pred_batch, 0.0*M_pred_batch, Vs, M_indices, nodes_excluded, nodes_selected, torch.tensor(beta_smooth), dev)
        else:
            M_Y_pred_selected = M_pred_batch
            M_indices_selected_mapped = M_indices

        # We want to remove bias of node ordering: draw a random permutation
        # of the Vs node labels and apply it consistently to the predicted
        # matrices, the index tensor and the trips.
        k_nodes_shufled = torch.arange(Vs)[torch.randperm(Vs)]
        # Maps an old node label to its new position after shuffling.
        shuffle_k_dict = {old: new for new, old in enumerate(k_nodes_shufled.tolist())}

        M_Y_pred_selected_shuf = M_Y_pred_selected[:, k_nodes_shufled][:, :, k_nodes_shufled]
        M_indices_selected_mapped_shuf = M_indices_selected_mapped.clone()
        # Masks are computed on the unshuffled tensor so that already
        # relabelled entries are never relabelled a second time.
        for key, value in shuffle_k_dict.items():
            M_indices_selected_mapped_shuf[M_indices_selected_mapped == key] = value

        probs_pred = smooth_dp_utils.smooth_floyd_warshall_batch_adapted_parallel(
            M_Y_pred_selected_shuf, M_indices_selected_mapped_shuf, dev, beta_smooth)

        # Accumulate the targets of every trip in the batch into the row of
        # its image. Only the batch's trips are relabelled, not all of them.
        m_inter_total = torch.zeros(bs_X, Vs, Vs, Vs)
        for row, p in zip(batch_rows, idx_batch_paths):
            trip_shuf = [shuffle_k_dict[node] for node in selected_trips[p]]
            m_inter_total[row] += data_utils.get_m_inter(trip_shuf, Vs, Vs)

        # Normalise over the last axis. Slices with no observation become
        # 0/0 = NaN and are dropped from the loss by the mask.
        m_inter_total = (m_inter_total/m_inter_total.sum(-1).unsqueeze(-1)).to(dev)

        observed = ~torch.isnan(m_inter_total)
        true_paths_dist = m_inter_total[observed].reshape(-1, Vs)
        pred_paths_dist = probs_pred[observed].reshape(-1, Vs)
        # The variable name is historical: this is the cross-entropy / KL loss.
        loss_mse = cross_entropy_cont(true_paths_dist, pred_paths_dist).mean()

        loss_total = loss_mse
        loss_total.backward()
        opt.step()

        loss_batch_avg += (loss_mse/N_batches).detach()

    with torch.no_grad():
        # NOTE(review): the model is not switched to eval() here and N_eval is
        # fixed at 1000 while the whole validation tensor is fed to the model;
        # confirm both are intended.
        N_eval = 1000
        nodes_pred = model(val_inputs_tensor.to(dev)).clip(0.001)
        M_pred = utils_ww.nodes_to_M_batch(nodes_pred)

        # Hard shortest paths from node 0 to node N**2 - 1 on the predicted costs.
        path_pred = smooth_dp_utils.batch_dijkstra(M_pred.detach().cpu().numpy(),
                                                   np.repeat(np.array([[0, N**2-1]]), N_eval, 0))

        # Turn each predicted node list into a 0/1 indicator map over the grid.
        path_pred_map_all = torch.zeros((N_eval, N**2))
        for i in range(0, N_eval):
            path_pred_map = torch.zeros((N**2,))
            path_pred_map[path_pred[i]] = 1
            path_pred_map_all[i] = path_pred_map

        path_pred_map_all = path_pred_map_all.reshape(-1, N, N)

        # Cost of the predicted and of the labelled path under the true weights.
        cost_pred = (path_pred_map_all*val_weights[:N_eval]).sum(-1).sum(-1)

        cost_true = (val_labels.astype(float)[:N_eval]*val_weights[:N_eval]).sum(-1).sum(-1)

        # Share of validation samples whose cost gap is below each tolerance.
        perc_correct = (cost_pred - cost_true < 0.001).sum()/N_eval
        perc_correct_2 = (cost_pred - cost_true < 0.1).sum()/N_eval
        perc_correct_3 = (cost_pred - cost_true < 0.5).sum()/N_eval

        # Checkpoint on strict improvement, otherwise count towards early stopping.
        if perc_correct <= perc_correct_best:
            not_best_count_accum = not_best_count_accum + 1
            print('Did not improve results nr ', not_best_count_accum)
        else:
            perc_correct_best = perc_correct
            not_best_count_accum = 0
            _ = utils.check_or_create_folder("saved_models")
            torch.save(model.state_dict(), model_path)

        # float() also covers the case where every batch of the epoch was
        # skipped and the accumulator is still the plain int 0.
        print(epochs,
              ': Batches AVG:', round(float(loss_batch_avg), 4),
              '\t VAL perc:', round(perc_correct.item(), 4),
              '\t VAL perc <0.1:', round(perc_correct_2.item(), 4),
              '\t VAL perc <0.5:', round(perc_correct_3.item(), 4),
             )
        scheduler.step()

        if not_best_count_accum >= epochs_wait:
            print('Converged, exiting')
            break