from enum import Enum, auto

import numpy as np
import torch
from quantile_forest import RandomForestQuantileRegressor
from tqdm import tqdm

# from ConditionalGMM.ConditionalGMM import CondGMM
from clustering.kmeans_clustering import Kmeans
from clustering.linear_clustering import LinearClustering
from clustering.one_dimensional_clustering import OneDimensionalClustering
from error_sampler.CVAEErrorSampler import CVAEErrorSampler
from error_sampler.ClusteringErrorSampler import ClusteringErrorSampler
from error_sampler.ErrorSampler import ErrorSampler
from error_sampler.FlexCodeErrorSampler import FlexCodeErrorSampler
from error_sampler.KernelMixtureNetworkErrorSampler import NormalizingFlowsErrorSampler
from error_sampler.MarginalErrorSampler import MarginalErrorSampler
from error_sampler.NNKCDEErrorSampler import NNKCDEErrorSampler
from error_sampler.QRErrorSampler import QRErrorSampler
# from error_sampler.RFCDEErrorSampler import RFCDEErrorSampler
from models.regressors.regressor_factory import RegressorType
# import pypr.clustering.gmm
from sklearn.mixture import GaussianMixture



class ErrorSamplerType(Enum):
    """Identifiers for the available error-sampling strategies.

    Values are consecutive integers starting at 1, in declaration order.
    """

    # Marginal and clustering-based samplers.
    Marginal = 1
    LinearClustering = 2
    LinearClusteringWithX = 3
    KmeansClustering = 4
    KmeansClusteringWithX = 5

    # Model-based samplers.
    CVAE = 6
    QR = 7
    Normal = 8
    GMM = 9
    RF = 10
    RFCDE = 11
    NF = 12
    NNKCDE = 13
    FlexCode = 14


class RandomForestErrorSampler(ErrorSampler):
    """Error sampler backed by a quantile regression forest.

    The forest is fitted in :meth:`calibrate` to predict the error from ``z``.
    A sample for a test point is obtained by drawing a quantile level
    uniformly from ``[0, 1)`` and asking the forest for that quantile.
    """

    def __init__(self, dataset_name: str, y_dim, x_dim, z_dim, device, seed: int, saved_models_path: str,
                 figures_dir: str = None):
        """Create the (unfitted) quantile forest.

        Only ``y_dim``, ``z_dim`` and ``device`` are stored. ``dataset_name``,
        ``x_dim``, ``seed``, ``saved_models_path`` and ``figures_dir`` are
        accepted but not used by this class.
        """
        super().__init__()
        self.qrf = RandomForestQuantileRegressor(n_estimators=100, min_samples_leaf=10)
        self.y_dim = y_dim
        self.z_dim = z_dim
        self.device = device

    def fit(self, x_train, z_train, y_train, errors_train, deleted_train, x_val, z_val, y_val, errors_val, deleted_val,
            **kwargs):
        """Do nothing: the forest is fitted on the calibration data instead."""
        pass

    def calibrate(self, x_cal, z_cal, y_cal, errors_cal, deleted_cal, **kwargs):
        """Fit the quantile forest to predict ``errors_cal`` from ``z_cal``.

        ``x_cal``, ``y_cal``, ``deleted_cal`` and ``kwargs`` are ignored.
        """
        features = z_cal.cpu().detach().numpy()
        targets = errors_cal.squeeze().cpu().detach().numpy()
        self.qrf.fit(features, targets)

    def sample_error(self, x_test: torch.Tensor, z_test: torch.Tensor) -> torch.Tensor:
        """Draw one error sample per row of ``z_test``.

        For every row a fresh quantile level is drawn from NumPy's global
        random generator and the forest's prediction at that level is used as
        the sample. ``x_test`` is ignored.

        Returns:
            A 1-D tensor with one sampled error per row, placed on
            ``self.device``.
        """
        # Convert the whole batch once; doing the device-to-host transfer per
        # row inside the loop is needlessly slow.
        z_np = z_test.cpu().detach().numpy()

        errors = []
        for i in tqdm(range(z_np.shape[0]), desc='sampling errors with random forest...'):
            q = np.random.rand(1).item()
            # Slice (rather than index) so the forest receives a batch of one row.
            sampled = self.qrf.predict(z_np[i:i + 1], quantiles=q)
            errors.append(sampled.item())
        return torch.Tensor(errors).to(self.device)

    @property
    def name(self) -> str:
        """Short identifier of this sampler."""
        return "rf_error_sampler"


