import argparse
import math
import os

# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import random
import time

import matplotlib as mpl
import matplotlib.pyplot as plt
import ot
import pulp
import seaborn as sns
import torch
import torch_geometric
from cli.lkh_preprocess import write_parameter
from cli.lkh_preprocess import write_problem
from heuristics.local_search import LocalSearch
from LinSATNet import linsat_layer
from mm_cvrp.policy import Policy
from mm_cvrp.policy import action_sample
from mm_cvrp.policy import get_cost4plot
from mm_cvrp.trainer import MapDataset
from mm_cvrp.utils import get_src_vector
from torch.utils.data import DataLoader
from torch_geometric.data import Batch
from torch_geometric.data import Data
from utils.capacity import node2capacity


def get_clustering_result(
    model: Policy,
    dataset: torch.Tensor,
    device: str,
    data: torch.Tensor,
    n_node: int,
    n_agent: int,
    ignore_ot: bool = False,
    naive_flag: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Assign customer nodes to agents using the trained policy.

    The policy scores every (agent, node) pair. Unless ``ignore_ot`` is set,
    the scores are projected with LinSATNet onto the capacity/assignment
    constraints, the projected per-agent loads are rounded to integer
    capacities with a small MILP (PuLP/CBC), and a hard assignment is obtained
    with exact optimal transport (``ot.emd``) under those capacities.

    Args:
        model: trained policy, called as ``model(batch_graph, n_nodes=..., n_batch=...)``.
        dataset: node coordinates; ``dataset.shape[0]`` is the number of
            instances and ``dataset.shape[1]`` the number of nodes (depot
            included). The OT path only reads ``cost[0]``, i.e. the first instance.
        device: torch device the model and constraint tensors are placed on.
        data: coordinates of the single instance; only used when ``naive_flag`` is set.
        n_node: number of nodes including the depot.
        n_agent: number of agents (clusters).
        ignore_ot: sample actions directly from the raw policy output instead.
        naive_flag: override the action with an x-coordinate band split
            (assumes x coordinates lie in [0, 1]).

    Returns:
        ``(action, cost, pi, num_assigned_vector)``. ``pi`` and
        ``num_assigned_vector`` are ``None`` when ``ignore_ot`` is set. In the
        OT path ``cost[0]`` has been replaced in place by ``1 - cost[0]`` and
        ``num_assigned_vector`` holds the integer per-agent capacities.

    Raises:
        ValueError: if the capacity-rounding MILP has no optimal solution or
            the OT plan exceeds the computed capacities.
    """

    def calculate_optimal_transport_linsatnet(cost: torch.Tensor):
        """Turn policy scores into a hard, capacity-feasible assignment.

        ``cost[0]`` is overwritten in place with ``1 - cost[0]`` so that high
        scores become low transport costs.
        """
        n_customer = n_node - 1  # the depot (index 0) is never assigned
        cap = node2capacity[n_node]
        scores = cost[0].reshape(1, -1)

        # Per-agent load constraints on the flattened (n_agent, n_customer) plan:
        # load <= cap (A x <= b) and load >= int(cap / 2) (C x >= d with C = A).
        A = torch.zeros([n_agent, n_agent * n_customer], device=device)
        for i in range(n_agent):
            A[i, i * n_customer : (i + 1) * n_customer] = 1
        b = torch.tensor([cap] * n_agent, dtype=torch.float32, device=device)
        d = torch.tensor([int(cap * 0.5)] * n_agent, dtype=torch.float32, device=device)

        # Every customer must be assigned with total weight exactly 1 (E x = f).
        E = torch.zeros([n_customer, n_agent * n_customer], device=device)
        column_gap = n_customer * torch.arange(n_agent)
        for i in range(n_customer):
            E[i, i + column_gap] = 1
        f = torch.ones(n_customer, dtype=torch.float32, device=device)

        start = time.time()
        output = linsat_layer(scores, A=A, b=b, C=A, d=d, E=E, f=f, tau=1e-5, max_iter=1000)
        print(time.time() - start)
        soft_load = torch.sum(output.reshape(n_agent, n_customer), dim=1)
        # The projection is iterative (bounded by max_iter) and may not meet the
        # bounds exactly, so a violation here is only reported.
        if torch.any(soft_load > b):
            print("warning : capacity constraint is not satisfied")

        # Round the soft loads to integer capacities in [1, cap] summing to
        # n_customer, minimising sum_i |x_i - target_i| (linearised via slacks).
        target = soft_load.cpu().detach().numpy()
        prob = pulp.LpProblem(name="sample")
        x = [pulp.LpVariable(name=f"x{i}", cat=pulp.LpInteger, lowBound=1, upBound=cap) for i in range(n_agent)]
        slack = [pulp.LpVariable(name=f"l{i}", cat=pulp.LpContinuous, lowBound=0) for i in range(n_agent)]
        prob += pulp.lpSum(slack)
        for i in range(n_agent):
            # Keep the PuLP expression on the left so numpy scalars are folded
            # into the affine expression as constants.
            prob.addConstraint(slack[i] >= -x[i] + target[i])
            prob.addConstraint(slack[i] >= x[i] - target[i])
        prob.addConstraint(pulp.lpSum(x) == n_customer)
        prob.solve(pulp.PULP_CBC_CMD(msg=False))
        if prob.status != pulp.LpStatusOptimal:
            # Without this, pulp.value(...) returns None and torch.tensor fails obscurely.
            raise ValueError(f"capacity rounding MILP failed: {pulp.LpStatus[prob.status]}")

        src_vector = torch.tensor([pulp.value(v) for v in x])
        print(src_vector)
        print(target)

        dst_vector = torch.ones(cost.shape[2], dtype=torch.float32)
        cost[0] = 1 - cost[0]
        pi = ot.emd(src_vector, dst_vector, cost[0])
        # POT returns the plan on the cost matrix's device while src_vector is a
        # CPU tensor; compare on the CPU so this also works when running on CUDA.
        if torch.any(torch.sum(pi, dim=1).cpu() > src_vector):
            raise ValueError("capacity constraint is not satisfied")

        action, _ = action_sample(pi.unsqueeze(0))
        return action, cost, pi.cpu().detach(), src_vector

    model.to(device)
    model.eval()

    # One fully connected graph (self-loops included) per instance; the edge
    # set is identical for every instance, so it is built once.
    n_batch, n_graph_node = dataset.shape[0], dataset.shape[1]
    edge_index = torch.ones([n_graph_node, n_graph_node]).nonzero(as_tuple=False).t()
    batch_graph = Batch.from_data_list(
        data_list=[Data(x=dataset[i], edge_index=edge_index) for i in range(n_batch)]
    ).to(device)

    # Inference only: no autograd graph is needed for the scores.
    with torch.no_grad():
        cost = model(batch_graph, n_nodes=n_graph_node, n_batch=n_batch)

    if ignore_ot:
        # Original (non-OT) inference: sample directly from the policy output.
        action, _ = action_sample(cost)
        pi = None
        num_assigned_vector = None
    else:
        action, cost, pi, num_assigned_vector = calculate_optimal_transport_linsatnet(cost)

    if naive_flag:
        # Split customers into n_agent equal-width bands by x coordinate:
        # thresholds 1/n_agent, ..., (n_agent-1)/n_agent (0.2..0.8 for 5 agents).
        conditions = torch.arange(1, n_agent) / n_agent
        result = torch.sum(data[:, 0, None] > conditions, dim=1)
        # index 0 is the depot and is not assigned
        action = result[1:].unsqueeze(0)

    return action, cost, pi, num_assigned_vector


def get_naive_clustering_result(
    data: torch.Tensor,
    n_node: int,
    n_agent: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Baseline clustering: score nodes against random agent vectors, assign by OT.

    Args:
        data: node coordinates with the depot in row 0; each row has two
            coordinates (required by the matmul with ``cluster_vector``).
        n_node: number of nodes including the depot.
        n_agent: number of agents (clusters).

    Returns:
        ``(action, cost, pi, src_vector)``. This is the same 4-tuple layout as
        ``get_clustering_result``, which ``main`` relies on when unpacking.
        ``src_vector`` holds the per-agent capacities used as OT source marginals.
    """
    cluster_vector = torch.rand(size=[n_agent, 2])
    # (n_agent, 2) @ (1, 2, n_node - 1) -> (1, n_agent, n_node - 1); row 0 (depot) is skipped.
    cost = cluster_vector @ torch.transpose(data[1:], 0, 1).unsqueeze(0)

    src_vector = torch.tensor(get_src_vector(n_node - 1, n_agent), dtype=torch.float32)
    dst_vector = torch.ones(cost.shape[2], dtype=torch.float32)
    # High scores become low transport costs.
    pi = ot.emd(src_vector, dst_vector, 1 - cost[0])

    action, _ = action_sample(pi.unsqueeze(0))

    return action, cost, pi, src_vector


def plot(
    data: torch.Tensor,
    subtour_order: list[list[list[int]]],
    local_idx2global_idx: list[dict[tuple[int, int], int]],
    sub_tour_length: list[float],
    cost: torch.Tensor,
    pi: torch.Tensor,
    num_assigned_vector: torch.Tensor,
    subtour_length_list: list[list[float]],
    test_data: str,
    pretrained_path: str,
    output_folder: str,
    ignore_ot: bool,
    idx: int,
    after_edited: bool = False,
):
    """Draw the sub-tours of the first instance and save the figure as a PNG.

    Each tour in ``subtour_order[0]`` is drawn in its own ``tab10`` colour and
    labelled with ``"<number of customers> : <tour length>"``. The file is
    written to ``output_folder`` (created if missing). Its name is built from
    ``pretrained_path``, ``test_data``, ``ignore_ot`` and ``idx``, and is
    prefixed with ``LS_`` when ``after_edited`` is set. The path is printed.

    ``local_idx2global_idx``, ``cost``, ``pi`` and ``num_assigned_vector`` are
    not used; they are kept in the signature for caller compatibility.
    """
    plt.figure(figsize=(9, 9))
    cmap = plt.get_cmap("tab10")
    plt.scatter(data[:, 0], data[:, 1], s=30, c="gray")
    tour_length = subtour_length_list[0]

    for j, tour in enumerate(subtour_order[0]):
        color = cmap(j % 10)  # tab10 has 10 colours; main() rejects n_agent > 10
        # -2: the depot appears at both ends of every tour
        num_location = len(tour) - 2
        for k, (src, dst) in enumerate(zip(tour, tour[1:], strict=False)):
            plt.plot(
                [data[src, 0], data[dst, 0]],
                [data[src, 1], data[dst, 1]],
                color=color,
                # label only the first segment so each tour gets one legend entry
                label=f"{num_location} : {tour_length[j]}" if k == 0 else None,
            )
    plt.title(sub_tour_length)
    plt.legend()

    os.makedirs(output_folder, exist_ok=True)
    prefix = "LS_" if after_edited else ""
    filepath = (
        f"{output_folder}/{prefix}{pretrained_path.replace('/', ':')}_"
        f"{test_data.replace('/', ':')}_ignoreOT{ignore_ot}_{str(idx).zfill(3)}.png"
    )
    plt.savefig(filepath)
    plt.close()
    print(filepath)


def write_problem4LKH(
    subtour_order: list[list[list[int]]],
    locations: torch.Tensor,
    n_node: int,
    n_agent: int,
    output: str,
    solutionfile_base: str,
    capacity: int,
    timelimit: int,
    total_timelimit: int,
):
    """Write the LKH input files for one instance.

    All files go to the directory ``f"{output}_{total_timelimit}s"``, which is
    created if missing:

    * ``<solutionfile_base>.tsp`` via ``write_problem``;
    * ``<solutionfile_base>.par`` via ``write_parameter``, which is given the
      initial-tour path;
    * ``<solutionfile_base>_imcumbnet.txt``, the initial tour built from
      ``subtour_order[0]``.

    Raises:
        ValueError: if the initial tour does not contain exactly
            ``n_node + n_agent - 1`` entries.
    """

    def write_incumbent(filepath: str) -> None:
        # Tour entries are 1-indexed: the depot is 1 and customer k is k + 1.
        # Consecutive sub-tours are separated by ids n_node + 1, n_node + 2, ...;
        # presumably these are the extra depot copies declared by write_problem
        # (confirm in cli.lkh_preprocess).
        tours = subtour_order[0]
        tour = [1]
        for i, route in enumerate(tours):
            tour.extend(node + 1 for node in route[1:-1])  # drop the depot at both ends
            if i != len(tours) - 1:
                tour.append(n_node + i + 1)

        expected = n_agent + n_node - 1
        if len(tour) != expected:
            raise ValueError(f"initial tour has {len(tour)} entries, expected {expected}")

        with open(filepath, "w") as f:
            print("TYPE : TOUR", file=f)
            print(f"DIMENSION : {len(tour)}", file=f)
            print("TOUR_SECTION", file=f)
            for node_id in tour:
                print(node_id, file=f)
            print(-1, file=f)
            print("EOF", file=f)

    output = f"{output}_{total_timelimit}s"
    os.makedirs(output, exist_ok=True)
    problem_name = f"{solutionfile_base}"
    write_problem(locations, n_node, n_agent, problem_name, f"{output}/{solutionfile_base}.tsp", capacity)

    # `unit` is forwarded to write_parameter; its meaning is defined there.
    unit = n_node // n_agent
    initial_tour_filepath = f"{output}/{solutionfile_base}_imcumbnet.txt"
    write_parameter(
        f"{output}/{solutionfile_base}.par", problem_name, unit, output, timelimit, initial_tour_filepath
    )
    write_incumbent(initial_tour_filepath)


def _write_summary(txt_filename: str, lengths: list, times: list[float]) -> None:
    """Write per-instance tour lengths and averages to ``txt_filename``, then echo the averages."""
    ave_length = round(sum(lengths) / len(lengths), 3)
    ave_time = round(sum(times) / len(times), 3)
    with open(txt_filename, "w") as f:
        for i, row in enumerate(lengths):
            print(i, row, file=f)
        print("ave", ave_length, ave_time, file=f)
    print(txt_filename)
    print(ave_length)
    print("time", ave_time)


def main(args: argparse.Namespace):
    """Cluster, route and plot up to ``args.n_instance`` test instances.

    For each instance:

    * compute a clustering (policy + OT, or the naive baseline);
    * build sub-tours with ``get_cost4plot``;
    * optionally write LKH input files and plot the result;
    * optionally refine the tours with ``LocalSearch``.

    Per-instance lengths and average length/time are written to ``.txt``
    summaries in ``args.output_folder``.

    Raises:
        ValueError: if ``n_agent > 10`` (the plot colormap has 10 colours) or
            if ``--use-output4lkh`` is given without ``--timelimit``.
    """
    # Only takes effect if CUDA has not been initialised yet.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    dev = "cuda" if torch.cuda.is_available() else "cpu"
    n_agent = args.n_agent
    if n_agent > 10:
        raise ValueError("colormap must be modified")
    if args.use_output4lkh and args.timelimit is None:
        raise ValueError("--timelimit is required when --use-output4lkh is set")

    policy = Policy(
        in_chnl=2,
        hid_chnl=64,
        n_agent=n_agent,
        key_size_embd=32,
        key_size_policy=128,
        val_size=16,
        clipping=10,
        dev=dev,
        disable_softmax="softmaxFalse" in args.pretrained_path,
    )
    test_dataset = MapDataset(folder_path=args.test_data, mode="test", return_depot=True, augmentation=False)
    test_dataloader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)
    policy.load_state_dict(torch.load(args.pretrained_path, map_location=torch.device(dev)))

    cap = node2capacity[args.n_node]
    # Shared part of every output file name.
    run_tag = f"{args.pretrained_path.replace('/', ':')}_{args.test_data.replace('/', ':')}_ignoreOT{args.ignore_ot}"

    lengths, ls_lengths = [], []
    times, ls_times = [], []
    c = 0
    total_timelimit = args.timelimit
    for batched_data, depots in test_dataloader:
        for i in range(len(batched_data)):
            data = batched_data[i]
            begin = time.time()
            if args.naive:
                action, cost, pi, num_assigned_vector = get_naive_clustering_result(data, args.n_node, args.n_agent)
            else:
                action, cost, pi, num_assigned_vector = get_clustering_result(
                    policy, data.unsqueeze(0), dev, data, args.n_node, args.n_agent, ignore_ot=args.ignore_ot
                )
            sub_tour_length, subtour_order, subtour_length_list, local_idx2global_idx = get_cost4plot(
                action, data.unsqueeze(0), n_agent
            )
            times.append(time.time() - begin)
            # Presumably the dataset yields depot-relative coordinates
            # (return_depot=True); adding depots[i] restores absolute positions.
            data = data + depots[i]

            # Map cluster-local node indices to global ones (depot at both ends is skipped).
            for t_idx, tour in enumerate(subtour_order[0]):
                for n_idx in range(1, len(tour) - 1):
                    tour[n_idx] = local_idx2global_idx[0][(t_idx, tour[n_idx])]

            if args.use_output4lkh:
                # Give LKH whatever remains of the budget after clustering.
                timelimit = int(total_timelimit - math.floor(times[-1]))
                write_problem4LKH(
                    subtour_order,
                    data,
                    args.n_node,
                    args.n_agent,
                    args.lkh_output_folder,
                    f"{run_tag}_{str(c).zfill(3)}",
                    cap,
                    timelimit,
                    total_timelimit,
                )

            plot(
                data,
                subtour_order,
                local_idx2global_idx,
                sub_tour_length,
                cost,
                pi,
                num_assigned_vector,
                subtour_length_list,
                args.test_data,
                args.pretrained_path,
                args.output_folder,
                args.ignore_ot,
                c,
            )
            lengths.append(sub_tour_length[0])

            if args.add_local_search:
                begin = time.time()
                local_search_model = LocalSearch(
                    capacity=cap,
                    n_agent=args.n_agent,
                    n_node=args.n_node,
                    n_iter=1000,
                    locations=data,
                    disable_plot=True,
                )
                path_list, path_length_list = local_search_model(subtour_order[0], subtour_length_list[0])
                subtour_order = [path_list]
                subtour_length_list = [path_length_list]
                # Local-search time is reported on top of the clustering time.
                ls_times.append(time.time() - begin + times[-1])

                plot(
                    data,
                    subtour_order,
                    local_idx2global_idx,
                    max(subtour_length_list[0]),
                    cost,
                    pi,
                    num_assigned_vector,
                    subtour_length_list,
                    args.test_data,
                    args.pretrained_path,
                    args.output_folder,
                    args.ignore_ot,
                    c,
                    True,
                )
                ls_lengths.append(max(subtour_length_list[0]))

            c += 1
            if c >= args.n_instance:
                break
        if c >= args.n_instance:
            # Stop the outer loop too; otherwise each later batch would
            # process one more instance beyond n_instance.
            break

    _write_summary(f"{args.output_folder}/{run_tag}.txt", lengths, times)
    if args.add_local_search:
        _write_summary(f"{args.output_folder}/LS_{run_tag}.txt", ls_lengths, ls_times)


