import json
import os
import pathlib
import random
import time

import numpy as np
import pandas as pd
import toml

from deap import base, creator, tools
import gymnasium as gym
from gymnasium import spaces
from gymnasium.spaces import GraphInstance

from utils import compute_hypervolume
from scheduling.genetic_algorithm.operators import *

from scheduling.helper_functions import load_job_shop_env
from scheduling.genetic_algorithm.operators import (evaluate_individual, variation,
                                                    init_individual, init_population, mutate_shortest_proc_time,
                                                    mutate_sequence_exchange, pox_crossover)

# JSON file holding per-instance hypervolume reference points. The default is the
# original cluster location; set the REFERENCE_POINTS_FILE environment variable to
# point elsewhere without editing the source.
REFERENCE_POINTS_FILE = os.environ.get(
    "REFERENCE_POINTS_FILE",
    "/hpc/za64617/projects/GNN4APC_dev/code/src/scheduling/data/reference_points.json",
)


def get_fitness_bounds(fitnesses):
    """Return the per-objective lower and upper bounds of a set of fitness vectors.

    Args:
        fitnesses: sequence of fitness vectors, one per individual.

    Returns:
        np.ndarray whose row 0 holds the column-wise minima and row 1 the
        column-wise maxima of ``fitnesses``.
    """
    values = np.array(fitnesses)
    return np.array([values.min(axis=0), values.max(axis=0)])


def update_bounds(bounds, fitness):
    """Widen ``bounds`` in place so they also enclose a new batch of fitness vectors.

    Args:
        bounds: indexable pair ``[lower, upper]`` of per-objective bounds; it is
            mutated element-wise.
        fitness: sequence of fitness vectors, one per individual.

    Returns:
        The same ``bounds`` object, after updating.
    """
    batch = np.array(fitness)
    batch_min = np.min(batch, axis=0)
    batch_max = np.max(batch, axis=0)

    lower, upper = bounds[0], bounds[1]
    for obj, (lo, hi) in enumerate(zip(batch_min, batch_max)):
        lower[obj] = np.minimum(lower[obj], lo)
        upper[obj] = np.maximum(upper[obj], hi)
    return bounds


def get_edge_links(self):
    """Return the edge index (sources, targets) of a complete directed graph over the population.

    Every ordered pair ``(i, j)`` with ``0 <= i, j < self.population_size`` is
    listed, including self-pairs, in row-major order (``i`` outer, ``j`` inner).
    That order matches the flat mask produced by ``get_edges``.

    Returns:
        int64 array of shape (2, population_size ** 2).
    """
    n = self.population_size
    node_ids = np.arange(n)
    sources = np.repeat(node_ids, n)  # 0,0,...,1,1,...
    targets = np.tile(node_ids, n)    # 0,1,...,0,1,...
    return np.stack([sources, targets]).astype(np.int64)


def get_edges(self):
    """Return a flattened boolean edge mask connecting individuals on the same Pareto front.

    The population is split into non-dominated fronts with DEAP's
    ``tools.sortNondominated``. Every pair of distinct individuals that share a
    front is marked ``True``; self-pairs are always ``False``.

    Returns:
        bool array of length ``population_size ** 2``. Entry ``i * population_size + j``
        is the edge ``i -> j``, the same ordering that ``get_edge_links`` uses.
    """
    n = self.population_size
    fronts = tools.sortNondominated(self.population, n, first_front_only=False)

    # Locate individuals by identity. DEAP individuals are list subclasses, so
    # ``list.index`` compares by value: duplicate genomes would all resolve to the
    # first copy, and the remaining copies would never receive their front edges.
    position = {id(ind): idx for idx, ind in enumerate(self.population)}

    adjacency = np.zeros((n, n), dtype=np.bool_)
    for front in fronts:
        members = [position[id(ind)] for ind in front]
        adjacency[np.ix_(members, members)] = True
    # Members of a front are linked to each other, never to themselves.
    np.fill_diagonal(adjacency, False)
    return adjacency.reshape(-1)


