import argparse
import datetime
import multiprocessing
import os
from pickle import load

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from mip.cvrp import MIP_CVRP
from mip.cvrp_sum import MIP_CVRP_SUM
from mip.mip_utils import SolverChoices
from mip.mip_utils import SubtourEliminationMethods
from mm_cvrp.utils import get_src_vector
from utils.capacity import node2capacity


def execute(part_target_list, queue, data_path=None):
    """Solve a chunk of CVRP instances and report their results through *queue*.

    Used as a ``multiprocessing.Process`` target by ``main``: each worker handles
    one chunk of the task list and puts exactly one list onto *queue*.

    Args:
        part_target_list: iterable of 8-tuples
            ``(idx, formulation, problem_params, location, solver_params, log_path, output, plot)``.
            ``formulation == "maxmin"`` selects ``MIP_CVRP``; any other value selects
            ``MIP_CVRP_SUM``.
        queue: queue (e.g. ``multiprocessing.Queue``) that receives one list of
            ``(idx, dist, duration)`` tuples, where ``dist`` is what ``plot`` returns and
            ``duration`` is the second value returned by the model call.
        data_path: input data path forwarded to ``plot`` (used to name the figure).
            When ``None``, falls back to the module-level ``args.input`` set under
            ``__main__``.

    The collected list is put on *queue* even if solving raises. The parent waits for
    one message per worker, so it is not left blocking forever. The exception itself
    still propagates in this process afterwards.
    """
    result_box = []
    try:
        for idx, formulation, problem_params, location, solver_params, log_path, output, plot in part_target_list:
            # Both formulations take the same constructor arguments; only the class differs.
            model_cls = MIP_CVRP if formulation == "maxmin" else MIP_CVRP_SUM
            model = model_cls(
                src_vector=problem_params["src_vector"],
                n_agent=problem_params["n_agent"],
                n_node=problem_params["n_node"],
                locations=location,
                subtour_elimination=problem_params["subtour_elimination"],
            )
            paths, duration = model(
                solver_type=solver_params["solver_type"],
                timelimit=solver_params["timelimit"],
                n_thread=solver_params["n_thread"],
                log_path=log_path,
                show_log=solver_params["show_log"],
            )
            if data_path is None:
                # NOTE: ``args`` is only a module global when this file runs as a script.
                # Children inherit it under the "fork" start method, but not under
                # "spawn"/"forkserver"; pass ``data_path`` explicitly to be portable.
                data_path = args.input
            dist = plot(paths, location, output, data_path, idx)
            result_box.append((idx, dist, duration))
    finally:
        # Always report, even on failure, so the parent's per-worker queue.get() returns.
        queue.put(result_box)


def split_list(a, m):
    """Split the sliceable sequence *a* into *m* contiguous, order-preserving chunks.

    Chunk sizes differ by at most one. The first ``len(a) % m`` chunks get one extra
    element. When *m* exceeds ``len(a)``, the trailing chunks are empty.
    """
    base, extra = divmod(len(a), m)

    def _start(k):
        # Chunk k begins after k base-sized chunks, plus one extra element for each of
        # the first min(k, extra) chunks that received a leftover item.
        return k * base + min(k, extra)

    return [a[_start(k):_start(k + 1)] for k in range(m)]


def plot(paths, locations, output, data_path, idx):
    """Draw every route, save the figure as a PNG, and return a route-length summary.

    Args:
        paths: iterable of routes; each route is a sequence of node indices into
            ``locations``.
        locations: 2-D array indexed by node; columns 0 and 1 are plotted as x and y
            (assumes a numpy array of shape (n_node, 2) — TODO confirm).
        output: directory the PNG is written to (created if missing).
        data_path: input data path; every '/' is replaced by ':' to build the file name.
        idx: instance index appended to the file name.

    Returns:
        The smallest per-route length. Each route's length is the sum of Euclidean
        distances between consecutive nodes, plus the closing leg from the route's
        last node back to node 0.
    """
    fig = plt.figure(figsize=(10, 10))
    try:
        plt.scatter(locations[:, 0], locations[:, 1], s=10, c="gray")
        cmap = plt.get_cmap("tab10")

        dist_list = []
        for i, path in enumerate(paths):
            color = cmap(i % 10)
            dist = 0
            for src, dst in zip(path, path[1:], strict=False):
                dist += np.linalg.norm(locations[src] - locations[dst])
                plt.plot([locations[src, 0], locations[dst, 0]], [locations[src, 1], locations[dst, 1]], color=color)
                plt.text(locations[src, 0], locations[src, 1], src, color=color)
            # Close the tour from this route's own last node. Reusing the loop variable
            # ``dst`` here was undefined for a single-node first route, and stale (the
            # previous route's end) for later single-node routes.
            if len(path) > 0:
                dist += np.linalg.norm(locations[path[-1]] - locations[0])
            dist_list.append(dist)

        os.makedirs(output, exist_ok=True)
        filepath = f"{output}/{data_path.replace('/', ':', -1)}_{idx}.png"
        fig.savefig(filepath)
    finally:
        # One figure is created per instance and pyplot keeps each alive until closed,
        # so release it to avoid unbounded memory growth across instances.
        plt.close(fig)
    print(filepath)
    print(datetime.datetime.now())

    # NOTE(review): this reports the *shortest* route. The "maxmin" formulation presumably
    # minimises the longest route (max), and "summin" the total (sum). Confirm that min is
    # the intended metric.
    return min(dist_list)


