import argparse
import ast
import copy
import warnings
from itertools import product
from typing import List

import torch
import traceback

from quantile_forest import RandomForestQuantileRegressor

from calibration_schemes.AbstractCalibration import Calibration
from calibration_schemes.CQRCalibration import CQRCalibration
from calibration_schemes.OracleCQRCalibration import OracleCQRCalibration
from calibration_schemes.ResidualScoreCalibration import ResidualScoreCalibration
from calibration_schemes.TriplyRobustCalibration import TriplyRobustCalibration
from calibration_schemes.WeightedCalibration import WeightedCalibration
from calibration_schemes.PrivilegedConformalPrediction import PrivilegedConformalPrediction
from calibration_schemes.TwoStagedConformalPrediction import TwoStagedCalibration
from calibration_schemes.DummyCalibration import DummyCalibration
from data_utils.data_corruption.data_corruption_masker import DataCorruptionMasker
from data_utils.data_type import DataType
from data_utils.dataset_naming_utils import get_z_dim_from_data_name
from data_utils.datasets.dataset import Dataset
from data_utils.datasets.regression_dataset import RegressionDataset
from data_utils.datasets.synthetic_dataset_generator import PartiallyLinearDataGenerator
from data_utils.get_dataset_utils import get_regression_dataset
from error_sampler.OracleErrorSampler import OracleErrorSampler
from get_model_utils import get_proxy_qr_model, get_data_learning_mask_estimator, is_data_for_xgboost, is_data_for_rf
from imputation_methods.regression_imputations.BadImputation import BadImputation
from imputation_methods.regression_imputations.OracleImputation import OracleImputation
from models.data_mask_estimators.BadOracleDataMasker import BadOracleDataMasker
from models.data_mask_estimators.OracleDataMasker import OracleDataMasker
from models.data_mask_estimators.OracleDataMaskerWithDelta import OracleDataMaskerWithDelta
from models.qr_models.BadQuantileRegression import BadQuantileRegression
from models.qr_models.OracleQuantileRegression import OracleQuantileRegression
from models.qr_models.PredictionIntervalModel import PredictionIntervalModel
from models.qr_models.QuantileRegression import QuantileRegression
from models.data_mask_estimators.DataMaskEstimator import DataMaskEstimator
from models.data_mask_estimators.NetworkMaskEstimator import NetworkMaskEstimator, XGBoostMaskEstimator, RFMaskEstimator
from models.qr_models.RFQR import RFQR
from models.qr_models.XGBoostQR import XGBoostQR
from models.regressors.FullRegressor import FullRegressor
from models.regressors.regressor_factory import RegressorType, RegressorFactory
from plot_utils import display_plot
from results_helper.regression_results_helper import RegressionResultsHelper
from calibration_schemes.CalibrationByImputation import CalibrationByImputation
from clustering.kmeans_clustering import Kmeans
from clustering.linear_clustering import LinearClustering
from error_sampler.error_sampler_factory import ErrorSamplerFactory, ErrorSamplerType
from imputation_methods.ConditionalSampleImputator import ConditionalSampleImputator
from imputation_methods.ImputationMethod import ImputationMethod
from imputation_methods.SampleImputator import SampleImputator
from imputation_methods.regression_imputations.RegressorImputation import RegressorImputation
from imputation_methods.regression_imputations.RegressorImputationWithErrorSampling import \
    RegressorImputationWithErrorSampling
from utils.utils import set_seeds
import matplotlib
from sys import platform

# On platforms other than Windows and macOS (e.g. headless Linux servers), switch
# matplotlib to the non-interactive 'Agg' backend so figures can be rendered
# without a display.
if platform != 'win32' and platform != 'darwin':
    matplotlib.use('Agg')

# NOTE(review): this silences *all* warnings process-wide, including ones from
# torch / sklearn that may flag real problems.
warnings.filterwarnings("ignore")


