from typing import List

import numpy as np
import pandas as pd

from carla.models.api import MLModel
from carla.models.pipelining import decode, encode, scale


def constraint_violation(
    mlmodel: MLModel, counterfactuals: pd.DataFrame, factuals: pd.DataFrame
) -> List[int]:
    """
    Count immutable-feature violations per counterfactual.

    The factuals are pushed through ``mlmodel.perform_pipeline`` and compared
    with the counterfactuals in that transformed space. The factuals are
    transformed (instead of decoding the counterfactuals) because there is no
    guarantee that a counterfactual respects the one-hot-encoding
    relationships, so decoding it may not be well defined.

    Parameters
    ----------
    mlmodel: Black-box-model we want to discover. Must provide
        ``data.target``, ``data.immutables``, ``perform_pipeline`` and, when
        an immutable feature is not a column of the pipeline output,
        ``encoder.get_feature_names_out``.
    counterfactuals: Counterfactual examples in the pipeline output space,
        sharing their index labels with ``factuals``.
    factuals: Raw factuals, including the target column.

    Returns
    -------
    One violation count per counterfactual row that has no missing value in
    its immutable columns, in row order. Rows with a missing value are
    skipped, so the list can be shorter than ``counterfactuals``. Every
    differing column counts once, so a changed categorical feature can
    contribute more than one violation (one per differing encoded column).
    """
    processed_factuals = mlmodel.perform_pipeline(
        factuals.drop(mlmodel.data.target, axis=1)
    )

    # Immutables that survive the pipeline under their own name can be
    # compared directly; the remaining ones were replaced by encoded columns.
    passthrough = [
        feature
        for feature in mlmodel.data.immutables
        if feature in processed_factuals.columns
    ]
    encoded_features = [
        feature
        for feature in mlmodel.data.immutables
        if feature not in processed_factuals.columns
    ]

    expanded = []
    if encoded_features:
        # Only touch the encoder when it is actually needed, and only once.
        encoded_names = list(mlmodel.encoder.get_feature_names_out())
        # NOTE: any encoded column whose name starts with the original feature
        # name is attributed to that feature. This is a weak assumption
        # (e.g. "work" also matches "workclass_*"), but a stricter mapping
        # would require a deeper refactor.
        expanded = [
            name
            for feature in encoded_features
            for name in encoded_names
            if name.startswith(feature)
        ]

    # Drop duplicates (overlapping prefixes can select a column twice) so a
    # single changed column is never counted more than once.
    immutable_columns = list(dict.fromkeys(passthrough + expanded))

    cf_values = counterfactuals[immutable_columns]
    # Skip rows that have a missing value in any immutable column.
    cf_values = cf_values[cf_values.notna().all(axis=1)]
    # Align the factuals with the remaining counterfactuals by index label.
    factual_values = processed_factuals[immutable_columns].loc[cf_values.index]

    # Argument order matters: np.isclose scales rtol by its second argument.
    unchanged = np.isclose(
        factual_values.to_numpy(dtype=float),
        cf_values.to_numpy(dtype=float),
        rtol=1e-6,
    )
    return (len(immutable_columns) - unchanged.sum(axis=1)).tolist()