def _collect_results(queue, processes, poll_interval=1.0):
    """Gather the one result list each worker puts on *queue*, without hanging.

    The queue is read *before* the workers are joined, because joining a process that
    still has buffered queue data can deadlock. Waiting stops once every worker has
    exited, so a worker that died without reporting cannot block the parent forever.

    Args:
        queue: ``multiprocessing.Queue`` each worker puts exactly one list onto.
        processes: the already-started ``multiprocessing.Process`` objects.
        poll_interval: seconds to wait on the queue between worker liveness checks.

    Returns:
        Concatenation of all received lists, in arrival order.
    """
    import queue as queue_lib  # for ``Empty``; aliased because the parameter is named ``queue``

    results = []
    remaining = len(processes)
    while remaining > 0:
        try:
            results.extend(queue.get(timeout=poll_interval))
            remaining -= 1
            continue
        except queue_lib.Empty:
            pass
        if any(p.is_alive() for p in processes):
            continue
        # Every worker has exited. A process flushes its buffered queue items before
        # terminating, so anything it managed to put is now readable: drain without
        # blocking, then stop.
        while remaining > 0:
            try:
                results.extend(queue.get_nowait())
            except queue_lib.Empty:
                break
            remaining -= 1
        break
    if remaining > 0:
        print(f"warning: {remaining} worker process(es) exited without reporting results")
    return results


def main(args: argparse.Namespace) -> None:
    """Solve up to ``args.n_instance`` instances from ``args.input`` in parallel and write a summary.

    Each instance is solved in a worker process (see ``execute``) and its routes are
    plotted. A table of ``idx dist duration`` rows plus an ``ave`` line is written to
    ``<output>/<basename(args.input)>.txt``.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file; only pass
    # trusted input files.
    with open(args.input, "rb") as f:
        locations = load(f)
    output = args.output + "/" + args.formulation + "_" + args.constraint + "_" + str(args.n_agent) + "agent"
    # Create the output directory up front. The solver log paths below point inside it
    # and are used before ``plot`` (which also creates it) runs.
    os.makedirs(output, exist_ok=True)

    n_node = len(locations[0])
    # FIXME: change how the capacity is obtained
    # (previously: get_src_vector(n_node - 1, args.n_agent)).
    src_vector = [node2capacity[n_node] for _ in range(args.n_agent)]

    problem_params = {
        "src_vector": src_vector,
        "n_agent": args.n_agent,
        "n_node": n_node,
        "subtour_elimination": args.constraint,
    }
    solver_params = {
        "solver_type": args.solver,
        "timelimit": args.timelimit,
        "n_thread": args.n_thread,
        "show_log": args.show_log,
    }
    target_list = [
        (
            i,
            args.formulation,
            problem_params,
            locations[i],
            solver_params,
            f"{output}/log_{os.path.basename(args.input)}_{i}.txt",
            output,
            plot,
        )
        for i in range(min(args.n_instance, len(locations)))
    ]
    # At least one worker (split_list divides by this), never more than there are instances.
    m = max(1, min(args.num_process, len(target_list)))
    target_chunk_list = split_list(target_list, m)

    queue = multiprocessing.Queue()
    processes = [multiprocessing.Process(target=execute, args=(chunk, queue)) for chunk in target_chunk_list]
    for p in processes:
        p.start()
    # Drain the queue before joining to avoid the join-while-buffered deadlock.
    results = _collect_results(queue, processes)
    for p in processes:
        p.join()

    results = sorted(results, key=lambda x: x[0])

    txt_filename = f"{output}/{os.path.basename(args.input)}.txt"
    if results:
        ave_distance = round(sum(dist for _, dist, _ in results) / len(results), 3)
        ave_duration = round(sum(duration for _, _, duration in results) / len(results), 3)
    else:
        ave_distance = ave_duration = float("nan")
    with open(txt_filename, "w") as f:
        print("idx dist duration", file=f)
        for idx, dist, duration in results:
            print(idx, dist, duration, file=f)
        print("ave", ave_distance, ave_duration, file=f)
    print(txt_filename)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--n-agent", type=int, default=5, help="number of agents")
    parser.add_argument(
        "--n-instance", type=int, default=3, help="number of instances to solve (capped at the number in the input)"
    )
    parser.add_argument("--output", type=str, default="mip_result", help="output folder")
    parser.add_argument("--input", type=str, required=True, help="pickled input data file")
    parser.add_argument("--solver", choices=SolverChoices.getValues(), default="gurobi", help="solver")
    parser.add_argument(
        "--constraint", choices=SubtourEliminationMethods.getValues(), required=True, help="subtour elimination method"
    )
    parser.add_argument(
        "--formulation",
        choices=["maxmin", "summin"],
        required=True,
        help="objective formulation: maxmin (MIP_CVRP) or summin (MIP_CVRP_SUM)",
    )
    parser.add_argument("--timelimit", type=int, default=3600, help="time limit passed to the solver")
    parser.add_argument("--n-thread", type=int, default=1, help="number of threads passed to the solver")
    parser.add_argument("--show-log", action="store_true", help="show the solver log")
    parser.add_argument("--num-process", type=int, default=1, help="number of worker processes")

    # ``args`` stays a module-level name: ``execute`` reads ``args.input`` from this global.
    args = parser.parse_args()
    main(args)
