import numpy as np
import warnings

from utils.utils import score, split_indices_and_prep_dataset, save_cv_results
from sklearn.utils._testing import ignore_warnings
from sklearn.exceptions import ConvergenceWarning
from sklearn.model_selection import GridSearchCV, PredefinedSplit
from sklearn.exceptions import ConvergenceWarning

from models.lin_reg import LogReg

# Suppress ConvergenceWarning globally. This runs at import time and edits the
# interpreter-wide warnings filter, so it also affects every other module
# running in this process (linear_eval_cv is additionally wrapped in
# ignore_warnings for the same category).
warnings.filterwarnings("ignore", category=ConvergenceWarning)

def _as_flat_float32(features):
    """Cast *features* to float32, map +/-inf to NaN and flatten to 2-D."""
    features = features.astype(np.float32)
    # map INFs to NANs for imputer
    features = np.where(np.isinf(features), np.nan, features)
    return features.reshape(features.shape[0], -1)


def _strip_param_prefix(params, keep_levels):
    """Keep only the last *keep_levels* ``__``-separated parts of each key."""
    return {
        "__".join(name.split("__")[-keep_levels:]): value
        for name, value in params.items()
    }


@ignore_warnings(category=ConvergenceWarning)
def linear_eval_cv(
        cfg,
        subjects,
        dataset,
        test_dataset,
        n_train,
        n_val,
        n_test,
        setting,
        world_size,
        n_folds,
        fold,
        ncv_i):
    """Grid-search a linear model on train/val, refit it and score the test set.

    The hyper-parameters are selected with a single predefined split (fit on
    the training epochs, score on the validation epochs). The model is then
    refit with the best parameters, the test epochs are predicted, and the
    outcome is handed to ``score`` and ``save_cv_results``.

    Args:
        cfg: Configuration mapping. Reads ``cfg["model"]["n_classes"]`` and
            ``cfg["training"]["num_workers"]``; it is also forwarded to
            ``split_indices_and_prep_dataset`` and ``save_cv_results``.
        subjects, n_val, n_test, setting, world_size, n_folds: Forwarded
            unchanged to ``split_indices_and_prep_dataset``.
        dataset: Object exposing ``labels``, ``features``, ``train_epochs``,
            ``val_epochs`` and ``test_epochs``.
        test_dataset: Optional separate dataset for testing. When truthy its
            ``labels``, ``features`` and ``test_epochs`` are used and invalid
            rows (see ``determine_invalid_data``) are dropped.
        n_train: Forwarded to the split helper and to ``save_cv_results``;
            also decides whether the final refit uses train only (< 500) or
            train + val.
        fold, ncv_i: Forwarded to the split helper and ``save_cv_results``.

    Returns:
        None. Results are persisted through ``save_cv_results``.
    """
    _, _, _, dataset, test_dataset, sub_ids = split_indices_and_prep_dataset(
        cfg, subjects, dataset, test_dataset, n_train, n_val, n_test,
        setting, world_size, n_folds, fold, ncv_i)

    # data and label prep
    y = dataset.labels[:]
    n_labels = y.shape[1]
    multitarget = n_labels > 1

    # NOTE(review): OVR_LogReg and LinReg were referenced without an import
    # (NameError). They are presumably defined next to LogReg in
    # models.lin_reg -- confirm. Imported lazily so a missing name only
    # affects the branch that needs it.
    if multitarget:
        from models.lin_reg import OVR_LogReg
        model, param_grid = OVR_LogReg
    elif cfg["model"]["n_classes"] == 1:
        from models.lin_reg import LinReg
        model, param_grid = LinReg
    else:
        model, param_grid = LogReg
    y = y.squeeze()

    averaging = False

    X_train = _as_flat_float32(dataset.features[dataset.train_epochs])
    X_val = _as_flat_float32(dataset.features[dataset.val_epochs])
    y_train = y[dataset.train_epochs]
    y_val = y[dataset.val_epochs]

    if test_dataset:  # Test data and labels from *test_dataset*
        test_labels = test_dataset.labels[:]
        X_test = _as_flat_float32(test_dataset.features[:])[test_dataset.test_epochs]
        y_test = test_labels[test_dataset.test_epochs]

        # Labels are filtered before squeezing: the helper inspects shape[1].
        to_del = determine_invalid_data(y_test)

        X_test = np.delete(X_test, to_del, axis=0)
        y_test = np.delete(y_test, to_del, axis=0).squeeze()
        test_ids = np.delete(sub_ids["test"], to_del, axis=0)
    else:
        # Flatten exactly like X_train / X_val so the model sees the same
        # feature layout at prediction time.
        X_test = _as_flat_float32(dataset.features[dataset.test_epochs])
        y_test = y[dataset.test_epochs]
        test_ids = sub_ids["test"]

    X_trainval = np.concatenate((X_train, X_val))
    y_trainval = np.concatenate((y_train, y_val))

    # Do train/evaluation split and grid-search.
    # PredefinedSplit: -1 keeps a sample in every training set, 0 puts it in
    # test fold 0. This yields exactly one split (fit on train, score on val).
    # Labelling the training rows 1 would add a second split that fits on val
    # and scores on train.
    fold_indices = np.concatenate((
        np.full(len(X_train), -1, dtype=int),
        np.zeros(len(X_val), dtype=int),
    ))
    cv_setup = PredefinedSplit(test_fold=fold_indices)
    gs = GridSearchCV(model, param_grid, cv=cv_setup, refit=False,
                      n_jobs=cfg["training"]["num_workers"])
    gs.fit(X_trainval, y_trainval)

    # Fetch results and find best parameters
    results = gs.cv_results_
    print(results['mean_test_score'], flush=True)
    best_index = results['rank_test_score'].argmin()
    best_params = results['params'][best_index]

    # Grid keys carry pipeline prefixes; the final estimator only understands
    # its own names. Multi-target keeps one extra level (e.g. estimator__C).
    best_params = _strip_param_prefix(best_params, 2 if multitarget else 1)

    final_estimator = model[-1]
    if "probability" in final_estimator.get_params():
        # Estimators do not support item assignment; use set_params.
        final_estimator.set_params(probability=True)
    final_estimator.set_params(**best_params)

    if n_train < 500:
        model.fit(X_train, y_train)
    else:
        model.fit(X_trainval, y_trainval)

    try:
        if multitarget:
            n_iter = final_estimator.estimators_[0].n_iter_
        else:
            n_iter = final_estimator.n_iter_
    except (AttributeError, IndexError):
        n_iter = 0  # Catch closed-form estimators.
    else:
        print("Number of iters:", n_iter, flush=True)
    best_params["n_iter"] = n_iter

    # Predict the test-set given the trained model
    # NOTE(review): [:, 1] keeps only the second probability column; for
    # multi-label / multi-class outputs confirm this is the intended column.
    try:
        y_pred = model.predict_proba(X_test)[:, 1]
    except (AttributeError, IndexError):
        # No (usable) predict_proba, e.g. regressors: fall back to predict.
        y_pred = model.predict(X_test)
    y_pred = y_pred.reshape(-1, n_labels)

    # Go from epoch-level prediction to subject-level
    sub_ys_true, sub_ys_pred, metrics = score(
        y_test, y_pred, test_ids, cfg["model"]["n_classes"], (not averaging),
        logits=False)

    save_cv_results("SSL_LIN", cfg, sub_ys_true, sub_ys_pred, metrics, 0.,
                    best_params, n_train, fold, ncv_i)

def determine_invalid_data(labels: np.ndarray) -> np.ndarray:
    """Return the row indices of samples that carry no usable label.

    Single-label input (1-D, or 2-D with one column): rows equal to -999.
    Multi-label input (2-D with several columns): rows where every label is 0,
    i.e. samples for which no label applies.

    Args:
        labels: Label array, either ``(n_samples,)`` or
            ``(n_samples, n_labels)``.

    Returns:
        1-D integer array of row indices to delete (empty if none).
    """
    labels = np.asarray(labels)

    # A flat vector is treated as single-label instead of failing on shape[1].
    if labels.ndim == 1 or labels.shape[1] == 1:
        return np.where(labels == -999)[0]

    return np.where(np.all(labels == 0, axis=1))[0]