import numpy as np
from sklearn.linear_model import LinearRegression, Ridge
from tqdm import tqdm
from sklearn.metrics import r2_score, mean_squared_error
from time_series_influences.utils import split_time_series


def block_removal_experiment(value_dict, X, y, X_test, y_test, predictor='lr'):
    """Run the block-removal evaluation in both directions for each value list.

    For every entry of ``value_dict`` the per-sample values are handed to
    ``block_removal_core`` twice: first with ``ascending=True`` and then with
    ``ascending=False``.

    Args:
        value_dict: Mapping from a name to the list of per-sample values
            forwarded to ``block_removal_core`` as ``value_list``.
        X, y: Training data, forwarded unchanged.
        X_test, y_test: Evaluation data, forwarded unchanged.
        predictor: Predictor name, forwarded unchanged.

    Returns:
        ``{'ascending': {...}, 'descending': {...}}`` where each inner dict
        maps the keys of ``value_dict`` to the corresponding
        ``block_removal_core`` result.
    """
    results = {'ascending': {}, 'descending': {}}
    directions = (('ascending', True), ('descending', False))
    for name, values in value_dict.items():
        # Same per-key call order as before: ascending first, then descending.
        for label, is_ascending in directions:
            results[label][name] = block_removal_core(
                X, y, X_test, y_test, values,
                ascending=is_ascending, predictor=predictor,
            )
    return results

def block_removal_core(X, y, X_test, y_test, value_list, ascending=True, predictor='lr'):
    """Refit a regressor while progressively dropping samples ranked by value.

    Training samples are ordered by ``value_list``. For each percentile
    ``p`` in ``range(0, 100, step)`` the first ``p`` percent of that ordering
    is removed, a fresh ``LinearRegression`` is fitted on the remaining
    samples, and its MSE and R2 on the test data are recorded.

    The percentile step is ``min(len(X) // 200, 5)``, clamped to at least 1 so
    that inputs with fewer than 200 samples no longer produce a zero step
    (which made ``range`` raise ``ValueError``).

    Args:
        X, y: Training features and targets. They are indexed with an integer
            index array, so they are assumed to be numpy arrays -- TODO
            confirm against callers.
        X_test, y_test: Data used to score every refitted model.
        value_list: One value per training sample; its argsort defines the
            removal order.
        ascending: If truthy, the lowest-valued samples are removed first;
            otherwise the highest-valued samples are removed first.
        predictor: Name of the model to fit. Only ``'lr'`` is implemented.

    Returns:
        dict with keys ``'r2'`` and ``'mse'``, each a list holding one score
        per evaluated percentile. A step whose fit or scoring fails is
        recorded as ``0`` for both metrics.

    Raises:
        ValueError: If ``predictor`` is not ``'lr'``.
    """
    # Fail fast: previously an unknown predictor surfaced as a NameError on
    # the unbound ``mse``/``r2`` at the first append.
    if predictor != 'lr':
        raise ValueError(
            f"Unsupported predictor {predictor!r}; only 'lr' is implemented."
        )

    n_sample = len(X)

    # Sample indices ordered from lowest to highest value; reversed when the
    # highest-valued samples should be removed first.
    removal_order = np.argsort(value_list)
    if not ascending:
        removal_order = removal_order[::-1]

    # Clamp to >= 1: n_sample // 200 is 0 for fewer than 200 samples.
    step = max(1, min(n_sample // 200, 5))

    accuracy_dict = {'r2': [], 'mse': []}
    for percentile in tqdm(range(0, 100, step)):
        # Drop the first `percentile` percent of the ordering, keep the rest.
        start_index = int(n_sample * percentile / 100)
        kept = removal_order[start_index:]
        try:
            model = LinearRegression()
            model.fit(X[kept], y[kept])
            y_pred = model.predict(X_test)
            mse = mean_squared_error(y_test, y_pred)
            r2 = r2_score(y_test, y_pred)
        except Exception:
            # Best-effort: a failed step is recorded as zeros so the curve
            # keeps its length. KeyboardInterrupt/SystemExit now propagate.
            mse, r2 = 0, 0

        accuracy_dict['mse'].append(mse)
        accuracy_dict['r2'].append(r2)

    return accuracy_dict
