import autograd.numpy as np
import pickle

from methods.reup import q_determine 
from methods.reup import graph, bayesian_utils


def generate_recourse(x_0, model, params=dict()):
    #adjust the dimensionality
    x_0 = x_0.reshape(-1, 1)
    dim = x_0.shape[0]

    dname=params["dataset_name"]
    if dname=="synthesis" or dname=="german":
        size=50
    else:
        size=25

    # General parameters
    train_data = params['train_data']
    labels = params['labels']
    feasible_set = train_data[labels == 1]

    train_data = np.concatenate([x_0.reshape(1, -1), train_data])
    pos_idx = np.where(labels == 1)[0] + 1

    # Bayesian ReUP Graph parameters
    prior_Sigma = np.random.normal(loc=0.0, scale=1.0, size=(dim, dim))
    prior_Sigma = prior_Sigma @ prior_Sigma.T
    CONST=4
    prior_m = dim + CONST

    sessions = params['bayesian_reup_params']['sessions']
    iterations = params['bayesian_reup_params']['iterations']
    lr = params['bayesian_reup_params']['lr']
    n_neighbors = params['bayesian_reup_params']['n_neighbors']
    
    TAU = 0.1
    A_0 = bayesian_utils.generate_A_0(dim)

    #Preference elicitation + Bayesian inference
    post_Sigma, post_m, log_dict  = q_determine.bayesian_PE(
                                    A_0, 
                                    prior_Sigma, 
                                    prior_m, 
                                    x_0, 
                                    dim, 
                                    feasible_set, 
                                    sessions, 
                                    iterations, 
                                    lr, 
                                    tau=TAU, 
                                    size=size)

    # Graph-based recourse generation
    graph_opt = graph.bayesian_build_graph(train_data, post_Sigma, post_m, n_neighbors)
    path = graph.shortest_path_graph(graph_opt, pos_idx)[2]
    recourse = train_data[path[-1]]   
    cost = graph.eval_cost(A_0, train_data, path)
    feasible = True

    #logging
    log_dict['recourse'] = recourse
    log_dict['path'] = path
    log_dict['cost'] = cost
    log_dict['x_0'] = x_0
    log_dict['A_0'] = A_0
    
    return recourse, cost, feasible, log_dict