import argparse
import os
from pickle import load

import torch
from utils.capacity import node2capacity


def write_problem(
    locations: torch.Tensor, num_node: int, num_agent: int, problem_name: str, filepath: str, capacity: int
) -> None:
    """Write one instance as an LKH/TSPLIB-style CVRP problem file.

    Args:
        locations: Per-node ``(x, y)`` coordinates, iterated in order. Node ``i``
            (0-based) is written with id ``i + 1``. The annotation says tensor, but
            any iterable of 2-element pairs is accepted.
        num_node: Value written as ``DIMENSION``. It should equal the number of
            entries in ``locations``; this is not checked here.
        num_agent: Value written as ``VEHICLES``.
        problem_name: Used in the ``NAME`` and ``COMMENT`` headers.
        filepath: Destination path. An existing file is overwritten.
        capacity: Value written as ``CAPACITY``.
    """
    # NOTE(review): no DEMAND_SECTION / DEPOT_SECTION is emitted, so whatever LKH
    # assumes in their absence applies -- confirm this is intended.
    with open(filepath, "w") as f:
        print(f"NAME : {problem_name}", file=f)
        print(f"COMMENT : {problem_name} tsp", file=f)
        print("TYPE : CVRP", file=f)
        print(f"DIMENSION : {num_node}", file=f)
        print(f"VEHICLES : {num_agent}", file=f)
        print(f"CAPACITY : {capacity}", file=f)
        print("EDGE_WEIGHT_TYPE : EUC_2D", file=f)
        print("NODE_COORD_SECTION", file=f)
        # Node ids in the file are 1-based.
        for i, (x, y) in enumerate(locations):
            print(f"{i+1} {x} {y}", file=f)
        print("EOF", file=f)


def write_parameter(filepath: str, problem_name: str, unit, output, timelimit, initial_tour_file_path=None) -> None:
    """Write an LKH parameter (.par) file for one problem instance.

    Args:
        filepath: Destination path of the .par file. An existing file is overwritten.
        problem_name: Instance name. The referenced problem and tour files are
            ``{output}/{problem_name}.tsp`` and ``{output}/{problem_name}.txt``.
        unit: Value written as ``MTSP_MIN_SIZE``.
        output: Directory prefix used in the problem and tour file paths.
        timelimit: Written as ``TOTAL_TIME_LIMIT`` unless it equals -1.
        initial_tour_file_path: Optional ``INITIAL_TOUR_FILE``. Omitted when None.
    """
    prefix = f"{output}/{problem_name}"
    settings = [
        f"PROBLEM_FILE = {prefix}.tsp",
        f"OUTPUT_TOUR_FILE = {prefix}.txt",
        "MTSP_OBJECTIVE = MINMAX",
        "MOVE_TYPE = 5",
        "PATCHING_C = 3",
        "PATCHING_A = 2",
        "RUNS = 1",
        f"MTSP_MIN_SIZE = {unit}",
        "SCALE = 1000",
    ]
    # Optional entries are appended only when requested (-1 means "no time limit").
    if timelimit != -1:
        settings.append(f"TOTAL_TIME_LIMIT = {timelimit}")
    if initial_tour_file_path is not None:
        settings.append(f"INITIAL_TOUR_FILE = {initial_tour_file_path}")

    with open(filepath, "w") as handle:
        handle.writelines(f"{entry}\n" for entry in settings)


def main(args: argparse.Namespace) -> None:
    """Convert pickled instances into LKH problem (.tsp) and parameter (.par) files.

    For each of the first ``min(data.shape[0], args.num_instance)`` instances in
    the pickle at ``args.input``, writes ``<output>/<phrase>_<i>.tsp`` and
    ``<output>/<phrase>_<i>.par`` and prints both paths.

    ``<phrase>`` is the input's parent-directory name joined to its file name
    with ``_`` (e.g. ``runs/data.pkl`` -> ``runs_data.pkl``). ``<output>`` is
    ``args.output``, suffixed with ``_<timelimit>s`` when a time limit is given.

    Args:
        args: Parsed CLI arguments with ``input``, ``n_agent``, ``num_instance``,
            ``timelimit`` (-1 for no limit) and ``output``.

    Raises:
        ValueError: If ``timelimit`` is neither -1 nor positive, or if ``n_agent``
            is not positive.
    """
    if args.timelimit != -1 and args.timelimit <= 0:
        raise ValueError("specify appropriate timelimit")
    # n_agent is used as a divisor below. Reject bad values before any file is
    # written, instead of failing with ZeroDivisionError after the first .tsp.
    if args.n_agent <= 0:
        raise ValueError("specify a positive number of agents")

    # NOTE(review): pickle.load can execute arbitrary code -- only pass trusted input files.
    with open(args.input, "rb") as f:
        data = load(f)
    # Presumably data is indexed (instance, node, xy) -- TODO confirm against the generator.
    output = args.output if args.timelimit == -1 else f"{args.output}_{args.timelimit}s"
    os.makedirs(output, exist_ok=True)

    num_node = data.shape[1]
    capacity = node2capacity[num_node]
    # Identical for every instance; written as MTSP_MIN_SIZE.
    unit = num_node // args.n_agent

    # Resolve to an absolute path so this works for bare file names (no directory
    # component) and with the platform's own path separator.
    input_path = os.path.abspath(args.input)
    phrase = f"{os.path.basename(os.path.dirname(input_path))}_{os.path.basename(input_path)}"

    for i in range(min(data.shape[0], args.num_instance)):
        problem_name = f"{phrase}_{i}"
        tsp_path = f"{output}/{problem_name}.tsp"
        print(tsp_path)
        write_problem(data[i], num_node, args.n_agent, problem_name, tsp_path, capacity)
        par_path = f"{output}/{problem_name}.par"
        write_parameter(par_path, problem_name, unit, output, args.timelimit)
        print(par_path)


if __name__ == "__main__":
    # CLI entry point: parse arguments and generate the LKH files.
    parser = argparse.ArgumentParser(
        description="Generate LKH problem (.tsp) and parameter (.par) files from pickled instances."
    )
    parser.add_argument("-i", "--input", type=str, required=True, help="path to the pickled instance data")
    parser.add_argument("-a", "--n-agent", type=int, required=True, help="number of agents (vehicles)")
    parser.add_argument(
        "-n", "--num-instance", type=int, required=True, help="maximum number of instances to convert"
    )
    parser.add_argument(
        "-t", "--timelimit", type=int, default=-1, help="LKH total time limit in seconds (-1: no limit)"
    )
    parser.add_argument("--output", type=str, default="LKH_problem_exp", help="output folder")

    args = parser.parse_args()
    main(args)
