"""This script is used for generating a txt file of tour lengths of different solution methods for a provided file of test instances"""

import argparse
from problems.tsp.dynamic_tsp import DTSP
from utils import load_model
from train import set_decode_type
import torch
from torch.utils.data import DataLoader
from eval_baseline import _calc_insert_cost, run_insertion, distance_matrix
from tqdm import tqdm
import numpy as np
from utils import move_to



def get_costs(dataset, pi):
    """Return the closed-tour length of each instance in a batch of tours.

    Only meaningful for *completed* DTSP tours: every row of ``pi`` must be a
    permutation of ``0 .. pi.size(1) - 1``.

    Args:
        dataset: node coordinates, shape ``(batch, n, coord_dim)``
            (torch.Tensor); ``n`` must equal ``pi.size(1)``.
        pi: node permutations representing tours, shape ``(batch, n)``
            (integer torch.Tensor). It is moved to ``dataset.device`` if needed.

    Returns:
        ``(lengths, edge_lengths)`` where ``lengths`` has shape ``(batch,)``
        and ``edge_lengths`` has shape ``(batch, n)``. Column ``k < n - 1`` of
        ``edge_lengths`` is the edge ``pi[:, k] -> pi[:, k + 1]``; the last
        column is the closing edge ``pi[:, -1] -> pi[:, 0]``.

    Raises:
        AssertionError: if any row of ``pi`` is not a valid permutation.
    """
    # Indexing tensors must live on the same device as the gathered tensor.
    pi = pi.to(dataset.device)

    # Check that tours are valid, i.e. each row contains 0 .. n-1 exactly once.
    expected = torch.arange(pi.size(1), dtype=pi.dtype, device=pi.device).view(1, -1).expand_as(pi)
    assert (expected == pi.sort(1)[0]).all(), "Invalid tour:\n{}\n{}".format(dataset, pi)

    # Coordinates in visiting order: d[b, k] is the k-th node of tour b.
    d = dataset.gather(1, pi.unsqueeze(-1).expand_as(dataset))

    # Rolling by -1 pairs each node with its successor, wrapping last -> first,
    # so the closing edge is included as the final column.
    edge_lengths = (torch.roll(d, shifts=-1, dims=1) - d).norm(p=2, dim=2)

    return edge_lengths.sum(1), edge_lengths

def run_partial_insertion(nodes, visited_tour, unvisited_tour, method):
    """Re-plan the not-yet-visited part of a tour with an insertion heuristic.

    Builds an open path that starts at the current position (the last entry
    of ``visited_tour``), visits every node of ``unvisited_tour`` and ends at
    ``nodes[0]``. Both endpoints are fixed; the other nodes are added one at a
    time, in an order chosen by ``method``, at the cheapest position along the
    path.

    Args:
        nodes: coordinates of all nodes (np.ndarray, one row per node).
        visited_tour: node indices already visited, in order. Must be
            non-empty, since its last entry is the current position.
        unvisited_tour: node indices still to be visited.
        method: ``'random'`` (insert in the order of ``unvisited_tour``),
            ``'nearest'`` or ``'farthest'``.

    Returns:
        ``(cost, tour)``. ``tour`` is a list of indices into the local array
        ``[current, *unvisited, start]``: ``tour[0] == 0`` is the current
        node, ``tour[-1] == len(unvisited_tour) + 1`` is the start node, and a
        local index ``k`` in between refers to ``unvisited_tour[k - 1]``.
        ``cost`` is the length of that open path; the start -> current edge
        is not part of the route and is not counted.

    Raises:
        NotImplementedError: if ``method == 'cheapest'``.
        ValueError: for any other unknown ``method``.
    """
    # Validate up front: previously an unknown method left `a` unbound and
    # 'cheapest' relied on an assert that disappears under `python -O`.
    if method == 'cheapest':
        raise NotImplementedError("Not yet implemented")  # try all and find cheapest insertion cost
    if method not in ('random', 'nearest', 'farthest'):
        raise ValueError("Unknown insertion method: {}".format(method))

    # The current node is the final node in the visited tour.
    current_node = nodes[visited_tour[-1]]
    # NOTE(review): the path is closed at nodes[0], i.e. this assumes the
    # overall tour starts at node 0 -- confirm for initial tours not rooted there.
    starting_node = nodes[0]
    unvisited_nodes = nodes[unvisited_tour]

    loc = np.concatenate((np.expand_dims(current_node, 0), unvisited_nodes, np.expand_dims(starting_node, 0)))
    n = len(loc)
    D = distance_matrix(loc, loc)

    # Both endpoints are fixed members of the path from the start.
    in_tour = np.zeros(n, dtype=bool)
    in_tour[0] = True
    in_tour[-1] = True
    tour = [0, n - 1]
    for i in range(1, n - 1):  # one iteration per unvisited node
        feas = ~in_tour
        feas_ind = np.flatnonzero(feas)
        if method == 'random':
            # Keep the given order so results are deterministic.
            a = i
        else:
            # Distance from every node outside the path to its closest node on it.
            dist_to_tour = D[np.ix_(feas, in_tour)].min(1)
            if method == 'nearest':
                a = feas_ind[dist_to_tour.argmin()]  # node nearest to any in tour
            else:  # 'farthest'
                a = feas_ind[dist_to_tour.argmax()]  # node whose closest tour node is farthest
        in_tour[a] = True

        # Only consecutive pairs of the open path are candidate positions; the
        # start -> current edge is never split, so both endpoints stay fixed.
        # _calc_insert_cost presumably returns one cost per (prev, next) pair.
        ind_insert = np.argmin(_calc_insert_cost(D, tour[:-1], tour[1:], a))
        tour.insert(ind_insert + 1, a)

    cost = D[tour[:-1], tour[1:]].sum()
    return cost, tour

