import time
import torch
import numpy as np
import scipy
import matplotlib.pyplot as plt
from tqdm import trange
from random import shuffle
from copy import copy
from scipy.sparse.csgraph import minimum_spanning_tree
from scipy.sparse import coo_matrix
from rtdl_fast import RTD_Lite_TSP

def extend(alist, MAX_ITER, value = None):
    """Right-pad ``alist`` so that it has ``MAX_ITER`` entries.

    The padding element is ``value``; when ``value`` is None the last element
    of ``alist`` is repeated instead.  An empty ``alist`` always produces
    ``MAX_ITER`` zeros, whatever ``value`` is.  A list that already has
    ``MAX_ITER`` or more entries comes back with the same contents (nothing
    is truncated).
    """
    if not len(alist):
        return [0] * MAX_ITER

    filler = alist[-1] if value is None else value
    missing = MAX_ITER - len(alist)

    # A non-positive ``missing`` yields an empty padding list.
    return alist + [filler] * missing

def get_next_after_max_mst(D):
    """Return the smallest entry of ``D`` that exceeds the heaviest MST edge.

    The minimum spanning tree of ``D`` is computed with scipy, its largest
    edge weight is taken, and the entries of ``D`` are then scanned in
    ascending order for the first one strictly greater than that weight.
    Returns None when no entry of ``D`` is larger.
    """
    spanning_tree = coo_matrix(minimum_spanning_tree(D))
    heaviest_mst_edge = np.max(spanning_tree.data)

    ascending_weights = sorted(D.reshape(-1))

    return next((w for w in ascending_weights if w > heaviest_mst_edge), None)

class Logger:
    """Collects per-iteration search statistics and aggregates them over problems.

    Within this file ``iteration_finished`` is called by the 2-opt routines
    after every applied swap and once more right before they return.
    ``problem_started`` resets the per-problem traces and ``problem_solved``
    folds them into the aggregates.
    """

    def __init__(self, MAX_ITER):
        """Create a logger whose aggregates hold ``MAX_ITER`` slots per problem."""
        # Element-wise sum of the (padded) best-distance traces of all problems.
        self.sum_distances = np.zeros(MAX_ITER)
        # Padded attempt traces: None until the first problem is solved,
        # 1-D after the first problem, one stacked row per problem afterwards.
        self.sum_attempts = None
        self.time_start = time.time()
        # Set by problem_solved(); None until then.
        self.time_end = None
        self.MAX_ITER = MAX_ITER
        # Per-problem traces exist from construction on, so that
        # iteration_finished() does not fail with AttributeError when
        # problem_started() has not been called yet.
        self.distances = []
        self.attempts = []

    def problem_started(self):
        """Reset the per-problem traces."""
        self.distances = []
        self.attempts = []

    def problem_solved(self):
        """Fold the traces of the current problem into the aggregates."""
        # Distances are padded with their last recorded value (zeros if the
        # trace is empty) and added element-wise.
        self.sum_distances += extend(self.distances, self.MAX_ITER)

        # Attempts are padded with NaN (zeros if the trace is empty).
        padded_attempts = extend(self.attempts, self.MAX_ITER, float('nan'))

        if self.sum_attempts is None:
            self.sum_attempts = np.array(padded_attempts)
        else:
            self.sum_attempts = np.vstack((self.sum_attempts, padded_attempts))

        self.time_end = time.time()

    def iteration_finished(self, best_distance, attempts_cnt):
        """Record the best distance and the number of attempts of one iteration."""
        self.attempts.append(attempts_cnt)
        self.distances.append(best_distance)

def create_D_tour(N, D1, tour):
    """Build an ``N x N`` matrix that keeps only the edges of ``tour``.

    Every entry is ``inf`` except the pairs of consecutive tour vertices
    (including the closing edge ``tour[-1] -> tour[0]``), which receive the
    corresponding weight ``D1[a, b]`` in both orientations.
    """
    D_tour = np.full((N, N), np.inf)

    for pos in range(N):
        # pos == 0 pairs the last tour vertex with the first one.
        a, b = tour[pos - 1], tour[pos]
        D_tour[a, b] = D_tour[b, a] = D1[a, b]

    return D_tour

