# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/05b_evaluate.ipynb (unless otherwise specified).

__all__ = ['pl_logger', 'load_model', 'load_model_jupyter', 'proximity', 'sparsity', 'compute_manifold_dist',
           'compute_insensitivity', 'cf_gen_parallel', 'model_cf_gen', 'evaluate', 'test_evaluate']

# Cell
from numpy.lib.utils import deprecate
from .import_essentials import *
from .utils import *
from .train import *
from .training_module import *
from .net import *
from .baseline import ExplainerBase

from pytorch_lightning.metrics.functional.classification import accuracy
from sklearn.preprocessing import MinMaxScaler
from sklearn.neighbors import NearestNeighbors
# imports from captum library
from captum.attr import LayerConductance, LayerActivation, LayerIntegratedGradients
from captum.attr import IntegratedGradients, DeepLift, GradientShap, NoiseTunnel, FeatureAblation

# Named logger "lightning"; presumably the one PyTorch Lightning itself logs to — confirm for the installed version.
pl_logger = logging.getLogger(name='lightning')
# Global matplotlib style for every figure drawn after import (assumes the "science" style sheet is installed).
plt.style.use(['science'])

# Cell
def load_model(checkpoint_path: str, n_iter: Optional[int] = None, module=BaselineModel, t_configs=None):
    """Restore a model from a checkpoint by resuming a Lightning `Trainer` on it.

    Args:
        checkpoint_path (str): path to the checkpoint file.
        n_iter (int, optional): passed to the trainer as `max_epochs`. When omitted
            it is inferred from a file name containing `epoch=<n>` (e.g.
            `epoch=<n>-step=<s>.ckpt`) as `n + 1` — presumably so that `fit` only
            restores the saved state without training further; confirm.
        module: class exposing `load_from_checkpoint` (default `BaselineModel`).
        t_configs (dict, optional): extra `pl.Trainer` keyword arguments;
            defaults to `{'gpus': 0}`.

    Returns:
        The model returned by `module.load_from_checkpoint`, after `trainer.fit`.

    Raises:
        FileNotFoundError: `checkpoint_path` does not exist.
        ValueError: `n_iter` is omitted and the file name has no `epoch=<int>` part.
    """
    import re  # local import keeps the module-level import block unchanged

    # `None` default instead of a shared mutable dict literal
    if t_configs is None:
        t_configs = {'gpus': 0}
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f'{checkpoint_path} is not found.')

    if n_iter is None:
        # Parse the file name only: directory names may contain '-' or '='
        # themselves, which broke the previous split-based parsing.
        match = re.search(r'epoch=(\d+)', os.path.basename(checkpoint_path))
        if match is None:
            raise ValueError(
                f'cannot infer n_iter from {checkpoint_path}; '
                'expected a file name like "epoch=<n>-step=<s>.ckpt".')
        n_iter = int(match.group(1)) + 1

    model = module.load_from_checkpoint(checkpoint_path)
    trainer = pl.Trainer(
        max_epochs=n_iter, resume_from_checkpoint=checkpoint_path, num_sanity_val_steps=0,
        logger=False, checkpoint_callback=False, **t_configs)
    trainer.fit(model)
    return model


def load_model_jupyter(checkpoint_path: str, n_iter: int, module=BaselineModel, m_configs=None, t_configs=None):
    """Restore a model from a checkpoint with explicit model configs (notebook variant).

    Unlike `load_model`, `n_iter` is required, `m_configs` is forwarded to
    `load_from_checkpoint` as `configs`, and the trainer keeps Lightning's
    default logger and checkpoint callback (no `logger=False` /
    `checkpoint_callback=False` is passed).

    Args:
        checkpoint_path (str): path to the checkpoint file.
        n_iter (int): passed to the trainer as `max_epochs`.
        module: class exposing `load_from_checkpoint` (default `BaselineModel`).
        m_configs (dict, optional): model configs; defaults to an empty dict.
        t_configs (dict, optional): extra `pl.Trainer` keyword arguments;
            defaults to `{'gpus': 0}`.

    Returns:
        The loaded model, after `trainer.fit`.

    Raises:
        FileNotFoundError: `checkpoint_path` does not exist.
    """
    # Fresh dicts per call: the module may keep a reference to `configs`, so a
    # shared default dict could leak state between calls.
    m_configs = {} if m_configs is None else m_configs
    t_configs = {'gpus': 0} if t_configs is None else t_configs
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f'{checkpoint_path} is not found.')

    model = module.load_from_checkpoint(checkpoint_path, configs=m_configs)
    trainer = pl.Trainer(
        max_epochs=n_iter, resume_from_checkpoint=checkpoint_path, num_sanity_val_steps=0, **t_configs)
    trainer.fit(model)
    return model