def _as_numpy(x):
    """Return ``x`` as a NumPy array, detaching and moving torch tensors to the CPU first."""
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def rerun_insertion(nodes_batch, new_nodes_batch, times_batch, method):
    """Build dynamic-TSP tours by re-running an insertion heuristic on every arrival.

    For each instance ``b``:

    1. An initial tour over ``nodes_batch[b]`` is built with ``run_insertion``.
    2. For every arrival ``k``, in the order of ``times_batch[b]``, node
       ``new_nodes_batch[b][k]`` is appended to the node list.
    3. The first ``times_batch[b][k]`` positions of the current tour are kept
       as already visited.
    4. The rest of the tour, plus the new node, is re-planned with
       ``run_partial_insertion``.

    Args:
        nodes_batch: initial node coordinates, indexed by batch first
            (torch.Tensor or array-like).
        new_nodes_batch: coordinates of the arriving nodes per instance; the
            k-th arrival gets node index ``len(nodes_batch[b]) + k``.
        times_batch: per instance and arrival, the number of tour positions
            already visited (used as a slice bound, so integer-valued).
        method: insertion method forwarded to both heuristics.

    Returns:
        torch.Tensor of shape ``(batch, n_total)``, one tour of node indices
        per instance. All instances must end with the same number of nodes so
        the tours can be stacked.
    """
    tours = []
    for b in range(len(times_batch)):
        nodes = _as_numpy(nodes_batch[b])
        new_nodes = _as_numpy(new_nodes_batch[b])
        times = _as_numpy(times_batch[b])

        # Tour before any arrival happens. It is never mutated in place below
        # (slicing and np.append return new objects), so no copy is needed.
        # NOTE(review): run_partial_insertion ends the re-planned path at node 0;
        # confirm run_insertion's tour starts at node 0 for the chosen method.
        _, tour = run_insertion(nodes, method)

        for k, t in enumerate(times):
            # Register the arriving node; it becomes the last index.
            nodes = np.concatenate((nodes, np.expand_dims(new_nodes[k], 0)))
            new_idx = len(nodes) - 1
            # Nodes stepped to before time t are fixed; the heuristic may
            # reorder only the remainder (plus the new node, appended at any
            # position since it will be re-placed anyway).
            visited_tour = tour[:t]
            unvisited_tour = np.append(tour[t:], new_idx)

            _, local_tour = run_partial_insertion(nodes, visited_tour, unvisited_tour, method)
            # Drop the fixed endpoints (current node, start node) and map local
            # index j back to the global node index unvisited_tour[j - 1].
            remaining_tour = [unvisited_tour[j - 1] for j in local_tour[1:-1]]
            tour = np.append(visited_tour, remaining_tour)

        tours.append(tour)

    return torch.tensor(np.array(tours))


