import torch
import numpy as np

import sys

import gurobipy as gp
from gurobipy import GRB


def _solve_subset_mip(srcs, trg, num_threads, debug, mip_gap):
    """Solve one subset-sum MIP with Gurobi and return the selection mask.

    Chooses binary x_i minimizing z subject to |trg - sum_i srcs[i] * x_i| <= z,
    i.e. the subset of ``srcs`` whose sum is closest to ``trg`` (within ``mip_gap``).

    Args:
        srcs: List of float coefficients.
        trg: Target value.
        num_threads: Value passed to Gurobi's ``Threads`` parameter.
        debug: If True, Gurobi's solver log is shown.
        mip_gap: Value passed to Gurobi's ``MIPGap`` parameter.

    Returns:
        float64 numpy array of length ``len(srcs)``; 1.0 marks a selected element.

    Raises:
        gp.GurobiError: propagated from Gurobi; the caller decides whether to retry.
        RuntimeError: if the solve does not end with status ``GRB.OPTIMAL``.
    """
    n = len(srcs)
    # empty=True defers start-up so OutputFlag=0 is applied before env.start();
    # a plain gp.Env() is already started when it is constructed.
    with gp.Env(empty=True) as env:
        env.setParam('OutputFlag', 0)
        env.start()
        with gp.Model(env=env) as m:
            m.Params.OutputFlag = int(debug)
            # Random seed per solve (presumably so a retry does not replay the same search).
            m.Params.Seed = np.random.randint(0, 2000000)
            m.Params.Threads = num_threads
            m.Params.MIPGap = mip_gap

            x = m.addVars(n, vtype=GRB.BINARY)
            z = m.addVar(vtype=GRB.CONTINUOUS)
            # Build sum_i srcs[i] * x_i once and reuse it in both constraints.
            subset_expr = gp.LinExpr(srcs, [x[i] for i in range(n)])
            m.setObjective(z, GRB.MINIMIZE)
            # Together these two constraints encode |trg - subset_expr| <= z.
            m.addConstr(trg - subset_expr <= z)
            m.addConstr(subset_expr - trg <= z)
            m.optimize()

            if m.Status != GRB.OPTIMAL:
                raise RuntimeError(
                    f'Feasible solution not found for weight value {trg} and coefficients:\n{srcs}\n'+
                    'Try increasing c, increasing epsilon, or both.')
            # Binary solution values may be off by the integrality tolerance
            # (e.g. 0.9999999), so round before thresholding.
            return np.array([float(round(x[i].X) > 0) for i in range(n)], dtype=float)


def find_subset(problem,
        eps: float=1e-2, check_trg_lt_eps: bool=False,
        num_threads: int=0, debug: bool=False,
        mip_gap: float=0.01, max_attempts: int=5):
    """Approximate a target value by a subset sum of source values using Gurobi.

    Args:
        problem: ``(problem_idx, (srcs, trg))``. ``srcs`` must support ``.tolist()``
            (presumably a 1-D numpy array or torch tensor; confirm with callers) and
            ``trg`` must be convertible with ``float()``.
        eps: Tolerance used by the ``check_trg_lt_eps`` shortcut and shown in debug output.
        check_trg_lt_eps: If True and ``|trg| <= eps``, skip the solver and return the
            empty subset.
        num_threads: Passed to Gurobi's ``Threads`` parameter.
        debug: Show Gurobi's solver log and print a summary of the result.
        mip_gap: Passed to Gurobi's ``MIPGap`` parameter (default keeps the old 0.01).
        max_attempts: Total number of solve attempts when Gurobi raises ``GurobiError``.

    Returns:
        ``(problem_idx, mask, rel_error, abs_error)``. ``mask`` is a float numpy array
        with 1.0 for selected sources. ``abs_error = |sum(srcs * mask) - trg|`` and
        ``rel_error = abs_error / |trg|``. When ``trg == 0``, ``rel_error`` is 0.0 for an
        exact match and ``inf`` otherwise, instead of dividing by zero.

    Raises:
        ValueError: if ``max_attempts < 1``.
        gp.GurobiError: if every attempt fails with a Gurobi error.
        RuntimeError: if Gurobi finishes without an optimal solution.
    """
    if max_attempts < 1:
        raise ValueError(f'max_attempts must be >= 1, got {max_attempts}')

    problem_idx, (srcs, trg) = problem
    srcs = srcs.tolist()
    trg = float(trg)
    n = len(srcs)
    if check_trg_lt_eps and (abs(trg) <= eps):    # check if the magnitude of w (trg) is less than eps
        # Empty subset: its sum is 0, so the absolute error is |trg| and the
        # relative error is reported as 1.
        return problem_idx, np.zeros(n), 1., abs(trg)

    for attempt in range(1, max_attempts + 1):
        try:
            mask = _solve_subset_mip(srcs, trg, num_threads, debug, mip_gap)
        except gp.GurobiError as e:
            if attempt >= max_attempts:
                raise
            print(f'Caught gurobi error: {e}\n...Trying again (times {attempt+1})')
        else:
            break

    subset_sum = (np.array(srcs) * mask).sum()
    abs_error = float(abs(subset_sum - trg))
    if trg != 0:
        rel_error = abs_error / abs(trg)
    else:
        # Relative error is undefined for a zero target: report 0 for an exact
        # match and inf otherwise rather than raising ZeroDivisionError.
        rel_error = 0.0 if abs_error == 0 else float('inf')

    if debug:     # print verbose information
        num_used = mask.sum()
        print('\n' + '-' * 96)
        print('\nNumber of elements in subset:', num_used)
        print('\nSubset sum:', subset_sum, 'is approximately equal to', trg)
        print('\nDifference between subset sum and w:', abs_error, ', epsilon =', eps)
        print('This difference is less than epsilon:', abs_error <= eps)

    return problem_idx, mask, rel_error, abs_error
