import heapq
import math
import numpy as np
from codebleu.syntax_match import calc_syntax_match


def _mean_or_nan(values):
    """Return the arithmetic mean of ``values``, or NaN when ``values`` is empty."""
    # np.mean([]) also evaluates to NaN but emits a RuntimeWarning; be explicit.
    return np.mean(values) if values else np.float64('nan')


def population_management(pop, size, reduc_info=None, multi_obj=False):
    """Deduplicate, rank and truncate a population of candidate individuals.

    Steps:
      1. Drop individuals whose ``'objective'`` is ``None`` or NaN.
      2. Keep one individual per distinct ``(objective, problem)`` pair. Among
         duplicates the survivor is the one with the smallest ``'runtime'``
         (``multi_obj=True``) or the shortest ``'code'`` (``multi_obj=False``);
         the earliest one in ``pop`` wins exact ties.
      3. Rank the survivors: with ``sortby_dominance_dissimilarity`` when
         ``multi_obj`` is true, otherwise by ``'objective'`` in descending order
         (higher objective first).
      4. Return the first ``size`` ranked individuals.

    Survivors are collected in first-seen order, so the input handed to the
    ranking step (and therefore the order of equally ranked individuals in the
    single-objective case) does not depend on set iteration order or on string
    hash randomization.

    Args:
        pop: List of dicts. Each needs ``'objective'`` and ``'problem'``
            (hashable), plus ``'runtime'`` when ``multi_obj`` is true or
            ``'code'`` (anything supporting ``len``) otherwise.
        size: Maximum number of individuals to keep.
        reduc_info: Optional pair ``(reduc_pop, top_k)``. ``reduc_pop`` is a list
            of dicts with a ``'problem'`` key. When given, each entry is updated
            IN PLACE with the mean ``'objective'`` of the ``top_k`` best-ranked
            individuals sharing its ``'problem'``. In the multi-objective case
            the mean ``'runtime'`` is stored as well and both values are rounded
            to 5 decimals. An entry with no matching individual receives NaN.
        multi_obj: Selects the ranking and tie-breaking rules described above.

    Returns:
        The truncated, ranked list when ``reduc_info`` is ``None``; otherwise the
        tuple ``(truncated_list, reduc_pop)``.
    """
    # Individuals without a usable fitness value cannot be ranked.
    valid = [
        ind for ind in pop
        if ind['objective'] is not None and not math.isnan(ind['objective'])
    ]
    size = min(size, len(valid))

    if multi_obj:
        # Same fitness and same reduction: prefer the fastest runtime.
        def tie_break(ind):
            return ind['runtime']
    else:
        # Same fitness and same reduction: prefer the shortest code.
        def tie_break(ind):
            return len(ind['code'])

    # Single pass; dicts preserve insertion order, so survivors come out in the
    # order their (objective, problem) pair first appeared in `pop`.
    best_per_group = {}
    for ind in valid:
        group = (ind['objective'], ind['problem'])
        incumbent = best_per_group.get(group)
        # Strict '<' keeps the earliest individual on exact ties, like min().
        if incumbent is None or tie_break(ind) < tie_break(incumbent):
            best_per_group[group] = ind
    unique_pop = list(best_per_group.values())

    if multi_obj:  # TODO: generalize to more objectives
        pop_sorted = sortby_dominance_dissimilarity(unique_pop)
    else:
        # sorted() is stable: equal objectives keep their first-seen order.
        pop_sorted = sorted(unique_pop, key=lambda ind: ind['objective'], reverse=True)
    pop_new = pop_sorted[:size]

    if reduc_info is None:
        return pop_new

    # Also refresh the aggregate scores of the reductions, using the full ranked
    # list (not only the truncated one).
    reduc_pop = reduc_info[0]
    for reduc in reduc_pop:
        members = [
            ind for ind in pop_sorted if ind['problem'] == reduc['problem']
        ][:reduc_info[1]]
        mean_objective = _mean_or_nan([ind['objective'] for ind in members])
        if multi_obj:
            reduc['objective'] = np.round(mean_objective, 5)
            reduc['runtime'] = np.round(
                _mean_or_nan([ind['runtime'] for ind in members]), 5
            )
        else:
            reduc['objective'] = mean_objective
    return pop_new, reduc_pop