if __name__ == "__main__":
    # Seed every RNG the script touches so runs are reproducible.
    manual_seed = 1
    random.seed(manual_seed)
    torch.manual_seed(manual_seed)
    torch.cuda.manual_seed(manual_seed)
    # The previous spellings ("determinstic", "benchmarks") silently created
    # new attributes and left cuDNN's actual settings untouched.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # NOTE: PYTHONHASHSEED only affects interpreters started after this point
    # (e.g. subprocesses); the current process's hash seed is already fixed.
    os.environ["PYTHONHASHSEED"] = str(manual_seed)
    # Embed TrueType (Type 42) fonts in PDF/PS output.
    mpl.rcParams["pdf.fonttype"] = 42
    mpl.rcParams["ps.fonttype"] = 42

    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu", type=int, default=0, help="GPU index exposed via CUDA_VISIBLE_DEVICES")
    parser.add_argument("--n-agent", type=int, default=5, help="number of agents")
    parser.add_argument("--n-node", type=int, default=30, help="number of nodes including the depot")
    parser.add_argument("--n-instance", type=int, default=3, help="number of instance")
    parser.add_argument("--test-data", type=str, required=True, help="path to test data")
    parser.add_argument("--pretrained-path", type=str, required=True, help="path to network")
    parser.add_argument("--naive", action="store_true", help="execute naive clustering")
    parser.add_argument("--ignore-ot", action="store_true", help="ignoring optimal transportation")
    parser.add_argument("--add-local-search", action="store_true", help="refine the tours with local search")
    parser.add_argument("--output-folder", type=str, default="output", help="output folder")
    parser.add_argument(
        "--use-output4lkh", action="store_true", help="write LKH problem/parameter/initial-tour files"
    )
    parser.add_argument("--lkh-output-folder", type=str, default="output4lkh", help="output folder for LKH files")
    parser.add_argument(
        "--timelimit", type=int, help="total LKH time budget in seconds (required with --use-output4lkh)"
    )
    main(parser.parse_args())
