import torch
import numpy as np
from enum import Enum
from itertools import cycle, islice, dropwhile
from tsp_rtdl import EdgeSelector, calculate_total_distance

class OptCase(Enum):
    """
    The eight ways to reconnect a tour after a 3-opt cut.

    The cut splits the tour into three consecutive segments, labelled A, B, C
    (first, second, third). Each member names a reconnection by which of those
    segments reverse_segments reverses; an apostrophe marks a reversed segment.
    Member values are strings identical to the member names.
    """
    opt_case_1 = "opt_case_1"  # A  B  C   - current tour, nothing reversed
    opt_case_2 = "opt_case_2"  # A' B  C   - first segment reversed
    opt_case_3 = "opt_case_3"  # A  B  C'  - third segment reversed
    opt_case_4 = "opt_case_4"  # A' B  C'  - first and third reversed
    opt_case_5 = "opt_case_5"  # A' B' C   - first and second reversed
    opt_case_6 = "opt_case_6"  # A  B' C   - second segment reversed
    opt_case_7 = "opt_case_7"  # A  B' C'  - second and third reversed
    opt_case_8 = "opt_case_8"  # A' B' C'  - all three reversed

def get_vars2(vlist, triples):
    """
    Group triple indices by the vertices they contain.

    Args:
        vlist: only its length is used; it fixes the number of buckets.
        triples: sequence of index triples; entries 0, 1 and 2 of each must be
            valid bucket indices (0 <= v < len(vlist)).

    Returns:
        A list with one list per bucket. Bucket ``v`` holds the position in
        ``triples`` of every triple that mentions ``v``. A triple naming the
        same vertex twice is recorded twice.
    """
    buckets = [list() for _ in range(len(vlist))]

    for position, triple in enumerate(triples):
        for slot in range(3):
            buckets[triple[slot]].append(position)

    return buckets

def possible_segments(N):
    """
    Lazily enumerate every 3-opt cut triple (i, j, k) for a tour of N nodes.

    Cuts are at least two positions apart (j >= i + 2, k >= j + 2). k stays
    below N - 1 when i == 0 and below N otherwise, so each of the three
    resulting segments holds at least two nodes. Triples are produced in
    lexicographic order.
    """
    def _walk(first_positions):
        for i in first_positions:
            k_stop = N - 1 + (i > 0)
            for j in range(i + 2, N - 1):
                for k in range(j + 2, k_stop):
                    yield i, j, k

    # range(N) is built eagerly, just as the outermost iterable of a generator
    # expression would be.
    return _walk(range(N))

def ranked_iter(vlist, N):
    """
    Yield valid 3-opt cut triples built from the entries of ``vlist``, in rank order.

    Every ordered combination (v1, v2, v3) of entries of ``vlist`` is sorted
    into (m1, m2, m3). It is kept only if it meets the same constraints as
    possible_segments(N):

        m1 + 2 <= m2 < N - 1   and   m2 + 2 <= m3 < N - 1 + (m1 > 0)

    Each distinct triple is yielded once, the first time it appears. Triples
    involving earlier entries of ``vlist`` therefore come out first.

    Args:
        vlist: candidate cut positions, presumably ranked by
            EdgeSelector.get_sorted_edges (TODO confirm against tsp_rtdl).
        N: tour length.

    Yields:
        Ascending (m1, m2, m3) tuples of the original ``vlist`` values.
    """
    # Track yielded triples in a set. A dense np.zeros(N ** 3) table would cost
    # O(N^3) memory on every call (about 500 MB at N = 500), although at most
    # len(vlist) ** 3 entries are ever touched.
    seen = set()

    for v1 in vlist:
        for v2 in vlist:
            for v3 in vlist:
                m1, m2, m3 = sorted([v1, v2, v3])

                if m2 < m1 + 2 or m2 >= N - 1:
                    continue
                if m3 < m2 + 2 or m3 >= N - 1 + (m1 > 0):
                    continue

                # int() so tensor/NumPy scalars hash by value, not by identity.
                key = (int(m1), int(m2), int(m3))
                if key in seen:
                    continue
                seen.add(key)
                yield (m1, m2, m3)

