import time

import numpy as np
import pulp
from mip.mip_utils import SolverChoices
from mip.mip_utils import SubtourEliminationMethods


class MIP_CVRP_SUM:
    """PuLP model of a capacitated vehicle routing problem that minimises total distance.

    Node ``self.depot`` (index 0) is the depot; every other node is a customer that is
    left exactly once. The depot is left ``len(src_vector)`` times, i.e. one route per
    truck. The objective is the sum of Euclidean lengths of all selected arcs.
    """

    def __init__(
        self,
        src_vector: list[int],
        n_agent: int,
        n_node: int,
        locations: np.ndarray,
        subtour_elimination: SubtourEliminationMethods,
    ) -> None:
        """
        Parameters
        ----------
        src_vector : list[int]
            One entry per truck. Its length is the number of routes, and
            ``max(src_vector)`` bounds the number of customers on any route
            (see ``formulate``).
        n_agent : int
            Stored as ``self.n_agent``; not used by the model in this class.
        n_node : int
            Number of nodes, including the depot.
        locations : np.ndarray
            Node coordinates indexed by node id; distances are
            ``np.linalg.norm(locations[i] - locations[j])``.
        subtour_elimination : SubtourEliminationMethods
            Subtour elimination formulation, given as an enum member or its ``.value``.
        """
        self.src_vector = src_vector
        self.n_truck = len(src_vector)
        self.n_agent = n_agent
        self.n_node = n_node
        self.locations = locations
        # index of the depot node
        self.depot = 0
        # Symmetric Euclidean distance for every ordered pair (i, j) with i != j.
        self.dist = {}
        for i in range(self.n_node):
            for j in range(i + 1, self.n_node):
                d = np.linalg.norm(locations[i] - locations[j])
                self.dist[i, j] = d
                self.dist[j, i] = d
        self.subtour_elimination_methods = subtour_elimination

    @staticmethod
    def _is_choice(selected, member) -> bool:
        """Return True if ``selected`` is the enum ``member`` itself or its ``.value``."""
        return selected == member or selected == member.value

    @staticmethod
    def _is_selected(value: float | None) -> bool:
        """Return True when a solved binary variable value should be read as 1.

        ``None`` (variable not set, e.g. no solution was loaded) counts as unselected.
        """
        # Solvers may return binaries with tiny numerical noise (1e-10, 0.9999999),
        # so compare against 0.5 rather than testing ``> 0``.
        return value is not None and value > 0.5

    @staticmethod
    def _gurobi_options(timelimit, n_thread, log_path) -> list[tuple[str, object]]:
        """Build Gurobi ``(name, value)`` options, omitting any whose value is None.

        Passing ``("LogFile", None)`` through would be rendered as the literal
        ``LogFile=None`` rather than "no log file".
        """
        candidates = [("TimeLimit", timelimit), ("Threads", n_thread), ("LogFile", log_path)]
        return [(name, value) for name, value in candidates if value is not None]

    def _add_mtz_constraints(self, problem, x, z) -> None:
        """Add Miller-Tucker-Zemlin ordering rows, skipping any already present in ``problem``.

        For every arc (i, j) into a non-depot node j: if ``x[i, j] == 1`` then
        ``z[j] >= z[i] + 1``; otherwise the big-M term (``n_node``) relaxes the row.
        """
        big_m = self.n_node
        for i in range(self.n_node):
            for j in range(self.n_node):
                if j == self.depot or i == j:
                    continue
                name = f"subtour_elimination_{i}_{j}"
                # Both formulate() and the MTZ branch of add_subtour_elimination_methods
                # request these rows; PuLP rejects duplicate constraint names by default,
                # so adding them twice would abort model construction.
                if name in problem.constraints:
                    continue
                problem.addConstraint(z[j] - z[i] >= 1 - (1 - x[i, j]) * big_m, name)

    def add_subtour_elimination_methods(self, problem, x, z, f):
        """Add the subtour elimination rows selected by ``self.subtour_elimination_methods``.

        Parameters
        ----------
        problem : pulp.LpProblem
            Model to add constraints to (mutated in place).
        x, z, f : dict
            Arc binaries, node order variables and arc flow variables from ``formulate``.

        Raises
        ------
        ValueError
            If the configured method is neither MTZ nor single commodity flow.
        """
        method = self.subtour_elimination_methods
        if self._is_choice(method, SubtourEliminationMethods.MTZ):
            self._add_mtz_constraints(problem, x, z)
        elif self._is_choice(method, SubtourEliminationMethods.single_comodity_flow):
            # Every customer consumes exactly one unit of flow (inflow - outflow == 1).
            for i in range(self.n_node):
                if i == self.depot:
                    continue
                problem.addConstraint(
                    pulp.lpSum([f[j, i] for j in range(self.n_node) if i != j])
                    - pulp.lpSum([f[i, j] for j in range(self.n_node) if i != j])
                    == 1,
                    f"subtour_elimination_{i}",
                )

            # Flow may only travel on selected arcs.
            for i in range(self.n_node):
                for j in range(self.n_node):
                    if i == j:
                        continue
                    problem.addConstraint(f[i, j] <= self.n_node * x[i, j], f"{i}_{j}")

        else:
            raise ValueError("input appropriate subtour elimination methods")

    def formulate(self, lp_path: str | None = "sample.lp") -> tuple[pulp.LpProblem, dict, dict, dict]:
        """Build the CVRP model.

        Parameters
        ----------
        lp_path : str | None
            File the model is written to in LP format. The default, ``"sample.lp"``, is
            written to the current directory. Pass ``None`` to skip writing.

        Returns
        -------
        tuple
            ``(problem, x, z, f)``, where:

            - ``x[i, j]`` are binary arc-selection variables.
            - ``z[i]`` are integer visit-order variables in ``[0, n_node]``.
            - ``f[i, j]`` are non-negative integer arc-flow variables.
        """
        problem = pulp.LpProblem(name="cvrp", sense=pulp.LpMinimize)
        x = {
            (i, j): pulp.LpVariable(name=f"x_{i}_{j}", cat="Binary")
            for i in range(self.n_node)
            for j in range(self.n_node)
            if i != j
        }
        z = {
            i: pulp.LpVariable(name=f"z_{i}", cat="Integer", lowBound=0, upBound=self.n_node)
            for i in range(self.n_node)
        }
        f = {
            (i, j): pulp.LpVariable(name=f"f_{i}_{j}", lowBound=0, cat="Integer")
            for i in range(self.n_node)
            for j in range(self.n_node)
            if i != j
        }

        # Objective: total length of the selected arcs.
        problem += pulp.lpSum(
            [self.dist[i, j] * x[i, j] for i in range(self.n_node) for j in range(self.n_node) if i != j]
        )

        for i in range(self.n_node):
            # out-degree == in-degree at every node
            problem.addConstraint(
                pulp.lpSum([x[i, j] for j in range(self.n_node) if j != i])
                == pulp.lpSum([x[j, i] for j in range(self.n_node) if j != i]),
                f"flow_constraint_node{i}",
            )
            # each customer is left once; the depot is left once per truck
            count = 1 if i != self.depot else self.n_truck
            problem.addConstraint(
                pulp.lpSum([x[i, j] for j in range(self.n_node) if j != i]) == count,
                f"count_node{i}",
            )

        # MTZ ordering is always added: the capacity rows below are expressed on z, so z
        # must carry the visit order whichever subtour formulation is selected.
        self._add_mtz_constraints(problem, x, z)

        # With z increasing by >= 1 along a route, the k-th customer has z >= k, so this
        # bounds the number of customers per route.
        # NOTE(review): the largest entry is applied to every route, so per-truck
        # capacities are not enforced individually.
        capacity = max(self.src_vector)
        for i in range(self.n_node):
            if i == self.depot:
                continue
            problem.addConstraint(z[i] <= capacity, f"capacity_{i}")

        self.add_subtour_elimination_methods(problem, x, z, f)

        if lp_path is not None:
            problem.writeLP(lp_path)

        return problem, x, z, f

    def getPuLPSolver(
        self, solver_type: SolverChoices, timelimit: int, n_thread: int, log_path: str, show_log: bool
    ) -> pulp.PULP_CBC_CMD | pulp.CPLEX_CMD | pulp.GUROBI_CMD:
        """Return a PuLP solver object configured with the given options.

        Parameters
        ----------
        solver_type : SolverChoices
            ``SolverChoices.cbc``, ``cplex`` or ``gurobi`` (enum member or its ``.value``).
        timelimit : int
            Time limit passed to the solver.
        n_thread : int
            Number of solver threads.
        log_path : str
            Log file path, used only when ``show_log`` is false.
        show_log : bool
            If true, solver output goes to the console and no log file is set.

        Returns
        -------
        pulp.PULP_CBC_CMD | pulp.CPLEX_CMD | pulp.GUROBI_CMD

        Raises
        ------
        ValueError
            If ``solver_type`` is not one of the supported solvers.
        """
        log_file = None if show_log else log_path
        msg = bool(show_log)

        if self._is_choice(solver_type, SolverChoices.cbc):
            return pulp.PULP_CBC_CMD(timeLimit=timelimit, msg=msg, threads=n_thread, logPath=log_file)
        if self._is_choice(solver_type, SolverChoices.cplex):
            return pulp.CPLEX_CMD(timeLimit=timelimit, msg=msg, threads=n_thread, logPath=log_file)
        if self._is_choice(solver_type, SolverChoices.gurobi):
            return pulp.GUROBI_CMD(msg=msg, options=self._gurobi_options(timelimit, n_thread, log_file))
        raise ValueError("Input appropriate solver name.")

    def get_result(self, x: dict) -> list[list[int]]:
        """Decode solved arc variables into routes.

        Each route starts and ends at the depot, e.g. ``[0, 3, 1, 0]``. One route is
        returned per selected arc leaving the depot, ordered by the first customer's index.

        Raises
        ------
        ValueError
            If the selected arcs from some depot exit do not lead back to the depot.
        """
        # Successor of every customer (the depot has one successor per truck, handled below).
        next_position: dict[int, int | None] = {i: None for i in range(self.n_node)}
        for i in range(self.n_node):
            if i == self.depot:
                continue
            for j in range(self.n_node):
                if i != j and self._is_selected(pulp.value(x[i, j])):
                    next_position[i] = j

        second_node_list = [
            j
            for j in range(self.n_node)
            if j != self.depot and self._is_selected(pulp.value(x[self.depot, j]))
        ]

        paths = []
        for current in second_node_list:
            path = [self.depot, current]
            while next_position[current] != self.depot:
                nxt = next_position[current]
                # A valid route visits each customer at most once, so it can never
                # exceed n_node entries before returning to the depot.
                if nxt is None or len(path) > self.n_node:
                    raise ValueError("selected arcs do not form a closed route from the depot")
                path.append(nxt)
                current = nxt
            path.append(self.depot)
            paths.append(path)

        return paths

    def __call__(self, solver_type, timelimit, n_thread, log_path, show_log):
        """Formulate, solve and decode the model.

        Returns
        -------
        tuple
            ``(paths, duration)``, where ``paths`` is as returned by ``get_result`` and
            ``duration`` is the wall-clock seconds spent in ``problem.solve``.

        Raises
        ------
        ValueError
            If the solver status is anything other than "Optimal" or "Not Solved".
        """
        problem, x, _z, _f = self.formulate()
        solver = self.getPuLPSolver(solver_type, timelimit, n_thread, log_path, show_log)
        begin = time.time()
        status = problem.solve(solver)
        duration = time.time() - begin
        # "Not Solved" is tolerated (presumably for time-limited runs); unset variables
        # are then read as unselected by get_result.
        if pulp.LpStatus[status] not in {"Optimal", "Not Solved"}:
            raise ValueError("cannot find the optimal solution")
        result = self.get_result(x)
        return result, duration
