import numpy as np
import pandas as pd
from numpy import linalg as LA

import torch
from torch import nn
import torch.optim as optim
from torch.autograd import Variable
from carla.recourse_methods.processing import reconstruct_encoding_constraints

from carla import log


# def hyper_sphere_coordindates(n_search_samples, instance, high, low, p_norm=2):

#     # Implementation follows the Random Point Picking over a sphere
#     # The algorithm's implementation follows: Pawelczyk, Broelemann & Kasneci (2020);
#     # "Learning Counterfactual Explanations for Tabular Data" -- The Web Conference 2020 (WWW)
#     # It ensures that points are sampled uniformly at random using insights from:
#     # http://mathworld.wolfram.com/HyperspherePointPicking.html

#     # This one implements the growing spheres method from
#     # Thibaut Laugel et al (2018), "Comparison-based Inverse Classification for
#     # Interpretability in Machine Learning" -- International Conference on Information Processing
#     # and Management of Uncertainty in Knowledge-Based Systems (2018)

#     """
#     :param n_search_samples: int > 0
#     :param instance: numpy input point array
#     :param high: float>= 0, h>l; upper bound
#     :param low: float>= 0, l<h; lower bound
#     :param p: float>= 1; norm
#     :return: candidate counterfactuals & distances
#     """

#     delta_instance = np.random.randn(n_search_samples, instance.shape[1])
#     dist = np.random.rand(n_search_samples) * (high - low) + low  # length range [l, h)
#     norm_p = LA.norm(delta_instance, ord=p_norm, axis=1)
#     d_norm = np.divide(dist, norm_p).reshape(-1, 1)  # rescale/normalize factor
#     delta_instance = np.multiply(delta_instance, d_norm)
#     candidate_counterfactuals = instance + delta_instance

#     return candidate_counterfactuals, dist



def get_recourse(instance,
                 step_direction,
                 step_indicator,
                 stepSize,
                 actionable_Featrues
                 ):
    """
    Return the point reached by one masked step *against* ``step_direction``.

    The per-feature move is
    ``actionable_Featrues * ((step_indicator * step_direction) * stepSize)``
    and it is subtracted from ``instance``, so a positive direction component
    decreases that feature.

    :param instance: 1-D array-like current point.
    :param step_direction: per-feature direction (in this module, a gradient sign).
    :param step_indicator: per-feature 0/1 mask of the features stepped this round.
    :param stepSize: scalar or per-feature step length.
    :param actionable_Featrues: per-feature 0/1 mask of features allowed to change.
    :return: the new point as an ndarray.
    """
    move = np.multiply(
        actionable_Featrues,
        np.multiply(np.multiply(step_indicator, step_direction), stepSize),
    )
    return instance - move



def get_cost(x,
             action,
             lower_bounds,
             upper_bounds,
             interpolator_list
             ):
    """
    Compute a per-feature percentile-shift cost between ``x`` and ``action``.

    For each feature k:

    * the value from ``x`` is clamped into ``[lower_bounds[k], upper_bounds[k]]``;
    * if it equals the ``action`` value, the cost is 0;
    * otherwise the two values are ordered (low, high) and nudged 1e-4 toward
      each other. Both are mapped through ``interpolator_list[k]``
      (presumably a per-feature CDF / percentile function — TODO confirm).
      The cost is ``|log((1 - F(high)) / (1 - F(low)))| + 1e-4``.

    If evaluating the interpolator raises an ``Exception`` (e.g. out of range,
    missing entry), the percentiles fall back to 1e-4 and 1 - 1e-4.

    :param x: 1-D sequence of original feature values.
    :param action: 1-D sequence of proposed feature values.
    :param lower_bounds: per-feature lower bounds applied to ``x``.
    :param upper_bounds: per-feature upper bounds applied to ``x``.
    :param interpolator_list: per-feature callables mapping a value to a percentile.
    :return: float ndarray of per-feature costs.
    """
    buffer = 1e-4
    cost = []
    for k, (orig, new, lb, ub) in enumerate(zip(x, action, lower_bounds, upper_bounds)):
        orig = min(ub, max(lb, orig))

        if orig == new:
            c = 0
        else:
            low, high = (new, orig) if new < orig else (orig, new)
            # Pull both ends slightly inward so neither hits a percentile of exactly 0 or 1.
            low = low + buffer
            high = high - buffer

            try:
                low_percentile = interpolator_list[k](low)
                high_percentile = interpolator_list[k](high)
            except Exception:
                # Best-effort fallback; unlike a bare ``except:`` this does not
                # swallow KeyboardInterrupt / SystemExit.
                low_percentile = 0 + 1e-4
                high_percentile = 1 - 1e-4

            c = np.abs(np.log((1.0 - high_percentile) / (1.0 - low_percentile))) + 1e-4
        cost.append(c)

    return np.array(cost, dtype=float)
    
    
    

