"""
BQ-NCO
Copyright (c) 2023-present NAVER Corp.
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
"""

import argparse
import os.path
from functools import partial
from multiprocessing import Pool

import numpy as np
from concorde.tsp import TSPSolver
from scipy.spatial.distance import pdist, squareform

SCALE = 1e6


def solve(instance_coords, reorder, time_bound):
    """Solve one 2-D Euclidean TSP instance with Concorde.

    Coordinates are multiplied by ``SCALE`` before being handed to Concorde
    with the ``EUC_2D`` norm (which works on integer-rounded distances); the
    tour length returned below is recomputed on the original, unscaled
    coordinates.

    Args:
        instance_coords: NumPy array of node coordinates, read as
            ``instance_coords[:, 0]`` (x) and ``instance_coords[:, 1]`` (y).
        reorder: if True, permute the nodes into tour order so that the
            returned tour is the identity permutation.
        time_bound: time limit forwarded unchanged to ``TSPSolver.solve``.

    Returns:
        A tuple ``(coords, tour, tour_len)``:

        - ``reorder=True``: ``coords`` is a NumPy array of the nodes in tour
          order with the tour's start node repeated as the last row, ``tour``
          is ``np.arange(num_nodes)`` and ``tour_len`` is ``None``.
        - ``reorder=False``: ``coords`` is a nested list in the original node
          order with node 0 repeated as the last row, ``tour`` is Concorde's
          tour as an array and ``tour_len`` is the Euclidean length of the
          closed tour.
    """
    solver = TSPSolver.from_data(
        instance_coords[:, 0] * SCALE, instance_coords[:, 1] * SCALE, norm="EUC_2D"
    )
    solution = solver.solve(time_bound=time_bound)
    tour = np.array(solution[0])

    # Close the tour by returning to its own start node. Concorde's tours
    # normally start at node 0, in which case this equals ``tour + [0]``.
    closed_tour = np.append(tour, tour[0])

    if reorder:
        coords_reordered = instance_coords[closed_tour]
        return coords_reordered, np.arange(len(tour)), None

    # Only the n edges of the closed tour are needed, so sum their lengths
    # directly instead of building the full O(n^2) pairwise distance matrix.
    legs = np.diff(instance_coords[closed_tour], axis=0)
    tour_len = np.linalg.norm(legs, axis=1).sum()

    # Output format: original node order, with node 0 repeated at the end.
    coords_closed = instance_coords.tolist()
    coords_closed.append(coords_closed[0])
    return coords_closed, tour, tour_len


if __name__ == "__main__":
    # Script entry point: generate random TSP instances in the unit square,
    # solve each one with `solve`, and save everything to one .npz file.
    parser = argparse.ArgumentParser(description="generate and solve TSP")
    parser.add_argument("--num_instances", type=int, default=10, help="Number of TSP instances")
    parser.add_argument("--num_nodes", type=int, default=100, help="Number of nodes per instance")
    parser.add_argument(
        "--output_filename",
        type=str,
        required=True,
        help="Output .npz file path (the .npz suffix is added if missing)",
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--reorder", dest="reorder", action="store_true", help="Reorder nodes/tours. training dataset MUST BE reordered")
    parser.add_argument("--parallel", action="store_true", help="Solve instances in parallel")
    parser.add_argument("--time_bound", type=int, default=-1, help="Time limit to solve one instance in seconds")

    args = parser.parse_args()
    np.random.seed(args.seed)

    all_instance_coords = np.random.random([args.num_instances, args.num_nodes, 2])

    if args.parallel:
        with Pool() as pool:
            outputs = pool.map(
                partial(solve, reorder=args.reorder, time_bound=args.time_bound),
                all_instance_coords,
            )
    else:
        outputs = [
            solve(instance_coords, args.reorder, args.time_bound)
            for instance_coords in all_instance_coords
        ]

    coords, solutions, tour_lens = zip(*outputs) if outputs else ((), (), ())

    payload = {
        "coords": np.array(coords),
        "solutions": np.stack(solutions),
        "reorder": args.reorder,
    }
    if not args.reorder:
        # `solve` only computes tour lengths for non-reordered instances.
        payload["tour_lens"] = np.array(tour_lens)

    # np.savez_compressed appends ".npz" itself when it is missing; compute the
    # real path up front so the directory check and the final message match it.
    out_path = args.output_filename
    if not out_path.endswith(".npz"):
        out_path += ".npz"
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    np.savez_compressed(out_path, **payload)

    print("Data transformed and saved to " + out_path)
