"""This script is used for generating training and evaluation datasets for the partially dynamic travelling salesman problem"""

import argparse
import numpy as np
from tqdm import tqdm

def norm_2D(first_coords, second_coords):
    """Return the Euclidean distance between two 2D points.

    Only components 0 and 1 of each argument are read, so any indexable
    pair-like object (list, tuple, array row) is accepted.
    """
    delta_x = first_coords[0] - second_coords[0]
    delta_y = first_coords[1] - second_coords[1]
    squared_length = delta_x**2 + delta_y**2
    return squared_length**0.5

def calc_insert_cost(D, prv, nxt, ins):
    """Return the extra tour length caused by visiting ``ins`` between ``prv`` and ``nxt``.

    ``D`` is indexed as ``D[a][b]``; the result is the length of the detour
    ``prv -> ins -> nxt`` minus the length of the direct edge ``prv -> nxt``.
    """
    detour_length = D[prv][ins] + D[ins][nxt]
    direct_length = D[prv][nxt]
    return detour_length - direct_length

def get_distance_matrix(nodes_coords):
    """Return the pairwise Euclidean distance matrix of ``nodes_coords`` as a NumPy array.

    Entry ``[i][j]`` is ``norm_2D(nodes_coords[i], nodes_coords[j])``.
    """
    rows = []
    for origin in nodes_coords:
        # One row per origin point: its distance to every point (itself included).
        rows.append([norm_2D(origin, target) for target in nodes_coords])
    return np.array(rows)

def find_insertion_position(D, tour, current):
    """Find the cheapest place to insert the newest node into the remaining route.

    The candidate route is ``[current, *tour, 0]``: the current node, the
    nodes still to be visited, and node ``0`` as the closing node. The node
    being inserted is addressed as index ``-1`` of ``D`` (its last row/column).

    Args:
        D: Distance matrix indexed as ``D[a][b]``; its last row/column must
            belong to the node being inserted.
        tour: Node indices still to be visited, in order. Not modified.
        current: Index of the node the route currently stands at.

    Returns:
        The index ``i >= 1`` into the candidate route such that placing the
        new node between ``route[i - 1]`` and ``route[i]`` has the lowest
        insertion cost (the first such index on ties). ``None`` only if no
        candidate cost compares below infinity.
    """
    # Build a fresh list: the original aliased `tour` and appended/inserted
    # into it, silently modifying the caller's list.
    route = [current, *tour, 0]
    best_cost = float('inf')
    best_position = None
    for position, (prv, nxt) in enumerate(zip(route, route[1:]), start=1):
        cost = calc_insert_cost(D, prv, nxt, -1)
        if cost < best_cost:
            best_cost = cost
            best_position = position
    return best_position

def truncated_bivariate_normal(alimits, blimits, mean, cov, n_samples):
    """Draw ``n_samples`` points from a bivariate normal restricted to a rectangle.

    Rejection sampling: batches of ``n_samples`` points are drawn with
    ``np.random.multivariate_normal`` and only the points inside the rectangle
    are kept, until enough have been collected.

    Args:
        alimits: ``[low, high]`` bounds for the first coordinate.
        blimits: ``[low, high]`` bounds for the second coordinate.
        mean: Mean passed to ``np.random.multivariate_normal``.
        cov: Covariance passed to ``np.random.multivariate_normal``.
        n_samples: Number of points to return.

    Returns:
        Array of shape ``(n_samples, 2)`` with the accepted points.
    """
    lower_corner = [alimits[0], blimits[0]]
    upper_corner = [alimits[1], blimits[1]]
    kept = np.zeros((0, 2))
    while kept.shape[0] < n_samples:
        draws = np.random.multivariate_normal(mean, cov, size=(n_samples,))
        above_lower = np.all(draws >= lower_corner, axis=1)
        below_upper = np.all(draws <= upper_corner, axis=1)
        inside = draws[above_lower & below_upper]
        kept = np.concatenate((kept, inside), axis=0)
    # The last batch may overshoot; keep exactly the requested number of rows.
    return kept[:n_samples, :]

