import json
import os
import pathlib
import random
import time

import gymnasium as gym
import numpy as np
import toml
from deap import base, creator, tools
from gymnasium.spaces import Box
from gymnasium.spaces import GraphInstance

from utils import compute_hypervolume
from routing.genetic_algorithm.operators import *
from routing.helper_functions import *


def get_fitness_bounds(fitnesses):
    """Return the per-objective lower and upper bounds of a set of fitness vectors.

    Args:
        fitnesses: array-like of shape (n_individuals, n_objectives).

    Returns:
        np.ndarray of shape (2, n_objectives); row 0 holds the per-objective
        minima and row 1 the per-objective maxima.
    """
    values = np.asarray(fitnesses)
    lower = values.min(axis=0)
    upper = values.max(axis=0)
    return np.stack((lower, upper))


def update_bounds(bounds, fitness):
    """Widen ``bounds`` in place so they also cover the given fitness vectors.

    Args:
        bounds: indexable pair ``[lower, upper]`` of per-objective bounds
            (e.g. the array returned by ``get_fitness_bounds``); mutated in place.
        fitness: array-like of shape (n_individuals, n_objectives).

    Returns:
        The same ``bounds`` object, after updating.
    """
    batch = np.asarray(fitness)
    batch_low = batch.min(axis=0)
    batch_high = batch.max(axis=0)

    # Update element by element so list-based bounds are mutated in place too.
    for obj, (low, high) in enumerate(zip(batch_low, batch_high)):
        bounds[0][obj] = np.minimum(bounds[0][obj], low)
        bounds[1][obj] = np.maximum(bounds[1][obj], high)

    return bounds


def get_edge_links(self):
    """Return the edge index list of a fully connected graph over the population.

    Every ordered pair ``(i, j)`` with ``0 <= i, j < population_size`` is listed,
    including self-pairs, in row-major order (``i`` varies slowest).

    Args:
        self: any object exposing an integer ``population_size``.

    Returns:
        np.ndarray of dtype int64 and shape ``(2, population_size ** 2)``;
        row 0 holds source indices, row 1 target indices.
    """
    size = self.population_size
    node_ids = np.arange(size)
    sources = np.repeat(node_ids, size)  # 0,0,...,1,1,...
    targets = np.tile(node_ids, size)    # 0,1,...,0,1,...
    return np.stack((sources, targets)).astype(np.int64)


def get_edges(self):
    """Build a flattened adjacency mask linking individuals of the same Pareto front.

    The population is split into non-dominated fronts with DEAP's
    ``tools.sortNondominated``; every ordered pair ``(i, j)`` with ``i != j``
    whose members share a front is marked ``True``.

    Args:
        self: any object exposing ``population`` (a list of DEAP individuals
            with evaluated fitnesses) and ``population_size``.

    Returns:
        np.ndarray of dtype bool and shape ``(population_size ** 2,)``; entry
        ``i * population_size + j`` describes the edge ``i -> j``, matching the
        row-major pair order produced by ``get_edge_links``.
    """
    size = self.population_size
    fronts = tools.sortNondominated(self.population, size, first_front_only=False)

    # Locate individuals by object identity, not list equality. DEAP individuals
    # are list subclasses, so ``population.index(ind)`` would return the first
    # *equal* genome. Duplicated genomes (e.g. unmutated clones) would then all
    # collapse onto one index, and the remaining copies would get no edges.
    # sortNondominated returns the very objects it was given, so id() is reliable.
    positions = {}
    for idx, ind in enumerate(self.population):
        positions.setdefault(id(ind), []).append(idx)

    adjacency = np.zeros((size, size), dtype=np.bool_)
    for front in fronts:
        members = sorted({idx for ind in front for idx in positions[id(ind)]})
        adjacency[np.ix_(members, members)] = True

    # No self-loops: only pairs with i != j are connected.
    np.fill_diagonal(adjacency, False)
    return adjacency.ravel()


def normalize_fitnesses(population, bounds):
    """Min-max normalise every individual's fitness values using per-objective bounds.

    Args:
        population: iterable of individuals exposing ``ind.fitness.values``.
        bounds: indexable pair ``[lower, upper]``; ``bounds[0][k]`` and
            ``bounds[1][k]`` are the bounds of objective ``k``.

    Returns:
        A list (one entry per individual) of lists of normalised values. An
        objective whose lower and upper bound coincide maps to 0.5.

    Raises:
        ValueError: if a lower bound exceeds its upper bound.
    """
    def _scale(obj, value):
        # Bounds are read lazily, so an empty population never touches them.
        low, high = bounds[0][obj], bounds[1][obj]
        if low > high:
            raise ValueError('bound min > bound max', low, high, bounds)
        if low == high:
            return 0.5
        return (value - low) / (high - low)

    return [
        [_scale(obj, value) for obj, value in enumerate(ind.fitness.values)]
        for ind in population
    ]


