from methods.reup.chebysev import chebysev_center
from methods.reup.q_determine import exhaustive_search, find_q
from methods.reup.gd import gd
from methods.reup import bayesian_utils, q_determine, gd
import numpy as np

def generate_recourse(x0, model, random_state, params=None):
    """Generate a recourse for ``x0`` with Bayesian ReUP.

    The pipeline:

    1. Draws a random prior (``prior_Sigma``, ``prior_m``).
    2. Runs preference elicitation plus Bayesian inference via
       ``q_determine.bayesian_PE``, using a ground-truth matrix ``A_0`` from
       ``bayesian_utils.generate_A_0``.
    3. Searches for a recourse against ``model`` with ``gd.bayesian_gd``
       under the resulting posterior.
    4. Scores that recourse with ``bayesian_utils.l1_norm_diag``.

    Args:
        x0: Input instance. It is reshaped with ``reshape(-1, 1)`` and
            ``reshape(1, -1)``, so any array with ``dim`` elements works.
        model: Model passed straight through to ``gd.bayesian_gd``.
        random_state: Currently unused (see the NOTE below). Kept for
            interface compatibility.
        params: Configuration dict. Required keys are ``"dataset_name"``,
            ``"train_data"``, ``"labels"``, ``"cat_indices"`` and
            ``"bayesian_reup_params"``. The last holds ``"sessions"``,
            ``"iterations"``, ``"lr"`` and ``"lr_gd"``.

    Returns:
        Tuple ``(recourse, cost, feasible, log_dict)``:

        - ``recourse`` and ``feasible`` are the values returned by
          ``gd.bayesian_gd``.
        - ``cost`` comes from ``bayesian_utils.l1_norm_diag``.
        - ``log_dict`` is the dict returned by ``bayesian_PE``, extended with
          the keys ``"recourse"``, ``"cost"``, ``"x_0"`` and ``"A_0"``.

    Raises:
        KeyError: If a required key is missing from ``params``.
    """
    # Avoid a shared mutable default; a missing config still fails fast below.
    if params is None:
        params = {}

    # Column form (dim, 1) is used for elicitation/cost; row form (1, dim)
    # is what the gradient-descent search receives.
    x_0 = x0.reshape(-1, 1)
    x_row = x0.reshape(1, -1)
    dim = x_0.shape[0]

    # Search size forwarded to bayesian_PE; larger for these two datasets.
    size = 50 if params["dataset_name"] in ("synthesis", "german") else 25

    # General parameters
    train_data = params["train_data"]
    labels = params["labels"]
    cat_indices = params["cat_indices"]
    # Rows of the training data labelled 1. Assumes train_data and labels
    # are aligned numpy arrays.
    feasible_set = train_data[labels == 1]

    # Bayesian ReUP prior: a random symmetric PSD matrix (G @ G.T) plus a
    # scalar dim + offset. These are presumably a Wishart-style scale matrix
    # and degrees of freedom; confirm against q_determine.bayesian_PE.
    # NOTE(review): this draws from NumPy's global RNG and `random_state` is
    # not used, so reproducibility relies on the caller seeding np.random.
    prior_Sigma = np.random.normal(loc=0.0, scale=1.0, size=(dim, dim))
    prior_Sigma = prior_Sigma @ prior_Sigma.T
    prior_dof_offset = 4
    prior_m = dim + prior_dof_offset

    # Posterior-inference and recourse-generation hyperparameters.
    reup_params = params["bayesian_reup_params"]
    sessions = reup_params["sessions"]
    iterations = reup_params["iterations"]
    lr = reup_params["lr"]
    lr_gd = reup_params["lr_gd"]

    tau = 0.1
    # Called after the prior draw, as in the original, so the global RNG
    # stream is consumed in the same order.
    A_0 = bayesian_utils.generate_A_0(dim)

    # Preference elicitation + Bayesian inference.
    post_Sigma, post_m, log_dict = q_determine.bayesian_PE(
        A_0,
        prior_Sigma,
        prior_m,
        x_0,
        dim,
        feasible_set,
        sessions,
        iterations,
        lr,
        tau=tau,
        size=size,
    )

    recourse, feasible = gd.bayesian_gd(
        post_Sigma,
        post_m,
        model,
        x_row,
        cat_indices,
        binary_cat_features=True,
        lr=lr_gd,
        lambda_param=1.0,
        n_iter=1000,
        t_max_min=1000,
        clamp=True,
    )

    cost = bayesian_utils.l1_norm_diag(recourse.reshape(-1, 1), x_0, A_0)

    # Logging
    log_dict["recourse"] = recourse
    log_dict["cost"] = cost
    log_dict["x_0"] = x_0
    log_dict["A_0"] = A_0

    return recourse, cost, feasible, log_dict