import argparse
import math
import os
import time

import matplotlib.pyplot as plt
import numpy as np
from heuristics.local_search import LocalSearch
from utils.capacity import node2capacity


def read_problem_file(problem_file: str) -> tuple[np.ndarray, int, int]:
    location = []
    key = "NODE_COORD_SECTION"
    flag = False
    num_location = None
    num_vehicle = None
    with open(problem_file, "r") as f:
        for row in f:
            if not flag and key not in row:
                if "DIMENSION" in row:
                    num_location = int(row.split()[-1])
                if "VEHICLES" in row:
                    num_vehicle = int(row.split()[-1])
                continue
            elif not flag and key in row:
                flag = True
            elif "EOF" not in row:
                _, x, y = row.split()
                location.append([float(x), float(y)])

    return np.array(location), num_location, num_vehicle


def read_solution_file(solution_file: str, num_location: int, num_vehicles: int) -> list[list[int]]:
    """Split the single tour in a solution file into one route per vehicle.

    Node ids in ``TOUR_SECTION`` are 1-based and are converted to 0-based
    indices. An id greater than ``num_location`` marks a route boundary. At
    that point the current route is closed at the depot (index 0) and the next
    route starts at the depot. The final route is also closed at the depot.
    The tour section ends at ``-1`` or ``EOF``. Ids may be spread over lines in
    any way, and blank lines are ignored.

    Args:
        solution_file: Path to the tour file.
        num_location: Number of real nodes; larger ids act as route separators.
        num_vehicles: Expected number of routes.

    Returns:
        Routes as lists of 0-based node indices, each ending at the depot.

    Raises:
        ValueError: if the number of routes differs from ``num_vehicles``, or
            a token in the tour section is not an integer.
    """
    depot = 0
    all_tour: list[list[int]] = []
    single_tour: list[int] = []
    in_tour_section = False
    with open(solution_file, "r", encoding="utf-8") as f:
        for row in f:
            if not in_tour_section:
                # Skip header lines until the tour section begins; the header
                # line itself carries no node ids.
                if "TOUR_SECTION" in row:
                    in_tour_section = True
                continue
            for token in row.split():
                if token == "EOF":
                    in_tour_section = False
                    break
                node_id = int(token)
                if node_id == -1:
                    # -1 terminates the tour section.
                    in_tour_section = False
                    break
                # Ids are 1-based; subtract 1 to index into the location array.
                idx = node_id - 1
                if idx >= num_location:
                    single_tour.append(depot)
                    all_tour.append(single_tour)
                    single_tour = [depot]
                else:
                    single_tour.append(idx)

    single_tour.append(depot)
    all_tour.append(single_tour)

    if len(all_tour) != num_vehicles:
        raise ValueError(
            f"num is different: expected {num_vehicles} tours, "
            f"got {len(all_tour)} from {solution_file}"
        )

    return all_tour


def get_cost(location: np.ndarray, all_tour: list[list[int]]) -> list[float]:
    """Return the Euclidean length of each tour, rounded to 3 decimals.

    Args:
        location: Node coordinates indexed by node; only columns 0 (x) and
            1 (y) are used.
        all_tour: Tours as sequences of indices into ``location``.

    Returns:
        One length per tour, in the order of ``all_tour``. A tour with fewer
        than two nodes has no edges and gets length ``0``.
    """
    cost_box = []
    for tour in all_tour:
        # Sum the lengths of consecutive edges; hypot avoids overflow in dx**2 + dy**2.
        dist = sum(
            math.hypot(location[i][0] - location[j][0], location[i][1] - location[j][1])
            for i, j in zip(tour, tour[1:])
        )
        cost_box.append(round(dist, 3))

    return cost_box


