from multiprocessing.spawn import old_main_modules
import torch
from ML_Models.base_model import BaseMLModel

def find_impactful_data_greedy(model: "BaseMLModel", xcf: torch.Tensor, Xtrain, ytrain, thres=0.0, mode="cnt", addc=1e-3):
    """Greedily pick the training point whose removal most invalidates the counterfactuals.

    Candidate parameters are formed as
    ``model.parameter_change_under_removal(Xtrain, ytrain) + model.get_all_params()``
    (presumably one approximate parameter update per training point -- confirm
    against ``BaseMLModel``). The counterfactuals are re-scored with
    ``model.predict_from_parameters`` and each candidate is scored column-wise
    (reductions run over ``dim=0``). The index of the best-scoring candidate
    is returned.

    NOTE(review): assumes ``predict_from_parameters`` returns a [B, N] tensor
    (B counterfactuals x N candidate removals) -- TODO confirm.

    :param model: The current machine learning model.
    :param xcf: [B, input_dim] batch of counterfactuals generated with the original
        model that are supposed to be invalidated.
    :param Xtrain: Training inputs passed to ``parameter_change_under_removal``.
    :param ytrain: Training labels passed to ``parameter_change_under_removal``.
    :param thres: Decision threshold; a prediction ``< thres`` counts as invalidated.
    :param mode: Scoring rule for each candidate column:
        ``"cnt"``      -- number of predictions below ``thres``;
        ``"maxcng"``   -- mean of (current prediction - new prediction);
        ``"maxinval"`` -- sum of ``relu(thres - prediction + addc)``.
    :param addc: Margin added inside the hinge for ``mode="maxinval"``.
    :return: The index returned by ``torch.max(scores, dim=0)`` for the chosen candidate.
    :raises ValueError: If ``mode`` is not one of the supported values.
    """
    # Fail fast, before the (potentially expensive) parameter/prediction computation;
    # previously an unknown mode surfaced as an UnboundLocalError on `value, idx`.
    if mode not in ("cnt", "maxcng", "maxinval"):
        raise ValueError(
            f"Unknown mode {mode!r}; expected one of 'cnt', 'maxcng', 'maxinval'."
        )

    new_params = model.parameter_change_under_removal(Xtrain, ytrain) + model.get_all_params()
    new_predictions = model.predict_from_parameters(xcf, new_params)
    print(new_predictions.shape)
    if mode == "cnt":
        # How many counterfactuals each candidate pushes below the threshold.
        scores = torch.sum(new_predictions < thres, dim=0)
    elif mode == "maxcng":
        # Column vector so the current predictions broadcast against every candidate column.
        old_prediction = model.predict_with_logits(xcf).reshape(-1, 1)
        scores = torch.mean(old_prediction - new_predictions, dim=0)
    else:  # mode == "maxinval"
        # Hinge loss: positive once a prediction is within `addc` of, or below, the threshold.
        scores = torch.sum(torch.relu(-(new_predictions - thres) + addc), dim=0)
    value, idx = torch.max(scores, dim=0)
    print(value, idx)
    return idx

def select_random_point(model: BaseMLModel, xcf: torch.tensor, Xtrain, ytrain, thres, **params):
    """Baseline removal strategy: choose one training point uniformly at random.

    Only ``len(Xtrain)`` is used. ``model``, ``xcf``, ``ytrain``, ``thres`` and
    any extra keyword arguments are accepted only so this function shares the
    call signature of the other removal strategies.

    :return: Tensor of shape [1] with a random index in ``[0, len(Xtrain))``.
    """
    n_candidates = len(Xtrain)
    return torch.randint(0, n_candidates, size=(1,))

def _refit_and_measure(model_full, xcf, Xtrain, ytrain, keep, thres):
    """Refit ``model_full`` on the rows selected by boolean mask ``keep`` and return the fraction of ``xcf`` predicted below ``thres``."""
    model_full.reset_data_weights_and_size(torch.sum(keep.long()))
    model_full.fit(Xtrain[keep], ytrain[keep])
    # Evaluation only: no autograd graph is needed for the invalidation rate.
    with torch.no_grad():
        pred = model_full.predict_with_logits(xcf)
    return torch.mean((pred < thres).float())


def compute_invalidation_curve(model_full, xcf: torch.Tensor, Xtrain, ytrain, strategy, n_points_removed = 10, thres=0.0, params=None):
    """Remove training points one at a time and track how many counterfactuals get invalidated.

    At each of the ``n_points_removed`` steps the model is refit on the remaining
    training points. The fraction of ``xcf`` with prediction ``< thres`` is
    recorded, and ``strategy`` then picks one of the remaining points to delete.
    After the last deletion the model is refit once more to fill the final entry.

    :param model_full: Model exposing ``reset_data_weights_and_size``, ``fit`` and
        ``predict_with_logits``; it is refit in place.
    :param xcf: Counterfactuals whose invalidation is measured.
    :param Xtrain: Full training inputs.
    :param ytrain: Full training labels.
    :param strategy: Callable ``(model, xcf, X_remaining, y_remaining, thres, **params)``
        returning an index into the *remaining* points. The model it receives has
        already been fit on those points.
    :param n_points_removed: Number of removal steps.
    :param thres: Decision threshold; a prediction ``< thres`` counts as invalidated.
    :param params: Optional keyword arguments forwarded to ``strategy``.
    :return: ``(invalidation_rate, deleted_indices)``. ``invalidation_rate`` has
        ``n_points_removed + 1`` entries (entry ``i`` is measured after ``i`` removals).
        ``deleted_indices`` holds the original ``Xtrain`` index removed at each step.
        It keeps ``torch.zeros``' default floating dtype for backward
        compatibility, so with float32 very large indices (> 2**24) are not
        exactly representable.
    """
    strategy_kwargs = params if params is not None else {}
    invalidation_rate = torch.zeros(n_points_removed + 1)
    keep = torch.ones(len(Xtrain), dtype=torch.bool)  # True = point still in the training set
    deleted_indices = torch.zeros(n_points_removed)
    for i in range(n_points_removed):
        invalidation_rate[i] = _refit_and_measure(model_full, xcf, Xtrain, ytrain, keep, thres)
        print(f"Invalidation rate ({i}) ", invalidation_rate[i].item())
        print(f"Computing {i+1}th removal...")
        # Maps positions in the reduced dataset handed to `strategy` back to original indices.
        org_ind = torch.nonzero(keep).flatten()
        # The model was already refit above; this log line is kept unchanged for compatibility.
        print("Fitting model...")
        idx = strategy(model_full, xcf, Xtrain[keep], ytrain[keep], thres, **strategy_kwargs)
        print(f"Deletion of point {org_ind[idx]}.")
        keep[org_ind[idx]] = False
        deleted_indices[i] = org_ind[idx]

    invalidation_rate[n_points_removed] = _refit_and_measure(model_full, xcf, Xtrain, ytrain, keep, thres)
    print(f"Invalidation rate ({n_points_removed}) ", invalidation_rate[n_points_removed].item())
    return invalidation_rate, deleted_indices