def generate_valset(nodes_coords: list[list[float]], tour_nodes: list[int], proportions: list[float], distributions: list[str], means: list[list[float]], covariances: list[list[list[float]]], n_new_nodes: int = 20, seed: int = 123):
    """Add dynamically arriving nodes to a solved static TSP instance.

    New nodes are split between locations with a multinomial draw. Each
    location gets its own arrival-time weighting and its own bivariate normal
    (truncated to the unit square) for coordinates. The nodes are then
    inserted one by one, in arrival order, at the cheapest position in the
    part of the tour that has not been visited yet.

    Args:
        nodes_coords: Coordinates of the static nodes, one ``[x, y]`` per node.
        tour_nodes: Tour over the static nodes as a list of node indices.
            Not modified.
        proportions: Multinomial probabilities, one per location.
        distributions: Arrival-time weighting per location: ``"am"`` (first
            half of the time slots weighted 2:1), ``"pm"`` (second half
            weighted 2:1) or ``"equal"`` (uniform).
        means: Mean of the coordinate distribution for each location.
        covariances: Covariance of the coordinate distribution for each
            location.
        n_new_nodes: Total number of dynamic nodes to add.
        seed: Seed passed to ``np.random.seed`` (global NumPy RNG state).

    Returns:
        A tuple ``(nodes_coords, new_coords, times, new_tour)``:
        ``nodes_coords`` is the input, unchanged; ``new_coords`` holds the
        dynamic node coordinates sorted by arrival time; ``times`` holds the
        sorted arrival times; ``new_tour`` is the tour after all insertions,
        where the k-th arriving node has index ``len(nodes_coords) + k``.

    Raises:
        ValueError: If an entry of ``distributions`` is not ``"am"``,
            ``"pm"`` or ``"equal"``.
    """
    np.random.seed(seed)

    # Candidate arrival times; slot k of a probability vector is time k + 2.
    time_slots = list(range(2, len(nodes_coords) + n_new_nodes))
    n_slots = len(time_slots)
    half = n_slots // 2

    probabilities = []
    for distribution in distributions:
        weights = np.ones(n_slots)
        if distribution == "am":
            weights[:half] = 2
        elif distribution == "pm":
            weights[half:] = 2
        elif distribution != "equal":
            # Previously an unknown name silently reused the previous
            # location's weights (or failed with a NameError on the first one).
            raise ValueError(
                "unknown arrival-time distribution {!r}; expected 'am', 'pm' or 'equal'".format(distribution)
            )
        probabilities.append(weights / np.sum(weights))

    # sample from a multinomial distribution to find how many nodes appear at each location
    n_loc = np.random.multinomial(n_new_nodes, proportions, size=1)[0]
    times_list = []
    new_coords_list = []
    available_times = time_slots
    for location in range(len(proportions)):
        # Renormalise this location's weights over the time slots still free,
        # so that no two nodes share an arrival time.
        slot_weights = probabilities[location][[t - 2 for t in available_times]]
        probs = slot_weights / np.sum(slot_weights)
        times = np.random.choice(available_times, n_loc[location], replace=False, p=probs)
        times_list.extend(times)
        taken = set(times.tolist())
        available_times = [t for t in available_times if t not in taken]
        new_coords_list.extend(
            truncated_bivariate_normal([0, 1], [0, 1], means[location], covariances[location], n_loc[location])
        )

    # Process arrivals chronologically.
    sorted_indices = np.argsort(times_list)
    times = np.array(times_list)[sorted_indices]
    new_coords = np.array(new_coords_list)[sorted_indices]

    all_coords = np.array(nodes_coords)
    new_tour = list(tour_nodes)  # work on a copy of the static tour
    for i, t in enumerate(times):
        # current node is the node we are at at time t
        current_node = new_tour[t - 1]
        # unvisited nodes is the remaining tour
        unvisited_nodes = new_tour[t:]
        # append the coordinate of the newly arrived node; it becomes the last row of D
        all_coords = np.append(all_coords, np.expand_dims(new_coords[i], 0), axis=0)
        D = get_distance_matrix(all_coords)
        best_position = find_insertion_position(D, unvisited_nodes, current_node)
        # best_position indexes the route [current, *unvisited, 0], whose
        # element 0 is tour index t - 1, so route index p maps to tour index
        # t - 1 + p. The previous `t + p` placed the node one slot later than
        # the position whose cost had been evaluated.
        new_tour.insert(t - 1 + best_position, len(all_coords) - 1)

    return nodes_coords, new_coords, times, new_tour