def parse_args_utils(args):
    """Convert raw CLI values on ``args`` into typed runtime settings.

    ``args`` is mutated in place:

    * ``hidden_dims``: parsed from its string form with ``ast.literal_eval``.
      For datasets whose name contains ``'facebook'`` the wider
      ``[64, 128, 64, 32]`` architecture is used, but only when the user kept
      the default ``[32, 64, 64, 32]``. An explicitly passed ``--hidden_dims``
      is honoured instead of being silently discarded.
    * ``batch_norm`` / ``multi_run``: integer flags turned into booleans (``> 0``).
    * ``data_type``: ``DataType.Real`` when the string equals ``'real'``
      (case-insensitive), otherwise ``DataType.Synthetic``.
    * ``device``: ``torch.device('cuda:0')`` when CUDA is available, else ``'cpu'``.
    * ``z_dim``: derived from ``dataset_name``.
    * ``gamma``: fixed to 1 (not exposed on the command line).

    Args:
        args: namespace returned by ``argparse.ArgumentParser.parse_args``.

    Returns:
        The same, mutated, namespace.
    """
    # Must match the ``--hidden_dims`` default declared in ``parse_args``.
    default_hidden_dims = [32, 64, 64, 32]
    hidden_dims = args.hidden_dims
    if isinstance(hidden_dims, str):
        hidden_dims = ast.literal_eval(hidden_dims)
    # Apply the per-dataset override only when no custom architecture was requested.
    if 'facebook' in args.dataset_name and hidden_dims == default_hidden_dims:
        hidden_dims = [64, 128, 64, 32]
    args.hidden_dims = hidden_dims

    args.batch_norm = args.batch_norm > 0
    args.data_type = DataType.Real if args.data_type.lower() == 'real' else DataType.Synthetic
    device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    args.device = torch.device(device_name)
    args.z_dim = get_z_dim_from_data_name(args.dataset_name)
    args.multi_run = args.multi_run > 0
    args.gamma = 1

    print(f"device: {device_name}")
    return args


