import enum
from nice import NICE
from counterfactual.import_essentials import *
from counterfactual.utils import *
from counterfactual.train import *
from counterfactual.training_module import *
from counterfactual.net import *
from counterfactual.baseline import ExplainerBase
from counterfactual.evaluate import compute_insensitivity, accuracy, proximity, load_model, evaluate


def _read_json_config(path):
    """Return the parsed contents of the JSON file at *path*.

    *path* is resolved relative to the current working directory. The file is
    opened in a ``with`` block so its handle is closed as soon as it has been
    parsed, rather than being left open until garbage collection.
    """
    with open(path) as f:
        return json.load(f)


dummy_config = _read_json_config("counterfactual/configs/dummy.json")
adult_config = _read_json_config("counterfactual/configs/adult.json")
student_config = _read_json_config("counterfactual/configs/student.json")
home_config = _read_json_config("counterfactual/configs/home.json")
breast_config = load_json("counterfactual/configs/extra/breast_cancer.json")
student_performance_config = load_json("counterfactual/configs/extra/student_performance.json")
titanic_config = load_json("counterfactual/configs/extra/titanic.json")

# One row per dataset:
#   (data_name, c_net checkpoint, c_net epoch, baseline checkpoint, baseline epoch, config)
# In every row, each epoch value equals the ``epoch=`` number in the matching
# checkpoint filename plus one.
_CHECKPOINT_TABLE = [
    ("adult",
     "saved_weights/adult/c_net/epoch=126-step=24256.ckpt", 127,
     "saved_weights/adult/baseline/epoch=55-step=10695.ckpt", 56,
     adult_config),
    ("student",
     "saved_weights/student/c_net/epoch=115-step=22155.ckpt", 116,
     "saved_weights/student/baseline/epoch=98-step=18908.ckpt", 99,
     student_config),
    ("home",
     "saved_weights/home/c_net/epoch=564-step=35029.ckpt", 565,
     "saved_weights/home/baseline/epoch=92-step=5765.ckpt", 93,
     home_config),
    ("breast",
     "saved_weights/extra/breast/c_net/epoch=440-step=1763.ckpt", 441,
     "saved_weights/extra/breast/baseline/epoch=383-step=1535.ckpt", 384,
     breast_config),
    ("student_performance",
     "saved_weights/extra/student_performance/c_net/epoch=451-step=1807.ckpt", 452,
     "saved_weights/extra/student_performance/baseline/epoch=287-step=1151.ckpt", 288,
     student_performance_config),
    ("titanic",
     "saved_weights/extra/titanic/c_net/epoch=61-step=371.ckpt", 62,
     "saved_weights/extra/titanic/baseline/epoch=63-step=383.ckpt", 64,
     titanic_config),
]

# Expand each table row into the dict layout consumed by the evaluation loop.
configs = [
    {
        "data_name": name,
        "c_net_path": c_net_ckpt,
        "c_net_epoch": c_net_ep,
        "baseline_path": base_ckpt,
        "baseline_epoch": base_ep,
        "config": cfg,
    }
    for name, c_net_ckpt, c_net_ep, base_ckpt, base_ep, cfg in _CHECKPOINT_TABLE
]


