import json
import time
import os
import toml
import pathlib
import random

import numpy as np
import gymnasium as gym
import pandas as pd

from gymnasium.spaces import Box
from deap import base, creator, tools
from gymnasium.spaces import GraphInstance

from utils import compute_hypervolume
from scheduling.helper_functions import record_stats
from routing.helper_functions import read_input_cvrp
from routing.pso.operators import eval_individual_fitness, generate_particle, update_particle


def get_fitness_bounds(fitnesses):
    """Return the per-objective ``[min, max]`` bounds of a set of fitness vectors.

    Args:
        fitnesses: sequence of fitness vectors, e.g. DEAP ``fitness.values`` tuples.

    Returns:
        ``np.ndarray`` stacking the column-wise minima (row 0) on top of the
        column-wise maxima (row 1); for an ``(n, m)`` input the shape is ``(2, m)``.
    """
    table = np.array(fitnesses)
    return np.stack((table.min(axis=0), table.max(axis=0)))


def update_bounds(bounds, fitness):
    """Widen ``bounds`` in place so they also cover every vector in ``fitness``.

    Args:
        bounds: two-row structure (row 0 = minima, row 1 = maxima), indexed as
            ``bounds[row][objective]``; it is modified in place.
        fitness: sequence of fitness vectors to fold into the bounds.

    Returns:
        The same ``bounds`` object, after updating.
    """
    batch = np.array(fitness)
    batch_lows = np.min(batch, axis=0)
    batch_highs = np.max(batch, axis=0)

    # Update objective by objective so ``bounds`` may be a numpy array or nested lists.
    for objective, (low, high) in enumerate(zip(batch_lows, batch_highs)):
        bounds[0][objective] = np.minimum(bounds[0][objective], low)
        bounds[1][objective] = np.maximum(bounds[1][objective], high)

    return bounds


def get_edge_links(self):
    """Return the int64 edge index of a complete directed graph over the population.

    Only ``self.population_size`` is read. The result has shape
    ``(2, population_size ** 2)``: column ``i * population_size + j`` holds the
    pair ``(i, j)``, self-loops ``(i, i)`` included.
    """
    n = self.population_size
    node_ids = np.arange(n)
    sources = np.repeat(node_ids, n)  # 0,0,..,0,1,1,..,1,...
    targets = np.tile(node_ids, n)    # 0,1,..,n-1,0,1,..,n-1,...
    return np.stack([sources, targets]).astype(np.int64)


def get_edges(self):
    """Return a flattened boolean mask linking individuals of the same Pareto front.

    ``self`` is the environment; ``self.population`` and ``self.population_size``
    are read. Entry ``i * population_size + j`` is True iff ``i != j`` and
    individuals ``i`` and ``j`` lie in the same non-dominated front. This is the
    same pair ordering that ``get_edge_links`` produces.

    Individuals are mapped back to their population slot by identity. Particles
    are lists, so ``list.index`` would compare by value: duplicate particles
    would all resolve to the first slot and the duplicates would never be
    connected. It would also cost O(n) per lookup.
    """
    size = self.population_size
    fronts = tools.sortNondominated(self.population, size, first_front_only=False)

    # sortNondominated returns the same objects it was given, so identity is a safe key.
    slot_of = {id(ind): slot for slot, ind in enumerate(self.population)}

    adjacency = np.zeros((size, size), dtype=np.bool_)
    for front in fronts:
        members = [slot_of[id(ind)] for ind in front]
        adjacency[np.ix_(members, members)] = True

    # Individuals are never linked to themselves.
    np.fill_diagonal(adjacency, False)
    return adjacency.ravel()


def normalize_fitnesses(population, bounds):
    """Min-max scale every individual's fitness values into ``[0, 1]`` using ``bounds``.

    Args:
        population: iterable of individuals exposing ``ind.fitness.values``.
        bounds: two-row structure; ``bounds[0][k]`` / ``bounds[1][k]`` are the
            minimum / maximum of objective ``k``.

    Returns:
        A list (one entry per individual) of lists of normalised values. An
        objective whose minimum equals its maximum is mapped to 0.5.

    Raises:
        ValueError: if a minimum bound exceeds its maximum bound.
    """
    def scale(objective, value):
        low = bounds[0][objective]
        high = bounds[1][objective]
        if low > high:
            raise ValueError('bound min > bound max', low, high, bounds)
        if low == high:
            # Degenerate range: place the value in the middle.
            return 0.5
        return (value - low) / (high - low)

    return [
        [scale(objective, value) for objective, value in enumerate(ind.fitness.values)]
        for ind in population
    ]