def sortby_dominance_dissimilarity(pop):
    """Rank individuals by how strongly they resemble individuals that dominate them.

    Each individual is scored on two "higher is better" criteria: its
    ``'objective'`` and its negated ``'runtime'``. Individual ``a`` is treated as
    dominating ``b`` when ``a`` is at least as good as ``b`` on both criteria.
    For an exactly tied pair only the earlier individual counts as the dominator.
    Every dominated individual accumulates the ``calc_syntax_match`` value
    between its ``'code'`` and the code of each of its dominators, and
    individuals are returned in ascending order of that accumulated value. An
    individual that nobody dominates accumulates 0.

    The sort is stable, so individuals with equal accumulated value keep their
    relative order from ``pop``.

    Args:
        pop: List of dicts with ``'objective'``, ``'runtime'`` and ``'code'``.

    Returns:
        A new list containing the same individuals in ranked order.
    """
    pop_size = len(pop)
    if pop_size < 2:
        # No pair to compare; nothing to rank.
        return list(pop)

    # Build each score vector once instead of once per compared pair.
    scores = [np.array([ind['objective'], -ind['runtime']]) for ind in pop]

    # penalty[d, v]: syntax similarity between dominator d and dominated v
    # (presumably in [0, 1] -- confirm against codebleu); 0 where d does not
    # dominate v.
    penalty = np.zeros((pop_size, pop_size))
    for i in range(pop_size):
        for j in range(i + 1, pop_size):
            if (scores[i] >= scores[j]).all():  # j is dominated by i
                penalty[i, j] = calc_syntax_match([pop[i]['code']], pop[j]['code'], 'python')
            elif (scores[j] >= scores[i]).all():  # i is dominated by j
                penalty[j, i] = calc_syntax_match([pop[j]['code']], pop[i]['code'], 'python')

    # Column sums: total similarity of each individual to all of its dominators.
    total_penalty = penalty.sum(0)
    # Stable sort so that equally penalized individuals (e.g. every
    # non-dominated one) come out in a reproducible order.
    ranking = np.argsort(total_penalty, kind='stable')
    return [pop[idx] for idx in ranking]


