import pandas as pd
import numpy as np
import random

from copy import deepcopy

from data_utils import normalize


def perturb_datapoints(dataset, means, std, datapoint, X, max_tries=10000):
    """Return the original and a randomly perturbed copy of a raw datapoint.

    Each sampling round picks between one and four features, weighted
    ``[0.8, 0.1, 0.05, 0.05]``, and perturbs them:

    * Categorical features switch to a different value from the dataset's
      option list.
    * Numerical features are shifted by one of ``-round(std)``,
      ``-round(std/2)``, ``round(std/2)`` or ``round(std)``. Features in
      ``pos_numeric`` only ever use the two non-negative shifts. A result
      below 0 is replaced by 1. A result above the column maximum in ``X`` is
      clamped to that maximum. The value is then truncated to ``int``.

    Rounds repeat until the perturbed point differs from the original.

    Args:
        dataset: One of ``'heloc'``, ``'adult'`` or ``'german_credit'``.
        means: Unused; kept for interface compatibility.
        std: Indexable of per-feature standard deviations (positional).
        datapoint: Sequence of raw feature values. Assumed to be a numpy
            array or a positionally indexed pandas Series; TODO confirm
            against callers.
        X: DataFrame whose columns line up positionally with ``datapoint``.
            It is only used for the per-column maximum.
        max_tries: Maximum number of sampling rounds before giving up.

    Returns:
        ``(original, perturbed)`` as two lists. Numerical features in both
        are truncated to ``int``. The caller's ``datapoint`` is not modified.

    Raises:
        ValueError: If ``dataset`` is not recognised.
        RuntimeError: If no distinct perturbation is found within
            ``max_tries`` rounds (e.g. every rounded std is zero).
    """
    if dataset == 'heloc':
        categorical_indices = []
        numerical_indices = [0, 1, 2, 3]
        pos_numeric = []
        categorical_options = {}

    elif dataset == 'adult':
        categorical_indices = [0, 2, 3, 6, 7]
        numerical_indices = [1, 4, 5]
        pos_numeric = [1, 4]
        categorical_options = {
            0: ['yes', 'no'],
            2: ['yes', 'no'],
            3: ['yes', 'no'],
            6: ['yes', 'no'],
            7: ['yes', 'no'],
        }

    elif dataset == 'german_credit':
        categorical_indices = [0, 2, 3]
        numerical_indices = [1, 4]
        pos_numeric = []
        categorical_options = {
            0: ['>= 200 DM / salary for at least 1 year', 'no checking account', '< 0 DM', '0 <= ... <= 200 DM'],
            2: ['No credits taken/all credits paid back duly', 'All credits at this bank paid back duly', 'Existing credits paid back duly till now', 'Critical account/other credits elsewhere', 'Delay in paying off in the past'],
            3: ['Furniture/equipment', 'Others', 'Car (Used)', 'Car (new)', 'Retraining', 'Repairs', 'Domestic Applicances', 'Business', 'Radio/television', 'Vacation'],
        }

    else:
        raise ValueError(f"Unknown dataset: {dataset!r}")

    # Work on a private copy so the caller's datapoint is never mutated.
    # Numerical features are truncated to int, like the perturbed values.
    original = list(datapoint)
    for feature in numerical_indices:
        original[feature] = int(original[feature])

    # Highest probability for perturbing 1 feature, lowest for 4.
    counts = [1, 2, 3, 4]
    weights = [0.8, 0.1, 0.05, 0.05]

    for _ in range(max_tries):
        perturbed_datapoint = deepcopy(original)

        num_features_to_perturb = random.choices(counts, weights=weights, k=1)[0]
        features_to_perturb = random.sample(range(len(original)), num_features_to_perturb)

        for feature in features_to_perturb:
            if feature in categorical_indices:
                current_value = perturbed_datapoint[feature]
                possible_values = [v for v in categorical_options[feature] if v != current_value]
                perturbed_datapoint[feature] = random.choice(possible_values)

            elif feature in numerical_indices:
                # Preset shifts derived from the feature's std; pick one at random.
                pert_values = [-round(std[feature]), -round(std[feature] / 2),
                               round(std[feature] / 2), round(std[feature])]
                pert_idx = random.randint(0, 3)
                if feature in pos_numeric:
                    # Map 0,1,2,3 -> 2,3,2,3 so the feature only moves up.
                    pert_idx = 2 + (pert_idx % 2)

                final_value = perturbed_datapoint[feature] + pert_values[pert_idx]
                # Negative results are replaced by 1 (not 0), as in the original code.
                if final_value < 0.:
                    final_value = 1.

                # Never exceed the largest value observed for this column in X.
                col_max = X.values[:, feature].max()
                if final_value > col_max:
                    final_value = col_max

                perturbed_datapoint[feature] = int(final_value)

        if perturbed_datapoint != original:
            return original, perturbed_datapoint

    raise RuntimeError(
        f"Could not find a distinct perturbation for dataset {dataset!r} "
        f"within {max_tries} tries"
    )



