import argparse
import ast
import copy
import warnings
from itertools import product
from typing import List

from calibration_schemes.AbstractCalibration import Calibration
from calibration_schemes.CQRCalibration import CQRCalibration
from calibration_schemes.OracleCQRCalibration import OracleCQRCalibration
from calibration_schemes.TriplyRobustCalibration import TriplyRobustCalibration
from calibration_schemes.WeightedCalibration import WeightedCalibration
from calibration_schemes.PrivilegedConformalPrediction import PrivilegedConformalPrediction
from calibration_schemes.TwoStagedConformalPrediction import TwoStagedCalibration
from calibration_schemes.DummyCalibration import DummyCalibration
from data_utils.data_corruption.data_corruption_masker import DataCorruptionMasker
from data_utils.data_type import DataType
from data_utils.dataset_naming_utils import get_z_dim_from_data_name
from data_utils.datasets.dataset import Dataset
from data_utils.datasets.regression_dataset import RegressionDataset
from data_utils.datasets.synthetic_dataset_generator import PartiallyLinearDataGenerator
from data_utils.get_dataset_utils import get_regression_dataset
from error_sampler.OracleErrorSampler import OracleErrorSampler
from get_model_utils import get_proxy_qr_model, get_data_learning_mask_estimator, is_data_for_xgboost
from imputation_methods.regression_imputations.BadImputation import BadImputation
from imputation_methods.regression_imputations.OracleImputation import OracleImputation
from models.data_mask_estimators.BadOracleDataMasker import BadOracleDataMasker
from models.data_mask_estimators.OracleDataMasker import OracleDataMasker
from models.qr_models.BadQuantileRegression import BadQuantileRegression
from models.qr_models.OracleQuantileRegression import OracleQuantileRegression
from models.qr_models.PredictionIntervalModel import PredictionIntervalModel
from models.qr_models.QuantileRegression import QuantileRegression
from models.data_mask_estimators.DataMaskEstimator import DataMaskEstimator
from models.data_mask_estimators.NetworkMaskEstimator import NetworkMaskEstimator, XGBoostMaskEstimator, RFMaskEstimator
from models.qr_models.XGBoostQR import XGBoostQR
from models.regressors.regressor_factory import RegressorType, RegressorFactory
from regression_main import run_experiment, parse_args
from results_helper.regression_results_helper import RegressionResultsHelper
from calibration_schemes.CalibrationByImputation import CalibrationByImputation
from clustering.kmeans_clustering import Kmeans
from clustering.linear_clustering import LinearClustering
from error_sampler.error_sampler_factory import ErrorSamplerFactory, ErrorSamplerType
from imputation_methods.ConditionalSampleImputator import ConditionalSampleImputator
from imputation_methods.ImputationMethod import ImputationMethod
from imputation_methods.SampleImputator import SampleImputator
from imputation_methods.regression_imputations.RegressorImputation import RegressorImputation
from imputation_methods.regression_imputations.RegressorImputationWithErrorSampling import \
    RegressorImputationWithErrorSampling
from utils import set_seeds
import matplotlib
from sys import platform

# Switch matplotlib to the non-interactive 'Agg' backend on every platform other than
# Windows and macOS (presumably so figures can be rendered to files on display-less
# Linux machines — confirm if other platforms need interactive backends).
if platform != 'win32' and platform != 'darwin':
    matplotlib.use('Agg')

# Suppress every Python warning for the remainder of the process.
warnings.filterwarnings("ignore")


def get_oracle_imputators(dataset_name, x_dim, y_dim, z_dim, args, scaled_y_max, scaled_y_min, saved_models_path,
                          figures_dir, seed, dataset: RegressionDataset, device) -> List[ImputationMethod]:
    """Build the imputation methods evaluated in this experiment.

    The current configuration yields exactly two imputators: an ``OracleImputation``
    and a ``BadImputation``. Every call constructs brand-new instances.

    Args:
        dataset_name: Dataset name, forwarded to ``OracleImputation``.
        x_dim: Dimension of x, forwarded to ``OracleImputation``.
        y_dim: Unused by the current configuration; kept for interface compatibility.
        z_dim: Dimension of z, forwarded to ``OracleImputation``.
        args: Unused by the current configuration; kept for interface compatibility.
        scaled_y_max: Unused by the current configuration; kept for interface compatibility.
        scaled_y_min: Unused by the current configuration; kept for interface compatibility.
        saved_models_path: Unused by the current configuration; kept for interface compatibility.
        figures_dir: Unused by the current configuration; kept for interface compatibility.
        seed: Random seed, forwarded to ``OracleImputation``.
        dataset: Supplies the data scaler (``dataset.scaler``) passed to ``OracleImputation``.
        device: Device, forwarded to ``OracleImputation``.

    Returns:
        A new list ``[OracleImputation(...), BadImputation()]``.
    """
    # NOTE: regressor-based imputators (RegressorImputation / RegressorImputationWithErrorSampling
    # built through RegressorFactory and ErrorSamplerFactory) used to be configured here, but the
    # code sat after an early ``return`` and never executed. The unreachable code and the unused
    # RegressorFactory construction that preceded the return were removed; restore them from
    # version control if those imputators are to be evaluated again.
    return [
        OracleImputation(dataset_name, x_dim, z_dim, device, dataset.scaler, seed),
        BadImputation(),
    ]


