import pandas as pd

import dice_ml

from methods.dice.dice_wrapper import DicePyTorchWrapper
from methods.reup import bayesian_utils

def generate_recourse(x0, model, random_state, params=None):
    """Generate a DiCE counterfactual for ``x0`` and evaluate its cost.

    Builds a ``dice_ml`` data/model pair from ``params``, asks the DiCE
    wrapper for ``k`` counterfactuals with ``desired_class="opposite"``, keeps
    the first entry of ``final_cfs_df_sparse`` and scores it against ``x0``
    with ``bayesian_utils.evaluate_cost_diag`` under the matrix returned by
    ``bayesian_utils.generate_A_0(dim)``.

    Args:
        x0: Input instance. Only ``x0.shape[0]`` and ``x0.reshape(-1, 1)`` are
            used, so presumably a 1-D numpy array -- TODO confirm with callers.
        model: Model handed to ``dice_ml.Model`` with backend ``'PYT'``
            (presumably a PyTorch module).
        random_state: Not used by this function; presumably kept so the
            signature matches the other recourse methods -- verify.
        params: Configuration mapping. Required keys: ``'dataframe'`` (a
            DataFrame containing the ``'label'`` outcome column),
            ``'numerical'`` (continuous feature names) and ``'dice_params'``
            with ``'k'``, ``'proximity_weight'`` and ``'diversity_weight'``.

    Returns:
        Tuple ``(counterfactual, cost, True, log_dict)``:
        ``counterfactual`` is ``recourse.final_cfs_df_sparse[0]``; ``cost``
        is the value from ``evaluate_cost_diag``; the third element is always
        ``True`` (presumably a success/validity flag -- confirm against the
        caller); ``log_dict`` holds the raw DiCE result, the cost, ``x0`` as a
        column vector and ``A_0``.

    Raises:
        KeyError: If a required key is missing from ``params``.
    """
    # Use a None sentinel instead of a mutable ``dict()`` default, which
    # would be shared between calls.
    if params is None:
        params = {}

    dim = x0.shape[0]
    df = params['dataframe']
    numerical = params['numerical']
    dice_params = params['dice_params']
    k = dice_params['k']

    full_dice_data = dice_ml.Data(dataframe=df,
                                  continuous_features=numerical,
                                  outcome_name='label')
    dice_model = dice_ml.Model(model=model, backend='PYT')
    dice = DicePyTorchWrapper(full_dice_data, dice_model)

    recourse = dice.generate_counterfactuals(
        x0,
        total_CFs=k,
        desired_class="opposite",
        posthoc_sparsity_param=None,
        proximity_weight=dice_params['proximity_weight'],
        diversity_weight=dice_params['diversity_weight'],
    )

    # Only the first counterfactual is evaluated and returned.
    best_cf = recourse.final_cfs_df_sparse[0]
    x0_col = x0.reshape(-1, 1)

    A_0 = bayesian_utils.generate_A_0(dim)
    cost = bayesian_utils.evaluate_cost_diag(best_cf.reshape(-1, 1), x0_col, A_0)

    # NOTE(review): debug output to stdout; consider switching to logging.
    print(cost)
    print()

    log_dict = {
        'recourse': recourse,
        'cost': cost,
        'x_0': x0_col,
        'A_0': A_0,
    }

    return best_cf, cost, True, log_dict