# Utils
import os
import torch
import numpy as np

# Recourse Methods
from Recourse_Methods.gradient_methods import SCFE
from Recourse_Methods.gradient_methods import DICE
from Recourse_Methods.manifold_methods import REVISE
from Recourse_Methods.manifold_methods import CEM


from Recourse_Methods.causal_methods import CausalRecourse


def _get_input_subset(model, inputs: torch.Tensor,
                      decision_threshold: float = 0.5) -> tuple:
    """
    Select the rows of ``inputs`` whose model score is strictly below ``decision_threshold``.

    :param model: callable mapping ``inputs`` to a tensor of scores whose first
                  dimension indexes the rows of ``inputs``
    :param inputs: batch of inputs; rows are selected along dimension 0
    :param decision_threshold: score cut-off used for both the selection and the
                               hard prediction
    :return: ``(input_subset, predicted_label_subset)``. ``input_subset`` holds the
             selected rows of ``inputs``. ``predicted_label_subset`` holds their hard
             predictions ``(score > decision_threshold) * 1``, which are all zero for
             the selected rows. Rows scoring exactly ``decision_threshold`` are not
             selected.
    """
    # Evaluate the model a single time, and without building an autograd graph:
    # only hard comparisons against the threshold are needed here.
    with torch.no_grad():
        scores = model(inputs)

    yhat = (scores > decision_threshold) * 1

    # torch.where(cond)[0] mirrors np.where(cond)[0]: it gives the row index of every
    # True entry, also for (N, 1)-shaped scores. Unlike a round-trip through
    # ``.numpy()``, it also works for tensors that live on a CUDA device.
    selected_indices = torch.where(scores < decision_threshold)[0]

    return inputs[selected_indices], yhat[selected_indices]
    

def get_recourses(loader, train_loader, ml_model, vae_model, dataname, task: str = "regression",
                  step: float = 0.00, lr: float = 1e-3, max_iter: int = 1000, threshold: float = 0.5) -> dict:
    """
    Generate counterfactual recourses for the first batch of ``loader``.

    Only the first batch yielded by ``loader`` is used. Within that batch, only the rows
    whose ``ml_model`` score is strictly below ``threshold`` are kept. Each selected
    recourse method then produces one counterfactual per kept row, targeting
    ``1 - predicted_label``.

    :param loader: iterable whose first item unpacks into ``(inputs, labels, indices)``;
                   only ``inputs`` is used
    :param train_loader: required for the SCM structural equation models; its
                         ``dataset.X`` is passed to ``CausalRecourse`` (only when
                         ``dataname == 'german'``)
    :param ml_model: classifier passed to every recourse method and used to select inputs
    :param vae_model: model passed as ``model_vae`` to REVISE and CEM
    :param dataname: ``'german'`` selects ``['cr', 'scfe', 'revise']``; any other value
                     selects ``['scfe', 'revise', 'cem', 'dice']``
    :param task: forwarded as ``setting`` to SCFE and DICE
    :param step: currently unused
    :param lr: learning rate for SCFE and DICE (REVISE and CEM use a fixed ``lr=1e-2``)
    :param max_iter: iteration budget for SCFE and DICE (REVISE and CEM use a fixed 1000)
    :param threshold: decision threshold used to select inputs and as the methods' target
                      threshold
    :return: dict keyed by method name. Each method that was run maps to
             ``{"recourses": Tensor, "inputs": Tensor}``. The keys 'scfe', 'revise',
             'dice' and 'cem' are always present and stay ``0`` for methods that were
             not run.
    """
    if dataname == 'german':
        recourse_names = ['cr', 'scfe', 'revise']
    else:
        recourse_names = ['scfe', 'revise', 'cem', 'dice']

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Only the first batch is processed. Use the builtin next(): DataLoader iterators
    # in recent PyTorch releases no longer provide a Python-2 style ``.next()`` method.
    inputs, _labels, _indices = next(iter(loader))
    inputs = inputs.to(device).float()

    # The loader's labels are not used. The target class of each recourse is derived
    # from the model's own hard prediction on the selected rows.
    inputs, labels = _get_input_subset(ml_model, inputs, decision_threshold=threshold)

    # These methods make the 'IMF' assumptions
    scfe = SCFE(ml_model, setting=task, lr=lr,
                max_iter=max_iter, target_thres=threshold)

    dice = DICE(ml_model, setting=task, lr=lr,
                max_iter=max_iter, target_thres=threshold)

    # These methods make the "Manifold assumption" and use gradients in latent space.
    # NOTE(review): their hyperparameters are fixed here rather than taken from
    # ``lr`` / ``max_iter`` -- confirm this is intended.
    revise = REVISE(classifier=ml_model, model_vae=vae_model, _lambda=0.00, step=0.05, lr=1e-2, max_iter=1000,
                    target_threshold=threshold)

    cem = CEM(classifier=ml_model, model_vae=vae_model, _lambda=0.00, step=0.05, lr=1e-2, max_iter=1000,
              target_threshold=threshold)

    recourse_methods = {
        'scfe': scfe,
        'dice': dice,
        'revise': revise,
        'cem': cem,
    }

    # Placeholders: methods that are not run for this dataset keep the value 0.
    results_dict = {
        "scfe": 0,
        "revise": 0,
        'dice': 0,
        'cem': 0
    }

    if "cr" in recourse_names:
        cr = CausalRecourse(classifier=ml_model,
                            X_train=train_loader.dataset.X,
                            step_size=-1e-2,
                            lamb=1,
                            delta_max=0.1)

        recourse_methods["cr"] = cr
        results_dict["cr"] = 0

    for recourse_name in recourse_names:
        print(f"  Finding recourses using: {recourse_name}")
        method = recourse_methods[recourse_name]

        counterfactuals = torch.zeros_like(inputs)
        for j in range(labels.shape[0]):
            # Request the opposite of the currently predicted binary class.
            counterfactuals[j, :] = method.generate_counterfactuals(inputs[j].reshape(-1),
                                                                    target_class=1 - labels[j])

        results_dict[recourse_name] = {
            "recourses": counterfactuals,
            "inputs": inputs,
        }

    return results_dict