def get_edge_penalty(res, D_tour, D1):
    """Compute a penalty for every (birth, death) pair in ``res[1]['1->2']``.

    For each pair the penalty is the ``D_tour`` weight of the death edge minus
    the ``D1`` weight of the birth edge.  Returns a dict mapping the death
    edge ``(death[0], death[1])`` to its penalty, together with the sum of
    all penalties.
    """
    penalties = {}
    penalty_sum = 0.0

    pairs = res[1]['1->2']

    for birth, death in pairs:
        dead_edge = (death[0], death[1])
        born_weight = D1[birth[0], birth[1]]
        gap = (D_tour[dead_edge[0], dead_edge[1]] - born_weight).item()

        penalties[dead_edge] = gap
        penalty_sum += gap

    return penalties, penalty_sum

def get_point(edge2point, edge):
    """Look up ``edge`` in ``edge2point``, ignoring its orientation.

    The edge is tried as given first; if that yields None the reversed edge
    ``(edge[1], edge[0])`` is tried.  Returns None when neither is present.
    """
    direct = edge2point.get(edge)

    if direct is not None:
        return direct

    return edge2point.get((edge[1], edge[0]))

class EdgeSelector:
    """Ranks the edges of a tour by a penalty derived from an RTDL computation.

    ``get_rtdl_weights`` runs the wrapped ``RTD_Lite_TSP`` object on a tour and
    stores the resulting edge ranking; ``get_sorted_edges`` translates that
    ranking into positions of a (possibly newer) tour.
    """

    def __init__(self, D1):
        """Keep the distance matrix and create the RTDL object.

        ``D1`` is presumably a square torch tensor of pairwise distances (it
        is indexed as ``D1[a, b]`` and its ``shape[0]`` / ``.item()`` are
        used) — confirm against the callers.
        """
        self.D1 = D1
        self.rtdl_obj = RTD_Lite_TSP(None, self.D1, cache_r2_min = True)
        # Disabled: self.next_mst_edge = get_next_after_max_mst(D1)

    def get_rtdl_weights(self, tour):
        """Recompute edge penalties for ``tour`` and store the edge ranking.

        Sets ``tour_edges``, ``edge_penalty`` and ``edges_sorted`` (edges in
        order of decreasing penalty).
        """
        n_vertices = self.D1.shape[0]

        # One (edge, weight, position) record per tour edge; position 0 is the
        # closing edge (tour[-1], tour[0]).
        self.tour_edges = [
            ((tour[pos - 1], tour[pos]), self.D1[tour[pos - 1], tour[pos]].item(), pos)
            for pos in range(n_vertices)
        ]

        # The heaviest tour edge is withheld from the edge list handed to the
        # RTDL object; only the remaining edges are passed on.
        heaviest = max(self.tour_edges, key = lambda record : record[1])
        heaviest_edge, heaviest_pos = heaviest[0], heaviest[2]
        remaining = [record for record in self.tour_edges if record[2] != heaviest_pos]

        self.rtdl_obj.r1_edge_idx = np.array([record[0] for record in remaining])
        self.rtdl_obj.r1_edge_w = torch.tensor([record[1] for record in remaining], dtype=torch.float64)

        D_tour = create_D_tour(n_vertices, self.D1, tour)
        self.rtdl_obj.r1 = torch.tensor(D_tour)
        res = self.rtdl_obj()

        self.edge_penalty, _ = get_edge_penalty(res, D_tour, self.D1)
        penalties = self.edge_penalty

        # The withheld edge gets the smallest penalty seen (marked as the
        # "optimal variant" by the original author).  Disabled alternatives:
        # its own weight, or its weight minus self.next_mst_edge.
        penalties[heaviest_edge] = np.min(list(penalties.values()))

        # Stable sort: edges with equal penalty keep their insertion order.
        self.edges_sorted = sorted(penalties, key = lambda edge : -penalties[edge])

    def get_sorted_edges(self, tour):
        """Return all positions of ``tour`` ordered by the stored edge ranking.

        Position ``i`` stands for the tour edge ``(tour[i-1], tour[i])``.
        Positions whose edge appears in ``edges_sorted`` (in either
        orientation) come first, in ranking order; all other positions are
        appended afterwards.  Requires ``get_rtdl_weights`` to have been
        called before.
        """
        n_positions = len(tour)
        position_of_edge = {
            (tour[pos - 1], tour[pos]): pos for pos in range(n_positions)
        }

        lookups = (get_point(position_of_edge, edge) for edge in self.edges_sorted)
        ranked = [pos for pos in lookups if pos is not None]

        # Positions for which the ranking carries no information.
        unranked = list(set(range(n_positions)) - set(ranked))

        return ranked + unranked