def parse_static_line(line):
    """Parse one line of a static TSP file into ``(coords, tour)``.

    The expected layout is ``x1 y1 x2 y2 ... output n1 n2 ... n1``: node
    coordinates, the literal token ``output``, then a 1-based tour that
    repeats its first node at the end.

    Returns:
        ``coords`` as a list of ``[x, y]`` floats and ``tour`` as a list of
        0-based node indices without the repeated closing node.

    Raises:
        ValueError: If the line has no ``output`` token or a field is not
            numeric.
    """
    # Whitespace split, so the result does not depend on a trailing space
    # before the newline (splitting on " " dropped a real node without it).
    tokens = line.split()
    output_idx = tokens.index('output')
    num_nodes = output_idx // 2
    coords = [[float(tokens[idx]), float(tokens[idx + 1])] for idx in range(0, 2 * num_nodes, 2)]
    # Convert to 0-based indices and drop the repeated start node at the end.
    tour = [int(node) - 1 for node in tokens[output_idx + 1:]][:-1]
    return coords, tour


def format_instance_line(old_coords, new_coords, times, tour):
    """Format one dynamic instance as a single output line.

    Layout: static coordinates, ``new_nodes``, dynamic coordinates, ``times``,
    arrival times, ``output``, the 1-based tour, the 1-based first tour node
    again (closing the cycle), a trailing space and a newline.
    """
    old_part = " ".join(str(x) + " " + str(y) for x, y in old_coords)
    new_part = " ".join(str(x) + " " + str(y) for x, y in new_coords)
    times_part = " ".join(str(t) for t in times)
    tour_part = " ".join(str(node_idx + 1) for node_idx in tour)
    closing_node = str(tour[0] + 1)
    return (
        old_part + " new_nodes " + new_part + " times " + times_part
        + " output " + tour_part + " " + closing_node + " \n"
    )


def main():
    """Read a static TSP test set, add dynamic nodes to each instance and save the result."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_static", type=int, default=20)
    parser.add_argument("--num_dynamic", type=int, default=20)
    # Restrict to the presets handled below; any other value used to fail
    # later with a NameError.
    parser.add_argument("--distribution", type=str, default="unimodal", choices=["unimodal", "bimodal"])
    parser.add_argument("--seed", type=int, default=123456)
    # Required: the output path is built from it, and a missing value used to
    # fail only after every instance had been generated.
    parser.add_argument("--filename", type=str, default=None, required=True)
    parser.add_argument("--num_instances", type=int, default=1280)
    opts = parser.parse_args()

    static_node_file = "data/tsp/tsp" + str(opts.num_static) + "_testset.txt"

    print('\nLoading from {}...'.format(static_node_file))
    with open(static_node_file, "r") as f:
        lines = [line for line in f if line.strip()]

    if len(lines) < opts.num_instances:
        raise ValueError(
            "{} contains {} instances but --num_instances is {}".format(
                static_node_file, len(lines), opts.num_instances
            )
        )
    # Only the first num_instances lines are written, so only those are
    # parsed and processed. Seeds depend on position, so their output is the
    # same as when every line is processed.
    lines = lines[:opts.num_instances]

    nodes_coords = []
    tour_nodes = []
    for line in tqdm(lines, ascii=True):
        coords, tour = parse_static_line(line)
        nodes_coords.append(coords)
        tour_nodes.append(tour)

    # Positional arguments for generate_valset:
    # (proportions, distributions, means, covariances).
    presets = {
        "unimodal": ([1], ["equal"], [[0.2, 0.2]], [[[0.05, 0], [0, 0.05]]]),
        "bimodal": (
            [0.5, 0.5],
            ["am", "pm"],
            [[0.2, 0.2], [0.8, 0.8]],
            [[[0.1, 0], [0, 0.1]], [[0.1, 0], [0, 0.1]]],
        ),
    }
    preset = presets[opts.distribution]

    print('\nAdding {} Dynamic Nodes to Each Instance...'.format(opts.num_dynamic))
    output_lines = []
    instances = zip(nodes_coords, tour_nodes)
    # Instance k (1-based) is generated with seed opts.seed + k.
    for offset, (nodes_item, tour_item) in enumerate(tqdm(instances, total=len(nodes_coords), ascii=True), start=1):
        old_nodes, new_nodes, times, tour = generate_valset(
            nodes_item, tour_item, *preset, n_new_nodes=opts.num_dynamic, seed=opts.seed + offset
        )
        output_lines.append(format_instance_line(old_nodes, new_nodes, times, tour))

    # Each output line holds: old_node_coords, new_node_coords, times, solution
    # (the static tour with each dynamic node inserted by the insertion heuristic).
    filename = "data/tsp/" + opts.filename
    with open(filename, "w") as f:
        f.writelines(output_lines)


if __name__ == "__main__":
    main()