import os
import torch
import numpy as np
import scipy
from rtdl import RTD_Lite
import matplotlib.pyplot as plt
from tqdm import trange
from copy import copy
import time
from tsp_rtdl_util import create_problem
from subprocess import check_call, check_output, CalledProcessError

def calculate_total_distance(tour, distance_matrix):
    """Return the length of the closed tour ``tour``.

    Args:
        tour: sequence of city indices in visiting order.
        distance_matrix: pairwise distances indexable as ``distance_matrix[a, b]``
            (e.g. a 2-D numpy array).

    Returns:
        The sum of the edge lengths, including the edge from the last city back
        to the first. An empty tour has length 0.
    """
    # Step k uses the edge (tour[k - 1], tour[k]); for k == 0 the index -1
    # wraps to the last city, which closes the cycle.
    return sum(
        distance_matrix[tour[step - 1], tour[step]] for step in range(len(tour))
    )

def read_concorde_tour(filename):
    """Read a tour file in the format written by Concorde's ``-o`` option.

    The first line holds the node count ``n``. Every following line holds
    whitespace-separated node indices, which are concatenated in order.

    Args:
        filename: path of the tour file.

    Returns:
        list[int]: the node indices in visiting order.

    Raises:
        ValueError: if a token is not an integer, or if the number of nodes read
            differs from the count on the first line (including an empty file).
    """
    with open(filename, 'r') as f:
        n = None
        tour = []
        for line in f:
            if n is None:
                n = int(line)
            else:
                # Argument-less split() tolerates trailing/multiple spaces, tabs
                # and blank lines; split(" ") produced '' tokens that broke int().
                tour.extend(int(node) for node in line.split())
    # Explicit check instead of assert, so validation survives `python -O`.
    if len(tour) != n:
        raise ValueError(
            "Unexpected tour length: expected {}, got {}".format(n, len(tour))
        )
    return tour

def write_tsplib(filename, loc, name="problem"):
    """Write 2-D city coordinates to ``filename`` as a TSPLIB ``EUC_2D`` problem.

    Args:
        filename: destination path; an existing file is overwritten.
        loc: sequence of ``(x, y)`` pairs. Node ids in the file are the 1-based
            positions in ``loc``.
        name: value written to the ``NAME`` header field.
    """

    def to_tsplib_int(value):
        # TSPLIB does not accept float coordinates: scale by 1e7, then add 0.5
        # and truncate (round-half-up for non-negative values).
        return int(value * 10000000 + 0.5)

    with open(filename, 'w') as out:
        header_fields = [
            ("NAME", name),
            ("TYPE", "TSP"),
            ("DIMENSION", len(loc)),
            ("EDGE_WEIGHT_TYPE", "EUC_2D"),
        ]
        header = "\n".join(f"{key} : {value}" for key, value in header_fields)
        out.write(header + "\nNODE_COORD_SECTION\n")

        node_lines = [
            f"{idx}\t{to_tsplib_int(x)}\t{to_tsplib_int(y)}"
            for idx, (x, y) in enumerate(loc, start=1)
        ]
        out.write("\n".join(node_lines) + "\nEOF\n")

def solve_concorde_log(executable, directory, name, loc, disable_cache=False, problem_filename=None):
    """Solve a TSP instance with the Concorde executable and return its tour.

    Args:
        executable: path to the Concorde binary.
        directory: working directory for Concorde; the .tour and .log files (and,
            by default, the generated .tsp file) are placed here.
        name: base name for the generated files.
        loc: sequence of ``(x, y)`` coordinates; only used when
            ``problem_filename`` is not given, to generate the .tsp file.
        disable_cache: if False and ``<name>.concorde.pkl`` exists in
            ``directory``, the tour is read from it instead of running Concorde.
        problem_filename: optional existing TSPLIB file to solve. When omitted,
            ``<directory>/<name>.tsp`` is written from ``loc``.

    Returns:
        list[int] | None: the tour read from Concorde's output, or None if any
        step fails. The exception is printed, not raised.
    """
    import pickle  # only needed for the optional result cache

    # Decide before defaulting the path: only a .tsp file we derive ourselves
    # has to be generated from `loc`; a caller-supplied file is used as is.
    write_problem = not problem_filename
    if write_problem:
        problem_filename = os.path.join(directory, "{}.tsp".format(name))
    tour_filename = os.path.join(directory, "{}.tour".format(name))
    output_filename = os.path.join(directory, "{}.concorde.pkl".format(name))
    log_filename = os.path.join(directory, "{}.log".format(name))

    try:
        # May have already been run
        if os.path.isfile(output_filename) and not disable_cache:
            # NOTE(review): assumes the cache is a pickled (tour, duration) tuple.
            # pickle can execute arbitrary code, so only use trusted cache dirs.
            with open(output_filename, 'rb') as cache:
                tour, _duration = pickle.load(cache)
            return tour

        if write_problem:
            write_tsplib(problem_filename, loc, name=name)

        with open(log_filename, 'w') as f:
            try:
                # Concorde leaves traces of the solution in the current directory,
                # so run it from the target directory with absolute file paths.
                check_call([executable, '-s', '1234', '-x', '-o',
                            os.path.abspath(tour_filename), os.path.abspath(problem_filename)],
                           stdout=f, stderr=f, cwd=directory)
            except CalledProcessError as e:
                # Concorde is known to exit with 255 even when it succeeds; any
                # other code is a genuine failure (explicit check, not assert,
                # so it is not stripped under `python -O`).
                if e.returncode != 255:
                    raise

        return read_concorde_tour(tour_filename)

    except Exception as e:
        # Best-effort contract: report the failure and signal it with None.
        print("Exception occured")
        print(e)
        return None