def plot(
    location: np.ndarray,
    all_tour: list[list[int]],
    cost_box: list[float],
    filepath: str,
) -> None:
    """Draw every tour over the node scatter and save the figure to ``filepath``.

    Each tour is drawn as one polyline coloured from ``tab10``; colours repeat
    after ten tours. Its legend label is ``"<cost> : <len(tour) - 2>"``, where
    the second number is the tour length excluding its two endpoints. Tours
    with fewer than two nodes have no edges and are not drawn. The title is the
    largest value in ``cost_box``; it is omitted when ``cost_box`` is empty.

    Args:
        location: Node coordinates; column 0 is plotted as x, column 1 as y.
        all_tour: Tours as lists of row indices into ``location``.
        cost_box: Cost of each tour, aligned with ``all_tour``.
        filepath: Destination image path passed to ``savefig``.
    """
    fig, ax = plt.subplots(figsize=(9, 9))
    try:
        ax.scatter(location[:, 0], location[:, 1], c="gray")
        cmap = plt.get_cmap("tab10")
        has_labels = False
        for i, tour in enumerate(all_tour):
            if len(tour) < 2:
                continue
            # A single polyline per tour instead of one artist per edge.
            ax.plot(
                location[tour, 0],
                location[tour, 1],
                c=cmap(i % 10),
                label=f"{cost_box[i]} : {len(tour) - 2}",
            )
            has_labels = True

        # Avoid matplotlib's "no artists with labels" warning when nothing was drawn.
        if has_labels:
            ax.legend()
        if cost_box:
            ax.set_title(str(max(cost_box)))

        fig.savefig(filepath)
    finally:
        # Always release the figure so repeated calls (or a failed save) do not
        # accumulate open pyplot figures.
        plt.close(fig)
    # print(filepath)


def main(args: argparse.Namespace) -> tuple[list[float], list[float], float]:
    """Evaluate a solution file, refine it with local search, and plot both.

    Steps:
        1. Read coordinates and header values from ``args.problem_input``.
           A missing or zero vehicle count is treated as one vehicle.
        2. Split the tour in ``args.solution_input`` into per-vehicle routes
           and compute their lengths.
        3. Save a plot of those routes under ``args.output``.
        4. Run ``LocalSearch`` (``n_iter=300``) starting from those routes,
           time it, and save a plot of the result with an ``LS_`` prefix.

    Args:
        args: Parsed CLI namespace with ``problem_input``, ``solution_input``
            and ``output`` attributes.

    Returns:
        ``(cost_box, path_length_list, duration)``:

        - ``cost_box`` holds the initial route costs.
        - ``path_length_list`` holds the route costs returned by
          ``LocalSearch``. It is assumed to be a list of floats like
          ``cost_box``, since it is passed to ``plot`` in that role; confirm
          against ``LocalSearch``.
        - ``duration`` is the elapsed local-search time in seconds.
    """
    location, num_location, num_vehicle = read_problem_file(args.problem_input)
    if not num_vehicle:
        num_vehicle = 1
    all_tour = read_solution_file(args.solution_input, num_location, num_vehicle)
    cost_box = get_cost(location, all_tour)

    os.makedirs(args.output, exist_ok=True)
    # Shared file-name stem for the "before" and "after local search" plots.
    base_name = f"{num_location}node_{num_vehicle}agent_{os.path.basename(args.problem_input)}.png"
    plot(location, all_tour, cost_box, os.path.join(args.output, base_name))

    n_node = len(location)
    cap = node2capacity[n_node]

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    local_search_model = LocalSearch(
        capacity=cap,
        n_agent=num_vehicle,
        n_node=n_node,
        n_iter=300,
        locations=location,
        disable_plot=True,
    )
    path_list, path_length_list = local_search_model(all_tour, cost_box)
    duration = time.perf_counter() - start
    plot(location, path_list, path_length_list, os.path.join(args.output, f"LS_{base_name}"))

    return cost_box, path_length_list, duration


if __name__ == "__main__":
    # Command-line entry point: both input files are mandatory; plots go to --output.
    cli = argparse.ArgumentParser()
    cli.add_argument("-p", "--problem-input", type=str, required=True, help="path to problem file")
    cli.add_argument("-s", "--solution-input", type=str, required=True, help="path to solution file")
    cli.add_argument("--output", type=str, default="LKH_output", help="output-folder")
    main(cli.parse_args())