def softmax(x):
    """
    Numerically stable softmax over all elements of ``x``.

    The maximum is subtracted before exponentiating to avoid overflow. The
    same shifted exponentials are used in both numerator and denominator, so
    the result sums to 1. The previous version divided by ``sum(exp(x))`` and
    returned values scaled by ``exp(-max(x))``.

    :param x: array-like of scores.
    :return: ndarray of probabilities with the same shape as ``x``.
    """
    shifted = np.exp(x - np.max(x))
    return shifted / np.sum(shifted)
    
    
    
def get_step_indicator(instance,
                       instance_new, 
                       stepSize, 
                       step_direction, 
                       #interpolator_list, 
                       pref_vec,
                       actionable_Featrues, 
                       ):
    """
    Sample a 0/1 mask choosing which features move on the next step.

    The steps are:

    1. Each feature's cost is ``|log((1 - x + 1e-4) / (1 - c - 1e-4))| + 1e-4``.
       Here ``x`` comes from ``instance`` and ``c`` from the candidate
       ``instance_new + stepSize * step_direction``. NaN costs (negative
       ratio) are mapped to 0.
    2. Each feature's score is ``|pref / cost|`` (0 where the cost is 0).
       Scores are min-max normalised to [0, 1], scaled by 6.5, masked by
       ``actionable_Featrues`` and passed through ``softmax``.
    3. A Bernoulli mask is drawn from those probabilities, up to 100 times,
       until at least one actionable feature is selected.

    :param instance: 1-D array-like original point.
    :param instance_new: 1-D array-like current point of the search.
    :param stepSize: scalar or per-feature step length.
    :param step_direction: per-feature step direction.
    :param pref_vec: per-feature user preference weights.
    :param actionable_Featrues: per-feature 0/1 mask of features allowed to change.
    :return: ``(step_indicator, cost)``. ``step_indicator`` is the sampled
        0/1 ndarray mask; it may be all zeros if every one of the 100 draws
        was empty. ``cost`` is the list of per-feature costs.
    """
    instance_candidate = instance_new + np.multiply(stepSize, step_direction)

    cost = []
    for i, a in zip(instance, instance_candidate):
        c = np.abs(np.log((1.0 - i + 1e-4) / (1.0 - a - 1e-4))) + 1e-4
        # ``c > 0`` is False for NaN, so NaN costs become 0.
        cost.append(c if c > 0 else 0)

    # Preference per unit of cost; zero-cost features score 0.
    L = np.array([abs(p / c) if c > 0 else 0 for p, c in zip(pref_vec, cost)])

    # Min-max normalise. When every score is equal the range is 0; dividing
    # would give NaN probabilities, which np.random.binomial rejects. Use all
    # zeros (a uniform softmax) instead.
    L_min = np.min(L)
    L_range = np.max(L) - L_min
    if L_range > 0:
        step_indicator_value = (L - L_min) / L_range
    else:
        step_indicator_value = np.zeros_like(L, dtype=float)

    # Scaling the logits sharpens the softmax toward high-scoring features.
    lambda_scale = 6.5
    step_indicator_value = lambda_scale * step_indicator_value

    step_soft = softmax(np.multiply(step_indicator_value, actionable_Featrues))

    iterations = 100
    for _ in range(iterations):
        step_indicator = np.random.binomial(1, step_soft)
        step_indicator = np.multiply(actionable_Featrues, step_indicator)
        if sum(step_indicator) > 0:
            break

    return step_indicator, cost
    
    
