
import random
import operator
import math
import collections

from deap import creator

def generate_particle(size, s_min, s_max):
    """Build a new particle with random positions and random speeds.

    Parameters:
    size (int): Number of dimensions of the particle.
    s_min (float): Lower bound used when drawing each initial speed.
    s_max (float): Upper bound used when drawing each initial speed.

    Returns:
    creator.Particle: A particle holding ``size`` positions drawn with
    ``random.random()``, with ``speed``, ``smin`` and ``smax`` attributes set.
    """
    # Draw every position first, then every speed, so the sequence of
    # random-number draws stays the same for a given seed.
    positions = []
    for _ in range(size):
        positions.append(random.random())

    particle = creator.Particle(positions)

    velocities = []
    for _ in range(size):
        velocities.append(random.uniform(s_min, s_max))
    particle.speed = velocities

    # Remember the speed limits on the particle itself.
    particle.smin, particle.smax = s_min, s_max
    return particle


def remove_duplicates(vals):
    """Return a copy of *vals* in which repeated values are nudged apart.

    Every occurrence of a value that appears more than once is shifted upward
    by a multiple of a fixed 0.001 offset (first occurrence by 1x, second by
    2x, ...). Values that appear only once are passed through unchanged.

    A multiple that would land on a value already present in the output, or on
    an untouched single-occurrence value, is skipped and the next multiple is
    tried, so the shifting itself never re-introduces a duplicate.

    Parameters:
    vals (iterable): Hashable numeric values.

    Returns:
    list: The values in their original order, with duplicates shifted.
    Values so large (or infinite) that adding the offset does not change them
    are returned unchanged, since no multiple can be relied on to separate
    them.
    """
    offset = 0.001
    vals = list(vals)
    occurrences = collections.Counter(vals)

    # Single-occurrence values keep their value, so shifted duplicates must
    # steer clear of them even when they appear later in the sequence.
    taken = {val for val, count in occurrences.items() if count == 1}
    next_step = {}  # duplicated value -> next offset multiple to try
    uniq_part = []

    for val in vals:
        # ``val + offset != val`` guards against values the offset cannot
        # move, which would otherwise make the search below spin.
        if occurrences[val] > 1 and val + offset != val:
            step = next_step.get(val, 1)
            candidate = val + offset * step
            while candidate in taken:
                step += 1
                candidate = val + offset * step
            next_step[val] = step + 1
            val = candidate
        taken.add(val)
        uniq_part.append(val)

    return uniq_part
#
#
# Change floats to integers and deal with duplicates
def validate_particle(particle):
    """Convert a particle's numeric positions into 1-based ranks.

    The positions are first passed through ``remove_duplicates`` and then each
    one is replaced by its rank (1 = smallest) when ordered by ``float(value)``.

    Parameters:
    particle (iterable): Numeric position values.

    Returns:
    list: One integer rank per position, in the original position order. The
    result is always a permutation of ``1..len(particle)``: if equal values
    remain after de-duplication, they are ranked in order of appearance
    instead of sharing a rank.
    """
    unique_part = remove_duplicates(particle)

    # Diagnostic kept from the original implementation: report when equal
    # values are still present after de-duplication.
    if len(unique_part) > len(set(unique_part)):
        print("problem")

    # Argsort once (O(n log n)) instead of calling list.index per element.
    # sorted() is stable, so ties are broken by position.
    order = sorted(range(len(unique_part)), key=lambda i: float(unique_part[i]))

    validated_part = [0] * len(unique_part)
    for rank, position in enumerate(order, start=1):
        validated_part[position] = rank

    return validated_part


# def update_particle(part, best, phi1, phi2):
#     print(part)
#     u1 = (random.uniform(0, phi1) for _ in range(len(part)))
#     u2 = (random.uniform(0, phi2) for _ in range(len(part)))
#     # the particle's best position
#     v_u1 = map(operator.mul, u1, map(operator.sub, random.choice(part.best), part))
#     # the neighbourhood best
#     v_u2 = map(operator.mul, u2, map(operator.sub, best, part))
#     # update particle speed
#     part.speed = list(map(operator.add, part.speed, map(operator.add, v_u1, v_u2)))
#     # speed limits
#     for i, speed in enumerate(part.speed):
#         if abs(speed) < part.smin:
#             part.speed[i] = math.copysign(part.smin, speed)
#             # adjust maximum speed if necessary
#         elif abs(speed) > part.smax:
#             part.speed[i] = math.copysign(part.smax, speed)
#
#     new_part = list(map(operator.add, part, part.speed))
#     print(new_part)
#     part[:] = validate_particle(new_part)
#