def calculate_total_distance(tour, distance_matrix):
    """Return the length of the closed tour.

    Sums ``distance_matrix[a, b]`` over all consecutive vertex pairs of
    ``tour``, including the closing edge from the last vertex back to the
    first.  An empty tour has length 0.
    """
    length = 0

    for pos, city in enumerate(tour):
        # pos == 0 contributes the closing edge tour[-1] -> tour[0].
        length += distance_matrix[tour[pos - 1], city]

    return length

def two_opt_swap(tour, i, k):
    """
    Return a new tour in which the segment tour[i..k] (inclusive) is reversed;
    i < k is expected.
    Edges (i-1, i), (k, k+1) are swapped with (i-1, k), (i, k+1).
    The input tour is left untouched.
    """
    head = tour[:i]
    segment = tour[i:k+1]
    tail = tour[k+1:]

    return head + segment[::-1] + tail

def delta_two_opt_swap(d, tour, i, k):
    """
    Return the change in tour length that the 2-opt swap (i, k) would cause,
    without performing the swap; i < k is expected.
    Edges (i-1, i), (k, k+1) are swapped with (i-1, k), (i, k+1), so the
    result is negative when the swap shortens the tour.
    tour[k+1] is read, hence k must not be the last position.
    """
    before_i, seg_first = tour[i-1], tour[i]
    seg_last, after_k = tour[k], tour[k+1]

    added = d[before_i, seg_last] + d[seg_first, after_k]

    return added - d[before_i, seg_first] - d[seg_last, after_k]

def shuffled(x):
    """Return the elements of ``x`` as a new, randomly ordered list.

    ``x`` itself is not modified; the order comes from ``random.shuffle``.
    """
    permuted = [*x]
    shuffle(permuted)
    return permuted

def two_opt(tour, distance_matrix, max_iterations=1000, logger = None, break_loop = True):
    """Improve a tour with the 2-opt local search.

    Each pass scans the segment bounds ``i < k`` in lexicographic order and
    applies a swap whenever it shortens the tour.

    Args:
        tour: Tour as a sequence of city indices.  It must provide ``copy()``;
            the input itself is not modified.
        distance_matrix: Pairwise distances, indexed as ``distance_matrix[a, b]``.
        max_iterations: A new pass is started only while fewer than
            ``max_iterations - 1`` swaps have been applied, which leaves one
            logger slot for the final entry written before returning.
        logger: Optional object with ``iteration_finished(distance, attempts)``.
            It is called after every applied swap and once more at the end.
        break_loop: If true, a pass ends right after its first improving swap
            and the scan restarts from the beginning.  Otherwise the pass
            continues and applies every improving swap it meets.

    Returns:
        ``(best_tour, best_distance)``; the distance is recomputed from
        scratch for the returned tour.
    """
    improvement = True
    best_tour = tour.copy()
    best_distance = calculate_total_distance(best_tour, distance_matrix)
    iterations = 0
    attempts_cnt = 0

    # NOTE: the limit is only checked between passes, so with
    # break_loop=False a single pass may apply more swaps than the limit.
    while improvement and iterations < max_iterations - 1:
        improvement = False

        for i in range(0, len(best_tour) - 1):
            for k in range(i+1, len(best_tour) - 1):
                # Reversing positions 0..n-2 only flips the orientation of
                # the cycle, so that pair is never evaluated.
                if i == 0 and k == len(best_tour) - 2:
                    continue

                delta = delta_two_opt_swap(distance_matrix, best_tour, i, k)
                attempts_cnt += 1

                if delta < 0:
                    best_tour = two_opt_swap(best_tour, i, k)
                    improvement = True

                    iterations += 1
                    if logger:
                        # Running distance is only needed for logging; the
                        # returned distance is recomputed below.
                        best_distance += delta
                        logger.iteration_finished(best_distance, attempts_cnt)
                    attempts_cnt = 0

                if break_loop and improvement:
                    break

            if break_loop and improvement:
                break

    print('stat', iterations, attempts_cnt)

    if logger:
        logger.iteration_finished(best_distance, attempts_cnt)

    # The loop above stops at max_iterations - 1 applied swaps, so comparing
    # with == max_iterations could practically never report the limit.
    if iterations >= max_iterations - 1:
        print('MAX_ITER REACHED')

    best_distance = calculate_total_distance(best_tour, distance_matrix)

    return best_tour, best_distance

