import numpy as np
from scipy import stats
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score


def _compute_p_value(X, y_hat, scores, slope):
    """Return the two-tailed p-value for the slope of a simple linear regression.

    Runs a Student t-test of the null hypothesis ``slope == 0`` using
    ``n - 2`` degrees of freedom, where ``n`` is the number of rows in ``X``.
    More than two samples are therefore needed for the test to be defined.

    Args:
        X: Predictor values, one row per sample. Only the length of the first
            axis and the values themselves are used.
        y_hat: Fitted values of the regression, one per sample.
        scores: Observed target values, one per sample.
        slope: Estimated regression slope being tested.

    Returns:
        The two-tailed p-value of the slope's t-statistic.
    """
    X = np.asarray(X)
    residuals = np.asarray(scores) - np.asarray(y_hat)

    # Residual standard deviation; ddof=2 divides by n - 2 to account for the
    # two fitted parameters (slope and intercept).
    std_residuals = np.std(residuals, ddof=2)

    # Standard error of the slope: residual spread over the spread of X.
    stderr_slope = std_residuals / np.sqrt(np.sum((X - np.mean(X)) ** 2))

    # t-statistic for the slope with null hypothesis slope = 0
    t_statistic = slope / stderr_slope

    # Degrees of freedom
    df = X.shape[0] - 2

    # Use the survival function rather than ``1 - cdf``: for large |t| the cdf
    # rounds to exactly 1.0 and the subtraction would collapse the p-value to 0.
    p_value = 2 * stats.t.sf(np.abs(t_statistic), df)

    return p_value


def cross_validation_splits(X: list, Y: list, k=1):
    """Partition paired samples into consecutive train/validation splits.

    The data are walked in order in chunks of ``k`` samples. Each chunk in
    turn becomes the validation set and every other sample forms the training
    set. Note that ``k`` is the size of the held-out chunk, not the number of
    folds: ``k=1`` gives leave-one-out splits, and the final chunk is shorter
    when ``k`` does not divide ``len(X)``. With ``k == len(X)`` the single
    split has an empty training set.

    Args:
        X: Sequence of predictor values.
        Y: Sequence of target values; must have the same length as ``X``.
        k: Number of consecutive samples held out per split; must satisfy
            ``1 <= k <= len(X)``.

    Returns:
        A list of ``(X_train, Y_train, X_val, Y_val)`` tuples of NumPy arrays,
        one tuple per chunk, in data order.

    Raises:
        ValueError: If ``X`` and ``Y`` differ in length, or if ``k`` is outside
            ``[1, len(X)]`` (which is the case for any ``k`` on empty input).
    """
    n = len(X)
    if n != len(Y):
        raise ValueError("Length of X and Y must be the same.")

    if k < 1 or k > n:
        raise ValueError("Invalid value for k.")

    indices = list(range(n))
    splits = []

    for start in range(0, n, k):
        stop = start + k
        val_indices = indices[start:stop]
        # The validation indices are one contiguous run, so the training
        # indices are whatever lies on either side of it. Slicing avoids a
        # per-index membership scan of the validation list.
        train_indices = indices[:start] + indices[stop:]

        X_train = np.array([X[idx] for idx in train_indices])
        Y_train = np.array([Y[idx] for idx in train_indices])
        X_val = np.array([X[idx] for idx in val_indices])
        Y_val = np.array([Y[idx] for idx in val_indices])

        splits.append((X_train, Y_train, X_val, Y_val))

    return splits


def _held_out_abs_error(train_x, train_y, held_out_x, held_out_y):
    """Fit a fresh linear model on one training fold and return its absolute held-out error."""
    train_features = train_x[:, None]
    held_out_features = held_out_x[:, None]

    fold_model = LinearRegression()
    fold_model.fit(train_features, train_y)

    fold_prediction = fold_model.predict(held_out_features)
    # .item() only succeeds when the fold holds out exactly one sample.
    return np.fabs(held_out_y - fold_prediction).item()


