import argparse
import warnings
from typing import List

import numpy as np

from calibration_schemes.AbstractCalibration import Calibration
from calibration_schemes.CQRCalibration import CQRCalibration
from calibration_schemes.WeightedCalibration import WeightedCalibration
from calibration_schemes.PrivilegedConformalPrediction import PrivilegedConformalPrediction
from data_utils.datasets.regression_dataset import RegressionDataset
from models.data_mask_estimators.OracleDataMaskerWithDelta import OracleDataMaskerWithDelta
from models.data_mask_estimators.DataMaskEstimator import DataMaskEstimator
from models.data_mask_estimators.OracleDataMaskerWithDeltaMinMax import OracleDataMaskerWithUniformDeltaMinMax, \
    OracleDataMaskerWithLeftSidedDeltaMinMax, OracleDataMaskerWithSmallTailsDeltaMinMax, \
    OracleDataMaskerWithExtremeTailsDeltaMinMax, OracleDataMaskerWithRightSidedDeltaMinMax, \
    OracleDataMaskerWithBetaUDeltaMinMax, OracleDataMaskerWithBetaRightDeltaMinMax, \
    OracleDataMaskerWithBetaLeftDeltaMinMax, OracleDataMaskerWithBetaRightHillDeltaMinMax, \
    OracleDataMaskerWithBetaLeftHillDeltaMinMax
from regression_main import run_experiment, parse_args, get_models
import matplotlib
from sys import platform

# Force matplotlib's non-interactive 'Agg' backend on every platform other
# than Windows ('win32') and macOS ('darwin').
_INTERACTIVE_BACKEND_PLATFORMS = ('win32', 'darwin')
if platform not in _INTERACTIVE_BACKEND_PLATFORMS:
    matplotlib.use('Agg')

# Suppress all warnings (every category, every module) for the whole process.
warnings.filterwarnings("ignore")

# NOTE(review): this parser is not used anywhere in this file -- arguments are
# obtained via regression_main.parse_args(). Kept in case another module imports it.
parser = argparse.ArgumentParser()


# Supported values of ``args.data_masker`` and the oracle data-mask estimator class each one selects.
_DATA_MASKER_CLASSES = {
    'uniform_delta_min_max': OracleDataMaskerWithUniformDeltaMinMax,
    'left_sided_delta_min_max': OracleDataMaskerWithLeftSidedDeltaMinMax,
    'right_sided_delta_min_max': OracleDataMaskerWithRightSidedDeltaMinMax,
    'extreme_tails_delta_min_max': OracleDataMaskerWithExtremeTailsDeltaMinMax,
    'small_tails_delta_min_max': OracleDataMaskerWithSmallTailsDeltaMinMax,
    'beta_u_delta_min_max': OracleDataMaskerWithBetaUDeltaMinMax,
    'beta_right_delta_min_max': OracleDataMaskerWithBetaRightDeltaMinMax,
    'beta_left_delta_min_max': OracleDataMaskerWithBetaLeftDeltaMinMax,
    'beta_right_hill_delta_min_max': OracleDataMaskerWithBetaRightHillDeltaMinMax,
    'beta_left_hill_delta_min_max': OracleDataMaskerWithBetaLeftHillDeltaMinMax,
    'delta': OracleDataMaskerWithDelta,
}


def _delta_grid(v: float, c: float) -> np.ndarray:
    """Return the sorted, de-duplicated delta grid used by the single-delta ('delta') estimator.

    The grid is the union of: [-v, v) with step v/40, the value 0, [-1, 1) with step 0.05 shifted
    by -c, and [-0.03, 0.03) with step 0.01 shifted by -c. Every value is rounded to 3 decimals
    before de-duplication.
    """
    deltas = list(np.arange(-v, v, v / 40)) + [0]
    deltas += list(np.arange(-1, 1, 0.05) - c) + list(np.arange(-0.03, 0.03, 0.01) - c)
    # np.unique already returns its result sorted, so no separate sort is needed.
    return np.unique(np.round(deltas, 3))


def _min_max_delta_grid(v: float) -> np.ndarray:
    """Return the sorted, de-duplicated grid [-v, v) (step v/20) plus 0, rounded to 3 decimals."""
    deltas = list(np.arange(-v, v, v / 20)) + [0]
    return np.unique(np.round(deltas, 3))


