import sys
sys.path.append('../../')

import pickle
import cppimport.import_hook
from dld.src.cpp.cvrp import SISRs
import numpy as np
import time
import torch
import argparse
import logging
import os
import csv

def use_pkl_saved_problems(filename):
    """Load a batch of CVRP instances from a pickle file.

    Element ``i`` of the unpickled container is indexed as
    ``(depot_xy, node_xy, node_demand, capacity)``.

    Args:
        filename: Path to the ``.pkl`` file. ``pickle.load`` can execute
            arbitrary code, so only pass files from a trusted source.

    Returns:
        ``(capacity, depot_node_xy, depot_node_demand, grid_size)`` where
        ``capacity`` is a list of Python ints (one per instance),
        ``depot_node_xy`` has shape (batch, problem+1, 2) with the depot in
        position 0 of axis 1, ``depot_node_demand`` has shape
        (batch, problem+1) with a zero demand prepended for the depot, and
        ``grid_size`` is always ``None`` for this format.

    Raises:
        ValueError: If the file contains no instances, or if any capacity
            value is not an integral number.
    """
    import numbers  # local import keeps this block self-contained

    with open(filename, 'rb') as pickle_file:
        # NOTE: unsafe on untrusted input (pickle can run arbitrary code).
        data = pickle.load(pickle_file)

    dataset_size = len(data)
    if dataset_size == 0:
        # Without this, the `[:, None, :]` below fails with an opaque IndexError.
        raise ValueError(f"No instances found in {filename!r}")

    # Index by position (as the container is documented above) rather than iterating it.
    entries = [data[i] for i in range(dataset_size)]

    depot_data = np.array([entry[0] for entry in entries])[:, None, :]
    # shape: (batch, 1, 2)

    node_data = np.array([entry[1] for entry in entries])
    # shape: (batch, problem, 2)

    demand_data = np.array([entry[2] for entry in entries])
    # shape: (batch, problem)

    capacity = [entry[3] for entry in entries]
    for c in capacity:
        # numbers.Integral / numbers.Real also match NumPy scalars (e.g. np.int64,
        # np.float32), which a plain isinstance(c, int) check would reject.
        is_integral = isinstance(c, numbers.Integral) or (
            isinstance(c, numbers.Real) and float(c).is_integer()
        )
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if not is_integral:
            raise ValueError(f"All capacity values must be integers, got {c!r}.")

    # Normalise to plain Python ints.
    capacity = [int(c) for c in capacity]

    depot_node_xy = np.concatenate((depot_data, node_data), axis=1)
    # shape: (batch, problem+1, 2)
    depot_demand = np.zeros(shape=(dataset_size, 1), dtype="int")
    # shape: (batch, 1)
    depot_node_demand = np.concatenate((depot_demand, demand_data), axis=1)
    # shape: (batch, problem+1)

    return capacity, depot_node_xy, depot_node_demand, None