def proximity(x, c):
    """Mean L1 distance between paired rows of `x` and `c`.

    Absolute differences are summed along dim 1 (per row) and then averaged;
    the result is a 0-dim tensor.
    """
    l1_per_row = (x - c).abs().sum(dim=1)
    return l1_per_row.mean()

def sparsity(x: torch.Tensor, cf: torch.Tensor, cat_idx: int):
    """Average number of changed features between `x` and `cf`.

    Columns `[:, :cat_idx]` are counted with an L0 norm (number of non-zero
    differences per row). Columns `[:, cat_idx:]` contribute their summed
    absolute differences halved — presumably one-hot encoded, so a category
    switch (one 1->0 plus one 0->1) counts once. Both parts are averaged over
    rows and added; the result is a 0-dim tensor.
    """
    delta = x - cf
    n_cont_changed = torch.norm(delta[:, :cat_idx].abs(), p=0, dim=1).mean()
    n_cat_changed = delta[:, cat_idx:].abs().sum(dim=1).mean() / 2
    return n_cont_changed + n_cat_changed

def compute_manifold_dist(X, cfs):
    """Average distance from each counterfactual to its nearest row of `X`.

    A `NearestNeighbors` index (default settings) is fitted on `X` and queried
    for the single closest neighbour of every row in `cfs`.

    Returns:
        float: mean nearest-neighbour distance.
    """
    index = NearestNeighbors().fit(X)
    dists, _ = index.kneighbors(cfs, n_neighbors=1, return_distance=True)
    return float(np.mean(dists))


def compute_insensitivity(
    x: torch.Tensor, cf: torch.Tensor, pred_fn: Callable,
    cat_idx: int, threshold: float, scaler: 'MinMaxScaler'):
    """Check whether counterfactuals depend on changes too small to matter.

    On the continuous columns (`[:, :cat_idx]`), every feature whose change
    from `x` to `cf` is non-zero but smaller than `threshold` (given in
    original units, divided by the scaler's `data_range_`) is reverted to its
    value in `x`. The reverted counterfactuals `cf_hat` are scored with
    `pred_fn` and compared with the scores of `cf` at a 0.5 cut-off.

    Args:
        x: inputs, rows aligned with `cf`.
        cf: counterfactuals for `x`; not modified.
        pred_fn: maps a batch to scores; compared against 0.5.
        cat_idx: number of leading continuous columns.
        threshold: smallest meaningful change, in unnormalized units.
        scaler: fitted scaler; only `data_range_` is read. Assumes it has one
            entry per continuous column (it is broadcast against
            `[:, :cat_idx]`) — TODO confirm.

    Returns:
        dict with
            - "diffs": number of rows whose thresholded prediction differs
              between `cf` and `cf_hat` (0-dim tensor),
            - "total_num": number of rows with at least one reverted feature
              (0-dim tensor),
            - "cf_hat": counterfactuals with the small changes reverted,
            - "cf_y_hat": `pred_fn(cf_hat)`.
    """
    # normalized threshold = threshold / (max - min), on the same device as x
    data_range = torch.from_numpy(scaler.data_range_).to(x.device)
    threshold_normed = threshold / data_range
    # select continuous features
    x_cont = x[:, :cat_idx]
    cf_cont = cf[:, :cat_idx]
    abs_diff = torch.abs(x_cont - cf_cont)
    # A "small change" must be an actual change: unchanged features (diff == 0)
    # would otherwise match too and make `total_num` count nearly every row.
    small_change = (abs_diff > 0) & (abs_diff < threshold_normed)
    total_diffs = torch.sum(small_change.any(dim=1))
    # revert the small changes to the original input values
    cf_hat = cf.clone()
    cf_hat[:, :cat_idx] = torch.where(small_change, x_cont, cf_cont)
    cf_y = pred_fn(cf)
    cf_y_hat = pred_fn(cf_hat)
    return {
        "diffs": ((cf_y > 0.5) != (cf_y_hat > 0.5)).sum(),
        "total_num": total_diffs,
        "cf_hat": cf_hat,
        "cf_y_hat": cf_y_hat
    }