def bound_recourse(instance,
                   recourse, 
                   lower_bounds, 
                   upper_bounds
                   ):
    """
    Clamp a proposed recourse element-wise into ``[lower_bounds, upper_bounds]``.

    :param instance: unused; kept so existing callers keep working.
    :param recourse: array-like proposed point.
    :param lower_bounds: scalar or array-like lower bound(s), broadcast against ``recourse``.
    :param upper_bounds: scalar or array-like upper bound(s), broadcast against ``recourse``.
    :return: ndarray with ``recourse`` clipped into the bounds.
    """
    return np.clip(recourse, a_min=lower_bounds, a_max=upper_bounds)

def user_preferred_steps(
    instance,
    y_target,
    keys_mutable, keys_immutable,
    continuous_cols, binary_cols,
    feature_order,
    model,
    pref_vec, step_Size,
    actionable_features,
    lower_bounds, upper_bounds,
    cat_feature_indices,
    binary_cat_features: bool = True,
    n_search_samples=1000,
    p_norm=2,
    step=0.2,
    max_iter=1000,
):
    """
    Search for a recourse with preference-weighted, fixed-size steps.

    The search steps along the sign of the loss gradient, moving a sampled
    subset of actionable features on each step. It stops once the model's
    score at output index 1 reaches 0.5.

    :param instance: pandas object exposing ``to_numpy()``; presumably a single
        encoded row (Series) — TODO confirm.
    :param y_target: target for the MSE loss whose gradient sign gives the
        step direction.
    :param model: callable (e.g. a torch module) applied to float tensors; for a
        batch of one, ``output[0][1]`` is compared against the 0.5 threshold.
    :param pref_vec: per-feature user preference weights.
    :param step_Size: scalar or per-feature step length.
    :param actionable_features: per-feature 0/1 mask of features allowed to move.
    :param lower_bounds: per-feature (or scalar) lower bounds for clamping.
    :param upper_bounds: per-feature (or scalar) upper bounds for clamping.
    :param cat_feature_indices: forwarded to ``reconstruct_encoding_constraints``.
    :param binary_cat_features: forwarded to ``reconstruct_encoding_constraints``.
    :param max_iter: int >= 0; maximum number of steps (default matches the
        previously hard-coded 1000).
    :param keys_mutable, keys_immutable, continuous_cols, binary_cols,
        feature_order, n_search_samples, p_norm, step: currently unused; kept
        for interface compatibility.
    :return: ``(recourse, step_indicator_full, cost_full, recourse_candidates)``.
        The elements are:
        - the final point as a list of floats;
        - the per-feature count of steps taken (all zeros if ``instance`` is
          already at or above the threshold);
        - the per-feature costs from the last step;
        - the intermediate candidates, one 1-D array per step.
        If no candidate reaches the threshold within ``max_iter`` steps, the
        last candidate is returned.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    threshold = 0.5

    instance = torch.from_numpy(np.array([instance.to_numpy()])).float().to(device)
    y_target = torch.tensor([y_target]).float().to(device)
    n_features = instance.shape[1]

    # Snap the starting point onto valid categorical encodings, then rebuild it
    # as a fresh leaf tensor on ``device`` so the loss can be differentiated.
    instance_new_enc = reconstruct_encoding_constraints(
        instance.clone().requires_grad_(True), cat_feature_indices, binary_cat_features
    )
    instance_new = torch.tensor(
        instance_new_enc.detach().cpu().numpy(),
        dtype=torch.float32,
        device=device,
        requires_grad=True,
    )

    # The gradient is taken once at the starting point; only its sign is used.
    # NOTE(review): the direction is never recomputed as the point moves —
    # confirm that a fixed direction is intended.
    criterion = torch.nn.MSELoss()
    loss = criterion(model(instance_new), y_target)
    gradient = torch.autograd.grad(loss, instance_new)[0][0]
    direction = np.sign(gradient.detach().cpu().numpy())

    recourse_candidates = []
    step_indicator_full = [0] * n_features
    cost_full = [0] * n_features
    if model(instance)[0][1] >= threshold:
        # Already at/above the threshold: nothing to change.
        return list(instance.detach().cpu().numpy()[0]), step_indicator_full, cost_full, recourse_candidates

    # Loop invariants: the original point and the actionable step direction.
    instance_np = instance.detach().cpu().numpy()[0]
    step_direction = np.multiply(direction, actionable_features)

    binary_labels_flip_index_list = []
    recourse_found_flag = False
    for _ in range(max_iter):
        current = instance_new.detach().cpu().numpy()[0]
        step_indicator, cost_full = get_step_indicator(
            instance_np, current, step_Size, step_direction,
            pref_vec, actionable_features,
        )
        candidate = get_recourse(current, direction, step_indicator, step_Size, actionable_features)
        candidate = bound_recourse(instance, candidate, lower_bounds, upper_bounds)
        instance_new = torch.from_numpy(np.array([candidate])).float().to(device)

        # list + ndarray adds element-wise, counting steps per feature.
        step_indicator_full = step_indicator_full + step_indicator
        recourse_candidates.append(instance_new.detach().cpu().numpy().flatten())

        if model(instance_new)[0][1] >= threshold:
            recourse_found_flag = True
            binary_labels_flip_index_list = get_binary_label_flip_index_list(
                instance.flatten(), instance_new.flatten()
            )
            break

    if recourse_found_flag and len(binary_labels_flip_index_list) > 0 and len(recourse_candidates) > 1:
        # Try to reach the threshold earlier in the path by applying the final
        # binary flips to every earlier candidate.
        instance_new = cost_recovery_linear_search(
            recourse_candidates, binary_labels_flip_index_list, threshold, model
        )
    return list(instance_new.detach().cpu().numpy().flatten()), step_indicator_full, cost_full, recourse_candidates
    
def get_binary_label_flip_index_list(instance, recourse):
    """
    List the positions where ``instance`` and ``recourse`` differ by exactly 1.

    Any pair whose absolute difference equals 1 is reported; the name suggests
    the intent is to detect flipped 0/1 features. Pairing stops at the shorter
    of the two inputs.

    :param instance: sequence of original values.
    :param recourse: sequence of proposed values.
    :return: list of int indices in increasing order.
    """
    return [
        idx
        for idx, (before, after) in enumerate(zip(instance, recourse))
        if abs(before - after) == 1
    ]
    
def cost_recovery_linear_search(recourse_candidates, binary_labels_flip_index_list, threshold, model):
    """
    Return the earliest search candidate that is a valid recourse once the
    final candidate's binary flips are applied to it.

    All candidates are copied into one array. The columns in
    ``binary_labels_flip_index_list`` are overwritten with the final
    candidate's values, and the rows are scanned in order. The first row whose
    model output at index 1 reaches ``threshold`` is returned. If none does,
    the (flipped) final candidate is returned. The inputs are not modified.

    :param recourse_candidates: non-empty sequence of equal-length 1-D arrays,
        in search order.
    :param binary_labels_flip_index_list: feature indices copied from the last candidate.
    :param threshold: minimum value of ``model(row)[1]`` for acceptance.
    :param model: callable taking a 1-D float tensor; its output is indexed with ``[1]``.
    :return: 1-D float tensor on the active device (CUDA if available, else CPU).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    flipped = np.array(recourse_candidates)
    # Fancy indexing on the right-hand side yields a copy, so the in-place
    # column assignment is safe.
    flipped[:, binary_labels_flip_index_list] = flipped[-1, binary_labels_flip_index_list]

    with torch.no_grad():
        for row in flipped:
            candidate = torch.from_numpy(row).float().to(device)
            # .cpu() before .numpy() so this also works when the model runs on CUDA.
            if model(candidate).detach().cpu().numpy()[1] >= threshold:
                return candidate
    return torch.from_numpy(flipped[-1]).float().to(device)