import torch
import numpy as np

import sys

from ortools.linear_solver import pywraplp


def _relative_error(abs_error, trg):
    """Return ``abs_error / |trg|`` without raising when ``trg`` is zero.

    For a zero target the ratio is undefined; an exact match is reported as
    0.0 and any non-zero miss as ``inf``.
    """
    denom = abs(trg)
    if denom != 0.0:
        return abs_error / denom
    return 0.0 if abs_error == 0.0 else float('inf')


def _solve_subset_mask(srcs, trg, num_threads, timeout, debug):
    """Solve the subset-selection MIP and return the 0/1 mask as a float array.

    Minimizes ``z`` subject to ``|sum_i x_i * srcs[i] - trg| <= z`` with
    binary ``x_i``. If the solver ends without an incumbent solution, the
    all-zero mask (empty subset) is returned.
    """
    solver = pywraplp.Solver.CreateSolver('SCIP')
    if solver is None:
        # CreateSolver returns None when the requested backend is missing.
        raise RuntimeError('OR-Tools could not create the SCIP solver backend')
    solver_params = pywraplp.MPSolverParameters()
    solver_params.SetDoubleParam(solver_params.RELATIVE_MIP_GAP, 0.05)

    x = [solver.BoolVar(f'x_{i}') for i in range(len(srcs))]
    z = solver.NumVar(0.0, solver.infinity(), 'z')

    # Build the linear expression once; the two constraints together encode
    # |subset_expr - trg| <= z.
    subset_expr = sum(var * coeff for var, coeff in zip(x, srcs))
    solver.Add(subset_expr - trg <= z)
    solver.Add(trg - subset_expr <= z)
    solver.Minimize(z)

    solver.SetNumThreads(num_threads)
    # NOTE(review): OR-Tools interprets this limit in milliseconds, so the
    # default timeout of 300 is 0.3 s -- confirm the intended unit.
    solver.set_time_limit(timeout)
    if not debug:
        solver.SuppressOutput()
    status = solver.Solve(solver_params)

    if status not in (pywraplp.Solver.OPTIMAL, pywraplp.Solver.FEASIBLE):
        # No incumbent available (e.g. time limit hit before any solution):
        # fall back to the empty subset, which is always feasible.
        return np.zeros(len(srcs))

    # Round rather than truncate: the solver may report a selected binary as
    # 0.9999999, which int() would silently turn into 0.
    return np.array([float(round(var.solution_value())) for var in x])


def find_subset(problem,
        eps: float=1e-2, check_trg_lt_eps: bool=False,
        num_threads: int=0, timeout: int=300, debug: bool=False):
    """Pick a subset of source values whose sum approximates a target.

    Args:
        problem: ``(problem_idx, (srcs, trg))``. ``srcs`` must provide
            ``.tolist()`` (e.g. a NumPy array or tensor) and ``trg`` must be
            convertible with ``float()``.
        eps: Tolerance used by the ``check_trg_lt_eps`` shortcut and shown in
            the debug report; it does not constrain the solver.
        check_trg_lt_eps: If True and ``|trg| <= eps``, skip the solver and
            return the empty subset.
        num_threads: Forwarded to the solver's ``SetNumThreads``.
        timeout: Forwarded unchanged to the solver's ``set_time_limit``.
        debug: If True, keep solver output enabled and print a summary.

    Returns:
        ``(problem_idx, mask, rel_error, abs_error)`` where ``mask`` is a
        float array of 0/1 flags over ``srcs``, ``abs_error`` is
        ``|sum(srcs * mask) - trg|`` and ``rel_error`` is
        ``abs_error / |trg|`` (0.0 or ``inf`` when ``trg`` is zero). The
        ``check_trg_lt_eps`` shortcut always reports ``rel_error`` as 1.0.

    Raises:
        RuntimeError: If the SCIP backend cannot be created.
    """
    problem_idx, (srcs, trg) = problem
    srcs = srcs.tolist()
    trg = float(trg)
    n = len(srcs)

    if check_trg_lt_eps and (abs(trg) <= eps):    # check if the magnitude of w is less than eps
        return problem_idx, np.zeros(n), 1., abs(trg)

    if n == 0:
        # Nothing to choose from: the empty subset is the only answer, so
        # there is no need to set up a solver.
        mask = np.zeros(0)
    else:
        mask = _solve_subset_mask(srcs, trg, num_threads, timeout, debug)

    subset_sum = (np.array(srcs) * mask).sum()
    abs_error = float(abs(subset_sum - trg))
    rel_error = float(_relative_error(abs_error, trg))

    if debug:     # print verbose information
        num_used = mask.sum()
        print('\n' + '-' * 96)
        print('\nNumber of elements in subset:', num_used)
        print('\nSubset sum:', subset_sum, 'is approximately equal to', trg)
        print('\nDifference between subset sum and w:', abs_error, ', epsilon =', eps)
        print('This difference is less than epsilon:', abs_error <= eps)

    return problem_idx, mask, rel_error, abs_error