def use_saved_problems(filename):
    """Load a batch of CVRP instances stored with ``torch.save``.

    The file is loaded onto the CPU and read as a dict holding the tensors
    ``'depot_xy'``, ``'node_xy'``, ``'node_demand'`` and ``'capacity'``,
    plus a ``'grid_size'`` entry that is returned unchanged.

    Returns:
        ``(capacity, depot_node_xy, depot_node_demand, grid_size)`` where
        ``capacity`` is the capacity tensor as a NumPy array,
        ``depot_node_xy`` stacks depot coordinates before node coordinates
        along axis 1, and ``depot_node_demand`` prepends an all-zero integer
        column (the depot) to the node demands.
    """
    saved = torch.load(filename, map_location='cpu')
    tensor_keys = ('depot_xy', 'node_xy', 'node_demand', 'capacity')
    arrays = {key: saved[key].numpy() for key in tensor_keys}
    grid_size = saved['grid_size']

    node_xy = arrays['node_xy']
    batch = node_xy.shape[0]

    # Depot goes first along axis 1.
    depot_node_xy = np.concatenate([arrays['depot_xy'], node_xy], axis=1)

    # The depot has no demand: prepend a zero column.
    zero_column = np.zeros(shape=(batch, 1), dtype="int")
    depot_node_demand = np.concatenate([zero_column, arrays['node_demand']], axis=1)

    return arrays['capacity'], depot_node_xy, depot_node_demand, grid_size


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("dataset_path", help="Filename of the dataset(s) to evaluate")
    parser.add_argument('--results_dir', default='results', help="Name of results directory")
    # The search is limited either by a per-instance time limit or by an
    # iteration count, never both; argparse rejects the combination with a
    # usage error (an `assert` would be stripped under `python -O`).
    budget_group = parser.add_mutually_exclusive_group()
    budget_group.add_argument('--timelimit', type=int, default=None, help="Timelimit per instance in seconds")
    budget_group.add_argument('--iterations', type=int, default=None,
                              help="Number of iterations per instance (mutually exclusive with --timelimit)")
    parser.add_argument('--nb_instances', type=int, default=-1,
                        help="Number of instances to process (negative: all)")

    opts = parser.parse_args()

    dataset_basename = os.path.splitext(os.path.basename(opts.dataset_path))[0]

    # Both loaders return (capacity, depot_node_xy, depot_node_demand, grid_size).
    if opts.dataset_path.endswith(".pkl"):
        capacity, depot_node_xy, depot_node_demand, grid_size = use_pkl_saved_problems(opts.dataset_path)
    elif opts.dataset_path.endswith('.pt'):
        capacity, depot_node_xy, depot_node_demand, grid_size = use_saved_problems(opts.dataset_path)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if grid_size != 1:
            raise ValueError(f"Expected grid_size == 1 in {opts.dataset_path!r}, got {grid_size}")
    else:
        raise NotImplementedError(f"Unsupported dataset format {opts.dataset_path!r}; expected .pkl or .pt")

    # Start/end temperatures passed to SISRs.search (identical for both formats).
    start_temp = 0.1
    end_temp = 0.001

    # Bounds for the default iteration budget, used only when neither
    # --timelimit nor --iterations is given.
    default_minProblemSize = 100
    default_maxProblemSize = 1000
    default_itMin = 3e7
    default_itMax = 3e8

    if opts.timelimit is None:
        if opts.iterations is None:
            # Linear interpolation between itMin at minProblemSize and itMax at
            # maxProblemSize, as in the SISRs paper.
            # NOTE(review): problem_size counts the depot (shape[1] == customers + 1);
            # confirm whether the paper's formula is meant to use the customer count.
            problem_size = depot_node_xy.shape[1]
            iterations = int(default_itMin + (default_itMax - default_itMin) * (
                (problem_size - default_minProblemSize) / (default_maxProblemSize - default_minProblemSize)))
        else:
            iterations = opts.iterations
        t = 0  # presumably 0 means "no time limit" for SISRs.search — confirm
        results_dir = os.path.join(opts.results_dir, f"{dataset_basename}_{iterations}iter")
    else:
        iterations = 0
        t = opts.timelimit
        results_dir = os.path.join(opts.results_dir, f"{dataset_basename}_{t}s")

    dataset_size = depot_node_xy.shape[0]
    nb_instances = opts.nb_instances
    if nb_instances < 0:
        nb_instances = dataset_size
    # Clamp so that requesting more instances than exist does not crash with an
    # IndexError only after every available instance has already been solved.
    nb_instances = min(nb_instances, dataset_size)

    # Logging
    os.makedirs(results_dir, exist_ok=True)
    log_file = os.path.join(results_dir, 'log.txt')
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),  # write logs to a file
            logging.StreamHandler(),  # and mirror them to the console
        ]
    )
    if opts.nb_instances > dataset_size:
        logging.warning("Requested %d instances but the dataset only contains %d; processing %d.",
                        opts.nb_instances, dataset_size, nb_instances)

    # Loop invariants.
    nb_customers = depot_node_xy.shape[1] - 1  # axis 1 holds the depot plus the customers
    results_csv = os.path.join(results_dir, "results.csv")
    solutions_csv = os.path.join(results_dir, "solutions.csv")

    dataset_costs = []
    dataset_rt = []
    dataset_iter = []
    for i in range(nb_instances):
        instance = SISRs.Instance(nb_customers, capacity[i], depot_node_demand[i], depot_node_xy[i])
        # perf_counter is monotonic, unlike time.time(), so it is the right clock for durations.
        t_start = time.perf_counter()
        solution, _, nb_iter = SISRs.search(instance, iterations, t, start_temp, end_temp)
        runtime = time.perf_counter() - t_start
        cost = solution.totalCosts

        dataset_costs.append(cost)
        dataset_rt.append(runtime)
        dataset_iter.append(nb_iter)

        # Wrap every tour with node 0 at both ends.
        solution_routes = [[0, *route, 0] for route in solution.getTourList()]

        logging.info(
            "Instance {:3d}/{:3d}, score: {:.3f}, running_mean: {:.3f} iter: {}".format(
                i + 1, nb_instances, cost, np.mean(dataset_costs), nb_iter,
            ))

        # Rows are appended after each instance, so results are on disk even if the
        # run is interrupted; note that re-running into the same directory appends
        # to the existing files.
        with open(results_csv, mode='a', newline='') as file:
            csv.writer(file).writerow([i + 1, cost, runtime, nb_iter])

        with open(solutions_csv, mode='a', newline='') as file:
            csv.writer(file).writerow([i + 1, solution_routes])

    logging.info(" *** Test Done *** ")
    logging.info(" AVG. COSTS: {:.4f} ".format(np.mean(dataset_costs)))
    logging.info(" AVG. RUNTIME: {:.2f} ".format(np.mean(dataset_rt)))
    logging.info(" AVG. ITERATIONS: {:.0f} ".format(np.mean(dataset_iter)))