# Cell
def cf_gen_parallel(cf_params: dict, CFExplainer: ExplainerBase, is_parallel: bool = True, test_size: int = None) -> Dict:
    """Generate one counterfactual per validation instance and gather raw results.

    Args:
        cf_params (dict): keyword arguments forwarded to `CFExplainer`; the
            black-box model is read from `cf_params['model']`.
        CFExplainer (ExplainerBase): CF algorithm class, instantiated as
            `CFExplainer(x, **cf_params)` for every instance `x`.
        is_parallel (bool): generate the counterfactuals with joblib on all
            cores. The reported time then comes from a separate sequential
            run over at most the first 100 instances.
        test_size (int, optional): number of leading validation instances to
            explain, capped at the validation-set size; defaults to all.

    Returns:
        Dict: inputs `x`, counterfactuals `cf`, predictions, timing and metric
            values, updated with the output of `compute_insensitivity`
            (which overrides `diffs`, `total_num`, `cf_hat` and `cf_y_hat`).
    """
    model = cf_params['model']
    cat_idx = len(model.continous_cols)

    val_dataset = model.val_dataset
    _, label = val_dataset[:]
    n_available = len(val_dataset)
    # rows beyond the dataset size would never be filled and stay uninitialised
    length = n_available if test_size is None else min(test_size, n_available)
    n_features = val_dataset[0][0].size(-1)

    X = torch.empty((length, n_features))
    cfs = torch.empty_like(X)

    print(f"y-axis: {n_features}")

    def gen_step(x):
        # Return the pair instead of writing into X / cfs: with joblib's
        # process-based backends such writes land in the worker's copy and
        # never reach this process.
        x = x.reshape(1, -1)
        cf_exp = CFExplainer(x, **cf_params)
        # generate counterfactual explanation for algo and model
        cf = cf_exp.generate_cf(1000)
        return x.detach(), cf.detach()

    def inputs(n):
        # lazily yields the inputs of the first `n` validation instances
        return (val_dataset[ix][0] for ix in tqdm(range(n)))

    if is_parallel:
        pairs = Parallel(n_jobs=-1, max_nbytes=None, verbose=False)(
            delayed(gen_step)(x) for x in inputs(length)
        )
        for ix, (x, cf) in enumerate(pairs):
            X[ix, :] = x
            cfs[ix, :] = cf
        # Wall-clock time of the parallel run is not a per-instance cost, so
        # time a short sequential run instead (never more than `length` rows).
        test_length = min(100, length)
        start = time.time()
        for x in inputs(test_length):
            gen_step(x)
    else:
        test_length = length
        start = time.time()
        for ix, x in enumerate(inputs(length)):
            X[ix, :], cfs[ix, :] = gen_step(x)

    total_time = time.time() - start
    average_time = total_time / test_length if test_length > 0 else 0.

    # validity metrics
    y_pred = model.predict(X)
    y_prime = torch.ones((length)) - y_pred
    cf_y_algo = model.predict(cfs)
    print('y_prime.shape: ', y_prime.shape)
    print('cf_y.shape: ', cf_y_algo.shape)

    # robustness
    insensitivity_result = compute_insensitivity(
        X, cfs, pred_fn=model.predict,
        cat_idx=cat_idx, threshold=2., scaler=model.normalizer
    )

    return_res = {
        "x": X,
        "cf": cfs,
        "cf_hat": cfs,
        "y_prime": y_prime,
        "cf_y": cf_y_algo,
        "cf_y_hat": cf_y_algo,
        "diffs": -1,
        "total_num": -1,
        "cat_idx": cat_idx,
        "total_time": total_time,
        "average_time": average_time,
        # only the first `length` labels correspond to rows of X
        "pred_accuracy": accuracy(label[:length], y_pred),
        "validity": accuracy(y_prime, cf_y_algo).item(),
        "proximity": proximity(X, cfs).item()
    }
    return_res.update(insensitivity_result)
    return return_res