class ErrorSamplerFactory:
    """Builds :class:`ErrorSampler` instances from a shared configuration.

    The constructor only records the configuration; samplers are created on
    demand by :meth:`generate_error_sampler`.
    """

    def __init__(self, dataset_name: str, saved_models_path: str, figures_dir: str, seed: int, x_dim: int, y_dim: int,
                 z_dim: int,
                 hidden_dims, batch_norm: bool, dropout: float, lr: float, wd: float, device, scaled_y_min: float,
                 scaled_y_max: float):
        """Store the settings that are forwarded to the sampler constructors."""
        self.dataset_name = dataset_name
        self.saved_models_path = saved_models_path
        self.figures_dir = figures_dir
        self.seed = seed
        self.x_dim = x_dim
        self.z_dim = z_dim
        self.y_dim = y_dim
        self.hidden_dims = hidden_dims
        self.dropout = dropout
        self.batch_norm = batch_norm
        self.lr = lr
        self.wd = wd
        self.device = device
        self.scaled_y_min = scaled_y_min
        self.scaled_y_max = scaled_y_max

    def generate_error_sampler(self, error_sampler_type: ErrorSamplerType) -> ErrorSampler:
        """Create the error sampler corresponding to ``error_sampler_type``.

        Args:
            error_sampler_type: Which sampler to build.

        Returns:
            A newly constructed sampler.

        Raises:
            ValueError: If no sampler is wired up for ``error_sampler_type``
                (this includes the enum members whose branches are commented
                out below).
        """
        # Name handed to the samplers that take a dataset/model name of their own.
        error_sampler_data_name = f'{self.dataset_name}_{error_sampler_type.name}_error_from_z'
        if error_sampler_type == ErrorSamplerType.Marginal:
            return MarginalErrorSampler()
        elif error_sampler_type == ErrorSamplerType.KmeansClustering:
            return ClusteringErrorSampler(Kmeans(use_x=False))
        # elif error_sampler_type == ErrorSamplerType.KmeansClusteringWithX:
        #     return ClusteringErrorSampler(Kmeans(use_x=True))
        elif error_sampler_type == ErrorSamplerType.LinearClustering:
            return ClusteringErrorSampler(
                LinearClustering(self.dataset_name, self.saved_models_path, self.figures_dir, self.seed))
        # elif error_sampler_type == ErrorSamplerType.LinearClusteringWithX:
        #     return ClusteringErrorSampler(
        #         LinearClustering(self.dataset_name, self.saved_models_path, self.figures_dir, self.seed, use_x=True))
        elif error_sampler_type == ErrorSamplerType.QR:
            return QRErrorSampler(error_sampler_data_name, self.saved_models_path, self.x_dim, self.y_dim, self.z_dim,
                                  self.scaled_y_min, self.scaled_y_max, self.hidden_dims, self.dropout,
                                  self.batch_norm,
                                  self.lr, self.wd, self.device, self.figures_dir, self.seed,
                                  )
        elif error_sampler_type == ErrorSamplerType.CVAE:
            # NOTE(review): the meaning of the literal 3 is defined by
            # CVAEErrorSampler's signature, which is not visible here - confirm.
            return CVAEErrorSampler(error_sampler_data_name, self.y_dim, self.x_dim, self.z_dim, 3, self.device,
                                    seed=self.seed,
                                    saved_models_path=self.saved_models_path,
                                    dropout=self.dropout, lr=self.lr, wd=self.wd,
                                    figures_dir=self.figures_dir)
        # elif error_sampler_type == ErrorSamplerType.GMM:
        #     return GMMErrorSampler(error_sampler_data_name, self.y_dim, self.x_dim, self.z_dim, self.device,
        #                            seed=self.seed,
        #                            saved_models_path=self.saved_models_path,
        #                            figures_dir=self.figures_dir)
        # elif error_sampler_type == ErrorSamplerType.Normal:
        #     return GMMErrorSampler(error_sampler_data_name, self.y_dim, self.x_dim, self.z_dim, self.device,
        #                            seed=self.seed,
        #                            saved_models_path=self.saved_models_path,
        #                            figures_dir=self.figures_dir,
        #                            n_components=1)
        elif error_sampler_type == ErrorSamplerType.RF:
            return RandomForestErrorSampler(error_sampler_data_name, self.y_dim, self.x_dim, self.z_dim, self.device,
                                            seed=self.seed,
                                            saved_models_path=self.saved_models_path,
                                            figures_dir=self.figures_dir)
        # elif error_sampler_type == ErrorSamplerType.RFCDE:
        #     return RFCDEErrorSampler()
        elif error_sampler_type == ErrorSamplerType.FlexCode:
            return FlexCodeErrorSampler()
        elif error_sampler_type == ErrorSamplerType.NF:
            return NormalizingFlowsErrorSampler(self.dataset_name, self.saved_models_path, self.figures_dir,
                                                self.seed, self.z_dim, self.device, lr=self.lr, wd=self.wd)
        elif error_sampler_type == ErrorSamplerType.NNKCDE:
            return NNKCDEErrorSampler()
        else:
            # ValueError is still caught by callers that handle ``Exception``.
            raise ValueError(f"invalid error sampler type: {error_sampler_type.name}")