def regress_channel_correlation_score(task_correlation_score: dict):
    """Regress task scores on per-channel correlations, one linear fit per channel.

    For every task, ``task_data["score"]`` is the list of scores and
    ``task_data["channels"]`` maps each channel to a list of correlations
    aligned with those scores. Per channel, entries whose correlation is
    ``None`` are dropped; for the ``"ping_pong_cooperative"`` task, entries
    with a score above 20 are dropped as well (each removal is printed). A
    ``LinearRegression`` is then fitted on the remaining pairs and evaluated
    in-sample (RMSE, R^2, slope p-value) and out-of-sample through
    ``cross_validation_splits(..., k=1)``.

    Args:
        task_correlation_score: Mapping of task name to a dict with the keys
            ``"score"`` and ``"channels"`` described above. The inputs are
            copied before filtering and are not modified.

    Returns:
        A dict keyed by task name. Each value maps a metric name to a
        per-channel dict (predictions, fitted regressor, RMSE, R^2, slope,
        filtered correlations, filtered scores, p-value, cross-validation
        mean absolute error and its standard error), plus ``"null_model_r2"``:
        the R^2 of predicting the mean of the task's unfiltered scores.
    """
    task_regression_result = {}

    for task_name, task_data in task_correlation_score.items():
        task_result = {
            "predictions_per_channel": {},
            "linear_regressor_per_channel": {},
            "rmse_per_channel": {},
            "r2_per_channel": {},
            "linear_regression_slope_per_channel": {},
            # Placeholder to fix the key order; filled in after the channel loop.
            "null_model_r2": None,
            "correlations_per_channel": {},
            "scores_per_channel": {},
            "p_value_per_channel": {},
            "cross_validation_mean_error_per_channel": {},
            "cross_validation_std_error_per_channel": {},
        }

        scores = task_data["score"]

        for channel, correlations in task_data["channels"].items():
            # Work on copies so filtering never touches the caller's lists.
            channel_scores = scores.copy()
            channel_correlations = correlations.copy()

            # Drop samples that have no correlation value. Deleting from the
            # back keeps the not-yet-processed positions valid.
            missing_positions = [
                position
                for position, value in enumerate(channel_correlations)
                if value is None
            ]
            for position in missing_positions[::-1]:
                del channel_scores[position]
                del channel_correlations[position]

            # Drop samples whose score is too high (noise data).
            if task_name == "ping_pong_cooperative":
                noisy_positions = [
                    position
                    for position, value in enumerate(channel_scores)
                    if value > 20
                ]
                for position in noisy_positions[::-1]:
                    del channel_scores[position]
                    del channel_correlations[position]
                    print(f"{channel} removing ping pong experiment {position} due to score too high")

            # In-sample fit on all remaining samples.
            regressor = LinearRegression()
            features = np.array(channel_correlations)[:, None]
            regressor.fit(features, channel_scores)
            fitted = regressor.predict(features)
            slope = regressor.coef_[0]

            # Out-of-sample error: one held-out sample per fold (k=1).
            folds = cross_validation_splits(channel_correlations, channel_scores, k=1)
            fold_errors = [_held_out_abs_error(*fold) for fold in folds]

            mean_error = np.mean(fold_errors)
            # Standard error of the mean absolute error across folds.
            std_error = np.std(fold_errors) / np.sqrt(len(fold_errors))

            task_result["predictions_per_channel"][channel] = fitted.tolist()
            task_result["linear_regressor_per_channel"][channel] = regressor
            task_result["rmse_per_channel"][channel] = np.sqrt(
                mean_squared_error(channel_scores, fitted)
            )
            task_result["r2_per_channel"][channel] = r2_score(channel_scores, fitted)
            task_result["linear_regression_slope_per_channel"][channel] = slope
            task_result["correlations_per_channel"][channel] = channel_correlations
            task_result["scores_per_channel"][channel] = channel_scores
            task_result["p_value_per_channel"][channel] = _compute_p_value(
                features, fitted, channel_scores, slope
            )
            task_result["cross_validation_mean_error_per_channel"][channel] = mean_error
            task_result["cross_validation_std_error_per_channel"][channel] = std_error

        # Baseline: R^2 of always predicting the mean of the unfiltered scores.
        mean_score = np.mean(scores)
        task_result["null_model_r2"] = r2_score(scores, [mean_score] * len(scores))

        task_regression_result[task_name] = task_result

    return task_regression_result