def get_calibration_schemes(dataset, args) -> List[Calibration]:
    """Assemble every calibration scheme compared in this experiment.

    The result starts with ``DummyCalibration`` and ``OracleCQRCalibration``. It is followed by one
    ``TriplyRobustCalibration`` per (imputator, data-mask estimator) pair, ordered imputator-major.
    Every pair is built from freshly constructed imputator, data-mask estimator and
    ``CQRCalibration`` objects, so no two triply-robust schemes share one of those instances.

    Other combinations (weighted CQR, standalone privileged conformal prediction, standalone
    calibration-by-imputation, plain CQR) are currently disabled.

    Args:
        dataset: Provides ``scaler``, ``data_masker``, ``dataset_name``, ``x_dim``/``y_dim``/``z_dim``
            and ``scaled_y_min``/``scaled_y_max``.
        args: Parsed experiment arguments; ``alpha``, ``device``, ``gamma``, ``seed``,
            ``figures_dir`` and ``saved_models_path`` are read here.

    Returns:
        The list of calibration schemes.
    """
    data_scaler = dataset.scaler
    data_masker = dataset.data_masker
    dataset_name = dataset.dataset_name
    device = args.device
    z_dim, x_dim, y_dim = dataset.z_dim, dataset.x_dim, dataset.y_dim
    scaled_y_max, scaled_y_min = dataset.scaled_y_max, dataset.scaled_y_min
    alpha = args.alpha
    figures_dir = args.figures_dir
    saved_models_path = args.saved_models_path
    seed = args.seed

    schemes = [DummyCalibration(alpha), OracleCQRCalibration(alpha)]

    def build_imputators() -> List[ImputationMethod]:
        # Rebuilt on each call so every scheme receives its own untouched imputator instance.
        return get_oracle_imputators(dataset_name, x_dim, y_dim, z_dim, args, scaled_y_max, scaled_y_min,
                                     saved_models_path, figures_dir, seed, dataset, device)

    def build_mask_estimators() -> List[DataMaskEstimator]:
        # Oracle and deliberately-bad oracle maskers; args.gamma is read lazily on each call.
        return [
            OracleDataMasker(data_scaler, data_masker, dataset_name, x_dim, z_dim, gamma=args.gamma),
            BadOracleDataMasker(data_scaler, data_masker, dataset_name, x_dim, z_dim, gamma=args.gamma),
        ]

    imputator_count = len(build_imputators())
    for imputator_idx in range(imputator_count):
        for mask_estimator in build_mask_estimators():
            fresh_imputator = build_imputators()[imputator_idx]
            imputation_scheme = CalibrationByImputation(fresh_imputator, CQRCalibration(alpha), alpha)
            privileged_scheme = PrivilegedConformalPrediction(CQRCalibration(alpha), alpha, dataset_name,
                                                              data_scaler, mask_estimator, device=device)
            schemes.append(TriplyRobustCalibration(alpha, imputation_scheme, privileged_scheme))

    return schemes


def get_models(dataset: Dataset, args) -> List[PredictionIntervalModel]:
    """Create the prediction-interval models compared in this experiment.

    Args:
        dataset: Supplies the dataset names, x/y/z dimensions and the data scaler.
        args: Parsed experiment arguments; ``alpha``, ``device`` and ``seed`` are read.

    Returns:
        ``[OracleQuantileRegression(...), BadQuantileRegression(args.alpha)]``, in that order.
    """
    oracle_qr = OracleQuantileRegression(
        dataset.dataset_name,
        dataset.original_dataset_name,
        x_dim=dataset.x_dim,
        y_dim=dataset.y_dim,
        z_dim=dataset.z_dim,
        alpha=args.alpha,
        data_scaler=dataset.scaler,
        device=args.device,
        seed=args.seed,
    )
    # Intentionally poor baseline model.
    bad_qr = BadQuantileRegression(args.alpha)
    models = [oracle_qr, bad_qr]
    return models


def main(args=None):
    """Run the experiment once, or once per seed in 3..29 when ``args.multi_run`` is set.

    Args:
        args: Parsed experiment arguments; obtained from ``parse_args()`` when ``None``.
            In multi-run mode ``args.seed`` is overwritten in place for each run and is
            left at the last seed (29) afterwards.
    """
    args = parse_args() if args is None else args

    if not args.multi_run:
        run_experiment(args, get_models, get_calibration_schemes)
        return

    for current_seed in range(3, 30):
        args.seed = current_seed
        run_experiment(args, get_models, get_calibration_schemes)


# Launch the experiment only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