class routingEnv(gym.Env):
    """Gymnasium environment in which an agent steers a DEAP genetic algorithm on CVRP instances.

    Every ``step`` runs one generation. Offspring are produced with the
    registered ``ordered_crossover`` / ``mutation_shuffle`` operators and
    evaluated on two minimised objectives. The next population is then chosen
    with ``tools.selNSGA2``.

    The 2-dimensional action in ``[-1, 1]`` sets the crossover probability
    (0.5-0.7) and the value passed to the mutation operator (0-0.2).
    Observations are graphs whose nodes are the population's min-max
    normalised fitness vectors.

    ``parameters`` keys read by this class::

        parameters['environment']: population_size, nr_objectives,
            max_generations, reward_factor, instance_file, problem_instances,
            data_dir (optional)
        parameters['results_saving'] (optional): save_result, folder, exp_name
    """

    # Directory holding the per-size ideal-point / reference-point JSON files;
    # override with parameters['environment']['data_dir'].
    DATA_DIR = '/hpc/za64617/projects/GNN4APC_dev/code/src/routing/data'
    # Instance sizes recognised through a 'cvrp_<size>_' marker in instance_file,
    # checked in this order.
    SIZE_TAGS = ('20', '50', '100', '200', '500')

    def __init__(self, parameters):
        env_params = parameters['environment']
        self.population_size = env_params['population_size']
        self.nr_objectives = env_params['nr_objectives']
        self.max_generations = env_params['max_generations']

        # Objective-value snapshots, one per generation of the current episode.
        self.solution_spaces = []
        self.solution_space = Box(low=0, high=1, shape=(self.population_size, self.nr_objectives),
                                  dtype=np.float32)

        # action[0] -> crossover probability, action[1] -> mutation parameter.
        self.action_space = Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        self.reference_point = None
        self.generation = 0
        self.done = False
        self.reward_factor = env_params['reward_factor']
        self.instance_file = env_params['instance_file']
        self.problem_instances = env_params['problem_instances']
        self.data_dir = env_params.get('data_dir', self.DATA_DIR)

        # Unrecognised instance sizes fall back to the 'cvrp_500_' ideal points.
        size_tag = self._size_tag(self.instance_file) or '500'
        with open(self._data_file('ideal_points', size_tag), 'r') as json_file:
            self.ideal_points = json.load(json_file)

        self.save_results = False
        if 'results_saving' in parameters:
            saving = parameters['results_saving']
            self.save_results = saving['save_result']
            self.folder = saving['folder']
            self.exp_name = saving['exp_name']

        self.bounds = []

    @classmethod
    def _size_tag(cls, instance_file):
        """Return the first size tag whose 'cvrp_<size>_' marker occurs in ``instance_file``, else None."""
        for tag in cls.SIZE_TAGS:
            if f'cvrp_{tag}_' in instance_file:
                return tag
        return None

    def _data_file(self, kind, size_tag):
        """Path of the ``<kind>_<size>_2_obj.json`` file inside ``self.data_dir``."""
        return os.path.join(self.data_dir, f'{kind}_{size_tag}_2_obj.json')

    def _load_reference_point(self):
        """Return the stored reference point for the current problem instance, or None if unavailable."""
        size_tag = self._size_tag(self.instance_file)
        if size_tag is None:
            return None
        path = self._data_file('reference_points', size_tag)
        if not os.path.isfile(path):
            return None
        with open(path, 'r') as file:
            reference_points = json.load(file)
        key = str(self.problem_instance)
        if key not in reference_points:
            return None
        print('using reference point from file')
        print('ref_point:', reference_points[key])
        return reference_points[key]

    def _evaluate(self, individuals):
        """Evaluate ``individuals``, store their fitness in place and return the values as lists."""
        fitnesses = [list(self.toolbox.evaluate(ind)) for ind in individuals]
        for ind, fit in zip(individuals, fitnesses):
            ind.fitness.values = tuple(fit)
        return fitnesses

    def _observe(self):
        """Build the observation dict: population graph plus normalised generation counter."""
        nodes = np.array(normalize_fitnesses(self.population, self.bounds))
        graph = GraphInstance(nodes=nodes, edges=get_edges(self), edge_links=get_edge_links(self))
        return {'graph': graph,
                'additional_features': np.array([self.generation / self.max_generations])}

    def reset(self, *, seed=None, options=None):
        """Start a new episode on a randomly chosen problem instance.

        Args:
            seed: optional seed. When given, it seeds Gymnasium's ``np_random``
                and Python's ``random`` module, which drives the GA.
            options: accepted for Gymnasium API compatibility; unused.

        Returns:
            ``(observation, info)`` with an empty ``info`` dict.
        """
        super().reset(seed=seed)
        if seed is not None:
            random.seed(seed)

        self.problem_instance = random.choice(self.problem_instances)
        nb_customers, truck_capacity, dist_matrix_data, dist_depot_data, demands_data = read_input_cvrp(
            self.instance_file, self.problem_instance)
        self.toolbox = base.Toolbox()

        # Note: creator.create registers these classes globally in deap.creator
        # and re-registers them on every reset.
        creator.create("FitnessMin", base.Fitness, weights=(-1.0, -1.0))
        creator.create("Individual", list, fitness=creator.FitnessMin)

        # An individual is a random permutation of the customer ids 1..nb_customers.
        self.toolbox.register('indexes', random.sample, range(1, nb_customers + 1), nb_customers)
        self.toolbox.register('individual', tools.initIterate, creator.Individual, self.toolbox.indexes)
        self.toolbox.register('population', tools.initRepeat, list, self.toolbox.individual)

        self.toolbox.register("mate", ordered_crossover)
        self.toolbox.register("mutate", mutation_shuffle)
        self.toolbox.register("select", tools.selNSGA2)
        self.toolbox.register('evaluate', eval_individual_fitness, truck_capacity=truck_capacity,
                              dist_matrix_data=dist_matrix_data, dist_depot_data=dist_depot_data,
                              demands_data=demands_data)

        self.population = self.toolbox.population(self.population_size)
        self._evaluate(self.population)

        if self.save_results:
            self.hof = tools.ParetoFront()
            self.hof.update(self.population)

        self.bounds = get_fitness_bounds([ind.fitness.values for ind in self.population])
        # Copies are required: update_bounds later widens self.bounds in place,
        # but the reference point must stay fixed for the whole episode.
        if self.save_results:
            reference_point = self._load_reference_point()
            if reference_point is None:
                # Do not keep None or a point from a previous instance; use the
                # same fallback as training mode instead.
                print('NO REFERENCE POINT KNOWN, using the initial population bounds')
                reference_point = self.bounds[1].copy()
            self.reference_point = reference_point
        else:
            self.reference_point = self.bounds[1].copy()

        self.generation = 0
        self.done = False

        # Start a fresh history each episode so it does not grow without bound.
        self.solution_spaces = []
        self.solution_space = np.array([ind.fitness.values for ind in self.population]).astype(np.float32)
        self.solution_spaces.append(self.solution_space)
        self.initial_hv = self.step_hv = compute_hypervolume(self.population, self.nr_objectives,
                                                             self.reference_point)
        if not self.save_results:
            self.ideal_hv = compute_hypervolume([tuple(self.ideal_points[str(self.problem_instance)])],
                                                self.nr_objectives, self.reference_point)
        self.best_hv = self.initial_hv

        return self._observe(), dict()

    def _improvement_reward(self, episode_hv):
        """Reward for raising the episode's best hypervolume to ``episode_hv``.

        Progress is the percentage of the way from the initial hypervolume to
        the ideal-point hypervolume. The reward is the increase of
        ``(reward_factor * progress) ** 2``, rounded to one decimal. It is 0
        when the ideal and initial hypervolumes coincide, because the
        percentage is then undefined.
        """
        span = self.ideal_hv - self.initial_hv
        if span == 0:
            return 0
        current_gap = (episode_hv - self.initial_hv) / span * 100
        previous_gap = (self.best_hv - self.initial_hv) / span * 100
        return round(((self.reward_factor * current_gap) ** 2) - ((self.reward_factor * previous_gap) ** 2), 1)

    def step(self, action):
        """Run one GA generation with operator rates derived from ``action``.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. ``truncated``
            is always ``False``; the episode terminates after ``max_generations``.
        """
        # Sanitise the raw policy output: NaN/inf -> finite, then clamp to the action box.
        action1 = np.clip(np.nan_to_num(action[0]), -1, 1)
        action2 = np.clip(np.nan_to_num(action[1]), -1, 1)

        reward = 0
        self.generation += 1
        cxpb = (action1 + 1) * 0.1 + 0.5  # in [0.5, 0.7]
        mutpb = (action2 + 1) * 0.1  # in [0.0, 0.2]

        offspring = []
        for _ in range(self.population_size):
            if random.random() <= cxpb:
                # Only the first child of the crossover is kept.
                ind1, ind2 = list(map(self.toolbox.clone, random.sample(self.population, 2)))
                self.toolbox.mate(ind1, ind2)
                del ind1.fitness.values, ind2.fitness.values
            else:
                ind1 = self.toolbox.clone(random.choice(self.population))

            self.toolbox.mutate(ind1, mutpb)
            del ind1.fitness.values
            offspring.append(ind1)

        fitnesses = self._evaluate(offspring)

        if self.save_results:
            self.hof.update(offspring)

        self.bounds = update_bounds(self.bounds, fitnesses)

        # Select the next generation from parents + offspring.
        self.population = self.toolbox.select(self.population + offspring, self.population_size)
        self.solution_space = np.array([list(ind.fitness.values) for ind in self.population]).astype(np.float32)
        self.solution_spaces.append(self.solution_space)

        if not self.save_results:
            episode_hv = compute_hypervolume(self.population, self.nr_objectives, self.reference_point)
            # Reward only improvements over the best hypervolume seen this episode.
            if self.best_hv < episode_hv:
                reward = self._improvement_reward(episode_hv)
                self.best_hv = episode_hv

        if self.generation == self.max_generations:
            self.done = True

            if self.save_results:
                print('saving results')
                self.save_result()

        return self._observe(), reward, self.done, False, {}

    def sample(self, episodes=5):
        """Smoke-test the environment by running ``episodes`` episodes with random actions.

        Prints the wall-clock duration of each episode.
        """
        for episode in range(episodes):
            start_time = time.time()

            print("start episode: ", episode)
            self.reset()
            done = False
            while not done:
                action = self.action_space.sample()  # Take random action
                _, _, done, _, _ = self.step(action)
            duration = time.time() - start_time
            print(f"Episode {episode} completed in {duration:.2f} seconds")

    def save_result(self):
        """Write the hall-of-fame hypervolume to ``<folder>/<exp_name>_results.csv``."""
        import pandas as pd  # pd is not imported at module level

        output_dir = self.folder
        pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)

        # Strip leading/trailing slashes so exp_name is always relative to folder.
        # Interior slashes are allowed and create sub-directories (made below).
        exp_name = self.exp_name.strip("/")

        results = {
            'problem_instance': self.problem_instance,
            'hypervolume': compute_hypervolume(self.hof, self.nr_objectives, self.reference_point),
        }
        print('hv:', results['hypervolume'], 'ref_point:', self.reference_point, len(self.hof))

        results_csv_path = os.path.join(output_dir, f'{exp_name}_results.csv')
        pathlib.Path(results_csv_path).parent.mkdir(parents=True, exist_ok=True)
        # One-row table: columns problem_instance, hypervolume.
        pd.DataFrame.from_dict(results, orient='index').T.to_csv(results_csv_path, index=False)
        print('results saved :-)')

    # def plot(self):
    #     # Create a color map
    #     colors = cm.rainbow(np.linspace(0, 1, len(self.solution_spaces)))
    #
    #     plt.figure(figsize=(10, 8))
    #
    #     # Plot each array with its corresponding color
    #     for idx, array in enumerate(self.solution_spaces):
    #         plt.scatter(array[:, 0], array[:, 1], color=colors[idx], label=f"Timestep {idx}")
    #
    #     plt.xlabel("Objective 1")
    #     plt.ylabel("Objective 2")
    #     plt.legend(loc="upper right", bbox_to_anchor=(1.15, 1))
    #     plt.show()


# --------------------------------------------------------------------------------------------------------------------
# import toml
# import matplotlib.pyplot as plt
# import matplotlib.cm as cm
#
# if __name__ == "__main__":
#     config_filepath = 'C:/Users/s143036/PycharmProjects/GNN4APC/v9/config_DTLZ.toml'
#     with open(config_filepath, 'r') as toml_file:
#         parameters = toml.load(toml_file)
#     env = MoopEnv(parameters)
#     env.sample()

# --------------------------------------------------------------------------------------------------------------------
if __name__ == "__main__":
    # Smoke-test the environment with random actions, using the routing_100 config.
    config_path = "configs/config_routing_100.toml"
    with open(config_path, 'r') as config_file:
        config = toml.load(config_file)
    env = routingEnv(config)
    env.sample()