#def vertex_iterator_rtdl(end_point, n_cities):
#    i = end_point
#
#    # 
#    # i < k. Edges (i-1,i) and (k,k+1) are removed.
#    #
#    for k in range(i+1, n_cities-1):
#        if i == 0 and k == n_cities - 2:
#            continue
#        yield (i, k)
#
#    k = end_point - 1
#
#    for i in reversed(range(0, k)):
#        if i == 0 and k == n_cities - 2:
#            continue
#        yield (i, k)

def vertex_iterator3_rtdl(end_point, n_cities, vlist):
    """Yield the 2-opt segment bounds that pair ``end_point`` with each ``vlist`` entry.

    For every position ``v`` in ``vlist`` the two positions are ordered as
    ``lo <= hi`` and ``(lo, hi - 1)`` is yielded.  Pairs that are equal or
    adjacent are skipped, as is the pair made of the first and the last tour
    position.
    """
    last_position = n_cities - 1

    for other in vlist:
        if other < end_point:
            lo, hi = other, end_point
        else:
            lo, hi = end_point, other

        far_enough = hi - lo > 1
        spans_whole_tour = lo == 0 and hi == last_position

        if far_enough and not spans_whole_tour:
            yield (lo, hi - 1)

#def vertex_iterator2_rtdl(end_point, n_cities, vlist):
#    i = end_point
#
#    # 
#    # i < k. Edges (i-1,i) and (k,k+1) are removed.
#    #
#    k_list = list(range(i+1, n_cities-1))
#    vrank = {v : i for i, v in enumerate(vlist)}
#    k_list.sort(key = lambda x : vrank[x+1])
#
#    #print(k_list)
#    #print(vrank)
#
#    for k in k_list:
#    #for k in range(i+1, n_cities-1):
#        if i == 0 and k == n_cities - 2:
#            continue
#        yield (i, k, k+1)
#
#    k = end_point - 1
#
#    i_list = list(range(0, k))
#    i_list.sort(key = lambda x : vrank[x])
#
#    #for i in reversed(range(0, k)):
#    for i in i_list:
#        if i == 0 and k == n_cities - 2:
#            continue
#        yield (i, k)

