from torch import tensor, float32, unique
from pandas import DataFrame, concat, get_dummies


def parse_recourse_df(recourse_df, device):
    """Turn a recourse DataFrame into (state, action, next-state) tensors.

    Args:
        recourse_df: DataFrame with columns ``"x"`` (original feature values)
            and ``"x_recourse"`` (feature values after recourse). Presumably
            each cell holds one row's feature sequence — confirm against
            callers. ``"yes"``/``"no"`` entries are encoded by
            ``feature_tensor``.
        device: torch device on which all returned tensors are created.

    Returns:
        ``(states, actions, next_states, ep_nums)``, where ``actions`` is the
        element-wise difference ``next_states - states`` and ``ep_nums`` is a
        tensor built from the DataFrame's index.
    """
    before, after = (
        feature_tensor(recourse_df[column], device=device)
        for column in ("x", "x_recourse")
    )
    episode_ids = tensor(recourse_df.index, device=device)
    # An "action" is the change applied to the features by recourse.
    return before, after - before, after, episode_ids


def parse_recourse_df_gc(recourse_df, device):
    """Parse a German-credit recourse DataFrame into one-hot transition tensors.

    The original rows (``"x"``) and recourse rows (``"x_recourse"``) are
    stacked into one frame before ``get_dummies``. This ensures both halves get
    the same indicator columns in the same order. The first five positional
    features are named ``status``, ``duration``, ``creditHistory``,
    ``purpose`` and ``amount``. After encoding, the columns must match
    ``feature_order`` from ``.features.german_credit`` exactly.

    Args:
        recourse_df: DataFrame with columns ``"x"`` and ``"x_recourse"``. Each
            cell is presumably a per-row feature sequence — confirm against
            callers.
        device: torch device on which all returned tensors are created.

    Returns:
        ``(states, actions, next_states, ep_nums)`` where ``states`` and
        ``next_states`` are float32 tensors of the encoded features,
        ``actions = next_states - states``, and ``ep_nums`` is a tensor built
        from ``recourse_df.index``.

    Raises:
        AssertionError: if the encoded columns differ from ``feature_order``
            (including a different number of columns, e.g. when a category
            value occurs in neither ``x`` nor ``x_recourse``).
    """
    from .features.german_credit import feature_order

    n = len(recourse_df)
    # Stack both halves so get_dummies sees every category from either side
    # and emits identical indicator columns for them; rows [:n] are the
    # original states and rows [n:] the post-recourse states.
    df = concat(
        [
            DataFrame(list(recourse_df["x"])),
            DataFrame(list(recourse_df["x_recourse"])),
        ],
        ignore_index=True,
    )
    assert len(df) == n * 2
    df = df.rename(
        columns={
            0: "status",
            1: "duration",
            2: "creditHistory",
            3: "purpose",
            4: "amount",
        }
    )
    df = get_dummies(df, dtype=float)

    # Compare as plain lists: an elementwise ``Index == list`` comparison
    # raises ValueError (not an assertion failure) when the lengths differ.
    columns = list(df.columns)
    expected = list(feature_order)
    assert columns == expected, (
        f"encoded columns {columns} do not match feature_order {expected}"
    )

    values = df.to_numpy(dtype="float32")
    states = tensor(values[:n], dtype=float32, device=device)
    next_states = tensor(values[n:], dtype=float32, device=device)
    actions = next_states - states
    ep_nums = tensor(recourse_df.index, device=device)
    return states, actions, next_states, ep_nums


def feature_tensor(feature_lists, device):
    """Build a float32 tensor from rows of feature values.

    The literal strings ``"no"`` and ``"yes"`` are encoded as 0 and 1. Every
    other value is passed to ``torch.tensor`` unchanged.

    Args:
        feature_lists: iterable of rows, each an iterable of feature values.
        device: torch device on which the tensor is created.

    Returns:
        A float32 tensor with one row per element of ``feature_lists``.
    """

    def _encode(value):
        # Binary yes/no answers become numeric indicators.
        if value == "no":
            return 0
        if value == "yes":
            return 1
        return value

    rows = [list(map(_encode, row)) for row in feature_lists]
    return tensor(rows, dtype=float32, device=device)


def make_split_features_and_thresholds(states, actions, next_states, features):
    """Compute candidate split thresholds for each feature function.

    Each feature ``f`` is called as ``f(states, actions, next_states)``. Its
    thresholds are the midpoints between consecutive sorted distinct values
    of the result. A feature with a single distinct value gets an empty
    threshold tensor.

    Args:
        states, actions, next_states: 2-D tensors whose second dimensions
            must all be equal.
        features: callables; there must be exactly three per state dimension.

    Returns:
        Dict mapping each feature callable to its tensor of thresholds.
    """
    n_dims = states.shape[1]
    assert (
        len(features) == 3 * n_dims
        and actions.shape[1] == n_dims
        and next_states.shape[1] == n_dims
    )

    def _midpoints(sorted_values):
        # Halfway points between neighbouring values (torch.unique sorts).
        return (sorted_values[:-1] + sorted_values[1:]) / 2.0

    return {
        feature: _midpoints(unique(feature(states, actions, next_states)))
        for feature in features
    }
