import os
import random

# Fix the global RNG at import time so the random node choices made by
# LocalSearch.__call__ (random.choice) are reproducible from run to run.
random.seed(a=0)

import matplotlib.pyplot as plt
import numpy as np


class LocalSearch:
    """Node-relocation local search that shortens the longest trip (min-max objective).

    Each trip is a list of node indices whose first and last entries are depot
    entries; those two entries are never moved. Every iteration picks the
    currently longest trip and tries to move one of its nodes into another trip
    so that the longer of the two affected trips ends up strictly shorter than
    the current longest trip.
    """

    def __init__(
        self,
        capacity: int,
        n_agent: int,
        n_node: int,
        n_iter: int,
        locations: np.ndarray,
        output_folder: str = "sample",
        disable_plot: bool = False,
    ) -> None:
        """Store the problem settings and precompute the pairwise distance matrix.

        Args:
            capacity: maximum number of non-depot entries a trip may hold before
                it stops accepting relocated nodes.
            n_agent: number of trips considered as relocation targets.
            n_node: number of nodes; the first ``n_node`` rows of ``locations`` are used.
            n_iter: number of local-search iterations per call.
            locations: node coordinates, one row per node.
            output_folder: directory that receives the plot images.
            disable_plot: when True, no directory is created and nothing is drawn.

        Raises:
            ValueError: if ``locations`` has fewer than ``n_node`` rows.
        """
        self.capacity = capacity
        self.n_agent = n_agent
        self.n_node = n_node
        self.n_iter = n_iter
        self.locations = locations

        # All pairwise Euclidean distances in one vectorised pass instead of an
        # O(n^2) Python loop. 1-D coordinates are treated as (n_node, 1).
        points = np.asarray(locations[: self.n_node], dtype=float)
        if len(points) < self.n_node:
            raise ValueError(f"locations has {len(points)} rows but n_node={self.n_node}")
        if points.ndim == 1:
            points = points[:, np.newaxis]
        self.dist = np.linalg.norm(points[:, np.newaxis, :] - points[np.newaxis, :, :], axis=-1)

        self.output_folder = output_folder
        self.disable_plot = disable_plot
        if not self.disable_plot:
            os.makedirs(self.output_folder, exist_ok=True)

    def plot(self, iter, path_list, path_length_list):
        """Save a picture of the trips as ``<output_folder>/<iter zero-padded to 3>.png``.

        Args:
            iter: number used in the file name (name kept for caller
                compatibility even though it shadows the builtin).
            path_list: trips to draw, one colour per trip.
            path_length_list: length of each trip; shown in the legend, and the
                maximum is used as the title.

        Does nothing when ``disable_plot`` is set.
        """
        if self.disable_plot:
            return

        fig = plt.figure(figsize=(10, 10))
        try:
            plt.scatter(self.locations[:, 0], self.locations[:, 1], s=10, c="gray")
            cmap = plt.get_cmap("tab10")

            for i, path in enumerate(path_list):
                color = cmap(i % 10)
                for seg, (src, dst) in enumerate(zip(path, path[1:], strict=False)):
                    plt.plot(
                        [self.locations[src, 0], self.locations[dst, 0]],
                        [self.locations[src, 1], self.locations[dst, 1]],
                        color=color,
                        # Only the first segment of a trip gets a legend entry:
                        # "<entries excluding the two depot entries> : <length>".
                        label=f"{len(path) - 2} : {round(path_length_list[i], 3)}" if seg == 0 else None,
                    )
                    plt.text(self.locations[src, 0], self.locations[src, 1], src, color=color)

            plt.legend()
            plt.title(max(path_length_list))
            filename = f"{self.output_folder}/{str(iter).zfill(3)}.png"
            plt.savefig(filename)
            print(filename)
        finally:
            # Close even if drawing/saving fails so figures do not accumulate.
            plt.close(fig)

    def _best_relocation(self, path_list, lengths, target_trip_idx, target_node_idx):
        """Find where to move ``path_list[target_trip_idx][target_node_idx]``.

        Returns:
            ``(trip_idx, insert_pos, save_length, gain_length)`` for the move that
            minimises the larger of the two affected trip lengths. It is returned
            only when that value is strictly below the target trip's current
            length; otherwise ``None``.
        """
        dist = self.dist
        trip = path_list[target_trip_idx]
        prev_node, node, next_node = trip[target_node_idx - 1], trip[target_node_idx], trip[target_node_idx + 1]
        # Length saved by removing `node` and linking its neighbours directly.
        save_length = dist[prev_node, node] + dist[node, next_node] - dist[prev_node, next_node]
        shrunk_length = lengths[target_trip_idx] - save_length

        best = None
        best_max_length = lengths[target_trip_idx]
        for t_idx in range(self.n_agent):
            if t_idx == target_trip_idx:
                continue
            other = path_list[t_idx]
            # -2: the two depot entries do not count toward capacity.
            if len(other) - 2 >= self.capacity:
                continue
            # Try every edge (other[k], other[k + 1]), including the one leaving
            # the depot, so that empty trips ([depot, depot]) can receive nodes.
            for k in range(len(other) - 1):
                gain_length = dist[other[k], node] + dist[node, other[k + 1]] - dist[other[k], other[k + 1]]
                new_max = max(shrunk_length, lengths[t_idx] + gain_length)
                if new_max < best_max_length:
                    best = (t_idx, k + 1, save_length, gain_length)
                    best_max_length = new_max
        return best

    def __call__(self, path_list: list[list[int]], path_length_list, max_trial: int = 10):
        """Run up to ``n_iter`` relocation iterations on ``path_list``.

        Args:
            path_list: one trip per agent; mutated in place.
            path_length_list: current length of each trip, aligned with
                ``path_list``. It is copied into a float array and not mutated.
            max_trial: number of randomly chosen nodes of the longest trip to try
                per iteration before moving on to the next iteration.

        Returns:
            ``(path_list, lengths)``: the same list object that was passed in,
            and a float ndarray of the updated trip lengths.

        NOTE(review): trips in the input are not validated against capacity;
        only trips that receive a node are checked.
        """
        # Float dtype so subtracting/adding fractional distances cannot fail on
        # integer input.
        lengths = np.array(path_length_list, dtype=float)
        n_moves = 0
        self.plot(n_moves, path_list, lengths)

        for _ in range(self.n_iter):
            target_trip_idx = int(np.argmax(lengths))
            # Movable positions: every entry except the leading/trailing depot.
            candidates = range(1, len(path_list[target_trip_idx]) - 1)
            if not candidates:
                # The longest trip has nothing to give away, so no move can lower the maximum.
                break

            for _trial in range(max_trial):
                target_node_idx = random.choice(candidates)
                move = self._best_relocation(path_list, lengths, target_trip_idx, target_node_idx)
                if move is None:
                    continue

                inserted_trip_idx, insert_pos, save_length, gain_length = move
                node = path_list[target_trip_idx].pop(target_node_idx)
                path_list[inserted_trip_idx].insert(insert_pos, node)
                # Exact deltas (no rounding) keep the stored lengths equal to the
                # values the move was judged on.
                lengths[target_trip_idx] -= save_length
                lengths[inserted_trip_idx] += gain_length
                # Number images by move count so the initial picture (000) is not overwritten.
                n_moves += 1
                self.plot(n_moves, path_list, lengths)
                break

        return path_list, lengths
