import torch
import numpy as np
import scipy
from rtdl import RTD_Lite
import matplotlib.pyplot as plt
from tqdm import trange

import numpy as np
import tsplib95
import scipy.spatial
import requests, tarfile, os, gzip, shutil
from tqdm.auto import tqdm
from tsplib95.loaders import load_problem, load_solution

def _safe_extractall(tar, directory):
    """Extract every member of *tar* into *directory*, refusing paths that escape it."""
    if hasattr(tarfile, "data_filter"):
        # Python 3.12+ (and security backports): the "data" filter rejects
        # absolute paths, ".." components and links pointing outside *directory*.
        tar.extractall(directory, filter="data")
        return
    # Older interpreters: validate member names ourselves before extracting
    # (this fallback checks names only, not link targets).
    root = os.path.realpath(directory)
    for member in tar.getmembers():
        target = os.path.realpath(os.path.join(directory, member.name))
        if os.path.commonpath([root, target]) != root:
            raise ValueError(
                f"Refusing to extract {member.name!r} outside {directory!r}"
            )
    tar.extractall(directory)


def _gunzip_tree(directory):
    """Replace every ``*.gz`` file under *directory* with its decompressed copy."""
    for root, _, files in os.walk(directory):
        for file in files:
            if file.endswith(".gz"):
                path = os.path.join(root, file)
                # Write next to the original with the ".gz" suffix stripped.
                with gzip.open(path, "rb") as f_in, open(path[:-3], "wb") as f_out:
                    shutil.copyfileobj(f_in, f_out)
                os.remove(path)


def download_and_extract_tsplib(url, directory="tsplib", delete_after_unzip=True,
                                archive_path="tsplib.tar.gz", timeout=60):
    """Download a ``.tar.gz`` archive, extract it and gunzip every file inside.

    Args:
        url: URL of the ``.tar.gz`` archive to download.
        directory: destination directory; created if missing.
        delete_after_unzip: remove the downloaded archive once extraction and
            decompression have finished.
        archive_path: where the downloaded archive is written (defaults to
            ``tsplib.tar.gz`` in the current working directory).
        timeout: seconds passed to ``requests.get``. It bounds the connect and
            per-read waits, not the total download time.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        ValueError: if an archive member would be extracted outside
            *directory* (on interpreters without ``tarfile`` extraction filters).
    """
    os.makedirs(directory, exist_ok=True)

    # Stream the archive to disk with a progress bar. requests has no default
    # timeout, so without one a stalled server would hang this call forever.
    with requests.get(url, stream=True, timeout=timeout) as r:
        r.raise_for_status()
        total_size = int(r.headers.get("content-length", 0))
        with open(archive_path, "wb") as f, tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            for chunk in r.iter_content(8192):
                f.write(chunk)
                pbar.update(len(chunk))

    # The archive comes from the network, so guard against path traversal.
    with tarfile.open(archive_path, "r:gz") as tar:
        _safe_extractall(tar, directory)

    # The individual instance files inside the archive are themselves gzipped.
    _gunzip_tree(directory)

    if delete_after_unzip:
        os.remove(archive_path)

# Download all TSPLIB95 .tsp instances and extract them into the `tsplib` directory (uncomment to run):
# download_and_extract_tsplib("http://comopt.ifi.uni-heidelberg.de/software/TSPLIB95/tsp/ALL_tsp.tar.gz")

def create_problem_from_dict(data):
    """Build city coordinates and their pairwise distance matrix from ``data``.

    Args:
        data: mapping with a ``'node_coords'`` entry, which may be
            - a tensor-like object exposing ``.numpy()`` (e.g. a torch tensor,
              possibly on GPU or requiring grad);
            - a dict ``{node_id: (x, y), ...}`` (tsplib95 style); rows are
              ordered by ascending node id;
            - anything ``np.array`` accepts (list of pairs, ndarray, ...).

    Returns:
        ``(cities, distance_matrix)``: the coordinates as a NumPy array and the
        Euclidean distance matrix from ``scipy.spatial.distance.cdist``.
    """
    coords = data['node_coords']
    if isinstance(coords, dict):
        # np.array(dict) would give a 0-d object array, so order rows by node id.
        cities = np.array([coords[node] for node in sorted(coords)])
    elif hasattr(coords, 'numpy'):
        # torch tensors must leave autograd and device memory before .numpy().
        if hasattr(coords, 'detach'):
            coords = coords.detach()
        if hasattr(coords, 'cpu'):
            coords = coords.cpu()
        cities = coords.numpy()
    else:
        cities = np.array(coords)

    # Compute the distance matrix (Euclidean, cdist's default metric).
    distance_matrix = scipy.spatial.distance.cdist(cities, cities)

    return cities, distance_matrix

def plot_tour(coordinates, tour, title="TSP Tour"):
    """Draw one closed TSP tour with matplotlib and display it.

    Consecutive cities are joined by lines, the last city is joined back to
    the first, and each city is labelled with its 0-based position in ``tour``.

    Args:
        coordinates: indexable so that ``coordinates[c]`` is city ``c``'s (x, y).
        tour: city indices in visiting order (must be non-empty).
        title: title shown above the plot.
    """
    path = [coordinates[city] for city in tour]
    xs = [point[0] for point in path]
    ys = [point[1] for point in path]
    # Repeat the starting city so the drawn polyline closes the loop.
    xs.append(xs[0])
    ys.append(ys[0])

    plt.figure(figsize=(8, 6))
    plt.plot(xs, ys, 'o-', markersize=8)
    plt.title(title)
    plt.xlabel('X Coordinate')
    plt.ylabel('Y Coordinate')
    for position, point in enumerate(path):
        plt.text(point[0], point[1], str(position))
    plt.grid()
    plt.show()


def plot_multiple_tours(coordinates, tours_with_names):
    """
    Plot several TSP tours side by side in a single row and show the figure.

    Parameters:
        coordinates: indexable so that ``coordinates[c]`` is the (x, y) pair
                     of city ``c``
        tours_with_names: list of [tour, name] pairs, one subplot per pair,
                          e.g. [[tour1, "Greedy"], [tour2, "RL"], ...]

    An empty ``tours_with_names`` draws nothing and returns immediately.
    Without that check, ``plt.subplots`` would reject a grid with zero columns.
    Each city marker is labelled with its 0-based position in its tour.
    """
    n = len(tours_with_names)
    if n == 0:
        return

    # squeeze=False always returns a 2-D array of Axes, so one tour needs no
    # special-casing.
    _, axes = plt.subplots(1, n, figsize=(6 * n, 6), squeeze=False)

    for ax, (tour, name) in zip(axes.ravel(), tours_with_names):
        # Repeat the first city at the end to close the loop; an empty tour
        # simply yields an empty subplot instead of an IndexError.
        closed = list(tour)
        closed += closed[:1]
        x = [coordinates[i][0] for i in closed]
        y = [coordinates[i][1] for i in closed]

        ax.plot(x, y, 'o-', markersize=8)
        ax.set_title(name)
        ax.set_xlabel('X Coordinate')
        ax.set_ylabel('Y Coordinate')
        ax.grid()

        for i, city in enumerate(tour):
            ax.text(coordinates[city][0], coordinates[city][1], str(i))

    plt.tight_layout()
    plt.show()