def parse_args(argv=None):
    """Build the command-line parser, parse arguments and post-process them.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so existing callers are unaffected.
            Passing a list allows driving the script programmatically.

    Returns:
        The namespace after ``parse_args_utils`` has converted it into typed
        settings.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_type', type=str, default="Synthetic",
                        help='type of data set: "real" (case-insensitive) for real data; '
                             'any other value is treated as synthetic')
    parser.add_argument('--x_dim', type=int, default=10,
                        help='x dim of synthetic dataset')

    parser.add_argument('--dataset_name', type=str, default='partially_linear_syn',
                        help='dataset to use')
    parser.add_argument('--data_path', type=str, default='datasets/real_data',
                        help='')
    parser.add_argument('--non_linearity', type=str, default="lrelu",
                        help='')
    parser.add_argument('--dropout', type=float, default=0.1,
                        help='')

    parser.add_argument('--data_size', type=int, default=30000,
                        help='')
    parser.add_argument('--hidden_dims', type=str, default='[32, 64, 64, 32]',
                        help='')
    parser.add_argument('--seed', type=int, default=0,
                        help='random seed')
    parser.add_argument('--alpha', type=float, default=0.1,
                        help='risk level')
    parser.add_argument('--bs', type=int, default=128,
                        help='batch size')
    parser.add_argument('--wait', type=int, default=200,
                        help='patience (wait) used when fitting models and calibration schemes')
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='learning rate')
    parser.add_argument('--wd', type=float, default=0.0,
                        help='weight decay')
    parser.add_argument('--base_results_save_dir', type=str, default="./results",
                        help="results save dir")

    parser.add_argument('--training_ratio', type=float, default=0.5,
                        help="fraction of samples used for training")
    parser.add_argument('--validation_ratio', type=float, default=0.1,
                        help="fraction of samples used for validation")
    parser.add_argument('--calibration_ratio', type=float, default=0.2,
                        help="fraction of samples used for calibration")
    parser.add_argument('--epochs', type=int, default=5000,
                        help="number of epochs for offline training")
    parser.add_argument('--figures_dir', type=str, default='./figures',
                        help="figures_dir")
    parser.add_argument('--saved_models_path', type=str, default='./saved_models',
                        help="saved_models_path")
    parser.add_argument('--batch_norm', type=int, default=0,
                        help="batch norm")
    parser.add_argument('--multi_run', type=int, default=0,
                        help="multi_run")
    parser.add_argument('--data_masker', type=str, default='', help='data masker to use')

    parser.add_argument('--beta', type=float, default=0.005,
                        help='beta parameter passed to the privileged conformal prediction scheme')

    args = parser.parse_args(argv)
    args = parse_args_utils(args)
    return args


def get_imputators(dataset_name, x_dim, y_dim, z_dim, args, scaled_y_max: float, scaled_y_min: float,
                   saved_models_path: str, figures_dir: str, seed, dataset, ) -> List[
    ImputationMethod]:
    """Build the imputation methods used by ``CalibrationByImputation``.

    One ``RegressorImputationWithErrorSampling`` is created for every
    (regressor type, error-sampler type) pair in the lists below. Pairs are
    iterated regressor-major, and the error sampler is generated before the
    regressor.

    Args:
        dataset_name: name of the dataset, forwarded to the factories.
        x_dim, y_dim, z_dim: feature / target / privileged-feature dimensions.
        args: parsed CLI namespace. Network hyper-parameters (hidden_dims,
            batch_norm, dropout, lr, wd) and device are read from it.
        scaled_y_max, scaled_y_min: target range forwarded to the error sampler factory.
        saved_models_path, figures_dir, seed: forwarded to the factories. These
            explicit arguments are used rather than the same-named fields on ``args``.
        dataset: only used by the commented-out conditional-sample imputators.

    Returns:
        List of (not yet fitted) imputation methods.
    """
    # Non-regression imputators; uncomment to include them in the experiment.
    imputators: List[ImputationMethod] = [
        # SampleImputator(),
        # ConditionalSampleImputator(LinearClustering(dataset_name, saved_models_path, figures_dir, seed),
        #                            dataset, args),
        # ConditionalSampleImputator(Kmeans(), dataset, args),
    ]
    regressor_types = [
        # RegressorType.Linear,
        # RegressorType.Full,
        RegressorType.FullWithLinearity,
    ]
    error_sampler_types = [
        # ErrorSamplerType.Marginal,
        ErrorSamplerType.LinearClustering,
        # ErrorSamplerType.KmeansClustering,
    ]

    regressor_factory = RegressorFactory(dataset_name, saved_models_path, figures_dir, seed, x_dim, y_dim, z_dim,
                                         args.hidden_dims, args.batch_norm, args.dropout, args.lr, args.wd,
                                         args.device)
    error_sampler_factory = ErrorSamplerFactory(dataset_name, saved_models_path, figures_dir, seed, x_dim, y_dim,
                                                z_dim, args.hidden_dims, args.batch_norm, args.dropout, args.lr,
                                                args.wd, args.device, scaled_y_min, scaled_y_max)

    # product() preserves the original nested-loop order (regressor-major).
    for regressor_type, error_sampler_type in product(regressor_types, error_sampler_types):
        error_sampler = error_sampler_factory.generate_error_sampler(error_sampler_type)
        regressor = regressor_factory.generate_regressor(regressor_type)
        imputators.append(RegressorImputationWithErrorSampling(regressor, error_sampler))

    # Plain regressor imputation without error sampling (disabled):
    # for regressor_type in regressor_types:
    #     imputators.append(RegressorImputation(regressor_factory.generate_regressor(regressor_type)))
    return imputators


def get_calibration_schemes(dataset, args) -> List[Calibration]:
    """Create the conformal calibration schemes evaluated on ``dataset``.

    Currently this produces one ``CalibrationByImputation`` per imputator
    returned by ``get_imputators``. Each one wraps a fresh
    ``CQRCalibration(args.alpha)``. Other schemes are listed further down as
    commented-out options:
      * DummyCalibration / CQRCalibration (optionally ignore_masked=True)
      * WeightedCalibration and PrivilegedConformalPrediction, once per mask estimator
      * TriplyRobustCalibration, combining imputation with PCP

    Args:
        dataset: regression dataset. Read for its name, dimensions, scaler,
            data masker and scaled target range.
        args: parsed CLI namespace (alpha, device, paths, seed, network settings).

    Returns:
        List of (not yet fitted) calibration schemes.
    """
    data_scaler, data_masker = dataset.scaler, dataset.data_masker
    dataset_name = dataset.dataset_name
    device = args.device
    z_dim, x_dim, y_dim = dataset.z_dim, dataset.x_dim, dataset.y_dim
    scaled_y_max, scaled_y_min = dataset.scaled_y_max, dataset.scaled_y_min
    alpha = args.alpha
    figures_dir, saved_models_path, seed = args.figures_dir, args.saved_models_path, args.seed

    def get_data_mask_estimators() -> List[DataMaskEstimator]:
        # Candidate missingness-mask estimators. Only the commented-out
        # weighted / privileged / triply-robust options below use them.
        oracle_masker = OracleDataMasker(data_scaler, data_masker, dataset_name, x_dim, z_dim)
        learned_masker = get_data_learning_mask_estimator(args, x_dim, z_dim)
        rf_masker = RFMaskEstimator(args.dataset_name, args.saved_models_path, x_dim, z_dim, device=device,
                                    seed=args.seed)
        network_masker = NetworkMaskEstimator(args.dataset_name, args.saved_models_path, x_dim, z_dim,
                                              args.hidden_dims, args.dropout, args.batch_norm, args.lr, args.wd,
                                              device=device, figures_dir=args.figures_dir, seed=args.seed)
        xgb_masker = XGBoostMaskEstimator(args.dataset_name, args.saved_models_path, x_dim, z_dim,
                                          device=device, seed=args.seed)
        learned_masker_without_z = get_data_learning_mask_estimator(args, x_dim, 0)
        return [oracle_masker, learned_masker, rf_masker, network_masker, xgb_masker, learned_masker_without_z]

    imputators = get_imputators(dataset_name, x_dim, y_dim, z_dim, args, scaled_y_max, scaled_y_min,
                                saved_models_path, figures_dir, seed, dataset)
    calibration_schemes: List[Calibration] = [
        CalibrationByImputation(imputator, CQRCalibration(alpha), alpha)
        for imputator in imputators
    ]

    # Disabled alternatives (append to calibration_schemes to enable):
    # calibration_schemes += [DummyCalibration(alpha), CQRCalibration(alpha),
    #                         CQRCalibration(alpha, ignore_masked=True)]
    #
    # for data_mask_estimator in get_data_mask_estimators():
    #     calibration_schemes.append(
    #         WeightedCalibration(CQRCalibration(alpha), alpha, dataset_name, data_scaler, data_mask_estimator,
    #                             device=device))
    #
    # for data_mask_estimator in get_data_mask_estimators():
    #     calibration_schemes.append(
    #         PrivilegedConformalPrediction(CQRCalibration(alpha), alpha, dataset_name, data_scaler,
    #                                       data_mask_estimator, beta=args.beta, device=device))
    #
    # for i in range(len(imputators)):
    #     for data_mask_estimator in get_data_mask_estimators():
    #         imputator = get_imputators(dataset_name, x_dim, y_dim, z_dim, args, scaled_y_max, scaled_y_min,
    #                                    saved_models_path, figures_dir, seed, dataset)[i]
    #         calibration_by_imputation = CalibrationByImputation(imputator, CQRCalibration(alpha), alpha)
    #         pcp = PrivilegedConformalPrediction(CQRCalibration(alpha), alpha, dataset_name, data_scaler,
    #                                             data_mask_estimator, device=device)
    #         calibration_schemes.append(TriplyRobustCalibration(alpha, calibration_by_imputation, pcp))

    return calibration_schemes


def get_models(dataset: RegressionDataset, args) -> List[PredictionIntervalModel]:
    """Return a one-element list holding the base quantile-regression model.

    The model family depends on the dataset name:
      * ``RFQR`` when ``is_data_for_rf`` is true,
      * otherwise ``XGBoostQR`` when ``is_data_for_xgboost`` is true,
      * otherwise the neural ``QuantileRegression`` built from the CLI
        network hyper-parameters.
    """
    name = dataset.dataset_name
    if is_data_for_rf(name):
        return [RFQR(name, args.saved_models_path, seed=args.seed, alpha=args.alpha)]
    if is_data_for_xgboost(name):
        return [XGBoostQR(name, args.saved_models_path, seed=args.seed, alpha=args.alpha)]
    network_qr = QuantileRegression(name, args.saved_models_path, dataset.x_dim, dataset.y_dim, args.alpha,
                                    hidden_dims=args.hidden_dims, dropout=args.dropout, lr=args.lr,
                                    wd=args.wd, device=args.device, figures_dir=args.figures_dir,
                                    seed=args.seed)
    return [network_qr]


def _evaluate_calibration_scheme(calibration_scheme, model, dataset, args, results_helper, uncalibrated,
                                 store_train_performance):
    """Fit, calibrate and score a single calibration scheme, saving metrics via ``results_helper``.

    ``uncalibrated`` maps 'train' / 'val' / 'cal' / 'test' to the model's
    uncalibrated intervals for that split.
    """
    print(f"started working on calibration: {calibration_scheme.name} with model {model.name}")
    calibration_scheme.fit(dataset.x_train, dataset.y_train, deleted_train=dataset.deleted_train,
                           x_val=dataset.x_val, y_val=dataset.y_val, deleted_val=dataset.deleted_val,
                           epochs=args.epochs, batch_size=args.bs, n_wait=args.wait,
                           z_train=dataset.z_train, z_val=dataset.z_val,
                           train_uncalibrated_intervals=uncalibrated['train'],
                           val_uncalibrated_intervals=uncalibrated['val'])
    # full_y_cal is only consumed by oracle schemes.
    calibration_scheme.calibrate(dataset.x_cal, dataset.y_cal, dataset.z_cal, dataset.deleted_cal,
                                 uncalibrated['cal'], full_y_cal=dataset.full_y_cal)
    test_calibrated = calibration_scheme.construct_calibrated_uncertainty_sets(
        dataset.x_test, uncalibrated['test'], z_test=dataset.z_test)
    train_calibrated = None
    if store_train_performance:
        train_calibrated = calibration_scheme.construct_calibrated_uncertainty_sets(
            dataset.x_train, uncalibrated['train'], z_test=dataset.z_train)
    results_helper.save_performance_metrics(uncalibrated['train'], train_calibrated,
                                            uncalibrated['test'], test_calibrated,
                                            dataset, model, calibration_scheme)


def run_experiment(args, models_generator, calibration_schemes_generator, store_train_performance=True):
    """Run one seed of the experiment for ``args.dataset_name``.

    Loads the dataset and builds the models, re-seeding before each step. For
    every model it then:
      1. fits the model and switches it to eval mode,
      2. builds fresh calibration schemes via ``calibration_schemes_generator``,
      3. computes uncalibrated intervals on the train/val/cal/test splits,
      4. fits, calibrates and scores each scheme.

    A scheme that raises is reported with its traceback and skipped. The
    remaining schemes still run.

    Args:
        args: post-processed CLI namespace.
        models_generator: callable ``(dataset, args) -> list of models``.
        calibration_schemes_generator: callable ``(dataset, args) -> list of schemes``.
        store_train_performance: when True, calibrated train intervals are also
            computed and passed to the results helper; otherwise ``None`` is passed.
    """
    print(f"starting seed: {args.seed} data: {args.dataset_name}")
    results_helper = RegressionResultsHelper(args.base_results_save_dir, args.seed)
    set_seeds(args.seed)
    dataset = get_regression_dataset(args)
    set_seeds(args.seed)
    models = models_generator(dataset, args)

    for model in models:
        model.fit(dataset.x_train, dataset.y_train, dataset.deleted_train,
                  dataset.x_val, dataset.y_val, dataset.deleted_val,
                  batch_size=args.bs, wait=args.wait, epochs=args.epochs)
        model.eval()
        calibration_schemes = calibration_schemes_generator(dataset, args)
        # Built in train, val, cal, test order.
        uncalibrated = {
            split: model.construct_uncalibrated_intervals(getattr(dataset, f"x_{split}"))
            for split in ('train', 'val', 'cal', 'test')
        }
        for calibration_scheme in calibration_schemes:
            try:
                _evaluate_calibration_scheme(calibration_scheme, model, dataset, args, results_helper,
                                             uncalibrated, store_train_performance)
            except Exception as e:
                # Best effort: one failing scheme must not abort the whole run.
                print(f"error: failed on calibration scheme: {calibration_scheme.name} because {e}")
                traceback.print_exc()
                print()


def main(args=None):
    """Script entry point.

    Args:
        args: an already post-processed namespace. When ``None``, arguments are
            parsed from the command line.

    With ``args.multi_run`` set, seeds 0..29 are run sequentially and
    ``args.seed`` is overwritten for each run. Otherwise a single run uses
    ``args.seed``. Train-set calibrated metrics are never stored.
    """
    if args is None:
        args = parse_args()

    if not args.multi_run:
        run_experiment(args, get_models, get_calibration_schemes, store_train_performance=False)
        return

    for seed in range(30):
        args.seed = seed
        run_experiment(args, get_models, get_calibration_schemes, store_train_performance=False)


# Run the experiment only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
