import argparse
from deap import base, creator, tools
import numpy as np
import time
import json
import logging
import random
import pathlib
import os
import pandas as pd

from scheduling.helper_functions import record_stats
from routing.helper_functions import read_input_cvrp
from routing.pso.operators import eval_individual_fitness, generate_particle, update_particle
from utils import compute_hypervolume

# Default JSON configuration used by main() and as the --config_file default.
PARAM_FILE = "configs/pso_algorithm_routing.json"
# Root directory under which main() creates one results folder per problem instance.
DEFAULT_RESULTS_ROOT = "./results/routing_runs"

# Configure root logging at INFO level as an import-time side effect.
logging.basicConfig(level=logging.INFO)


def save_results(hof, logbook, folder, exp_name, kwargs):
    """Write the logbook, Pareto front and run summary of an experiment to CSV.

    Three files are written inside *folder*, which is created if missing:

    * ``<exp_name>_logbook.csv`` holds one row per logbook record.
    * ``<exp_name>_hof.csv`` holds one row per hall-of-fame individual, with
      columns ``Objective_1`` .. ``Objective_k``.
    * ``<exp_name>_results.csv`` is a single row with every entry of
      *kwargs*, plus ``min_obj_i`` / ``max_obj_i`` for each objective over
      the hall of fame.

    Args:
        hof: Iterable of individuals exposing ``fitness.values``. It may be
            empty, in which case no min/max columns are written.
        logbook: Records accepted by ``pandas.DataFrame`` (e.g. a DEAP Logbook).
        folder: Output directory.
        exp_name: File-name prefix; leading/trailing ``/`` are stripped.
        kwargs: Run parameters stored in the results file. This dict is NOT
            modified; the summary columns are added to a copy.
    """
    output_dir = pathlib.Path(folder)
    output_dir.mkdir(parents=True, exist_ok=True)
    prefix = exp_name.strip("/")

    pd.DataFrame(logbook).to_csv(output_dir / f'{prefix}_logbook.csv', index=False)

    # Materialise the objective vectors once; an empty front must not crash.
    objectives = [tuple(ind.fitness.values) for ind in hof]
    n_obj = len(objectives[0]) if objectives else 0
    columns = [f'Objective_{i + 1}' for i in range(n_obj)]
    pd.DataFrame(objectives, columns=columns).to_csv(output_dir / f'{prefix}_hof.csv', index=False)

    # Work on a copy so the caller's parameter dict is not polluted with summary columns.
    summary = dict(kwargs)
    for i in range(n_obj):
        column = [values[i] for values in objectives]
        summary[f'min_obj_{i}'] = min(column)
        summary[f'max_obj_{i}'] = max(column)

    results = pd.DataFrame.from_dict(summary, orient='index').T
    results.to_csv(output_dir / f'{prefix}_results.csv', index=False)


def initialize_run(**kwargs):
    """Build the DEAP toolbox, statistics, initial swarm and Pareto archive.

    Keyword args read here: ``instance_file``, ``problem_instance`` (passed to
    ``read_input_cvrp``), ``smin``, ``smax`` (particle speed bounds),
    ``cognitive_coef``, ``social_coef`` (update coefficients) and
    ``population_size``.

    Returns:
        Tuple ``(population, toolbox, stats, hof)``. Every particle in
        *population* has been evaluated and has ``best`` set to a one-element
        list with an independent copy of itself. *hof* is a ``ParetoFront``
        seeded with the initial population.
    """
    nb_customers, truck_capacity, dist_matrix_data, dist_depot_data, demands_data = read_input_cvrp(
        kwargs['instance_file'], kwargs['problem_instance'])

    # Two objectives, both minimised.
    creator.create("FitnessMin", base.Fitness, weights=(-1.0, -1.0))
    creator.create("Particle", list, fitness=creator.FitnessMin, speed=list, smin=None, smax=None, best=None)

    toolbox = base.Toolbox()
    toolbox.register("particle", generate_particle, size=nb_customers, s_min=kwargs['smin'], s_max=kwargs['smax'])
    toolbox.register("population", tools.initRepeat, list, toolbox.particle)
    toolbox.register("update", update_particle, phi1=kwargs['cognitive_coef'], phi2=kwargs['social_coef'])
    toolbox.register('evaluate', eval_individual_fitness, truck_capacity=truck_capacity,
                     dist_matrix_data=dist_matrix_data, dist_depot_data=dist_depot_data,
                     demands_data=demands_data)

    toolbox.register("select", tools.selNSGA2)

    # Per-objective statistics (axis=0 keeps one value per objective).
    stats = tools.Statistics(lambda ind: ind.fitness.values)
    stats.register("avg", np.mean, axis=0)
    stats.register("std", np.std, axis=0)
    stats.register("min", np.min, axis=0)
    stats.register("max", np.max, axis=0)

    initial_population = toolbox.population(n=kwargs['population_size'])

    for ind in initial_population:
        ind.fitness.values = list(toolbox.evaluate(ind))
        # Store an independent copy as the personal best. Storing `ind` itself
        # would alias it: toolbox.update moves particles in place, so the
        # "best" would silently follow the current position.
        ind.best = [toolbox.clone(ind)]

    hof = tools.ParetoFront()
    hof.update(initial_population)

    return initial_population, toolbox, stats, hof


def _reference_points_file(instance_file, reference_dir):
    """Return the reference-point JSON path matching the CVRP size in *instance_file*, or None."""
    # Only these instance sizes have reference-point files; larger sizes are
    # tested first, matching the historical lookup order.
    for size in (500, 200, 100, 50, 20):
        if f'cvrp_{size}_' in instance_file:
            return os.path.join(reference_dir, f'reference_points_{size}_2_obj.json')
    return None


