import time

import numpy as np
import pulp
from mip.mip_utils import SolverChoices
from mip.mip_utils import SubtourEliminationMethods


class MIP_CVRP:
    """Min-max vehicle routing model on a sparsified arc set, built with PuLP.

    The depot (node 0) is connected to every other node; every other node is
    connected to its ``k`` nearest non-depot nodes and to the depot.  The
    objective minimises ``cost``, which is constrained to be at least the
    travelled distance of every truck.

    Parameters
    ----------
    src_vector : list[int]
        One entry per truck; used on the right-hand side of the per-truck
        arc-count ("capacity") constraint.
    n_agent : int
        Stored as-is; not used by the model built in this class.
    n_node : int
        Number of nodes including the depot.
    locations : np.ndarray
        ``locations[i]`` is the coordinate vector of node ``i``.
    subtour_elimination : SubtourEliminationMethods | str
        Either an enum member or its ``.value``.
    k : int, optional
        Number of nearest non-depot neighbours kept as successors of each
        non-depot node (default 5, the previously hard-coded value).

    Raises
    ------
    ValueError
        If ``k`` is negative.
    """

    def __init__(
        self,
        src_vector: list[int],
        n_agent: int,
        n_node: int,
        locations: np.ndarray,
        subtour_elimination: "SubtourEliminationMethods | str",
        k: int = 5,
    ) -> None:
        if k < 0:
            raise ValueError("k must be non-negative")
        self.src_vector = src_vector
        self.n_truck = len(src_vector)
        self.n_agent = n_agent
        self.n_node = n_node
        self.locations = locations
        # Index of the depot node.
        self.depot = 0
        # Symmetric Euclidean distances; only pairs with i != j are stored.
        self.dist = {}
        for i in range(self.n_node):
            for j in range(i + 1, self.n_node):
                d = np.linalg.norm(locations[i] - locations[j])
                self.dist[i, j] = d
                self.dist[j, i] = d
        self.k = k
        self.subtour_elimination_methods = subtour_elimination
        self.succs = self.get_edge()
        # Invert the successor map once instead of scanning every node pair.
        self.preds = {i: set() for i in range(self.n_node)}
        for i, targets in self.succs.items():
            for j in targets:
                self.preds[j].add(i)

    def get_edge(self) -> dict[int, set[int]]:
        """Return the successor set of every node.

        The depot may go to any other node.  Every other node may go to its
        ``self.k`` nearest non-depot nodes (ties keep the lower index first)
        and back to the depot.
        """
        customers = [j for j in range(self.n_node) if j != self.depot]
        edges = {}
        for i in range(self.n_node):
            if i == self.depot:
                edges[i] = set(customers)
                continue
            # sorted() is stable, so equal distances stay in index order.
            by_distance = sorted(
                (j for j in customers if j != i),
                key=lambda j: self.dist[i, j],
            )
            edges[i] = {self.depot, *by_distance[: self.k]}

        return edges

    def add_subtour_elimination_methods(self, problem: "pulp.LpProblem", x: dict, z: dict, f: dict) -> None:
        """Add the configured subtour-elimination constraints to ``problem``.

        Parameters
        ----------
        problem : pulp.LpProblem
            Model that receives the constraints.
        x : dict
            Arc-selection variables keyed by ``(i, j, t)``.
        z : dict
            Per-node order variables keyed by ``i`` (used by MTZ).
        f : dict
            Arc-flow variables keyed by ``(i, j, t)`` (used by the
            single-commodity-flow method).

        Raises
        ------
        ValueError
            If the configured method is neither MTZ nor single commodity flow.
        """
        # Accept both an enum member and its raw value.
        method = getattr(self.subtour_elimination_methods, "value", self.subtour_elimination_methods)

        if method == SubtourEliminationMethods.MTZ.value:
            for t in range(self.n_truck):
                for i in range(self.n_node):
                    for j in self.succs[i]:
                        # Arcs returning to the depot are exempt, otherwise
                        # no closed route would be possible.
                        if j == self.depot:
                            continue
                        # If arc (i, j) is used, z[j] must exceed z[i] by at
                        # least one; otherwise the big-M term relaxes it.
                        problem.addConstraint(
                            z[j] - z[i] >= 1 - (1 - x[i, j, t]) * self.n_node,
                            f"subtour_elimination_{i}_{j}_truck{t}",
                        )
        elif method == SubtourEliminationMethods.single_comodity_flow.value:
            # Every non-depot node keeps exactly one unit of flow.
            for i in range(self.n_node):
                if i == self.depot:
                    continue
                problem.addConstraint(
                    pulp.lpSum([f[j, i, t] for j in self.preds[i] for t in range(self.n_truck)])
                    - pulp.lpSum([f[i, j, t] for j in self.succs[i] for t in range(self.n_truck)])
                    == 1,
                    f"subtour_elimination_{i}",
                )

            # Flow may only travel on arcs that are selected.
            for t in range(self.n_truck):
                for i in range(self.n_node):
                    for j in self.succs[i]:
                        problem.addConstraint(f[i, j, t] <= self.n_node * x[i, j, t], f"{i}_{j}_{t}")

        else:
            raise ValueError("input appropriate subtour elimination methods")

    def formulate(self) -> "tuple[pulp.LpProblem, dict, dict, dict]":
        """Build the PuLP model.

        Returns
        -------
        tuple
            ``(problem, x, z, f)`` where ``x`` holds the binary arc variables
            keyed by ``(i, j, t)``, ``z`` the integer order variables keyed by
            node, and ``f`` the integer flow variables keyed by ``(i, j, t)``.
        """
        problem = pulp.LpProblem(name="cvrp", sense=pulp.LpMinimize)
        x = {
            (i, j, t): pulp.LpVariable(name=f"x_{i}_{j}_{t}", cat="Binary")
            for i in range(self.n_node)
            for j in self.succs[i]
            for t in range(self.n_truck)
        }
        z = {
            i: pulp.LpVariable(name=f"z_{i}", cat="Integer", lowBound=0, upBound=self.n_node)
            for i in range(self.n_node)
        }
        # Flow variables are indexed by truck, matching how the flow
        # constraints look them up.
        f = {
            (i, j, t): pulp.LpVariable(name=f"f_{i}_{j}_{t}", lowBound=0, cat="Integer")
            for i in range(self.n_node)
            for j in self.succs[i]
            for t in range(self.n_truck)
        }
        cost = pulp.LpVariable(name="cost", lowBound=-1)

        # Objective: minimise the longest single-truck route.
        problem += cost

        for t in range(self.n_truck):
            problem.addConstraint(
                cost >= pulp.lpSum([self.dist[i, j] * x[i, j, t] for i in range(self.n_node) for j in self.succs[i]])
            )

        # Flow conservation: a truck leaves every node as often as it enters.
        for i in range(self.n_node):
            for t in range(self.n_truck):
                problem.addConstraint(
                    pulp.lpSum([x[i, j, t] for j in self.succs[i]]) == pulp.lpSum([x[j, i, t] for j in self.preds[i]]),
                    f"flow_constraint_node{i}_truck{t}",
                )

        # Every non-depot node is left exactly once, by exactly one truck.
        for i in range(self.n_node):
            if i == self.depot:
                continue
            problem.addConstraint(
                pulp.lpSum([x[i, j, t] for j in self.succs[i] for t in range(self.n_truck)]) == 1,
                f"no_double_booking_node{i}",
            )

        # Every truck enters and leaves the depot exactly once.
        for t in range(self.n_truck):
            problem.addConstraint(
                pulp.lpSum([x[j, self.depot, t] for j in self.preds[self.depot]]) == 1,
                f"depot_constraint_out_truck{t}",
            )
            problem.addConstraint(
                pulp.lpSum([x[self.depot, j, t] for j in self.succs[self.depot]]) == 1,
                f"depot_constraint_in_truck{t}",
            )

        self.add_subtour_elimination_methods(problem, x, z, f)

        for t in range(self.n_truck):
            # +2: departing from the depot and arriving at the depot.
            # NOTE(review): a closed route over m non-depot nodes uses m + 1
            # arcs, so this bound admits src_vector[t] + 1 such nodes --
            # confirm that this is the intended capacity.
            problem.addConstraint(
                pulp.lpSum([x[i, j, t] for i in range(self.n_node) for j in self.succs[i]]) <= self.src_vector[t] + 2,
                f"capacity_constraint_truck_{t}",
            )

        return problem, x, z, f

    def getPuLPSolver(
        self, solver_type: "SolverChoices | str", timelimit: int, n_thread: int, log_path: str, show_log: bool
    ) -> "pulp.PULP_CBC_CMD | pulp.CPLEX_CMD | pulp.GUROBI_CMD":
        """Return a configured PuLP command-line solver.

        Parameters
        ----------
        solver_type : SolverChoices | str
            One of cbc, cplex or gurobi, given as enum member or its value.
        timelimit : int
            Upper bound on the solve time, passed to the solver.
        n_thread : int
            Number of threads.
        log_path : str
            File the solver log is written to when ``show_log`` is false.
        show_log : bool
            If true the solver prints its log and no log file is requested.

        Returns
        -------
        solver

        Raises
        ------
        ValueError
            If ``solver_type`` is not one of the supported solvers.
        """
        solver_name = getattr(solver_type, "value", solver_type)
        verbose = bool(show_log)
        log_file = None if verbose else log_path

        if solver_name == SolverChoices.cbc.value:
            return pulp.PULP_CBC_CMD(timeLimit=timelimit, msg=verbose, threads=n_thread, logPath=log_file)
        elif solver_name == SolverChoices.cplex.value:
            return pulp.CPLEX_CMD(timeLimit=timelimit, msg=verbose, threads=n_thread, logPath=log_file)
        elif solver_name == SolverChoices.gurobi.value:
            options = [("TimeLimit", timelimit), ("Threads", n_thread)]
            # Options are forwarded verbatim as key=value, so a None log file
            # must be omitted rather than passed as the text "None".
            if log_file is not None:
                options.append(("LogFile", log_file))
            return pulp.GUROBI_CMD(msg=verbose, options=options)
        else:
            raise ValueError("Input appropriate solver name.")

    def get_result(self, x: dict) -> list[list[int]]:
        """Extract one depot-to-depot node sequence per truck from ``x``.

        Parameters
        ----------
        x : dict
            Solved arc variables keyed by ``(i, j, t)``.

        Returns
        -------
        list[list[int]]
            For every truck the visited nodes, starting and ending with the
            depot.

        Raises
        ------
        ValueError
            If the selected arcs of a truck do not form a closed route
            through the depot (for example when the variables carry no
            solution values).
        """
        paths = []

        for t in range(self.n_truck):
            next_position = {}
            for i in range(self.n_node):
                for j in self.succs[i]:
                    selected = pulp.value(x[i, j, t])
                    # Unsolved variables have no value; binaries are compared
                    # against 0.5 so numerical noise is not read as "used".
                    if selected is not None and selected > 0.5:
                        next_position[i] = j

            path = [self.depot]
            current = next_position.get(self.depot)
            while current != self.depot:
                # A dead end or a walk longer than the node count means the
                # arcs do not describe a route that returns to the depot.
                if current is None or len(path) >= self.n_node:
                    raise ValueError(f"truck {t} has no closed route through the depot")
                path.append(current)
                current = next_position.get(current)
            path.append(self.depot)
            paths.append(path)

        return paths

    def __call__(self, solver_type, timelimit, n_thread, log_path, show_log):
        """Formulate and solve the model.

        Parameters are forwarded to :meth:`getPuLPSolver`.

        Returns
        -------
        tuple
            ``(routes, seconds)`` on success, where ``routes`` is the output
            of :meth:`get_result` and ``seconds`` the time spent in the
            solver call.  ``(-1, -1)`` if the solver status is neither
            "Optimal" nor "Not Solved", or if it is "Not Solved" and no
            usable route could be read from the variables.
        """
        problem, x, _z, _f = self.formulate()
        solver = self.getPuLPSolver(solver_type, timelimit, n_thread, log_path, show_log)
        begin = time.perf_counter()
        status = problem.solve(solver)
        duration = time.perf_counter() - begin

        status_name = pulp.LpStatus[status]
        if status_name not in {"Optimal", "Not Solved"}:
            return -1, -1

        try:
            result = self.get_result(x)
        except ValueError:
            # An optimal status with a broken route is a modelling error and
            # must surface; otherwise the solver simply stopped without an
            # incumbent, which is reported through the failure sentinel.
            if status_name == "Optimal":
                raise
            return -1, -1
        return result, duration