def update_particle(part, best, phi1, phi2, inertia_weight, smin=-0.1, smax=0.1):
    """Advance one particle by a single PSO step, in place.

    The new speed is the inertia-weighted old speed plus two random pulls:
    one towards an entry chosen at random from ``part.best`` and one towards
    ``best``. Each speed component is clamped to ``[smin, smax]``, the speed
    is added to the position, and each position component is clamped to
    ``[0, 1]``.

    Parameters:
    part: List-like particle with ``speed`` (list of floats) and ``best``
        (non-empty sequence of positions) attributes. Both the position and
        ``part.speed`` are overwritten.
    best: Position the second pull is directed at.
    phi1 (float): Upper bound of the random weight for the ``part.best`` pull.
    phi2 (float): Upper bound of the random weight for the ``best`` pull.
    inertia_weight (float): Factor applied to the previous speed.
    smin (float): Lowest allowed speed component. Defaults to -0.1.
    smax (float): Highest allowed speed component. Defaults to 0.1.
        The ``smin``/``smax`` attributes stored on the particle are not
        consulted; pass them explicitly to use them.

    Raises:
    ValueError: If ``smin`` is greater than ``smax``.
    """
    if smin > smax:
        raise ValueError(f"smin ({smin}) must not exceed smax ({smax})")

    dimension = len(part)

    # Random weights are drawn in this order (u1, u2, then the choice of
    # personal best) so results for a given seed stay the same.
    u1 = [random.uniform(0, phi1) for _ in range(dimension)]
    u2 = [random.uniform(0, phi2) for _ in range(dimension)]
    personal_best = random.choice(part.best)

    # Pulls towards the chosen personal best and towards ``best``.
    v_u1 = [a * (b - c) for a, b, c in zip(u1, personal_best, part)]
    v_u2 = [a * (b - c) for a, b, c in zip(u2, best, part)]

    # Inertia-weighted speed update, then clamp each component.
    new_speed = [
        inertia_weight * spd + v1 + v2
        for spd, v1, v2 in zip(part.speed, v_u1, v_u2)
    ]
    part.speed = [max(min(spd, smax), smin) for spd in new_speed]

    # Move the particle and keep every coordinate inside [0, 1].
    new_part = [pos + spd for pos, spd in zip(part, part.speed)]
    part[:] = [min(max(0, x), 1) for x in new_part]


def get_subroutes(individual, truck_capacity, demands_data):
    """
    Splits a given route into subroutes based on truck capacity and customer demands.

    Customers are taken in the given order and added to the current subroute
    until the next customer would push the load above the truck capacity; that
    customer then starts a new subroute. A customer whose demand alone exceeds
    the capacity is still placed in a subroute of its own.

    Parameters:
    individual (list): The list of 1-based customer IDs in the route.
    truck_capacity (int): The maximum capacity of the truck.
    demands_data (list): The list of demands for each customer; the demand of
        customer ``c`` is read from ``demands_data[c - 1]``.

    Returns:
    list: A list of non-empty subroutes where each subroute is a list of
    customer IDs. Empty when ``individual`` is empty.
    """
    routes = []
    sub_route = []
    vehicle_load = 0

    for customer_id in individual:
        demand = demands_data[customer_id - 1]
        # Only close a subroute that actually has customers in it; otherwise
        # an oversized first customer would emit an empty subroute.
        if sub_route and vehicle_load + demand > truck_capacity:
            routes.append(sub_route)
            sub_route = []
            vehicle_load = 0
        sub_route.append(customer_id)
        vehicle_load += demand

    if sub_route:
        routes.append(sub_route)

    return routes


def compute_routes_fitness(routes, dist_matrix_data, distance_depot_data):
    """
    Computes the length of the longest subroute and the total length of all subroutes.

    The length of a subroute is the depot distance of its first customer, plus
    the depot distance of its last customer, plus the distances between
    consecutive customers. Empty subroutes contribute nothing and are skipped.

    Parameters:
    routes (list): A list of subroutes, each a list of 1-based customer IDs.
    dist_matrix_data (list): Customer-to-customer distances; the distance from
        customer ``a`` to customer ``b`` is read from
        ``dist_matrix_data[a - 1][b - 1]``.
    distance_depot_data (list): Depot distances; the value for customer ``c``
        is read from ``distance_depot_data[c - 1]``.

    Returns:
    tuple: ``(longest_sub_route, total_distance)``; ``(0, 0)`` when there are
    no non-empty subroutes.
    """
    total_distance = 0
    longest_sub_route = 0

    for route in routes:
        if not route:
            # An empty subroute has no first/last customer to index.
            continue

        # Depot -> first customer and last customer -> depot legs.
        route_distance = (
            distance_depot_data[route[0] - 1] + distance_depot_data[route[-1] - 1]
        )
        # Legs between consecutive customers.
        for current, following in zip(route, route[1:]):
            route_distance += dist_matrix_data[current - 1][following - 1]

        total_distance += route_distance
        if route_distance > longest_sub_route:
            longest_sub_route = route_distance

    return longest_sub_route, total_distance
#
#
def eval_individual_fitness(individual, truck_capacity, dist_matrix_data, dist_depot_data, demands_data):
    """
    Evaluates an individual by decoding it into routes and measuring them.

    The individual is decoded by sorting its positions by their values: the
    position holding the smallest value is visited first, and so on. The
    resulting visiting order is split into subroutes with ``get_subroutes``
    and measured with ``compute_routes_fitness``.

    Parameters:
    individual (list): One sortable value per customer.
    truck_capacity (int): The maximum capacity of the truck.
    dist_matrix_data (list): Customer-to-customer distances.
    dist_depot_data (list): Depot distances for each customer.
    demands_data (list): The list of demands for each customer.

    Returns:
    tuple: ``(longest_sub_route, total_distance)``.
    """
    visit_order = sorted(range(len(individual)), key=lambda k: individual[k])

    # get_subroutes and compute_routes_fitness look customers up with
    # ``customer_id - 1``, i.e. they expect 1-based IDs. Passing the 0-based
    # positions directly would make position 0 wrap around to the last entry
    # of every data list, so convert positions to 1-based IDs first.
    customer_ids = [position + 1 for position in visit_order]

    routes = get_subroutes(customer_ids, truck_capacity, demands_data)
    longest_sub_route, total_distance = compute_routes_fitness(routes, dist_matrix_data, dist_depot_data)
    return (longest_sub_route, total_distance)
#