def normalize_fitnesses(population, bounds):
    """Min-max scale each individual's fitness values per objective using ``bounds``.

    Args:
        population: iterable of individuals exposing ``ind.fitness.values``.
        bounds: pair ``[lower, upper]`` of per-objective bounds.

    Returns:
        List (one entry per individual) of lists of scaled values. An objective
        whose lower and upper bounds coincide maps to 0.5.

    Raises:
        ValueError: if a lower bound is greater than its upper bound.
    """
    def _scale(obj, value):
        lo, hi = bounds[0][obj], bounds[1][obj]
        if lo > hi:
            raise ValueError('bound min > bound max', lo, hi, bounds)
        if lo == hi:
            # Degenerate range: no spread to scale by, so use the midpoint.
            return 0.5
        return (value - lo) / (hi - lo)

    return [[_scale(obj, value) for obj, value in enumerate(ind.fitness.values)]
            for ind in population]


class schedulingEnv(gym.Env):
    """Gymnasium environment in which an agent tunes a DEAP genetic algorithm online.

    Each ``step`` runs one GA generation on a job-shop instance. The 2-D action sets
    the crossover and mutation probabilities. The observation is a graph: nodes are
    the population's normalized fitness vectors, and edges link individuals on the
    same Pareto front. Without result saving, the reward is shaped from the increase
    in best hypervolume relative to a stored ideal point.
    """

    def __init__(self, parameters):
        """Read environment settings from ``parameters`` and load the ideal points.

        Required keys under ``parameters['environment']``: population_size,
        nr_objectives, max_generations, alternative_objectives, reward_factor and
        problem_instances. The optional ``ideal_points_file`` key overrides the
        default ideal-points JSON path. If ``parameters`` contains a
        ``results_saving`` section, its ``save_result``, ``folder`` and ``exp_name``
        entries control CSV output at the end of each episode.
        """
        env_params = parameters['environment']
        self.population_size = env_params['population_size']
        self.nr_objectives = env_params['nr_objectives']
        self.max_generations = env_params['max_generations']
        self.alternative_objectives = env_params['alternative_objectives']
        self.jobShopEnv = None

        self.solution_spaces = []
        self.solution_space = spaces.Box(low=0, high=1, shape=(self.population_size, self.nr_objectives),
                                         dtype=np.float32)
        self.adjacency_matrix = spaces.Box(low=0, high=1, shape=(self.population_size, self.population_size),
                                           dtype=np.int32)

        # Action space: two continuous values in [-1, 1] (crossover and mutation controls)
        self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        self.reference_point = None
        self.generation = 0
        self.done = False
        self.reward_factor = env_params['reward_factor']
        self.problem_instances = env_params['problem_instances']

        ideal_points_file = env_params.get(
            'ideal_points_file',
            '/hpc/za64617/projects/GNN4APC_dev/code/src/scheduling/data/ideal_points_{}_obj.json'.format(
                self.nr_objectives))
        with open(ideal_points_file, 'r') as json_file:
            self.ideal_points = json.load(json_file)

        self.save_results = False
        if 'results_saving' in parameters:
            self.save_results = parameters['results_saving']['save_result']
            self.folder = parameters['results_saving']['folder']
            self.exp_name = parameters['results_saving']['exp_name']

        self.bounds = []

    def _observe(self):
        """Build the observation: population graph plus the fraction of generations elapsed."""
        graph = GraphInstance(
            nodes=np.array(normalize_fitnesses(self.population, self.bounds)),
            edges=get_edges(self),
            edge_links=get_edge_links(self))
        return {'graph': graph,
                'additional_features': np.array([self.generation / self.max_generations])}

    def _build_toolbox(self):
        """Create the DEAP Fitness/Individual classes and register the GA operators for the loaded instance."""
        self.toolbox = base.Toolbox()

        # Weights of -1.0 make every objective minimized.
        creator.create("Fitness", base.Fitness, weights=tuple([-1.0 for _ in range(self.nr_objectives)]))
        creator.create("Individual", list, fitness=creator.Fitness)

        self.toolbox.register("init_individual", init_individual, creator.Individual, None, jobShopEnv=self.jobShopEnv)
        self.toolbox.register("mate_TwoPoint", tools.cxTwoPoint)
        self.toolbox.register("mate_Uniform", tools.cxUniform, indpb=0.5)
        self.toolbox.register("mate_POX", pox_crossover, nr_preserving_jobs=1)

        self.toolbox.register("mutate_machine_selection", mutate_shortest_proc_time, jobShopEnv=self.jobShopEnv)
        self.toolbox.register("mutate_operation_sequence", mutate_sequence_exchange)
        self.toolbox.register("select", tools.selNSGA2)
        self.toolbox.register("evaluate_individual", evaluate_individual, jobShopEnv=self.jobShopEnv,
                              alt_objectives=self.alternative_objectives, objectives=self.nr_objectives)

    def _evaluate(self, individuals):
        """Evaluate ``individuals`` in place and return their fitness values.

        Only the first two components of each individual are passed to the
        registered evaluator. The first element of the evaluator's result is used as
        the fitness tuple.
        """
        genomes = [[ind[0], ind[1]] for ind in individuals]
        fitnesses = [self.toolbox.evaluate_individual(genome)[0] for genome in genomes]
        for ind, fit in zip(individuals, fitnesses):
            ind.fitness.values = fit
        return fitnesses

    def _load_reference_point(self):
        """Return the stored reference point for the current instance, or None if unavailable."""
        if not os.path.isfile(REFERENCE_POINTS_FILE):
            return None
        with open(REFERENCE_POINTS_FILE, 'r') as file:
            reference_points = json.load(file)
        if self.problem_instance not in reference_points:
            return None

        stored = reference_points[self.problem_instance]
        if not self.alternative_objectives:
            reference_point = stored[0:self.nr_objectives]
        else:
            print('USING ALTERNATIVE OBJECTIVES (2)')
            reference_point = stored[-2:]
        print('using reference point from file', reference_point)
        return reference_point

    def reset(self, *, seed=None, options=None):
        """Start a new episode on a randomly chosen problem instance.

        Args:
            seed: optional seed. It is forwarded to ``gym.Env.reset``; when given, it
                also seeds Python's ``random`` module (used here to pick the instance).
            options: accepted for Gymnasium API compatibility; unused.

        Returns:
            ``(observation, info)`` with an empty info dict.
        """
        super().reset(seed=seed)
        if seed is not None:
            random.seed(seed)

        # Keep only this episode's per-generation snapshots; otherwise the list
        # grows without bound over a long training run.
        self.solution_spaces = []

        self.problem_instance = random.choice(self.problem_instances)
        self.jobShopEnv = load_job_shop_env(self.problem_instance)
        self._build_toolbox()

        self.population = init_population(self.toolbox, self.population_size, )
        self._evaluate(self.population)

        if self.save_results:
            self.hof = tools.ParetoFront()
            self.hof.update(self.population)

        self.bounds = get_fitness_bounds([ind.fitness.values for ind in self.population])
        if not self.save_results:
            # Copy, because self.bounds is widened in place by update_bounds during the episode.
            self.reference_point = self.bounds[1].copy()
        else:
            self.reference_point = self._load_reference_point()
            if self.reference_point is None:
                print('NO REFERENCE POINT KNOWN')
                # Fall back to the worst point of the initial population instead of
                # reusing the previous episode's reference point (or None).
                self.reference_point = self.bounds[1].copy()

        self.generation = 0
        self.done = False

        self.solution_space = np.array([ind.fitness.values for ind in self.population]).astype(np.float32)
        self.solution_spaces.append(self.solution_space)
        self.initial_hv = self.step_hv = compute_hypervolume(self.population, self.nr_objectives, self.reference_point)
        if not self.save_results:
            self.ideal_hv = compute_hypervolume([tuple(self.ideal_points[self.problem_instance][:self.nr_objectives])],
                                                self.nr_objectives, self.reference_point)
        self.best_hv = self.initial_hv

        return self._observe(), dict()

    def _hypervolume_reward(self):
        """Return the shaped reward for improving this episode's best hypervolume.

        Progress is the percentage of the gap between the initial and the ideal
        hypervolume that has been closed. The reward is the increase in the squared,
        ``reward_factor``-scaled progress, rounded to one decimal. It is 0 when the
        hypervolume did not improve, or when the ideal and initial hypervolumes
        coincide (the gap is then undefined).
        """
        episode_hv = compute_hypervolume(self.population, self.nr_objectives, self.reference_point)
        if not self.best_hv < episode_hv:
            return 0

        reward = 0
        hv_range = self.ideal_hv - self.initial_hv
        if hv_range != 0:
            current_gap = (episode_hv - self.initial_hv) / hv_range * 100
            previous_gap = (self.best_hv - self.initial_hv) / hv_range * 100
            reward = round(((self.reward_factor * current_gap) ** 2) - ((self.reward_factor * previous_gap) ** 2), 1)
        self.best_hv = episode_hv
        return reward

    def step(self, action):
        """Advance the GA by one generation using the probabilities encoded in ``action``.

        Args:
            action: length-2 array. NaNs are replaced and values are clipped to [-1, 1].
                ``action[0]`` maps linearly to a crossover probability in [0.6, 1.0];
                ``action[1]`` maps to a mutation probability in [0.0, 0.1].

        Returns:
            ``(observation, reward, terminated, truncated, info)`` following the
            Gymnasium step API; ``truncated`` is always False.
        """
        action1 = np.clip(np.nan_to_num(action[0]), -1, 1)
        action2 = np.clip(np.nan_to_num(action[1]), -1, 1)

        reward = 0
        self.generation += 1
        cxpb = (action1 + 1) * 0.2 + 0.6  # (between 0.6 and 1)
        mutpb = (action2 + 1) * 0.05  # (between 0 and 0.1)

        offspring = variation(self.population, self.toolbox, self.population_size, cxpb, mutpb)

        # Precedence repair is intended only for '/dafjs/' and '/yfjs/' instances. The
        # previous condition `'/dafjs/' or '/yfjs/' in ...` was always true, because a
        # non-empty string literal is truthy.
        instance_name = getattr(self.jobShopEnv, 'instance_name', None) or self.problem_instance
        if '/dafjs/' in instance_name or '/yfjs/' in instance_name:
            offspring = repair_precedence_constraints(self.jobShopEnv, offspring)

        fitnesses = self._evaluate(offspring)

        if self.save_results:
            self.hof.update(offspring)

        self.bounds = update_bounds(self.bounds, fitnesses)

        # Select the next generation from parents + offspring (toolbox.select is selNSGA2).
        self.population = self.toolbox.select(self.population + offspring, self.population_size)
        self.solution_space = np.array([list(ind.fitness.values) for ind in self.population]).astype(np.float32)
        self.solution_spaces.append(self.solution_space)

        if not self.save_results:
            reward = self._hypervolume_reward()

        if self.generation == self.max_generations:
            self.done = True
            if self.save_results:
                self.save_result()

        return self._observe(), reward, self.done, False, {}

    # --------------------------------------------------------------------------------------------------------------------
    def sample(self, nr_episodes=5):
        """Run ``nr_episodes`` episodes with random actions and print each episode's wall-clock duration."""
        for episode in range(nr_episodes):
            start_time = time.time()

            print("start episode: ", episode)
            self.reset()
            done = False
            while not done:
                action = self.action_space.sample()  # Take random action
                _, _, done, _, _ = self.step(action)

            duration = time.time() - start_time
            print(f"Episode {episode} completed in {duration:.2f} seconds")

    def save_result(self):
        """Write the instance name and the Pareto archive's hypervolume to ``<folder>/<exp_name>_results.csv``."""
        # Strip leading/trailing slashes so exp_name is used as a relative file-name prefix.
        exp_name = self.exp_name.strip("/")
        results_csv_path = os.path.join(self.folder, f'{exp_name}_results.csv')
        # Create the parent directory; this also covers exp_name values containing subfolders.
        pathlib.Path(results_csv_path).parent.mkdir(parents=True, exist_ok=True)

        results = {
            'problem_instance': self.problem_instance,
            'hypervolume': compute_hypervolume(self.hof, self.nr_objectives, self.reference_point),
        }
        pd.DataFrame([results]).to_csv(results_csv_path, index=False)

    # def plot(self):
    #     # Create a color map
    #     colors = cm.rainbow(np.linspace(0, 1, len(self.solution_spaces)))
    #
    #     plt.figure(figsize=(10, 8))
    #
    #     # Plot each array with its corresponding color
    #     for idx, array in enumerate(self.solution_spaces):
    #         plt.scatter(array[:, 0], array[:, 1], color=colors[idx], label=f"Timestep {idx}")
    #
    #     plt.xlabel("Objective 1")
    #     plt.ylabel("Objective 2")
    #     plt.legend(loc="upper right", bbox_to_anchor=(1.15, 1))
    #     plt.show()


# --------------------------------------------------------------------------------------------------------------------
if __name__ == "__main__":
    # Smoke run: build the environment from a TOML config (path relative to the
    # working directory) and roll out random-action episodes.
    config_path = "configs/config_scheduling_5_5.toml"
    with open(config_path, 'r') as config_file:
        parameters = toml.load(config_file)
    env = schedulingEnv(parameters)  # name `env` kept: older sample() refers to this global
    env.sample()