def get_solution_cost_change(graph, route, case, i, j, k):
    """
    Return the gain (old tour length minus new tour length) of one 3-opt move.

    The cuts remove the tour edges (A, B), (C, D) and (E, F), where
        A, B = route[i - 1], route[i]
        C, D = route[j - 1], route[j]
        E, F = route[k - 1], route[k % len(route)]
    ``case`` selects how the three resulting segments are reconnected, using
    the same labelling as reverse_segments. A positive value means the move
    shortens the tour.

    Args:
        graph: distance matrix indexed as graph[u, v].
        route: current tour.
        case: OptCase member.
        i, j, k: cut positions.

    Returns:
        The gain for that reconnection. 0 for opt_case_1 (tour unchanged);
        None when ``case`` matches no OptCase member.
    """
    n = len(route)
    a, b = route[i - 1], route[i]
    c, d = route[j - 1], route[j]
    e, f = route[k - 1], route[k % n]

    if case == OptCase.opt_case_1:
        # ABC: keep the current tour.
        return 0

    if case == OptCase.opt_case_2:
        # A'BC: (A,B), (E,F) -> (B,F), (A,E)
        removed = graph[a, b] + graph[e, f]
        added = graph[b, f] + graph[a, e]
    elif case == OptCase.opt_case_3:
        # ABC': (C,D), (E,F) -> (D,F), (C,E)
        removed = graph[c, d] + graph[e, f]
        added = graph[d, f] + graph[c, e]
    elif case == OptCase.opt_case_4:
        # A'BC': all three edges -> (A,D), (B,F), (E,C)
        removed = graph[a, b] + graph[c, d] + graph[e, f]
        added = graph[a, d] + graph[b, f] + graph[e, c]
    elif case == OptCase.opt_case_5:
        # A'B'C: all three edges -> (C,F), (B,D), (E,A)
        removed = graph[a, b] + graph[c, d] + graph[e, f]
        added = graph[c, f] + graph[b, d] + graph[e, a]
    elif case == OptCase.opt_case_6:
        # AB'C: (B,A), (D,C) -> (C,A), (B,D)
        removed = graph[b, a] + graph[d, c]
        added = graph[c, a] + graph[b, d]
    elif case == OptCase.opt_case_7:
        # AB'C': all three edges -> (B,E), (D,F), (C,A)
        removed = graph[a, b] + graph[c, d] + graph[e, f]
        added = graph[b, e] + graph[d, f] + graph[c, a]
    elif case == OptCase.opt_case_8:
        # A'B'C': all three edges -> (A,D), (C,F), (B,E)
        removed = graph[a, b] + graph[c, d] + graph[e, f]
        added = graph[a, d] + graph[c, f] + graph[b, e]
    else:
        return None

    return removed - added

def reverse_segments(route, case, i, j, k):
    """
    Build the tour that results from applying one 3-opt move.

    The cuts split ``route`` into three segments that together make up the
    whole tour when read cyclically:
        first  = route[k % n:] + route[:i]   (wraps past the end of the list)
        second = route[i:j]
        third  = route[j:k]
    ``case`` chooses which of them are reversed. They are then concatenated
    again in the order first, second, third. An apostrophe in the labels below
    marks a reversed segment.

    Args:
        route: existing tour, as a list (segments are joined with list ``+``).
        case: OptCase member selecting the reconnection.
        i: first cut position.
        j: second cut position.
        k: third cut position, as produced by possible_segments / ranked_iter.

    Returns:
        The new tour as a list. It is a cyclic rotation of the reconnected
        tour and generally does not start with route[0].

    Raises:
        ValueError: if ``case`` is not an OptCase member.
    """
    n = len(route)
    start = k % n
    if (i - 1) < start:
        first_segment = route[start:] + route[:i]
    else:
        # Not reached for triples from possible_segments / ranked_iter
        # (there i < k < n); kept for other callers.
        first_segment = route[start:i]
    second_segment = route[i:j]
    third_segment = route[j:k]

    # (reverse first?, reverse second?, reverse third?) for each reconnection.
    reversals = {
        OptCase.opt_case_1: (False, False, False),  # ABC    - unchanged tour
        OptCase.opt_case_2: (True, False, False),   # A'BC
        OptCase.opt_case_3: (False, False, True),   # ABC'
        OptCase.opt_case_4: (True, False, True),    # A'BC'
        OptCase.opt_case_5: (True, True, False),    # A'B'C
        OptCase.opt_case_6: (False, True, False),   # AB'C
        OptCase.opt_case_7: (False, True, True),    # AB'C'
        OptCase.opt_case_8: (True, True, True),     # A'B'C'
    }
    if case not in reversals:
        # Fail loudly instead of returning an unbound local.
        raise ValueError("unknown 3-opt case: %r" % (case,))

    solution = []
    segments = (first_segment, second_segment, third_segment)
    for segment, reverse in zip(segments, reversals[case]):
        solution.extend(reversed(segment) if reverse else segment)
    return solution