def perturb_datapoint(dataset, means, std, datapoint, feature, desiderata=1):
    """Perturb exactly one feature of an encoded datapoint.

    Specifically for cost function evaluation ONLY.

    * Categorical, non-``german_credit``: the value at ``feature`` is replaced
      by a different value from the option list. For ``adult`` this flips
      between ``1`` and ``0``.
    * Categorical, ``german_credit``: ``feature`` identifies a group of
      one-hot columns (``categorical_options[feature]``). The active column
      is cleared and a different column of the group is set to 1. If no
      column of the group is active, a column is chosen from the whole
      group. If several are active, only the last one is cleared.
    * Numerical: the value is shifted by one of ``-round(std)``,
      ``-round(std/2)``, ``round(std/2)`` or ``round(std)``. Features in
      ``pos_numeric`` only use the two non-negative shifts. Negative results
      are replaced by 1. There is no upper clamp here.
    * A ``feature`` that is neither categorical nor numerical is left
      unchanged.

    Args:
        dataset: One of ``'heloc'``, ``'adult'`` or ``'german_credit'``.
        means: Unused; kept for interface compatibility.
        std: Indexable of per-feature standard deviations (positional).
        datapoint: Encoded datapoint whose elements support ``.item()``.
            Presumably a numpy array or torch tensor; TODO confirm.
        feature: Index of the feature (or one-hot group id) to perturb.
        desiderata: Unused; kept for interface compatibility.

    Returns:
        ``(datapoint, perturbed)``. ``datapoint`` is the caller's object,
        unmodified; ``perturbed`` is a deep copy with one feature changed.

    Raises:
        ValueError: If ``dataset`` is not recognised.
    """
    # Define the indices of categorical and numerical features
    if dataset == 'heloc':
        categorical_indices = []
        numerical_indices = [0, 1, 2, 3]
        pos_numeric = []
        categorical_options = {}

    elif dataset == 'adult':
        categorical_indices = [0, 2, 3, 6, 7]
        numerical_indices = [1, 4, 5]
        pos_numeric = [1, 4]
        categorical_options = {
            0: [1, 0],
            2: [1, 0],
            3: [1, 0],
            6: [1, 0],
            7: [1, 0],
        }

    elif dataset == 'german_credit':
        categorical_indices = [2, 3, 4]
        numerical_indices = [0, 1]
        pos_numeric = []
        # Group id -> column indices of that group's one-hot encoding.
        categorical_options = {
            2: [2, 3, 4, 5],
            3: [6, 7, 8, 9, 10],
            4: [11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
        }

    else:
        raise ValueError(f"Unknown dataset: {dataset!r}")

    # Copy so the caller's datapoint is left untouched.
    perturbed_datapoint = deepcopy(datapoint)

    if feature in categorical_indices:
        options = categorical_options[feature]
        if dataset == 'german_credit':
            # Last active one-hot column of the group, or None if none is set.
            active = [i for i in options if perturbed_datapoint[i].item() == 1]
            current_value = active[-1] if active else None
            possible_values = [v for v in options if v != current_value]
            new_value = random.choice(possible_values)
            if current_value is not None:
                perturbed_datapoint[current_value] = 0
            perturbed_datapoint[new_value] = 1

        else:
            current_value = perturbed_datapoint[feature].item()
            possible_values = [v for v in options if v != current_value]
            perturbed_datapoint[feature] = random.choice(possible_values)

    elif feature in numerical_indices:
        original_value = perturbed_datapoint[feature].item()

        pert_values = [-round(std[feature]), -round(std[feature] / 2),
                       round(std[feature] / 2), round(std[feature])]
        pert_idx = random.randint(0, 3)
        if feature in pos_numeric:
            # Map 0,1,2,3 -> 2,3,2,3 so the feature only moves up.
            pert_idx = 2 + (pert_idx % 2)

        final_value = original_value + pert_values[pert_idx]
        # Negative results are replaced by 1 (not 0), as in the original code.
        if final_value < 0.:
            final_value = 1.

        perturbed_datapoint[feature] = final_value

    return datapoint, perturbed_datapoint

    
    
    
        