def main():
    """Evaluate a trained DTSP policy against two baselines and write the tour lengths to a file.

    For every instance in ``--source_file`` three tour lengths are computed:

    - the greedy decode of the loaded model;
    - the dataset's ``tour_nodes`` (labelled ``concorde_with_insertion``);
    - the re-run insertion heuristic.

    All three are rounded to two decimals and written as one space-separated
    line to ``outputs/<--filename>``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--source_file", type=str, default=None)
    # Required: the output path is built from it, and a missing value used to
    # crash only after the whole evaluation had run.
    parser.add_argument("--filename", type=str, required=True, help="output file name, written under outputs/")
    parser.add_argument("--model", type=str, default="outputs/dtsp_20-50/rl-ar-var-20pnn-gnn-max-dtsp_20240909T132609")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--num_samples", type=int, default=128, help="number of instances to evaluate")
    parser.add_argument("--insertion_method", type=str, default="nearest", choices=["random", "nearest", "farthest"])
    opts = parser.parse_args()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    neighbors = 0.20
    knn_strat = 'percentage'

    model, _ = load_model(opts.model, extra_logging=True)
    set_decode_type(model, "greedy")
    model.eval()
    model.to(device)

    dataset = DTSP.make_dataset(
        filename=opts.source_file, batch_size=opts.batch_size, num_samples=opts.num_samples,
        neighbors=neighbors, knn_strat=knn_strat, supervised=True
    )
    dataloader = DataLoader(dataset, batch_size=opts.batch_size, shuffle=False, num_workers=0)

    policy_tour_lengths = []
    concorde_insertion_tour_lengths = []
    rerun_insertion_tour_lengths = []

    with torch.no_grad():
        for bat in tqdm(dataloader, ascii=True):
            nodes = bat['nodes']
            new_nodes = bat['new_nodes']
            graph = bat['graph']
            times = bat['times']
            # Initial nodes followed by arrivals; tours index into this tensor.
            total_nodes = torch.cat((nodes, new_nodes), dim=1)

            _, _, pi = model(move_to(nodes, device), move_to(graph, device), move_to(times, device),
                             move_to(new_nodes, device), return_pi=True)
            # pi is presumably on `device` (model and inputs live there), while
            # total_nodes is on the CPU; gather needs both on the same device.
            pi = pi.cpu()

            gt_cost, _ = get_costs(total_nodes, bat['tour_nodes'])
            cost, _ = get_costs(total_nodes, pi)

            insertion_tour = rerun_insertion(nodes, new_nodes, times, opts.insertion_method)
            insertion_cost, _ = get_costs(total_nodes, insertion_tour)

            policy_tour_lengths.extend(cost.tolist())
            concorde_insertion_tour_lengths.extend(gt_cost.tolist())
            rerun_insertion_tour_lengths.extend(insertion_cost.tolist())

    policy_tour_lengths = np.array(policy_tour_lengths).round(2)
    concorde_insertion_tour_lengths = np.array(concorde_insertion_tour_lengths).round(2)
    rerun_insertion_tour_lengths = np.array(rerun_insertion_tour_lengths).round(2)

    filename = "outputs/" + opts.filename

    # Output format (unchanged): policy lengths, then a label followed by each
    # baseline's lengths, all on a single space-separated line.
    with open(filename, "w", encoding="utf-8") as f:
        f.write(" ".join(str(x) for x in policy_tour_lengths))
        f.write(" concorde_with_insertion ")
        f.write(" ".join(str(x) for x in concorde_insertion_tour_lengths))
        f.write(" rerun_insertion ")
        f.write(" ".join(str(x) for x in rerun_insertion_tour_lengths))
        f.write("\n")


if __name__ == "__main__":
    main()