def get_calibration_schemes(dataset: RegressionDataset, args) -> List[Calibration]:
    """Build one calibration scheme per oracle data-mask estimator on a grid of delta values.

    Every scheme is a ``PrivilegedConformalPrediction`` wrapping ``CQRCalibration(args.alpha)``.
    The schemes differ only in the data-mask estimator they receive. ``args.data_masker`` selects
    the estimator class, and that class is instantiated once per grid point. For ``'delta'``, each
    estimator gets one ``delta``. For the ``*_delta_min_max`` families, each estimator gets a
    ``(delta_min, delta_max)`` pair with ``delta_max - delta_min > 1e-3``.

    Args:
        dataset: provides ``scaler``, ``data_masker``, ``dataset_name``, ``x_dim``, ``z_dim``,
            ``unscaled_x``, ``unscaled_z`` and the boolean mask ``d``.
        args: parsed arguments; reads ``alpha``, ``data_masker`` and ``device``.

    Returns:
        The calibration schemes, in grid order.

    Raises:
        ValueError: if ``args.data_masker`` is not one of the supported estimator families.
    """
    # Validate the masker name before doing any expensive work.
    data_masker_class = _DATA_MASKER_CLASSES.get(args.data_masker)
    if data_masker_class is None:
        raise ValueError(f"does not recognize args.data_masker={args.data_masker}")

    data_scaler = dataset.scaler
    data_masker = dataset.data_masker
    dataset_name = dataset.dataset_name
    z_dim = dataset.z_dim
    x_dim = dataset.x_dim
    alpha = args.alpha

    corruption_probabilities = data_masker.get_corruption_probabilities(dataset.unscaled_x, dataset.unscaled_z)
    # weights = (1 - mean(p)) / (1 - p), where p are the per-sample corruption probabilities.
    weights = (1 - corruption_probabilities.mean()) / (1 - corruption_probabilities)
    # The largest weight magnitude sets the symmetric range [-v, v) of the delta grid.
    v = max(abs(weights.max().item()), abs(weights.min().item()))

    if args.data_masker == 'delta':
        # Rows where dataset.d is False (presumably the un-masked samples -- confirm).
        unmasked_weights = weights[~dataset.d]
        c = unmasked_weights.sum().item() / (len(unmasked_weights) + 1)
        deltas = _delta_grid(v, c)
        print("# delta: ", len(deltas))
        data_mask_estimators = [data_masker_class(data_scaler, data_masker, dataset_name, x_dim, z_dim, delta=d)
                                for d in deltas]
    else:
        # Every other supported key is a '*_delta_min_max' family.
        deltas = _min_max_delta_grid(v)
        data_mask_estimators = [data_masker_class(data_scaler, data_masker, dataset_name, x_dim, z_dim,
                                                  delta_min=d_min, delta_max=d_max)
                                for d_min in deltas for d_max in deltas if d_max - d_min > 1e-3]
        print("# delta: ", len(data_mask_estimators))

    # NOTE: WeightedCalibration(CQRCalibration(alpha), alpha, dataset_name, data_scaler, estimator,
    # device=args.device, quick_mode=True) was previously used here as an alternative scheme.
    return [PrivilegedConformalPrediction(CQRCalibration(alpha), alpha, dataset_name, data_scaler,
                                          data_mask_estimator, args.device)
            for data_mask_estimator in data_mask_estimators]


def main(args=None, seeds=range(0, 1)):
    """Run the experiment with this module's calibration schemes.

    Args:
        args: parsed argument namespace. When ``None``, it is obtained from ``parse_args()``.
            Reads ``multi_run``. When that is truthy, ``args.seed`` is overwritten before each run.
        seeds: seeds to iterate over when ``args.multi_run`` is set. The default ``range(0, 1)``
            (seed 0 only) reproduces the previously hard-coded behaviour.
    """
    if args is None:
        args = parse_args()

    if not args.multi_run:
        run_experiment(args, get_models, get_calibration_schemes, store_train_performance=False)
        return

    for seed in seeds:
        args.seed = seed
        run_experiment(args, get_models, get_calibration_schemes, store_train_performance=False)


if __name__ == '__main__':
    main()
