import numpy as np
import pandas as pd

from carla.models.api import MLModel
from carla.recourse_methods.api import RecourseMethod
from carla.recourse_methods.catalog.user_preferred.library import (
    user_preferred_steps,
)
from carla.recourse_methods.processing import (
    check_counterfactuals,
    encode_feature_names,
)


class UserPreferred(RecourseMethod):
    """
    User-preferred recourse method.

    For every factual row, ``user_preferred_steps`` is called with a
    per-feature preference vector, per-feature step sizes, an actionability
    mask and per-feature bounds. It returns a counterfactual together with
    the steps taken, their costs and the recourse candidates it produced.

    The steps, costs and candidates of every processed factual are appended
    to instance-level history lists. They can be read back with
    :meth:`get_steps`, :meth:`get_cost` and :meth:`get_candidates`. The
    history is NOT cleared between calls to :meth:`get_counterfactuals`.

    Parameters
    ----------
    mlmodel : MLModel
        Model whose ``data`` / ``encoder`` / ``feature_input_order`` describe
        the feature space.
    hyperparams : optional
        Currently unused; accepted for interface compatibility.
    actionable_set : collection of str
        Names of the features the user is willing to change. Required.
    step_size : float, default 0.001
        Step size for continuous features; every other feature uses 1.0.
    pref_vec : array-like of float, optional
        Per-feature preference weights, aligned with
        ``mlmodel.feature_input_order``. When omitted, every feature gets 0.1
        except index 3, which gets 0.7 (the original hard-coded default).

    Raises
    ------
    ValueError
        If ``actionable_set`` is None, or ``pref_vec`` does not have exactly
        one entry per feature.
    """

    def __init__(
        self,
        mlmodel: MLModel,
        hyperparams=None,
        actionable_set=None,
        step_size=0.001,
        pref_vec=None,
    ) -> None:
        super().__init__(mlmodel)

        if actionable_set is None:
            # Explicit raise (not ``assert``) so the check survives ``python -O``.
            raise ValueError(
                "A list of actionable features needs to be passed to U.P."
            )

        feature_order = self._mlmodel.feature_input_order
        n_features = len(feature_order)

        self._immutables = encode_feature_names(
            self._mlmodel.data.immutables, feature_order
        )
        self._mutables = [
            feature for feature in feature_order if feature not in self._immutables
        ]

        # 1/0 mask aligned with ``feature_input_order``: 1 when the user listed
        # the feature in ``actionable_set``.
        self._actionable_features = [
            1 if feature in actionable_set else 0 for feature in feature_order
        ]

        # ``continous`` (sic) is the attribute name exposed by the data object.
        self._continuous = self._mlmodel.data.continous
        self._categoricals_enc = encode_feature_names(
            self._mlmodel.data.categoricals, feature_order
        )

        if pref_vec is None:
            self._pref_vec = self._default_pref_vec(n_features)
        else:
            self._pref_vec = np.asarray(pref_vec, dtype=float)
            if self._pref_vec.shape != (n_features,):
                raise ValueError(
                    f"pref_vec must have one entry per feature ({n_features}), "
                    f"got shape {self._pref_vec.shape}"
                )

        # Continuous features move in increments of ``step_size``; all other
        # features use a step of 1.0.
        self._stepSize_matrix = np.array(
            [
                step_size if feature in self._continuous else 1.0
                for feature in feature_order
            ]
        )

        # Per-feature box constraints [0, 1].
        # NOTE(review): presumably the encoded/normalised features live in
        # [0, 1] -- confirm against the model's scaler.
        self._lower_bound = np.zeros(n_features, dtype=int)
        self._upper_bound = np.ones(n_features, dtype=int)

        # Not yet configurable (was meant to come from hyperparams).
        self._binary_cat_features = True

        # Per-factual history, appended to by get_counterfactuals().
        self._steps_cfs = []
        self._cost_cfs = []
        self._candidates = []

    @staticmethod
    def _default_pref_vec(n_features):
        """Return the built-in preference vector: 0.1 everywhere, 0.7 at index 3."""
        pref_vec = np.full(n_features, 0.1)
        # NOTE(review): indices 3 and 4 are hard-coded and presumably tied to a
        # specific dataset's feature order -- confirm before reusing. Index 4
        # is re-set to the base value 0.1; kept so fewer than 5 features
        # still raises IndexError exactly as before.
        pref_vec[3] = 0.7
        pref_vec[4] = 0.1
        return pref_vec

    def _cat_feature_indices(self, df_enc_norm_fact):
        """Return the column positions of the encoded categorical features in ``df_enc_norm_fact``."""
        encoded_feature_names = self._mlmodel.encoder.get_feature_names_out(
            self._mlmodel.data.categoricals
        )
        return [
            df_enc_norm_fact.columns.get_loc(feature)
            for feature in encoded_feature_names
        ]

    def get_counterfactuals(self, factuals: pd.DataFrame) -> pd.DataFrame:
        """
        Compute one counterfactual per row of ``factuals``.

        The factuals are encoded, normalised and ordered first. The steps,
        costs and candidates for each row are appended to the instance
        history; earlier entries are kept.

        Returns
        -------
        pd.DataFrame
            The result of ``check_counterfactuals`` on the generated
            counterfactuals.
        """
        df_enc_norm_fact = self.encode_normalize_order_factuals(factuals)
        cat_features_indices = self._cat_feature_indices(df_enc_norm_fact)
        # Target label handed to user_preferred_steps (presumably the desired
        # class -- confirm in the library).
        y_target = +1

        list_cfs = []
        for _, row in df_enc_norm_fact.iterrows():
            counterfactual, steps, cost, recourse_candidates = user_preferred_steps(
                row,
                y_target,
                self._mutables,
                self._immutables,
                self._continuous,
                self._categoricals_enc,
                self._mlmodel.feature_input_order,
                self._mlmodel.raw_model,
                self._pref_vec,
                self._stepSize_matrix,
                self._actionable_features,
                self._lower_bound,
                self._upper_bound,
                cat_features_indices,
                self._binary_cat_features,
            )
            list_cfs.append(list(counterfactual))

            self._steps_cfs.append(list(steps))
            self._cost_cfs.append(list(cost))
            self._candidates.append(recourse_candidates)

        return check_counterfactuals(self._mlmodel, list_cfs)

    def get_factuals(self, factuals: pd.DataFrame) -> list:
        """
        Return the encoded, normalised and ordered factuals as a list of rows.

        Each row is a list of the values of one factual, in
        ``feature_input_order``.
        """
        df_enc_norm_fact = self.encode_normalize_order_factuals(factuals)
        return [list(row.to_numpy()) for _, row in df_enc_norm_fact.iterrows()]

    def get_steps(self):
        """Return the per-factual step lists recorded by get_counterfactuals()."""
        return self._steps_cfs

    def get_cost(self):
        """Return the per-factual cost lists recorded by get_counterfactuals()."""
        return self._cost_cfs

    def get_candidates(self):
        """Return the per-factual recourse candidates recorded by get_counterfactuals()."""
        return self._candidates
        
    
        
    