def _hypervolume_for(hof, instance_file, problem_instance, nr_of_objectives, reference_dir):
    """Return the hypervolume of *hof* w.r.t. the stored reference point, or None if none is known."""
    path = _reference_points_file(instance_file, reference_dir)
    if path is None or not os.path.isfile(path):
        logging.warning('NO REFERENCE POINT FILE FOUND for %s', instance_file)
        return None
    with open(path, 'r') as file:
        reference_points = json.load(file)
    key = str(problem_instance)
    if key not in reference_points:
        logging.warning('NO REFERENCE POINT KNOWN')
        return None
    return compute_hypervolume(hof, nr_of_objectives, list(reference_points[key]))


def _snapshot(ind, toolbox):
    """Return a detached deep copy of *ind* whose own ``best`` archive is dropped."""
    # Detach the archive temporarily so the deep copy does not recursively
    # duplicate every stored personal best (which would grow each generation).
    archive = ind.best
    ind.best = None
    try:
        return toolbox.clone(ind)
    finally:
        ind.best = archive


def run_algo(population, toolbox, folder, exp_name, stats, hof, **kwargs):
    """Run the multi-objective PSO loop, then compute the hypervolume and save results.

    Each generation works as follows:

    * Every particle is cloned and moved towards a random member of the
      Pareto front *hof*. The inertia weight decays geometrically:
      ``inertia_weight * inertia_damping ** gen``.
    * Each moved particle is re-evaluated.
    * Its personal-best archive is reduced to its first non-dominated front.
    * *hof* is updated.
    * NSGA-II selection over parents + offspring replaces *population* in place.

    Keyword args read here:

    * ``logbook``: flag forwarded to ``record_stats`` for the initial record.
    * ``ngen``, ``inertia_weight``, ``inertia_damping``.
    * ``instance_file``, ``problem_instance``, ``nr_of_objectives``: used for
      the hypervolume computation.
    * ``reference_points_dir`` (optional): directory containing the
      ``reference_points_<size>_2_obj.json`` files.

    Args:
        folder: Output directory; when None nothing is written to disk.
        exp_name: File-name prefix passed to ``save_results``.

    Returns:
        The hypervolume of *hof*, or None when no reference point is available.
    """
    df_list = []
    logbook = tools.Logbook()
    logbook.header = ["gen"] + (stats.fields if stats else [])
    record_stats(0, population, logbook, stats, kwargs['logbook'], df_list, logging)

    for gen in range(kwargs['ngen']):
        inertia = kwargs['inertia_weight'] * kwargs['inertia_damping'] ** gen

        # Clone the swarm for the new generation.
        offspring = [toolbox.clone(ind) for ind in population]

        # Update each particle's position and velocity in place, using a random
        # solution from the HoF's Pareto front as the global best.
        for ind in offspring:
            toolbox.update(ind, random.choice(hof.items), inertia_weight=inertia)

        for ind in offspring:
            ind.fitness.values = toolbox.evaluate(ind)

            # Add a frozen copy of the current state to the personal bests and
            # keep only the non-dominated ones. The live particle itself is
            # excluded so later in-place moves cannot corrupt the archive.
            archive = [b for b in (ind.best or []) if b is not ind]
            candidates = [_snapshot(ind, toolbox)] + archive
            ind.best = tools.sortNondominated(candidates, len(candidates), first_front_only=True)[0]

        hof.update(offspring)

        # Select the next generation based on Pareto dominance.
        population[:] = toolbox.select(population + offspring, len(population))

        # gen 0 is the initial population recorded above, so this is gen + 1.
        record_stats(gen + 1, population, logbook, stats, True, df_list, logging)

    reference_dir = kwargs.get('reference_points_dir',
                               "/hpc/za64617/projects/GNN4APC_dev/code/src/routing/data")
    hypervolume = _hypervolume_for(hof, kwargs['instance_file'], kwargs['problem_instance'],
                                   kwargs['nr_of_objectives'], reference_dir)
    if hypervolume is not None:
        kwargs['hypervolume'] = hypervolume

    if folder is not None:
        save_results(hof, logbook, folder, exp_name, kwargs)

    return hypervolume


def main(param_file=PARAM_FILE):
    """Load run parameters from a JSON file, run the PSO and save its results.

    The results folder is ``DEFAULT_RESULTS_ROOT/<problem_instance>`` and the
    file prefix is ``rseed<rseed>``. The ``rseed`` parameter also seeds both
    ``random`` and ``numpy.random``, so a run is reproducible from its name.

    Args:
        param_file: Path to the JSON parameter file.

    Returns:
        The hypervolume returned by ``run_algo``, or None if the parameter
        file does not exist.
    """
    try:
        with open(param_file) as handle:
            parameters = json.load(handle)
    except FileNotFoundError:
        logging.error(f"Parameter file {param_file} not found.")
        return None

    # The run is labelled by its seed, so actually seed the RNGs with it.
    random.seed(parameters['rseed'])
    np.random.seed(parameters['rseed'])

    folder = os.path.join(DEFAULT_RESULTS_ROOT, str(parameters['problem_instance']))
    exp_name = "rseed" + str(parameters['rseed'])
    population, toolbox, stats, hof = initialize_run(**parameters)
    return run_algo(population, toolbox, folder, exp_name, stats, hof, **parameters)


if __name__ == "__main__":
    # Command-line entry point: optionally override the config file path.
    parser = argparse.ArgumentParser(description="Run Particle Swarm Optimization for Routing")
    parser.add_argument(
        "--config_file",
        metavar='f',
        type=str,
        nargs="?",
        default=PARAM_FILE,
        # A bare `--config_file` (no value) would otherwise yield None.
        const=PARAM_FILE,
        help="Path to config file",
    )

    args = parser.parse_args()
    main(param_file=args.config_file)