from scipy.optimize import linprog


def solve_contract(outcome_dist, neg_cost, action, delta=0):
    """Solve the linear program for a contract under which ``action`` is preferred.

    The program handed to ``scipy.optimize.linprog`` is::

        minimise    outcome_dist[action] @ x
        subject to  (outcome_dist[a] - outcome_dist[action]) @ x
                        <= (neg_cost[action] - delta) - neg_cost[a]   for a != action
                    x >= 0        (``linprog``'s default bounds)

    The row belonging to ``action`` itself reduces to ``0 <= 0`` and never
    binds.

    Args:
        outcome_dist: NumPy array indexed by action on its first axis
            (presumably shape ``(n_actions, n_outcomes)`` holding outcome
            probabilities -- confirm against the caller).
        neg_cost: 1-D NumPy array with one entry per action (by its name,
            the negated cost of each action -- confirm against the caller).
            It is not modified.
        action: Index of the action the solution should favour.
        delta: Margin subtracted from ``neg_cost[action]`` before the
            constraints are built, so ``action`` has to win by at least
            ``delta`` rather than merely tie.

    Returns:
        The ``OptimizeResult`` produced by ``linprog``, returned as is.
        Feasibility is not checked here; callers should inspect
        ``res.success`` / ``res.status``.
    """
    target_dist = outcome_dist[action]
    # Row a holds outcome_dist[a] - outcome_dist[action].
    A_ub = outcome_dist - target_dist.reshape(1, -1)
    # astype() returns a fresh float array: the caller's array stays untouched
    # and a fractional delta is not truncated when neg_cost has an integer
    # dtype (item assignment into an integer array casts silently).
    shifted_neg_cost = neg_cost.astype(float)
    shifted_neg_cost[action] -= delta
    b_ub = shifted_neg_cost[action] - shifted_neg_cost
    return linprog(target_dist, A_ub=A_ub, b_ub=b_ub, method='highs')


def correct_delta(delta, action, q_values_p):
    """Cap ``delta`` at the gap between ``action``'s value and the lowest value.

    Avoids overpaying when the action is not worth it compared to the worst
    action: the result never exceeds
    ``q_values_p[action] - q_values_p.min()``.

    Args:
        delta: Proposed margin.
        action: Index into ``q_values_p``.
        q_values_p: Indexable container exposing ``.min()`` (by its name,
            the principal's Q-values -- confirm against the caller).

    Returns:
        ``delta`` itself when it does not exceed the gap, otherwise the gap.
    """
    gap_to_worst = q_values_p[action] - q_values_p.min()
    # Same comparison order as the built-in min(delta, gap_to_worst).
    return gap_to_worst if gap_to_worst < delta else delta


def create_wandb_tags_and_config(policy):
    """Build the run configuration dict and tag list for ``policy``.

    Args:
        policy: Object exposing the hyper-parameter attributes copied below
            and an ``env`` with ``depth``, ``n_states``, ``seed`` and a
            ``root`` carrying ``actions_count``, ``utility_principal`` and
            ``utility_agent``.

    Returns:
        A ``(config, tags)`` tuple. ``config`` maps setting names to values
        read from ``policy`` and its environment. ``tags`` is
        ``['tree', 'val']`` when ``policy.val_optimal`` is truthy and
        ``['tree', 'train_val']`` otherwise.
    """
    env = policy.env
    root = env.root

    # Share of the root's action counts that lies outside the last entry
    # (the name suggests the last action is the "trivial" one -- confirm).
    counts = root.actions_count
    fraction_non_trivial_actions = counts[:-1].sum() / counts.sum()

    # Attributes copied from the policy under their own names, in log order.
    policy_fields = (
        'inp_size',
        'hid_size',
        'out_size',
        'n_hid_layers',
        'gamma',
        'lr_start',
        'lr_end',
        'outcome_dist_known',
        'eps_start',
        'eps_end',
        'batch_size',
        'n_batches',
        'n_interactions',
        'n_warm_start_batches',
        'target_update_freq',
        'log_freq',
        'delta',
        'val_optimal',
    )
    config = {field: getattr(policy, field) for field in policy_fields}
    config.update({
        'depth': env.depth,
        'n_states': env.n_states,
        'fraction_non_trivial_actions': fraction_non_trivial_actions,
        'optimal_utility_principal': root.utility_principal,
        'optimal_utility_agent': root.utility_agent,
        'env_seed': env.seed,
    })

    mode_tag = 'val' if policy.val_optimal else 'train_val'
    tags = ['tree', mode_tag]

    return config, tags
