import numpy as np
import torch

from imputation_methods.ImputationMethod import ImputationMethod
from models.model_utils import ModelPrediction
from models.qr_models.PredictionIntervalModel import PredictionIntervals
from models.regressors.MeanRegressor import MeanRegressor
from utils.utils import corr, HSIC, get_y

from utils.conditional_independence_metrics import run_all_tests


class RegressorImputation(ImputationMethod):
    """Imputation method backed by a single mean regressor.

    Fitting, calibration and prediction are delegated to the wrapped
    ``MeanRegressor``; the imputed value is its ``predict_mean`` output.
    ``compute_performance`` reports point-prediction error metrics together
    with dependence statistics between the prediction errors and other
    variables (auxiliary ``z``, the true target, the missingness mask and,
    when available, calibrated prediction intervals).
    """

    def __init__(self, mean_regressor: MeanRegressor):
        """
        Args:
            mean_regressor: regressor whose mean prediction is used as the
                imputed value.
        """
        super().__init__()
        self.mean_regressor = mean_regressor

    def fit(self, x_train, y_train, z_train, deleted_train, x_val, y_val, z_val, deleted_val, epochs=1000,
            batch_size=64, n_wait=20,
            **kwargs):
        """Run the parent ``fit`` and then train the wrapped mean regressor.

        Args:
            x_train, y_train, z_train, deleted_train: training split.
            x_val, y_val, z_val, deleted_val: validation split.
            epochs: training epochs, forwarded to the regressor.
            batch_size: mini-batch size, forwarded to the regressor.
            n_wait: forwarded to the regressor (presumably an early-stopping
                patience — confirm against ``MeanRegressor.fit``).
            **kwargs: forwarded to both the parent and the regressor.
        """
        super().fit(x_train, y_train, z_train, deleted_train, x_val, y_val, z_val, deleted_val,
                    epochs=epochs, batch_size=batch_size, n_wait=n_wait, **kwargs)
        # Passed positionally with z last; this order presumably mirrors
        # MeanRegressor.fit's signature — keep them in sync.
        self.mean_regressor.fit(x_train, y_train, deleted_train, x_val, y_val, deleted_val, epochs, batch_size, n_wait,
                                z_train, z_val, **kwargs)

    @property
    def name(self):
        """Display name of this method (the wrapped regressor's name, as a string)."""
        return f"{self.mean_regressor.name}"

    def predict(self, x, z, **kwargs):
        """Return the wrapped regressor's mean prediction for ``(x, z)``."""
        return self.mean_regressor.predict_mean(x, z, **kwargs)

    def calibrate(self, x_cal, y_cal, z_cal, deleted_cal):
        """Calibrate the parent method and then the wrapped regressor on the calibration split."""
        super().calibrate(x_cal, y_cal, z_cal, deleted_cal)
        self.mean_regressor.calibrate(x_cal, y_cal, z_cal, deleted_cal)

    @staticmethod
    def _prefixed(prefix, metrics):
        """Return a copy of ``metrics`` with ``prefix`` prepended to every key."""
        return {f"{prefix}{k}": v for k, v in metrics.items()}

    def compute_performance(self, x_test, y, z_test, full_y_test, deleted_test,
                            test_calibrated_intervals: ModelPrediction):
        """Evaluate the regressor's mean predictions on a test split.

        Args:
            x_test: test features.
            y: not used by this method; kept for interface compatibility.
            z_test: auxiliary variables. Squeezed; a 1-D result is treated as
                a single column.
            full_y_test: ground-truth targets (including entries that are
                missing in the observed data); squeezed and passed through
                ``get_y`` before use.
            deleted_test: boolean mask over test rows, True where the target
                is missing.
            test_calibrated_intervals: when it is a ``PredictionIntervals``,
                its ``intervals`` (column 0 = lower bound, column 1 = upper
                bound) drive the coverage-conditioned metrics; otherwise those
                metrics are NaN and the interval tests are omitted.

        Returns:
            dict mapping metric name to a Python scalar (or whatever
            ``run_all_tests`` returns for the test-statistic entries).
        """
        model_prediction = self.mean_regressor.predict_mean(x_test, z_test).squeeze()
        full_y_test = get_y(full_y_test.squeeze())
        z_test = z_test.squeeze()
        if z_test.dim() == 1:
            # Single auxiliary variable: the 1-D vector is its own summary
            # (same rank as the multi-column branch below), and z_test becomes
            # an (n, 1) column for the multivariate tests.
            reduced_z_test = z_test
            z_test = z_test.unsqueeze(-1)
        else:
            # Several auxiliary variables: summarise each row by its mean for
            # the scalar correlation metrics.
            reduced_z_test = z_test.mean(dim=-1)

        errors = full_y_test - model_prediction

        # ---- point-prediction error metrics ----
        mae = errors.abs().mean().item()
        mse = (errors ** 2).mean().item()
        rmse = np.sqrt(mse).item()
        y_mean = full_y_test.mean().item()
        abs_mean_Y = full_y_test.abs().mean().item()
        # NOTE: mape / smape are inf/NaN if the targets (or targets and
        # predictions) contain zeros.
        mape = (errors / full_y_test).abs().mean().item() * 100
        smape = 100 * (2 * errors.abs() / (full_y_test.abs() + model_prediction.abs())).mean().item()
        relative_error = errors.mean().item() / y_mean
        abs_relative_error = mae / abs_mean_Y
        rrmse = rmse / y_mean
        nrmse = rmse / full_y_test.std().item()
        # Coefficient of determination: 1 - SS_res / SS_tot, where SS_tot is
        # the sum of *squared* deviations from the mean.
        ss_res = (errors ** 2).sum()
        ss_tot = ((full_y_test - full_y_test.mean()) ** 2).sum()
        r2 = 1 - (ss_res / ss_tot).item()

        # ---- interval-dependent metrics ----
        if isinstance(test_calibrated_intervals, PredictionIntervals):
            intervals = test_calibrated_intervals.intervals
            y2_covered = (full_y_test <= intervals[:, 1]) & (full_y_test >= intervals[:, 0])
            missing_covered = deleted_test & y2_covered
            missing_uncovered = deleted_test & ~y2_covered
            y2_error_corr_given_missing_and_covered = corr(full_y_test[missing_covered],
                                                           errors[missing_covered]).item()
            y2_error_corr_given_missing_and_uncovered = corr(full_y_test[missing_uncovered],
                                                             errors[missing_uncovered]).item()
            interval1_conditional_independence_tests = self._prefixed(
                "int1_", run_all_tests(errors.unsqueeze(1), intervals[:, 0], z_test))
            interval2_conditional_independence_tests = self._prefixed(
                "int2_", run_all_tests(errors.unsqueeze(1), intervals[:, 1], z_test))
        else:
            y2_error_corr_given_missing_and_covered = np.nan
            y2_error_corr_given_missing_and_uncovered = np.nan
            interval1_conditional_independence_tests = {}
            interval2_conditional_independence_tests = {}

        # ---- dependence between errors and other variables ----
        y1_error_corr = corr(reduced_z_test, errors).item()
        # HSIC is evaluated on a random subsample of at most 1500 rows,
        # presumably to bound its cost; uses NumPy's global RNG.
        subsample_idx = np.random.permutation(len(z_test))[:1500]
        y1_error_hsic = HSIC(z_test[subsample_idx], errors[subsample_idx].unsqueeze(-1)).item()
        y1_error_corr_given_m_eq_0 = corr(reduced_z_test[~deleted_test], errors[~deleted_test]).item()
        y2_error_corr = corr(full_y_test, errors).item()
        y2_error_corr_given_missing = corr(full_y_test[deleted_test], errors[deleted_test]).item()
        y2_error_corr_given_not_missing = corr(full_y_test[~deleted_test], errors[~deleted_test]).item()
        y2_error_sign_corr = corr(full_y_test, (errors > 0).float()).item()
        mask_error_corr = corr(deleted_test.float(), errors).item()
        idx = np.random.permutation(len(errors))[:1500]
        # NOTE(review): unlike the y1 HSIC call above, both arguments here are
        # 1-D; confirm HSIC accepts 1-D inputs.
        mask_error_hsic = HSIC(deleted_test.float()[idx], errors[idx]).item()
        model_prediction_conditional_independence_tests = self._prefixed(
            "model_pred_", run_all_tests(errors.unsqueeze(1), model_prediction, z_test))

        return {
            'y1 error corr': y1_error_corr,
            'y1 error hsic': y1_error_hsic,
            'y1 error corr given y2 exists': y1_error_corr_given_m_eq_0,

            'y2 error corr': y2_error_corr,
            'y2 error sign corr': y2_error_sign_corr,

            'y2 error corr given missing': y2_error_corr_given_missing,
            'y2 error corr given missing and covered': y2_error_corr_given_missing_and_covered,
            'y2 error corr given missing and uncovered': y2_error_corr_given_missing_and_uncovered,
            'y2 error corr given not missing': y2_error_corr_given_not_missing,

            'mask error corr': mask_error_corr,
            'mask error hsic': mask_error_hsic,

            'error given missing': errors[deleted_test].mean().item(),
            'error given not missing': errors[~deleted_test].mean().item(),

            'abs error given missing': errors[deleted_test].abs().mean().item(),
            'abs error given not missing': errors[~deleted_test].abs().mean().item(),
            "mae": mae,
            "mse": mse,
            "rmse": rmse,
            "rrmse": rrmse,
            "mape": mape,
            "smape": smape,
            "nrmse": nrmse,
            "relative_error": relative_error,
            "abs_relative_error": abs_relative_error,
            "r2": r2,
            **model_prediction_conditional_independence_tests,
            **interval1_conditional_independence_tests,
            **interval2_conditional_independence_tests,
        }
