import argparse
import datetime
import os
import time
from pickle import load

import matplotlib.pyplot as plt
import numpy as np
from heuristics.local_search import LocalSearch
from heuristics.trip_division import TripDivision
from mm_cvrp.utils import get_src_vector
from utils.capacity import node2capacity


def plot(paths, path_length_list, locations, output, data_path, idx, after_local_search=False):
    """Draw every route on one figure, save it as a PNG, and return the shortest closed tour.

    Each route in ``paths`` is drawn as a polyline through ``locations`` in its own
    ``tab10`` colour, and every node a segment starts from is labelled with its index.
    The figure is saved under ``output`` and then closed, so repeated calls do not
    accumulate open figures.

    Args:
        paths: Sequence of routes; each route is a sequence of node indices into ``locations``.
        path_length_list: Per-route values shown in the legend (``path_length_list[i]`` for
            route ``i``); its maximum is used as the figure title.
        locations: 2-D array indexed as ``locations[node, 0]`` / ``locations[node, 1]``.
        output: Directory the PNG is written to; created if missing.
        data_path: Input data path, embedded in the filename with ``/`` replaced by ``:``.
        idx: Instance index appended to the filename.
        after_local_search: If True, the filename is prefixed with ``LS_``.

    Returns:
        The smallest Euclidean length over all routes, where each route's length includes
        a closing edge from its last node back to node 0.
    """
    fig = plt.figure(figsize=(10, 10))
    try:
        plt.scatter(locations[:, 0], locations[:, 1], s=10, c="gray")
        cmap = plt.get_cmap("tab10")

        dist_list = []
        for i, path in enumerate(paths):
            color = cmap(i % 10)
            dist = 0
            for j, (src, dst) in enumerate(zip(path, path[1:], strict=False)):
                dist += np.linalg.norm(locations[src] - locations[dst])
                plt.plot(
                    [locations[src, 0], locations[dst, 0]],
                    [locations[src, 1], locations[dst, 1]],
                    color=color,
                    # Only the first segment carries a label so each route appears once in the legend.
                    label=f"{len(path) - 2} : {path_length_list[i]}" if j == 0 else None,
                )
                plt.text(locations[src, 0], locations[src, 1], src, color=color)
            if len(path) > 0:
                # Close the tour back to node 0. Use path[-1] rather than the loop variable so a
                # route with fewer than two nodes never reuses a stale (or unbound) endpoint.
                dist += np.linalg.norm(locations[path[-1]] - locations[0])
            dist_list.append(dist)

        plt.legend()
        plt.title(max(path_length_list))
        os.makedirs(output, exist_ok=True)
        prefix = "LS_" if after_local_search else ""
        filepath = f"{output}/{prefix}{data_path.replace('/', ':')}_{idx}.png"
        plt.savefig(filepath)
    finally:
        # Release the figure; otherwise every call leaves one open and memory keeps growing.
        plt.close(fig)
    print(filepath)
    print(datetime.datetime.now())

    return min(dist_list)


def _write_report(filename: str, scores: list, times: list) -> None:
    """Write one ``index score`` line per instance, then an ``ave`` line, and print the filename.

    The ``ave`` line holds the mean score and the mean runtime, each rounded to 3 decimals.
    ``scores`` and ``times`` must be non-empty.
    """
    with open(filename, "w", encoding="utf-8") as f:
        for i, score in enumerate(scores):
            print(i, score, file=f)
        print("ave", round(sum(scores) / len(scores), 3), round(sum(times) / len(times), 3), file=f)
    print(filename)


def main(args: argparse.Namespace) -> None:
    """Solve each instance with trip division, refine it with local search, and write reports.

    For the first ``min(args.n_instance, len(locations))`` instances loaded from the pickle
    file ``args.input``, this function does the following:

    * builds an initial solution with ``TripDivision``;
    * improves it with ``LocalSearch`` (300 iterations);
    * plots both solutions under ``{args.output}/{args.n_agent}agent``.

    The maximum of each stage's ``path_length_list`` and that stage's wall-clock time are
    collected. They are written to ``before_<input>.txt`` and ``after_<input>.txt`` in the same
    directory, where ``<input>`` is ``args.input`` with ``/`` replaced by ``:``.
    If there are no instances to solve, a message is printed and nothing is written.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only pass trusted input files.
    with open(args.input, "rb") as f:
        locations = load(f)
    output = f"{args.output}/{args.n_agent}agent"
    input_tag = args.input.replace("/", ":")

    timebox, box = [], []  # initial solution: runtime (s) and max of path_length_list
    ls_timebox, ls_box = [], []  # same metrics after local search

    n_instance = min(args.n_instance, len(locations))
    if n_instance <= 0:
        # Averages below would divide by zero; report and stop instead of crashing.
        print("no instances to solve")
        return

    for i in range(n_instance):
        instance = locations[i]
        cap = node2capacity[len(instance)]
        # Every agent gets the same capacity, looked up from the instance size.
        src_vector = [cap] * args.n_agent
        # NOTE(review): the solvers receive args.n_node while the capacity lookup uses
        # len(instance); confirm these agree for the input data.

        begin = time.perf_counter()
        initial_solution_model = TripDivision(
            src_vector=src_vector,
            n_agent=args.n_agent,
            n_node=args.n_node,
            locations=instance,
            force_random_sampling=True,
        )
        path_list, path_length_list = initial_solution_model()
        timebox.append(time.perf_counter() - begin)
        box.append(max(path_length_list))
        plot(path_list, path_length_list, instance, output, args.input, i)

        begin = time.perf_counter()
        local_search_model = LocalSearch(
            capacity=cap,
            n_agent=args.n_agent,
            n_node=args.n_node,
            n_iter=300,
            locations=instance,
            disable_plot=True,
        )
        path_list, path_length_list = local_search_model(path_list, path_length_list)
        ls_timebox.append(time.perf_counter() - begin)
        ls_box.append(max(path_length_list))
        plot(path_list, path_length_list, instance, output, args.input, i, True)

    print(sum(timebox) / len(timebox))
    _write_report(f"{output}/before_{input_tag}.txt", box, timebox)
    _write_report(f"{output}/after_{input_tag}.txt", ls_box, ls_timebox)


if __name__ == "__main__":
    # Command-line entry point: parse the experiment settings and run main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--n-agent", type=int, default=5, help="number of agents")
    parser.add_argument("--n-node", type=int, default=30, help="number of nodes")
    parser.add_argument("--n-instance", type=int, default=3, help="number of instances")
    parser.add_argument("--output", type=str, default="heuristic_result", help="output-folder")
    parser.add_argument("--input", type=str, required=True, help="input-data")

    args = parser.parse_args()
    main(args)