if __name__ == '__main__':
    # Manual smoke test: run single-objective population management on a small
    # hand-written population and print the surviving individuals.
    #
    # To rank saved populations instead, load each
    # '../../exp_tsp/50_12/results/pops_mcts_evored_2/population_generation_<k>.json'
    # (k = 1..10) with json.load and pass the result to
    # sortby_dominance_dissimilarity.
    pop = [
        {
            "problem": "Problem B1 involves selecting a bin by computing a score based on the current load and remaining capacity, then directly choosing the bin with the highest score.",
            "algorithm": "Filter out bins that cannot fit the item, then for each feasible bin compute a normalized weighted score of its current load and inverse remaining capacity, and finally select the bin with the highest score.",
            "code": "\nfrom typing import Tuple\nimport numpy as np\n\ndef solve_B(input_B: Tuple[float, np.ndarray, np.ndarray]) -> Tuple[int, np.ndarray]:\n    '''\n    Args:\n        input_B (Tuple[float, np.ndarray, np.ndarray]): A tuple representing the input for Problem B,\n            where:\n              - The first element is a float representing the item size.\n              - The second element is a numpy array (np.ndarray) of floats representing the current load in each bin.\n              - The third element is a numpy array (np.ndarray) of floats representing the remaining capacities of each bin.\n\n    Returns:\n        solution_B (Tuple[int, np.ndarray]): A tuple representing the solution for Problem B,\n            where:\n              - The first element is an int representing the selected bin index.\n              - The second element is a numpy array (np.ndarray) of floats storing the priority scores for each bin.\n    '''\n    item_size, current_load, remaining_capacity = input_B\n    n_bins = current_load.shape[0]\n    \n    # Determine which bins are feasible.\n    valid_mask = remaining_capacity >= item_size\n    \n    # Compute normalized values for current load and remaining capacity.\n    # Avoid division by zero by setting a minimum denominator of 1.\n    max_load = current_load.max() if current_load.max() > 0 else 1.0\n    max_remaining = remaining_capacity.max() if remaining_capacity.max() > 0 else 1.0\n    \n    norm_load = current_load / max_load\n    norm_remaining = remaining_capacity / max_remaining\n    \n    # Compute score as the weighted sum: 0.5 * normalized load + 0.5 * (1 - normalized remaining capacity)\n    scores = 0.5 * norm_load + 0.5 * (1 - norm_remaining)\n    \n    # Invalidate bins that cannot fit the item by assigning a very low score.\n    scores[~valid_mask] = -np.inf\n    \n    # Select the bin with the highest score. If none are valid, choose -1.\n    if np.all(~valid_mask):\n        selected_bin = -1\n    else:\n        selected_bin = int(np.argmax(scores))\n        \n    solution_B = (selected_bin, scores)\n    return solution_B",
            "objective": -2.19502,
            "runtime": 0.74185,
            # "other_inf": null,
            "op": "i1"
        },
        {
            "problem": "Problem B10 involves formulating the selection as a simplified assignment problem by building a cost matrix based on the incremental load each bin would receive, then choosing the bin with the minimal incremental cost.",
            "algorithm": "The algorithm iterates over each bin's incremental load from the cost matrix, selects the bin with the minimal incremental cost, and returns its index as the assignment solution.",
            "code": "\nfrom typing import Tuple\nimport numpy as np\n\ndef solve_B(input_B: Tuple[np.ndarray]) -> Tuple[np.ndarray]:\n    '''\n    Args:\n        input_B (Tuple[np.ndarray]): A tuple containing a single numpy.ndarray of shape (M, 1) representing the cost matrix for Problem B.\n    \n    Returns:\n        Tuple[np.ndarray]: A tuple containing a single numpy.ndarray (the solution for Problem B).\n    '''\n    cost_matrix, = input_B\n    # Find the index of the bin with the minimal incremental cost\n    selected_bin = np.argmin(cost_matrix[:, 0])\n    solution = np.array([selected_bin])\n    return (solution,)",
            "objective": -2.46938,
            "runtime": 0.42109,
            # "other_inf": null,
            "op": "i1"
        },
        {
            "problem": "Problem B1 involves selecting a bin by computing a score based on the current load and remaining capacity, then directly choosing the bin with the highest score.",
            "algorithm": "The algorithm computes a feasibility mask to filter bins that can accommodate the item, vectorizes the score calculation as the ratio of the current load to the total capacity (load plus remaining capacity) for each feasible bin (assigning \u2212\u221e for infeasible ones), and then selects the bin with the highest score using numpy\u2019s argmax function.",
            "code": "\nfrom typing import Tuple\nimport numpy as np\n\ndef solve_B(input_B: Tuple[float, np.ndarray, np.ndarray]) -> Tuple[int, np.ndarray]:\n    '''\n    Args:\n        input_B (Tuple[float, np.ndarray, np.ndarray]): A tuple representing the input for Problem B,\n            where:\n              - The first element is a float representing the item size.\n              - The second element is a numpy array (np.ndarray) of floats representing the current load in each bin.\n              - The third element is a numpy array (np.ndarray) of floats representing the remaining capacities of each bin.\n\n    Returns:\n        solution_B (Tuple[int, np.ndarray]): A tuple representing the solution for Problem B,\n            where:\n              - The first element is an int representing the selected bin index.\n              - The second element is a numpy array (np.ndarray) of floats storing the priority scores for each bin.\n    '''\n    item_size, loads, rem_caps = input_B\n    # Create a feasibility mask: only bins with enough remaining capacity are considered feasible.\n    feasible = rem_caps >= item_size\n    \n    # Compute the denominator to safely compute the ratio.\n    denom = loads + rem_caps\n    # Calculate scores using vectorized operations: load divided by (load + remaining capacity),\n    # while providing a safeguard for zero denominators.\n    with np.errstate(divide='ignore', invalid='ignore'):\n        computed_scores = np.where(denom > 0, loads / denom, 0.0)\n    \n    # For bins that are not feasible, assign a very low score to avoid selection.\n    scores = np.where(feasible, computed_scores, -np.inf)\n    \n    # Select the bin with the highest score.\n    selected_bin = int(np.argmax(scores))\n    \n    return (selected_bin, scores)",
            "objective": -2.46938,
            "runtime": 0.67157,
            # "other_inf": null,
            "op": "i1"
        },
        {
            "problem": "Problem B6 involves using a dynamic programming approach where bin selection is carried out by recursively computing and comparing aggregate capacity metrics, thereby choosing the bin with the optimal combined score.",
            "algorithm": "Iterate from the last bin backwards updating each bin\u2019s score as the maximum between skipping it and assigning the item\u2014i.e., if the current bin can accommodate the item, compute its score as (its capacity minus the item\u2019s size) plus the future aggregated score, and then take the maximum with the score from the next bin.",
            "code": "\nfrom typing import Tuple\nimport numpy as np\n\ndef solve_B(item_size: float, bin_caps: np.ndarray, agg_caps: np.ndarray) -> np.ndarray:\n    '''\n    Args:\n        item_size (float): Size of the item to be assigned to a bin.\n        bin_caps (np.ndarray): 1D numpy array of bin capacities.\n        agg_caps (np.ndarray): 1D numpy array of aggregated bin capacities computed from right to left.\n    \n    Returns:\n        np.ndarray: An array containing the aggregated/dynamic programming scores for each bin (i.e., solution_B).\n    '''\n    n = len(bin_caps)\n    dp = np.copy(agg_caps)\n    for i in range(n - 1, -1, -1):\n        if bin_caps[i] >= item_size:\n            candidate = (bin_caps[i] - item_size) + (dp[i + 1] if i < n - 1 else 0)\n        else:\n            candidate = (dp[i + 1] if i < n - 1 else 0)\n        dp[i] = candidate if i == n - 1 else max(candidate, dp[i + 1])\n    return dp",
            "objective": -2.46938,
            "runtime": 36.90012,
            # "other_inf": null,
            "op": "i1"
        },
    ]
    # The requested size (10) exceeds the population, so every individual that
    # survives deduplication is printed.
    pop_new = population_management(pop, 10, reduc_info=None, multi_obj=False)
    for ind in pop_new:
        print(ind)