# from typing import List
#
# import numpy as np
# import torch
#
# from error_sampler.ErrorSampler import ErrorSampler
# from models.abstract_models.NetworkLearningModel import NetworkLearningModel
# from models.qr_models.QuantileRegression import QuantileRegression
# # import normflows as nf
#
#
# # NOTE(review): This whole module is commented out and is not usable as-is.
# # Gaps visible in the commented code below (fix before re-enabling):
# #   * `nf` is used throughout, but its import above is itself commented out.
# #   * `loss` reads `x`, which is not one of its parameters, and returns None.
# #   * `fit` and `sample_error` use `self.forest`, but `__init__` only assigns
# #     `self.model`. They look carried over from a quantile-forest based
# #     sampler -- TODO confirm and port them to the flow model.
# #   * `List` and `QuantileRegression` are imported above but never used here.
# class NormalizingFlowsErrorSampler(ErrorSampler, NetworkLearningModel):
#     """Error sampler backed by a normalizing flow (work in progress).
#
#     Intended to learn an error distribution with a `normflows` model and
#     draw error samples from it; see the NOTE(review) above for what is
#     still missing.
#     """
#
#
#     def __init__(self):
#         """Build the flow model stored in `self.model`.
#
#         Uses a 2-D diagonal-Gaussian base distribution and 32 affine
#         coupling blocks, each followed by a swap of the two dimensions.
#         """
#         super().__init__()
#         base = nf.distributions.base.DiagGaussian(2)
#
#         # Define list of flows
#         num_layers = 32
#         flows = []
#         for i in range(num_layers):
#             # Neural network with two hidden layers having 64 units each
#             # Last layer is initialized by zeros making training more stable
#             # # Maps the 1 conditioning dim to 2 outputs for the coupling block.
#             param_map = nf.nets.MLP([1, 64, 64, 2], init_zeros=True)
#             # Add flow layer
#             flows.append(nf.flows.AffineCouplingBlock(param_map))
#             # Swap dimensions
#             flows.append(nf.flows.Permute(2, mode='swap'))
#             # If the target density is not given
#
#         self.model = nf.NormalizingFlow(base, flows)
#
#     def predict(self, x, **kwargs):
#         """Return `self.model.forward_kld(x)`.
#
#         NOTE(review): in normflows `forward_kld` appears to be the forward-KL
#         training objective rather than a prediction/sample -- confirm what
#         `predict` should actually return.
#         """
#         return self.model.forward_kld(x)
#
#
#     def loss(self, y, prediction, d, epoch, **kwargs):
#         """Unfinished training loss.
#
#         NOTE(review): `x` is not defined in this scope (NameError if run) and
#         the method falls through to `pass`, so it returns None.
#         """
#         loss = self.model.forward_kld(x)
#         pass
#
#     # def loss(self):
#     #     loss = self.model.forward_kld(x)
#     #
#     #     # When minimizing the reverse KLD based on the given target distribution
#     #     loss = model.reverse_kld(num_samples=512)
#     #
#     #     # Optimization as usual
#     #     loss.backward()
#     #     optimizer.step()
#
#     def fit(self, x_train, z_train, y_train, errors_train, deleted_train, x_val, z_val, y_val, errors_val, deleted_val,
#             **kwargs):
#         """Fit on train and validation data combined.
#
#         Concatenates the train and validation `z`, `errors` and `deleted`
#         tensors, drops rows whose `deleted` flag is set, and trains
#         `self.forest` on the remaining `z` -> `errors` pairs as float64 numpy
#         arrays. The `x_*` and `y_*` arguments are not used.
#
#         NOTE(review): `self.forest` is never created in `__init__`; this body
#         does not touch the flow in `self.model` at all.
#         """
#         new_z_train = torch.cat([z_train, z_val], dim=0)
#         new_d_train = torch.cat([deleted_train, deleted_val], dim=0)
#         new_error_train = torch.cat([errors_train, errors_val], dim=0)
#         # # `~new_d_train` assumes `deleted_*` are boolean masks -- TODO confirm.
#         self.forest.train(new_z_train[~new_d_train].detach().cpu().numpy().astype(np.double),
#                           new_error_train[~new_d_train].detach().cpu().numpy().astype(np.double))
#         # self.forest.
#
#     def sample_error(self, x_test, z_test):
#         """Draw one error per row of `z_test`.
#
#         For each row, draws a uniform random quantile level in [0, 1) and
#         queries `self.forest.predict_quantile` at that level. Returns a 1-D
#         tensor of length `len(z_test)` on `z_test.device`. `x_test` is unused.
#
#         NOTE(review): relies on `self.forest`, which `__init__` never sets.
#         The numpy conversion of `z_test` is redone on every loop iteration
#         and could be hoisted out of the loop.
#         """
#         rnd_quantile_levels = torch.rand(len(z_test)).detach().cpu().numpy().astype(np.double)
#         sampled_errors = torch.zeros(len(z_test)).to(z_test.device)
#         for i in range(len(z_test)):
#             sampled_errors[i] = self.forest.predict_quantile(z_test.detach().cpu().numpy().astype(np.double)[i],
#                                                              rnd_quantile_levels[i].item()).item()
#         # sampled_errors = self.forest.predict_quantile(z_test.detach().cpu().numpy().astype(np.double), rnd_quantile_levels)
#         # return torch.Tensor(sampled_errors).to(z_test.device)
#         return sampled_errors
#
#     @property
#     def name(self) -> str:
#         """Identifier string for this sampler: "nf_error_sampler"."""
#         return "nf_error_sampler"