class routingEnvPSO(gym.Env):
    """Gymnasium environment in which an agent steers a multi-objective PSO on CVRP instances.

    Each episode picks one problem instance from ``problem_instances`` (read from
    ``instance_file``) and initialises a DEAP particle population. Every ``step``
    then runs one PSO generation. The 3-dimensional action sets ``phi1`` and
    ``phi2`` (each in [1, 3]) and the inertia weight (in [0.3, 0.9]).

    Observations are a fully connected graph over the population. Node features
    are min-max normalised fitnesses. The edge mask links individuals in the same
    non-dominated front. A scalar ``generation / max_generations`` is included as
    an additional feature.

    In training mode (``save_results`` false) the reward measures hypervolume
    improvement, expressed relative to the gap between the initial hypervolume and
    the hypervolume of a precomputed ideal point.
    """

    # Directory with the ``ideal_points_*`` / ``reference_points_*`` JSON files.
    # Can be overridden with ``parameters['environment']['data_dir']``.
    DEFAULT_DATA_DIR = '/hpc/za64617/projects/GNN4APC_dev/code/src/routing/data'
    # Instance sizes with precomputed point files, matched as ``cvrp_<size>_`` in this order.
    INSTANCE_SIZES = (20, 50, 100, 200, 500)

    def __init__(self, parameters):
        """Read the configuration and load the ideal points for the configured instance size.

        Args:
            parameters: configuration mapping. ``parameters['environment']`` must
                provide ``population_size``, ``nr_objectives``, ``max_generations``,
                ``reward_factor``, ``instance_file`` and ``problem_instances``. It may
                also provide ``data_dir``. If ``parameters['results_saving']`` is
                present, it must provide ``save_result``, ``folder`` and ``exp_name``.
        """
        env_params = parameters['environment']
        self.population_size = env_params['population_size']
        self.nr_objectives = env_params['nr_objectives']
        self.max_generations = env_params['max_generations']

        # NOTE(review): every reset/step appends to this list and nothing clears it,
        # so it grows for the lifetime of the environment.
        self.solution_spaces = []
        self.solution_space = Box(low=0, high=1, shape=(self.population_size, self.nr_objectives),
                                  dtype=np.float32)

        # Three continuous controls, mapped in ``step`` to phi1, phi2 and the inertia weight.
        self.action_space = Box(low=-1, high=1, shape=(3,), dtype=np.float32)
        self.reference_point = None
        self.generation = 0
        self.done = False
        self.reward_factor = env_params['reward_factor']
        self.instance_file = env_params['instance_file']
        self.problem_instances = env_params['problem_instances']
        self.data_dir = env_params.get('data_dir', self.DEFAULT_DATA_DIR)

        # Instance files whose size is not recognised fall back to the 500-customer ideal points.
        ideal_size = self._instance_size() or 500
        with open(self._data_file('ideal_points', ideal_size), 'r') as json_file:
            self.ideal_points = json.load(json_file)

        self.save_results = False
        if 'results_saving' in parameters:
            saving = parameters['results_saving']
            self.save_results = saving['save_result']
            self.folder = saving['folder']
            self.exp_name = saving['exp_name']

        self.bounds = []

    def _instance_size(self):
        """Return the size encoded as ``cvrp_<size>_`` in ``instance_file``, or None if unknown."""
        for size in self.INSTANCE_SIZES:
            if f'cvrp_{size}_' in self.instance_file:
                return size
        return None

    def _data_file(self, kind, size):
        """Return the path of ``<kind>_<size>_2_obj.json`` inside ``data_dir``."""
        return os.path.join(self.data_dir, f'{kind}_{size}_2_obj.json')

    def _load_reference_point(self):
        """Return the stored reference point for the current instance, or None if unavailable."""
        size = self._instance_size()
        if size is None:
            return None
        path = self._data_file('reference_points', size)
        if not os.path.isfile(path):
            return None
        with open(path, 'r') as file:
            reference_points = json.load(file)
        key = str(self.problem_instance)
        if key not in reference_points:
            return None
        print('using reference point from file')
        reference_point = reference_points[key]
        print('ref_point:', reference_point)
        return reference_point

    def _snapshot(self, ind):
        """Return a detached copy of ``ind`` to store in a personal-best archive.

        Storing ``ind`` itself would alias the live particle, which
        ``toolbox.update`` moves in place, so the archived "best" would follow the
        particle. The copy's own archive is dropped so that archives do not nest.
        """
        snapshot = self.toolbox.clone(ind)
        snapshot.best = None
        return snapshot

    def _observe(self):
        """Build the observation dict: population graph plus the normalised generation counter."""
        nodes = np.array(normalize_fitnesses(self.population, self.bounds))
        graph = GraphInstance(nodes=nodes, edges=get_edges(self), edge_links=get_edge_links(self))
        return {'graph': graph,
                'additional_features': np.array([self.generation / self.max_generations])}

    def reset(self, *, seed=None, options=None):
        """Start a new episode on a randomly chosen problem instance.

        Args:
            seed: if given, seeds Python's ``random`` module. This class uses it to
                pick the instance and the global-best guides; whether the particle
                operators draw from the same generator is not visible here.
            options: accepted for Gymnasium API compatibility; unused.

        Returns:
            ``(observation, info)`` with an empty ``info`` dict.
        """
        if seed is not None:
            random.seed(seed)

        # Remove the previous episode's creator classes before re-creating them.
        if hasattr(creator, "FitnessMin"):
            del creator.FitnessMin
        if hasattr(creator, "Particle"):
            del creator.Particle

        self.problem_instance = random.choice(self.problem_instances)
        nb_customers, truck_capacity, dist_matrix_data, dist_depot_data, demands_data = read_input_cvrp(
            self.instance_file, self.problem_instance)

        creator.create("FitnessMin", base.Fitness, weights=(-1.0, -1.0))
        creator.create("Particle", list, fitness=creator.FitnessMin, speed=list, smin=None, smax=None, best=None)

        self.toolbox = base.Toolbox()
        self.toolbox.register("particle", generate_particle, size=nb_customers, s_min=-0.1, s_max=0.1)
        self.toolbox.register("population", tools.initRepeat, list, self.toolbox.particle)
        self.toolbox.register("update", update_particle)
        self.toolbox.register('evaluate', eval_individual_fitness, truck_capacity=truck_capacity,
                              dist_matrix_data=dist_matrix_data, dist_depot_data=dist_depot_data,
                              demands_data=demands_data)
        self.toolbox.register("select", tools.selNSGA2)

        self.population = self.toolbox.population(self.population_size)
        for ind in self.population:
            ind.fitness.values = tuple(self.toolbox.evaluate(ind))
            # The personal-best archive starts with a detached copy of the initial position.
            ind.best = [self._snapshot(ind)]

        self.hof = tools.ParetoFront()
        self.hof.update(self.population)

        self.bounds = get_fitness_bounds([ind.fitness.values for ind in self.population])
        if not self.save_results:
            # Copy: ``update_bounds`` mutates ``self.bounds`` in place during the episode,
            # while the reference point must stay fixed.
            self.reference_point = self.bounds[1].copy()
        else:
            self.reference_point = self._load_reference_point()
            if self.reference_point is None:
                # Never reuse a stale point from a previous instance. Fall back to the
                # population-based point used in training mode.
                print('NO REFERENCE POINT KNOWN')
                self.reference_point = self.bounds[1].copy()

        self.generation = 0
        self.done = False

        self.solution_space = np.array([ind.fitness.values for ind in self.population]).astype(np.float32)
        self.solution_spaces.append(self.solution_space)
        self.initial_hv = self.step_hv = compute_hypervolume(self.population, self.nr_objectives,
                                                             self.reference_point)
        if not self.save_results:
            # The ideal point's hypervolume is the 100% mark of the reward gap in ``step``.
            self.ideal_hv = compute_hypervolume([tuple(self.ideal_points[str(self.problem_instance)])],
                                                self.nr_objectives, self.reference_point)
        self.best_hv = self.initial_hv

        return self._observe(), {}

    def step(self, action):
        """Run one PSO generation with coefficients derived from ``action``.

        Args:
            action: sequence of at least three numbers. Entries 0, 1 and 2 are
                sanitised (NaN -> 0, infinities -> finite) and clipped to [-1, 1].

        Returns:
            ``(observation, reward, terminated, truncated, info)``. ``truncated`` is
            always False and ``info`` is empty. ``terminated`` becomes True once
            ``max_generations`` generations have run.
        """
        action1, action2, action3 = np.clip(np.nan_to_num(np.asarray(action)[:3]), -1, 1)

        reward = 0
        self.generation += 1
        phi1 = (action1 + 1) + 1  # in [1, 3]
        phi2 = (action2 + 1) + 1  # in [1, 3]
        inertia_weight = (action3 + 1) * 0.3 + 0.3  # in [0.3, 0.9]

        offspring = [self.toolbox.clone(ind) for ind in self.population]

        for ind in offspring:
            # Use a random solution from the HoF's Pareto front as the global best.
            # ``update`` moves ``ind`` in place (its return value is not used).
            self.toolbox.update(ind, random.choice(self.hof.items), phi1=phi1, phi2=phi2,
                                inertia_weight=inertia_weight)

        for ind in offspring:
            ind.fitness.values = self.toolbox.evaluate(ind)

            # Keep the non-dominated states among the new position and the archive. Store a
            # snapshot rather than ``ind`` so the archive does not move with the particle.
            candidates = [self._snapshot(ind)] + ind.best
            ind.best = tools.sortNondominated(candidates, len(candidates), first_front_only=True)[0]

        self.hof.update(offspring)
        self.bounds = update_bounds(self.bounds, [ind.fitness.values for ind in offspring])

        # Select the next generation population from parents + offspring.
        self.population = self.toolbox.select(self.population + offspring, self.population_size)
        self.solution_space = np.array([list(ind.fitness.values) for ind in self.population]).astype(np.float32)
        self.solution_spaces.append(self.solution_space)

        if not self.save_results:
            episode_hv = compute_hypervolume(self.population, self.nr_objectives, self.reference_point)

            if self.best_hv < episode_hv:
                gap_span = self.ideal_hv - self.initial_hv
                # The reward is expressed as a percentage of this span; skip it when the span
                # is zero instead of dividing by zero.
                if gap_span != 0:
                    current_gap = (episode_hv - self.initial_hv) / gap_span * 100
                    previous_gap = (self.best_hv - self.initial_hv) / gap_span * 100
                    reward = round(((self.reward_factor * current_gap) ** 2)
                                   - ((self.reward_factor * previous_gap) ** 2), 1)
                    print(reward)
                self.best_hv = episode_hv

        if self.generation == self.max_generations:
            self.done = True

            if self.save_results:
                print('saving results')
                self.save_result()

        return self._observe(), reward, self.done, False, {}

    def sample(self, episodes=5):
        """Run ``episodes`` episodes with uniformly random actions and print their durations."""
        for episode in range(episodes):
            start_time = time.time()

            print("start episode: ", episode)
            self.reset()
            done = False
            while not done:
                _, _, done, _, _ = self.step(self.action_space.sample())

            duration = time.time() - start_time
            print(f"Episode {episode} completed in {duration:.2f} seconds")

    def save_result(self):
        """Write the hall-of-fame hypervolume of the current instance to ``<folder>/<exp_name>_results.csv``.

        The CSV is opened in write mode, so each call overwrites the previous file.
        """
        output_dir = self.folder
        pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)

        # Strip surrounding slashes so os.path.join does not treat exp_name as an absolute path.
        exp_name = self.exp_name.strip("/")

        results = {
            'problem_instance': self.problem_instance,
            'hypervolume': compute_hypervolume(self.hof, self.nr_objectives, self.reference_point),
        }
        print('hv:', results['hypervolume'], 'ref_point:', self.reference_point, len(self.hof))

        results_csv_path = os.path.join(output_dir, f'{exp_name}_results.csv')
        # One-row table whose columns are the result keys.
        df = pd.DataFrame.from_dict(results, orient='index').T
        df.to_csv(results_csv_path, index=False)
        print('results saved :-)')


# --------------------------------------------------------------------------------------------------------------------

if __name__ == "__main__":
    # Manual smoke test: build the environment from the PSO routing config
    # and drive it with random actions.
    CONFIG_PATH = "configs/config_routing_pso.toml"
    with open(CONFIG_PATH, 'r') as config_file:
        env_parameters = toml.load(config_file)
    env = routingEnvPSO(env_parameters)
    env.sample()