import random

# Seed Python's global RNG so the random candidate order used for anchor
# selection (random.sample / random.shuffle in TripDivision.extractAnchor) is
# reproducible. Results can still vary between runs because that search is
# time-bounded.
random.seed(0)
import time
from itertools import combinations
from itertools import permutations

import numpy as np
from mm_cvrp.policy import get_cost
from tqdm import tqdm


class TripDivision:
    """Split nodes among trucks with anchor-based bin packing, then route each truck.

    One anchor node is picked per truck so that the anchors are far apart from
    each other. Every other node is assigned to the nearest anchor whose bin
    still has capacity (``src_vector``). The resulting assignment is handed to
    ``get_cost`` to build the per-truck paths.
    """

    def __init__(
        self,
        src_vector: list[int],
        n_agent: int,
        n_node: int,
        locations: np.ndarray,
        depot: int = 0,
        force_random_sampling: bool = False,
    ) -> None:
        # Capacity of each truck's bin, counted in nodes (the anchor itself included).
        self.src_vector = src_vector
        self.n_agent = n_agent
        self.n_node = n_node
        # NOTE(review): assumed to hold one coordinate row per node index; confirm against callers.
        self.locations = locations
        self.result = [[] for _ in range(self.n_agent)]
        self.dist = self.get_dist_matrix()
        # Once the number of anchor combinations reaches this value, anchors are
        # chosen by random sampling instead of by enumerating every combination.
        self.maxCombination = 10**8
        # Time budget, in seconds, for the anchor search.
        self.timeLimit = 3
        # Stored for callers; not read by the methods of this class.
        self.depot = depot
        # Always pick anchors by random sampling, regardless of the combination count.
        self.force_random_sampling = force_random_sampling

    def reset(self):
        """Clear the per-truck result lists."""
        self.result = [[] for _ in range(self.n_agent)]

    def get_dist_matrix(self):
        """Return the pairwise Euclidean distance matrix as a nested list.

        Entry ``[i][j]`` is ``||locations[i] - locations[j]||`` for the first
        ``n_node`` locations.
        """
        return [
            [np.linalg.norm(self.locations[i] - self.locations[j]) for j in range(self.n_node)]
            for i in range(self.n_node)
        ]

    def extractAnchor(self, dist, numTruck):
        """
        Pick ``numTruck`` anchor nodes that are spread as far apart as possible.

        A candidate set is scored by the smallest distance between any two of
        its members. The best-scoring set found within ``self.timeLimit``
        seconds is returned; among equal scores, the later set wins.

        Candidates are drawn by random sampling when the number of combinations
        reaches ``self.maxCombination`` or ``self.force_random_sampling`` is
        set. Otherwise every combination is visited in shuffled order until the
        time budget runs out.

        Parameters
        ==========
        dist : sequence of sequences of float
            Square distance matrix; ``dist[a][b]`` is the distance from node a to b.
        numTruck : int
            Number of anchors to pick (one per truck).

        Returns
        =======
        list or tuple of int
            The chosen anchor node indices.

        Raises
        ======
        ValueError
            If ``numTruck`` is not between 1 and ``len(dist)``.
        """
        idxList = list(range(len(dist)))
        if not 1 <= numTruck <= len(idxList):
            raise ValueError(
                f"numTruck must be between 1 and the number of nodes ({len(idxList)}), got {numTruck}"
            )

        def pickShortDist(idxSet):
            """Return the smallest distance between any two nodes of idxSet (inf for one node)."""
            if len(idxSet) < 2:
                # A lone anchor has no pairwise distance, so every singleton scores the same.
                return float("inf")
            # Both directions are checked, so an asymmetric matrix uses its shorter leg.
            return min(dist[src][dst] for src, dst in permutations(idxSet, 2))

        numComb = self.calculate_numComb(len(idxList), numTruck)
        anchor = None
        max2dist = -float("inf")

        if numComb >= self.maxCombination or self.force_random_sampling:
            # Too many combinations to enumerate: sample random sets until the
            # time budget is spent. At least one set is always evaluated.
            c = 0
            start = time.monotonic()
            while True:
                c += 1
                idxSet = random.sample(idxList, numTruck)
                size = pickShortDist(idxSet)
                if size >= max2dist:
                    anchor = idxSet
                    max2dist = size
                if time.monotonic() - start > self.timeLimit:
                    break
            print(f"random sampling trial : {c} : {round(c / numComb * 100, 2)}")
        else:
            # Materializes every combination so it can be shuffled;
            # maxCombination bounds how large this list may get.
            combList = list(combinations(idxList, numTruck))
            random.shuffle(combList)
            start = time.monotonic()
            for idxSet in tqdm(combList):
                size = pickShortDist(idxSet)
                if size >= max2dist:
                    anchor = idxSet
                    max2dist = size
                if time.monotonic() - start >= self.timeLimit:
                    break

        return anchor

    def anchorBasedPacking(self, dist, num_truck, method):
        """
        Assign every node to a truck, seeding one bin per truck with an anchor node.

        Non-anchor nodes are visited in an order given by ``method``. Each node
        is put into the bin of the nearest anchor (by ``dist``) that still has
        room. Bin ``i`` may hold ``self.src_vector[i]`` nodes, its anchor
        included.

        Parameters
        ==========
        dist : sequence of sequences of float
            Square distance matrix over all nodes.
        num_truck : int
            Number of trucks (and therefore anchors/bins).
        method : str
            ``"anchor"`` visits nodes closest to any anchor first.
            ``"anchor_update"`` visits first the nodes whose distances to the
            anchors vary the most.

        Returns
        =======
        dict
            Maps each anchor node to the list of nodes in its bin (anchor
            first), in the order ``extractAnchor`` returned the anchors.

        Raises
        ======
        ValueError
            If ``method`` is unknown, or if some node fits in no bin ("Not found").
        """
        # Validate before the (time-bounded, possibly slow) anchor search.
        if method not in ("anchor", "anchor_update"):
            raise ValueError(f"Unknown packing method: {method!r}")

        anchorList = self.extractAnchor(dist, num_truck)

        def anchor_distance(a, b):
            # Shorter of the two directions, so an asymmetric matrix is read symmetrically.
            return min(dist[a][b], dist[b][a])

        if method == "anchor":  # part2
            def value_func(item):
                return min(anchor_distance(anchor, item) for anchor in anchorList)
        else:  # "anchor_update", part3: largest variance first
            def value_func(item):
                return -np.var(np.array([anchor_distance(anchor, item) for anchor in anchorList]))

        anchor_set = set(anchorList)
        sorted_items = sorted((v for v in range(len(dist)) if v not in anchor_set), key=value_func)

        bin_status = {anchor: [anchor] for anchor in anchorList}
        anchor2tID = {anchor: i for i, anchor in enumerate(anchorList)}

        for item in sorted_items:
            # Try bins from nearest anchor to farthest; the sort is stable, so ties keep anchor order.
            to_anchor = {anchor: anchor_distance(item, anchor) for anchor in anchorList}
            for anchor in sorted(to_anchor, key=to_anchor.get):
                if len(bin_status[anchor]) < self.src_vector[anchor2tID[anchor]]:
                    bin_status[anchor].append(item)
                    break
            else:
                # No bin had spare capacity for this item.
                raise ValueError("Not found")

        return bin_status

    def calculate_numComb(self, n, r):
        """Return the binomial coefficient C(n, r): the number of ways to choose r of n items.

        Returns 0 when ``r > n``. Raises ValueError for negative arguments.
        """
        import math

        return math.comb(n, r)

    def __call__(self, method="anchor_update"):
        """
        Partition the nodes among the trucks and return each truck's route.

        Parameters
        ==========
        method : str
            Packing order passed to ``anchorBasedPacking``.

        Returns
        =======
        tuple
            ``(path_list, path_length_list)`` for the single instance solved by
            ``get_cost``. Node ids in ``path_list`` are translated back to global
            node indices; entries equal to 0 are left unchanged.
        """
        self.reset()
        bin_status = self.anchorBasedPacking(self.dist, self.n_agent, method)

        # Node -> truck index. Each node sits in exactly one bin, so a single
        # lookup replaces scanning every bin for every node.
        node2truck = {
            node: t_idx
            for t_idx, members in enumerate(bin_status.values())
            for node in members
        }

        # NOTE(review): packing covers node indices 0..n_node-1 (the rows of
        # self.dist), while this loop visits 1..n_node. Node 0 (presumably the
        # depot) is therefore left out of `action`, and index n_node never
        # matches a bin. Confirm this matches the layout get_cost expects.
        action = []
        local2global = {}
        counter = [1] * self.n_agent
        for node in range(1, self.n_node + 1):
            t_idx = node2truck.get(node)
            if t_idx is None:
                continue
            # Assumes get_cost numbers each truck's nodes 1, 2, ... in the order
            # they appear in `action`; remember which global node each one is.
            action.append(t_idx)
            local2global[(t_idx, counter[t_idx])] = node
            counter[t_idx] += 1

        action = np.array([action])
        data = np.array([self.locations])
        _, path_length_list, path_list = get_cost(action, data, self.n_agent, return_path=True)
        # Unwrap the batch dimension: only one instance was submitted.
        path_length_list = path_length_list[0]
        path_list = path_list[0]
        for t_idx, path in enumerate(path_list):
            for l_idx, local in enumerate(path):
                if local == 0:
                    continue  # 0 has no local2global entry (presumably the depot)
                path[l_idx] = local2global[(t_idx, local)]

        return path_list, path_length_list