def two_opt_rtdl(tour, distance_matrix, max_iterations=1000, verbose = 0, logger = None, rtdl_period= None, random_vlist = False, break_loop = True, progressive = True):
    """Improve a tour with 2-opt, trying candidate moves in RTDL penalty order.

    In every pass the tour positions are ordered (by ``EdgeSelector`` or at
    random), restricted to the current window ``position < opt_len``, and all
    pairs of the remaining positions are tried as 2-opt moves.

    Args:
        tour: Tour as a sequence of city indices.  It must provide ``copy()``;
            the input itself is not modified.
        distance_matrix: Pairwise distances, indexed as ``distance_matrix[a, b]``
            and convertible with ``torch.tensor``.
        max_iterations: No new pass is started once this many swaps have been
            applied.
        verbose: Currently unused.
        logger: Optional object with ``iteration_finished(distance, attempts)``.
            It is called after every applied swap and once more at the end.
        rtdl_period: The RTDL ranking is recomputed at the start of a pass
            whenever the number of applied swaps is a multiple of this value.
            Defaults to ``max(1, len(tour) // 10)``.
        random_vlist: If true, use a random position order instead of the
            RTDL ranking.
        break_loop: If true, a pass ends right after its first improving swap.
            Otherwise the pass continues and applies every improving swap.
        progressive: If true, start with a window of 10 positions and widen it
            by 10 each time a pass finds no improvement; otherwise use all
            positions from the start.

    Returns:
        ``(best_tour, best_distance)``; the distance is recomputed from
        scratch for the returned tour.
    """
    improvement = True
    best_tour = tour.copy()
    best_distance = calculate_total_distance(best_tour, distance_matrix)
    iterations = 0
    attempts_cnt = 0
    n_cities = len(tour)

    if rtdl_period is None:
        rtdl_period = max(1, len(tour) // 10)

    edge_sel = EdgeSelector(torch.tensor(distance_matrix))

    window_step = 10
    opt_len = window_step if progressive else n_cities

    # An unfinished window keeps the search alive through the
    # ``improvement = True`` assignment at the end of the pass.  The former
    # extra ``or opt_len < n_cities`` clause only took effect once the
    # iteration limit was already hit, i.e. it bypassed max_iterations.
    while improvement and iterations < max_iterations:
        improvement = False

        if random_vlist:
            vlist = list(range(n_cities))
            np.random.shuffle(vlist)
        else:
            if iterations % rtdl_period == 0:
                print('iter', iterations, 'calc rtdl')
                edge_sel.get_rtdl_weights(best_tour)

            vlist = edge_sel.get_sorted_edges(best_tour)

        # Only positions inside the current window take part in this pass.
        vlist = [pos for pos in vlist if pos < opt_len]
        # Each (i, k) pair is evaluated at most once per pass.
        checked = np.zeros((n_cities, n_cities))

        for end_point in vlist:
            for (i, k) in vertex_iterator3_rtdl(end_point, n_cities, vlist):
                if checked[i, k]:
                    continue
                checked[i, k] = 1

                new_delta = delta_two_opt_swap(distance_matrix, best_tour, i, k)
                attempts_cnt += 1

                if new_delta < 0:
                    improvement = True
                    iterations += 1
                    best_distance += new_delta
                    if logger:
                        logger.iteration_finished(best_distance, attempts_cnt)

                    best_tour = two_opt_swap(best_tour, i, k)
                    attempts_cnt = 0

                    if break_loop:
                        break

            if break_loop and improvement:
                break

        if not improvement and opt_len < n_cities:
            # Nothing left to gain inside the window: widen it and go on.
            opt_len = min(opt_len + window_step, n_cities)
            improvement = True

    print('stat', iterations, attempts_cnt)

    # With break_loop=False a pass may overshoot the limit, hence >=.
    if iterations >= max_iterations:
        print('MAX_ITER REACHED')

    if logger:
        logger.iteration_finished(best_distance, attempts_cnt)

    best_distance = calculate_total_distance(best_tour, distance_matrix)

    return best_tour, best_distance

def plot_tour(coordinates, tour, title="TSP Tour"):
    """Draw the closed tour with matplotlib and show the figure.

    ``coordinates[city]`` must give the (x, y) position of a city.  Each city
    is labelled with its position in ``tour``.
    """
    path = [coordinates[city] for city in tour]
    xs = [point[0] for point in path]
    ys = [point[1] for point in path]

    # Close the loop by going back to the starting city.
    xs.append(xs[0])
    ys.append(ys[0])

    plt.figure(figsize=(8, 6))
    plt.plot(xs, ys, 'o-', markersize=8)
    plt.title(title)
    plt.xlabel('X Coordinate')
    plt.ylabel('Y Coordinate')

    for order, city in enumerate(tour):
        location = coordinates[city]
        plt.text(location[0], location[1], str(order))

    plt.grid()
    plt.show()
