import argparse
import datetime
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from mip.cvrp import MIP_CVRP
from mip.cvrp_sum import MIP_CVRP_SUM
from mip.mip_utils import SolverChoices
from mip.mip_utils import SubtourEliminationMethods
from mm_cvrp.utils import get_src_vector


def plot(paths, locations, output, data_path, idx):
    """Draw each agent's path, save the figure as a PNG, and return the shortest tour length.

    Args:
        paths: One sequence of node indices per agent; consecutive nodes are joined by a line.
        locations: Array indexed as ``locations[node, coord]``; columns 0 and 1 are drawn as
            x and y. Node 0 is the point each tour is closed back to when measuring length.
        output: Directory the PNG is written to (created if missing).
        data_path: Input data path; every '/' is replaced by ':' to build the file name.
        idx: Instance index appended to the file name.

    Returns:
        The minimum, over all paths, of the Euclidean length along the path plus the leg
        from the path's last node back to node 0.
    """
    fig = plt.figure(figsize=(10, 10))
    try:
        plt.scatter(locations[:, 0], locations[:, 1], s=10, c="gray")
        cmap = plt.get_cmap("tab10")

        dist_list = []
        for i, path in enumerate(paths):
            color = cmap(i % 10)
            dist = 0
            for src, dst in zip(path, path[1:], strict=False):
                dist += np.linalg.norm(locations[src] - locations[dst])
                plt.plot([locations[src, 0], locations[dst, 0]], [locations[src, 1], locations[dst, 1]], color=color)
                plt.text(locations[src, 0], locations[src, 1], src, color=color)
            # Close the tour from this path's own last node; relying on the inner-loop
            # variable would reuse a stale (or unbound) node when the path has < 2 nodes.
            if len(path) > 0:
                dist += np.linalg.norm(locations[path[-1]] - locations[0])
            dist_list.append(dist)

        os.makedirs(output, exist_ok=True)
        filepath = f"{output}/{data_path.replace('/', ':', -1)}_{idx}.png"
        fig.savefig(filepath)
    finally:
        # Release the figure; otherwise one figure per solved instance accumulates in memory.
        plt.close(fig)
    print(filepath)
    print(datetime.datetime.now())

    # NOTE(review): for a min-max objective the longest tour (max) may be the intended
    # metric; confirm that reporting the shortest tour is deliberate.
    return min(dist_list)


def main(args: argparse.Namespace) -> None:
    """Solve up to ``args.n_instance`` instances with a MIP model and write a result summary.

    For each instance, builds ``MIP_CVRP`` when ``args.formulation == "maxmin"`` and
    ``MIP_CVRP_SUM`` otherwise, solves it, plots the returned paths with ``plot``, and
    records ``(distance, duration)``. Per-instance results and their averages are written
    to ``<output>/<basename(args.input)>.txt``, where ``<output>`` is
    ``<args.output>/<formulation>_<constraint>_<n_agent>agent``.
    """
    # NOTE(review): torch.load unpickles the file; only use trusted input data.
    locations = torch.load(args.input).numpy()
    output = os.path.join(args.output, f"{args.formulation}_{args.constraint}_{args.n_agent}agent")
    # Create the directory up front: the solver's log_path points into it, and plot()
    # (which also creates it) has not yet run when the first instance is solved.
    os.makedirs(output, exist_ok=True)
    input_name = os.path.basename(args.input)
    model_cls = MIP_CVRP if args.formulation == "maxmin" else MIP_CVRP_SUM

    box = []
    for i in range(min(args.n_instance, len(locations))):
        n_node = len(locations[i])
        model = model_cls(
            src_vector=get_src_vector(n_node - 1, args.n_agent),
            n_agent=args.n_agent,
            n_node=n_node,
            locations=locations[i],
            subtour_elimination=args.constraint,
        )
        paths, duration = model(
            solver_type=args.solver,
            timelimit=args.timelimit,
            n_thread=args.n_thread,
            log_path=f"{output}/log_{input_name}_{i}.txt",
            show_log=args.show_log,
        )
        dist = plot(paths, locations[i], output, args.input, i)
        box.append((dist, duration))

    if box:
        ave_distance = round(sum(dist for dist, _ in box) / len(box), 3)
        ave_duration = round(sum(duration for _, duration in box) / len(box), 3)
    else:
        # Nothing was solved (e.g. --n-instance 0); avoid dividing by zero.
        ave_distance = ave_duration = float("nan")

    txt_filename = f"{output}/{input_name}.txt"
    with open(txt_filename, "w", encoding="utf-8") as f:
        print("idx dist duration", file=f)
        for i, (dist, duration) in enumerate(box):
            print(i, dist, duration, file=f)
        print("ave", ave_distance, ave_duration, file=f)
    print(txt_filename)


if __name__ == "__main__":
    # Command-line entry point: parse solver/experiment options and run main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--n-agent", type=int, default=5, help="number of agents")
    parser.add_argument("--n-instance", type=int, default=3, help="maximum number of instances to solve")
    parser.add_argument("--output", type=str, default="mip_result", help="output folder")
    parser.add_argument("--input", type=str, required=True, help="path to the torch-saved input locations")
    parser.add_argument("--solver", choices=SolverChoices.getValues(), default="gurobi", help="MIP solver")
    parser.add_argument(
        "--constraint", choices=SubtourEliminationMethods.getValues(), required=True, help="subtour elimination method"
    )
    parser.add_argument(
        "--formulation",
        choices=["maxmin", "summin"],
        required=True,
        help="objective formulation (maxmin: MIP_CVRP, summin: MIP_CVRP_SUM)",
    )
    parser.add_argument("--timelimit", type=int, default=3600, help="time limit passed to the solver")
    parser.add_argument("--n-thread", type=int, default=8, help="number of solver threads")
    parser.add_argument("--show-log", action="store_true", help="show the solver log")

    args = parser.parse_args()
    main(args)