def three_opt(graph, route=None, logger = None):
    """
    Improve a TSP tour with first-improvement 3-opt local search.

    Each cut triple (i, j, k) from possible_segments(len(graph)) is scored
    with every OptCase reconnection. The first triple whose best
    reconnection has a positive gain is applied, and the scan restarts from
    the beginning. The search ends after a full scan finds no improving move.

    Args:
        graph: distance matrix indexed as graph[u, v]; len(graph) is the
            number of nodes.
        route: initial tour as a list of node ids. When None, the identity
            tour list(range(len(graph))) is used.
        logger: optional object whose iteration_finished(distance, attempts)
            is called after every improvement and once at the end; skipped
            when None.

    Returns:
        The improved tour, rotated to start at node 0 when node 0 is present
        (rotation does not change the tour length).
    """
    if route is None:
        # NOTE(review): a christofides_tsp(graph) start was apparently
        # intended here; the identity tour keeps the default usable.
        route = list(range(len(graph)))

    moves_cost = dict.fromkeys(OptCase, 0)

    improved = True
    best_found_route = route
    iterations = 0
    attempts_cnt = 0
    best_distance = calculate_total_distance(best_found_route, graph)

    while improved:
        improved = False
        for (i, j, k) in possible_segments(len(graph)):
            # Score every reconnection of this cut.
            for opt_case in OptCase:
                attempts_cnt += 1
                moves_cost[opt_case] = get_solution_cost_change(graph, best_found_route, opt_case, i, j, k)
            # Largest gain (old length - new length); max() keeps the first
            # OptCase on ties.
            best_return = max(moves_cost, key=moves_cost.get)

            if moves_cost[best_return] > 0:
                best_found_route = reverse_segments(best_found_route, best_return, i, j, k)
                improved = True
                iterations += 1
                best_distance -= moves_cost[best_return]
                if logger is not None:
                    logger.iteration_finished(best_distance, attempts_cnt)
                attempts_cnt = 0
                print('improved, iterations', iterations)
                break

    print(attempts_cnt)
    if logger is not None:
        logger.iteration_finished(best_distance, attempts_cnt)

    # Rotate so the tour starts at node 0. The membership check stops
    # cycle()/dropwhile() from looping forever when node 0 is absent.
    if 0 in best_found_route:
        cycled = cycle(best_found_route)
        skipped = dropwhile(lambda x: x != 0, cycled)
        best_found_route = list(islice(skipped, len(best_found_route)))

    return best_found_route

def three_opt_rtdl(graph, route = None, logger = None):
    """
    3-opt local search with candidate cuts ordered by an EdgeSelector.

    Works like a first-improvement 3-opt scan, but the order of triples is
    driven by the selector. Each scan takes
    ``vlist = edge_sel.get_sorted_edges(best_found_route)`` and walks the cut
    triples that ranked_iter builds from it. It applies the first
    reconnection with a positive gain and then restarts. The search ends when
    a scan over those candidates finds no improving move, so it is only as
    thorough as the candidate list.

    The selector's weights (get_rtdl_weights) are recomputed on the current
    tour before a scan whenever the number of applied moves is a multiple of
    max(1, len(route) // 10). This includes the first scan.

    Args:
        graph: distance matrix indexed as graph[u, v]; also handed to
            EdgeSelector as torch.tensor(graph).
        route: initial tour as a list of node ids. When None, the identity
            tour list(range(len(graph))) is used.
        logger: optional object whose iteration_finished(distance, attempts)
            is called after every improvement and once at the end; skipped
            when None.

    Returns:
        The improved tour, rotated to start at node 0 when node 0 is present
        (rotation does not change the tour length).
    """
    if route is None:
        route = list(range(len(graph)))

    moves_cost = dict.fromkeys(OptCase, 0)

    improved = True
    best_found_route = route
    iterations = 0
    attempts_cnt = 0
    best_distance = calculate_total_distance(best_found_route, graph)

    edge_sel = EdgeSelector(torch.tensor(graph))

    # Refresh the selector's weights roughly every tenth-of-the-tour improvements.
    epoch_freq = max(1, len(best_found_route) // 10)

    while improved:
        improved = False

        if iterations % epoch_freq == 0:
            print('calc rtdl')
            edge_sel.get_rtdl_weights(best_found_route)

        vlist = edge_sel.get_sorted_edges(best_found_route)

        # reverse_segments preserves length, so len(best_found_route) == len(route).
        for (i, j, k) in ranked_iter(vlist, len(best_found_route)):
            # Score every reconnection of this cut.
            for opt_case in OptCase:
                moves_cost[opt_case] = get_solution_cost_change(graph, best_found_route, opt_case, i, j, k)
                attempts_cnt += 1
            # Largest gain (old length - new length); max() keeps the first
            # OptCase on ties.
            best_return = max(moves_cost, key=moves_cost.get)

            if moves_cost[best_return] > 0:
                best_found_route = reverse_segments(best_found_route, best_return, i, j, k)
                improved = True
                iterations += 1
                best_distance -= moves_cost[best_return]
                print('improved, iterations', iterations, best_distance)
                if logger is not None:
                    logger.iteration_finished(best_distance, attempts_cnt)
                attempts_cnt = 0
                break

    print(attempts_cnt)
    if logger is not None:
        logger.iteration_finished(best_distance, attempts_cnt)

    # Rotate so the tour starts at node 0. The membership check stops
    # cycle()/dropwhile() from looping forever when node 0 is absent.
    if 0 in best_found_route:
        cycled = cycle(best_found_route)
        skipped = dropwhile(lambda x: x != 0, cycled)
        best_found_route = list(islice(skipped, len(best_found_route)))

    return best_found_route
