import numpy as np
import sys
from ortools.algorithms.python import knapsack_solver


def multidimensional_knapsack(values, weights, capacities):
    """Return the maximum total value of a 0/1 multidimensional knapsack.

    Each item is either taken once or left out. An item may be taken only
    if, in every dimension, its weight does not exceed the capacity that is
    still remaining in that dimension.

    Args:
        values: Sequence where ``values[i]`` is the value of item ``i``.
        weights: Sequence of per-item rows; ``weights[i][j]`` is the weight
            of item ``i`` in dimension ``j``. Only the first
            ``len(capacities)`` entries of each row are read.
        capacities: Sequence where ``capacities[j]`` is the capacity of
            dimension ``j``.

    Returns:
        The best achievable total value, or ``0`` when there are no items
        or nothing fits.

    The recurrence is ``best(i, caps) = max(best(i - 1, caps),
    values[i - 1] + best(i - 1, caps - weights[i - 1]))``. It is evaluated
    iteratively instead of recursively so that instances with many items do
    not exceed the interpreter's recursion limit.
    """
    n = len(values)
    m = len(capacities)
    start = tuple(capacities)

    def fits(state, item_weights):
        # An item fits when no dimension has less remaining capacity than
        # the item's weight in that dimension.
        return not any(state[j] < item_weights[j] for j in range(m))

    def take(state, item_weights):
        # Remaining capacities after putting the item into the knapsack.
        return tuple(state[j] - item_weights[j] for j in range(m))

    # Pass 1 (top-down): reachable[i] holds every remaining-capacity tuple
    # for which the sub-problem "items 1..i" is actually needed.
    reachable = [set() for _ in range(n + 1)]
    reachable[n].add(start)
    for i in range(n, 0, -1):
        item_weights = weights[i - 1]
        below = reachable[i - 1]
        for state in reachable[i]:
            below.add(state)  # branch where item i is skipped
            if fits(state, item_weights):
                below.add(take(state, item_weights))  # branch where it is taken

    # Pass 2 (bottom-up): with zero items every state is worth 0; each
    # level is then derived from the one below it.
    best = dict.fromkeys(reachable[0], 0)
    for i in range(1, n + 1):
        item_weights = weights[i - 1]
        item_value = values[i - 1]
        current = {}
        for state in reachable[i]:
            value = best[state]
            if fits(state, item_weights):
                value = max(value, item_value + best[take(state, item_weights)])
            current[state] = value
        best = current

    return best[start]


if __name__ == '__main__':
    # Benchmark driver: solves every instance stored in
    # dataset_mcts/test<PROBLEM_SIZE>_dataset.npz with the OR-Tools knapsack
    # solver and prints each optimum followed by the average over the file.

    # Small worked example for the pure-Python DP defined in this file:
    # values = [60, 100, 120]
    # weights = [[10, 1], [20, 2], [30, 3]] # Each item has weight for dimension 1 and 2
    # capacities = [50, 5] # Capacity for dimension 1 and 2
    # max_value = multidimensional_knapsack(values, weights, capacities)
    # print("Maximum value:", max_value) # Output: 220

    # Fail with a usage message instead of a bare IndexError when the
    # problem size argument is missing.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} PROBLEM_SIZE")
    problem_size = int(sys.argv[1])

    solver = knapsack_solver.KnapsackSolver(
        # knapsack_solver.SolverType.KNAPSACK_MULTIDIMENSION_BRANCH_AND_BOUND_SOLVER,
        knapsack_solver.SolverType.KNAPSACK_MULTIDIMENSION_CBC_MIP_SOLVER,
        "KnapsackExample",
    )
    # Inputs are multiplied by this factor and cast to int before being
    # handed to the solver; reported optima are divided by it again.
    factor = 10000

    dirname = 'dataset_mcts'

    # Open the archive in a context manager so its file handle is closed
    # once both arrays have been read.
    with np.load(f'{dirname}/test{problem_size}_dataset.npz') as raw_data:
        values_all, weights_all = raw_data['prizes'], raw_data['weights']

    data = []
    n_instance = weights_all.shape[0]
    # Every constraint has capacity 1; there is one per entry of the last
    # axis of the weights array.
    constraints = np.ones(weights_all.shape[-1])
    for i in range(n_instance):
        values = values_all[i]
        # The solver is given the transposed per-instance weights
        # (presumably one row per constraint -- confirm against the dataset).
        weights = weights_all[i].transpose()
        data.append((values, weights, constraints))

    total_value = 0
    for i, (values, weights, constraints) in enumerate(data):
        # opt = multidimensional_knapsack(values.tolist(), weights.tolist(), constraints.tolist())
        # NOTE: astype(int) truncates toward zero, so the scaled instance
        # matches the original data only to a resolution of 1/factor.
        c = (constraints*factor).astype(int).tolist()
        w = (weights*factor).astype(int).tolist()
        v = (values*factor).astype(int).tolist()
        solver.init(v, w, c)
        opt = solver.solve()
        total_value += opt
        print(f"{i+1}/{n_instance}: OPT = {opt / factor}", flush=True)
        # break
    print(f"problem size = {problem_size}: OPT = {total_value / n_instance / factor}")