def model_cf_gen(model: CounterfactualTrainingModule, is_parallel: bool = False) -> Dict:
    """Generate counterfactuals with the model's own `generate_cf` for the whole validation set.

    Args:
        model (CounterfactualTrainingModule): module providing `val_dataset`,
            `continous_cols`, `generate_cf`, `predict` and `normalizer`.
        is_parallel (bool): generate with joblib on all cores. The reported
            time then comes from a sequential run over the first tenth of the
            validation set (at least one instance).

    Returns:
        Dict: same keys as `cf_gen_parallel`, updated with the output of
            `compute_insensitivity`.
    """
    val_dataset = model.val_dataset
    cat_idx = len(model.continous_cols)
    val_X, label = val_dataset[:]
    # every row is overwritten below, so no initial values are needed
    X = torch.empty(val_X.size())
    cf_algo = torch.empty(val_X.size())

    def gen_step(x):
        x = x.reshape(1, -1)
        _cf = model.generate_cf(x)
        return x, _cf

    if is_parallel:
        result = Parallel(n_jobs=-1, max_nbytes=None, verbose=False)(
            delayed(gen_step)(x=x) for x, _ in tqdm(val_dataset)
        )
        # time a sequential run on a subset; at least one instance so the
        # average below never divides by zero
        length = max(1, len(val_dataset) // 10)
        start = time.time()
        for ix, (x, _) in enumerate(tqdm(val_dataset)):
            if ix >= length:
                break
            gen_step(x)
    else:
        length = len(val_dataset)
        start = time.time()
        result = [gen_step(x) for x, _ in tqdm(val_dataset)]

    total_time = time.time() - start
    average_time = total_time / length
    print(f"total time: {total_time}; length: {length}; average time: {average_time} ")

    for ix, (x, _cf_algo) in enumerate(result):
        X[ix, :] = x
        cf_algo[ix, :] = _cf_algo

    # validity metrics: desired label is the flip of the prediction; built from
    # the prediction's own shape rather than a leaked loop variable
    y_pred = model.predict(X)
    y_prime = torch.ones(y_pred.size()) - y_pred
    cf_y = model.predict(cf_algo)
    print(f"X: {X.size()}; cf_y: {cf_y.size()}; cf_algo: {cf_algo.size()}")

    # robustness
    insensitivity_result = compute_insensitivity(
        X, cf_algo, pred_fn=model.predict,
        cat_idx=cat_idx, threshold=2., scaler=model.normalizer
    )

    return_res = {
        "x": X,
        "cf": cf_algo,
        "cf_hat": cf_algo,
        "cf_y_hat": cf_y,
        "y_prime": y_prime,
        "cf_y": cf_y,
        "diffs": -1,
        "total_num": -1,
        "cat_idx": cat_idx,
        "total_time": total_time,
        "average_time": average_time,
        # X always covers the full validation set, so compare with all labels
        "pred_accuracy": accuracy(label, y_pred),
        "validity": accuracy(y_prime, cf_y).item(),
        "proximity": proximity(X, cf_algo).item()
    }
    return_res.update(insensitivity_result)
    return return_res

# Cell
def evaluate(result: Dict, dataset_name: str, cf_name: str, is_logging: bool = True, seed: int = None):
    """calculate metrics of CF algos and log the results

    Args:
        result (Dict): results generated from `cf_gen_parallel` / `model_cf_gen`
            - x: input instances
            - cf: counterfactual examples
            - cf_hat: counterfactuals with small continuous changes reverted
            - y_prime: desired label (the flip of the predicted label when the problem is binary)
            - cf_y_hat: outcomes of `cf_hat`
            - cat_idx: number of leading continuous columns
            - diffs, total_num: insensitivity counts (tensors or numbers)
            - average_time, pred_accuracy, validity, proximity: precomputed values
        dataset_name (str): dataset name
        cf_name (str): counterfactual algorithm's name
        is_logging (bool): write the updated metrics table to the csv file
        seed (int, optional): when given, the csv is named `metrics-{seed}.csv`

    Raises:
        ValueError: dataset name is invalid
        ValueError: cf_name is invalid

    Returns:
        Dict: final result (one value per metric for `cf_name`)
    """
    dataset_names = ['dummy', 'adult', 'student', 'home']
    extra_dataset_names = ['credit_card', 'german', 'student_performance', 'breast', 'heart', 'titanic']
    cf_names = ['VanillaCF', 'DiverseCF',
                'ProtoCF', 'VAE-CF', 'C-CHVAE', 'CounterfactualNet',
                'CounterfactualNet-NoPass', 'CounterfactualNet-Separate', 'CounterfactualNet-loss2=l1',
                'NICE', 'CounteRGAN'
                ]
    metrics = ['cat_proximity', 'cont_proximity', 'validity',
               'robustness', 'sparsity', 'diffs', 'total_num',
               'time', 'pred_accuracy', 'proximity', 'so_validity', 'fo_sparsity', 'so_proximity', 'manifold_dist']

    # validate the names before any (potentially expensive) metric computation
    if (dataset_name not in dataset_names) and (dataset_name not in extra_dataset_names):
        raise ValueError(
            f"dataset_name ({dataset_name}) is not valid; it should be one of {dataset_names + extra_dataset_names}.")
    if cf_name not in cf_names:
        raise ValueError(
            f"cf_name ({cf_name}) is not valid; it should be one of {cf_names}.")

    def as_number(value):
        # 0-dim tensors -> Python scalars so they can be stored in the csv
        return value.item() if torch.is_tensor(value) else value

    x = result['x']
    cf = result['cf']
    cf_hat = result['cf_hat']
    cat_idx = result['cat_idx']
    # convert first so the ratio works for plain ints as well as tensors
    # (torch.true_divide rejects Python ints)
    diffs = as_number(result['diffs'])
    total_num = as_number(result['total_num'])
    robustness = 1 - diffs / total_num if total_num != 0 else 0.

    is_extra = dataset_name in extra_dataset_names
    base_dir = f"results/extra/{dataset_name}" if is_extra else f"results/{dataset_name}"
    file_name = "metrics.csv" if seed is None else f"metrics-{seed}.csv"
    csv_path = f"{base_dir}/{file_name}"

    if os.path.exists(csv_path):
        r = pd.read_csv(csv_path, index_col=0).to_dict()
        for metric in metrics:
            # add columns for metrics introduced after the csv was written
            r.setdefault(metric, {cf_algo: -1 for cf_algo in cf_names})
    else:
        r = {metric: {cf_algo: -1 for cf_algo in cf_names}
             for metric in metrics}

    scores = {
        'cont_proximity': proximity(x[:, :cat_idx], cf[:, :cat_idx]).item(),
        'cat_proximity': proximity(x[:, cat_idx:], cf[:, cat_idx:]).item(),
        'validity': result['validity'],
        'robustness': robustness,
        'diffs': diffs,
        'total_num': total_num,
        'time': result['average_time'],
        'pred_accuracy': as_number(result['pred_accuracy']),
        'proximity': result['proximity'],
        'so_proximity': as_number(proximity(x, cf_hat)),
        'sparsity': as_number(sparsity(x, cf_hat, cat_idx=cat_idx)),
        'fo_sparsity': as_number(sparsity(x, cf, cat_idx=cat_idx)),
        'so_validity': accuracy(result['y_prime'], result['cf_y_hat']).item(),
        'manifold_dist': compute_manifold_dist(x, cf),
    }
    for metric, value in scores.items():
        r[metric][cf_name] = value

    if is_logging:
        # the results directory may not exist yet (e.g. on a fresh checkout)
        os.makedirs(os.path.dirname(csv_path), exist_ok=True)
        pd.DataFrame.from_dict(r).to_csv(csv_path)
        print("metrics have been saved")

    final_result = {metric: r[metric][cf_name] for metric in metrics}

    print("Final result:")
    pprint(final_result)

    return final_result

def test_evaluate():
    """Smoke-test `evaluate` on seeded random data for the "student" dataset.

    Builds a result dict containing every key `evaluate` reads and runs it with
    `is_logging=False`, so no csv under `results/` is (over)written.
    """
    gen = torch.Generator().manual_seed(0)
    n_rows, n_cols = 1000, 127
    x = torch.rand((n_rows, n_cols), generator=gen)
    cf = torch.rand((n_rows, n_cols), generator=gen)
    # binary 0/1 labels as float tensors
    y_prime = torch.randint(0, 2, (n_rows,), generator=gen).float()
    cf_y = torch.randint(0, 2, (n_rows,), generator=gen).float()
    result = {
        "x": x,
        "cf": cf,
        "cf_hat": cf,
        "y_prime": y_prime,
        "cf_y": cf_y,
        "cf_y_hat": cf_y,
        "cat_idx": 21,
        # tensors, as produced by `compute_insensitivity`
        "diffs": torch.tensor(10),
        "total_num": torch.tensor(100),
        "total_time": 0.,
        "average_time": 0.,
        "pred_accuracy": 1.,
        "validity": accuracy(y_prime, cf_y).item(),
        "proximity": proximity(x, cf).item(),
    }
    final_result = evaluate(result, dataset_name="student", cf_name="VanillaCF", is_logging=False)
    assert final_result['diffs'] == 10
    assert final_result['total_num'] == 100
    assert abs(final_result['robustness'] - 0.9) < 1e-6