def cf_gen_parallel(
    model: BaselineTrainingModule, 
    is_parallel: bool = True, 
    test_size: int = None
):
    """Generate NICE counterfactuals for the held-out rows of ``model.data`` and score them.

    ``model.data`` is split without shuffling via ``train_test_split`` (default
    proportions). The last column is the label and all other columns are features.
    NICE is fitted on the train part and asked to explain each selected test row.

    Args:
        model: trained baseline module. This function reads its ``data``,
            ``continous_cols``, ``discret_cols``, ``normalizer``, ``transform``
            and ``predict``.
        is_parallel: if True, counterfactuals are generated with joblib on all
            cores. Timing is then measured on a separate sequential re-run of
            at most the first 100 rows. If False, all rows are generated
            sequentially and that run is timed.
        test_size: number of test rows to explain. ``None`` means the whole
            test split; larger values are clipped to the split size.

    Returns:
        dict with the transformed inputs / counterfactuals, predictions,
        timing (``total_time``, ``average_time``), ``pred_accuracy``,
        ``validity`` and ``proximity``, updated with the entries returned by
        ``compute_insensitivity``.
    """
    data = model.data
    feature_cols = data.columns[:-1]
    cat_idx = len(model.continous_cols)

    def pred_fn(x: np.ndarray):
        # NICE passes raw (untransformed) feature rows. Return two columns:
        # [1 - prediction, prediction].
        assert isinstance(x, np.ndarray)
        x = pd.DataFrame(x, columns=feature_cols)
        x_transformed = model.transform(x, return_tensor=True)
        y_pred = model.predict(x_transformed)
        return np.stack([
            (1. - y_pred).cpu().detach().numpy().reshape(-1),
            y_pred.cpu().detach().numpy().reshape(-1)
        ], axis=-1)

    # Positional indices (in data.columns) of the discrete features, as NICE expects.
    cat_idx_list = [
        i
        for cat_col in model.discret_cols
        for i, col_name in enumerate(data.columns)
        if col_name == cat_col
    ]
    X, y = data[feature_cols].values, data[data.columns[-1]].values
    train_X, test_X, train_y, test_y = train_test_split(X, y, shuffle=False)

    nice = NICE(
        predict_fn=pred_fn, 
        X_train=train_X, y_train=train_y, 
        cat_feat=cat_idx_list, optimization='proximity')

    # Clip so every count below matches the number of rows actually processed.
    length = len(test_X) if test_size is None else min(test_size, len(test_X))
    test_X = test_X[:length]
    cfs = np.empty_like(test_X)

    def gen_step(ix):
        cfs[ix, :] = nice.explain(test_X[ix:ix + 1, :])

    if is_parallel:
        # joblib's default backend runs tasks in worker processes, so writes
        # into ``cfs`` from inside a task would be lost. Collect the returned
        # counterfactuals and store them here in the parent process instead.
        parallel_cfs = Parallel(n_jobs=-1, max_nbytes=None, verbose=False)(
            delayed(nice.explain)(test_X[ix:ix + 1, :])
            for ix in tqdm(range(length))
        )
        for ix, cf in enumerate(parallel_cfs):
            cfs[ix, :] = cf
        # Per-instance cost is timed on a sequential re-run of a prefix so it
        # is not distorted by parallelism.
        test_length = min(100, length)
    else:
        test_length = length

    start = time.time()
    for ix in tqdm(range(test_length)):
        gen_step(ix)
    total_time = time.time() - start
    average_time = total_time / test_length

    X = model.transform(pd.DataFrame(test_X, columns=feature_cols))
    cfs = model.transform(pd.DataFrame(cfs, columns=feature_cols))

    # Validity: the target label is 1 minus the model's prediction on the input.
    y_prime = 1. - model.predict(X)
    cf_y_algo = model.predict(cfs)
    test_y = torch.from_numpy(test_y[:length])

    # Robustness
    insensitivity_result = compute_insensitivity(
        X, cfs, pred_fn=model.predict,
        cat_idx=cat_idx, threshold=2., scaler=model.normalizer
    )

    return_res = {
        "x": X,
        "cf": cfs,
        "cf_hat": cfs,
        "y_prime": y_prime,
        "cf_y": cf_y_algo,
        "cf_y_hat": cf_y_algo,
        "diffs": -1,
        "total_num": -1,
        "cat_idx": cat_idx,
        "total_time": total_time,
        "average_time": average_time,
        "pred_accuracy": accuracy(test_y, model.predict(X)),
        "validity": accuracy(y_prime, cf_y_algo).item(),
        "proximity": proximity(X, cfs).item()
    }
    return_res.update(insensitivity_result)
    return return_res


if __name__ == "__main__":
    # For every dataset, load its baseline model, generate NICE counterfactuals
    # sequentially over the full test split, and hand the results to `evaluate`.
    for entry in configs:
        baseline = load_model(entry['baseline_path'], entry['baseline_epoch'])
        nice_result = cf_gen_parallel(model=baseline, is_parallel=False)
        nice_result["cat_idx"] = len(baseline.continous_cols)
        evaluate(nice_result, dataset_name=entry['data_name'], cf_name="NICE")

