from typing import List

import numpy as np
import torch

from error_sampler.ErrorSampler import ErrorSampler
from error_sampler.utils import sample_from_pdf
from models.abstract_models.NetworkLearningModel import NetworkLearningModel
from models.qr_models.QuantileRegression import QuantileRegression
import nflows
from nflows import transforms, distributions, flows

# Module-level transform: a 2-feature masked affine autoregressive layer
# followed by a random permutation, composed into a single transform.
# NOTE(review): nothing in the visible part of this file reads `transform`
# (NormalizingFlowsErrorSampler.__init__ builds its own local transform), so
# this looks like leftover example code that is executed on import - confirm
# that no other module imports it before removing.
transform = transforms.CompositeTransform([
    transforms.MaskedAffineAutoregressiveTransform(features=2, hidden_features=4),
    transforms.RandomPermutation(features=2)
])

import matplotlib.pyplot as plt
import sklearn.datasets as datasets

import torch
from torch import nn
from torch import optim

from nflows.flows.base import Flow
from nflows.distributions.normal import ConditionalDiagonalNormal
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform
from nflows.transforms.permutations import ReversePermutation
from nflows.nn.nets import ResidualNet


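# Dependency note: this module imports `nflows` (pip install nflows).
# NOTE(review): this note originally read `pip install cde`, but no `cde` import is visible in this file - confirm whether that package is still required.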
class NormalizingFlowsErrorSampler(ErrorSampler, NetworkLearningModel):
    """Error sampler backed by a conditional normalizing flow.

    The flow models a 1-dimensional target (the error) conditioned on a
    ``z_dim``-dimensional context vector. It is trained by minimising the
    negative log-likelihood (see :meth:`loss`) and queried through
    :meth:`sample_error`.
    """

    def __init__(self, dataset_name: str, saved_models_path: str, figures_dir: str, seed: int, z_dim: int, device,
                 lr: float, wd: float, *, num_layers: int = 5, hidden_features: int = 8):
        """Build the conditional flow and its Adam optimizer.

        Args:
            dataset_name: Forwarded unchanged to ``NetworkLearningModel.__init__``.
            saved_models_path: Forwarded unchanged to ``NetworkLearningModel.__init__``.
            figures_dir: Forwarded unchanged to ``NetworkLearningModel.__init__``.
            seed: Forwarded unchanged to ``NetworkLearningModel.__init__``.
            z_dim: Number of context features the flow is conditioned on.
            device: Device the flow is moved to via ``.to(device)``.
            lr: Learning rate for Adam (also forwarded to the base class).
            wd: Weight decay for Adam (also forwarded to the base class).
            num_layers: Number of (permutation, autoregressive) layer pairs.
                Defaults to 5, the previously hard-coded value.
            hidden_features: Hidden width of each autoregressive layer.
                Defaults to 8, the previously hard-coded value.
        """
        ErrorSampler.__init__(self)
        NetworkLearningModel.__init__(self, dataset_name, saved_models_path, figures_dir, seed, lr=lr, wd=wd)

        # The context encoder emits 2 values per sample for the 1-D base
        # distribution (presumably its mean and log-std, per nflows'
        # ConditionalDiagonalNormal - confirm against the nflows docs).
        base_dist = ConditionalDiagonalNormal(shape=[1],
                                              context_encoder=nn.Linear(z_dim, 2))

        # Named `layers` (not `transforms`) so the imported `nflows.transforms`
        # module is not shadowed inside this method.
        layers = []
        for _ in range(num_layers):
            layers.append(ReversePermutation(features=1))
            layers.append(MaskedAffineAutoregressiveTransform(features=1,
                                                              hidden_features=hidden_features,
                                                              context_features=z_dim))
        flow_transform = CompositeTransform(layers)

        self._network = Flow(flow_transform, base_dist).to(device)
        self._optimizer = optim.Adam(self._network.parameters(), lr=lr, weight_decay=wd)

    def loss(self, y, prediction, d, epoch, **kwargs):
        """Return the mean negative log-likelihood of ``y`` under the flow.

        The flow is conditioned on ``kwargs['x']``; ``prediction``, ``d`` and
        ``epoch`` are accepted for interface compatibility but are not used.

        Raises:
            KeyError: If ``x`` is not supplied as a keyword argument.
        """
        log_likelihood = self._network.log_prob(inputs=y, context=kwargs['x'])
        return -log_likelihood.mean()

    def predict(self, x, **kwargs):
        """Not implemented for this sampler; always returns ``None``."""
        return None

    def fit(self, x_train, z_train, y_train, errors_train, deleted_train, x_val, z_val, y_val, errors_val, deleted_val,
            **kwargs):
        """Train the flow to model the errors given ``z``.

        Records the range of the training errors in ``self.min_val`` /
        ``self.max_val`` and then delegates to
        ``NetworkLearningModel.fit_xy`` with ``z_*`` as inputs and
        ``errors_*`` as targets. ``x_train``, ``y_train``, ``x_val`` and
        ``y_val`` are not used here.
        """
        self.min_val = errors_train.min().item()
        self.max_val = errors_train.max().item()
        # The validation split is forwarded separately; it is deliberately not
        # concatenated into the training tensors.
        NetworkLearningModel.fit_xy(self, z_train, errors_train, deleted_train, z_val, errors_val, deleted_val, **kwargs)

    def sample_error(self, x_test, z_test):
        """Draw one error sample per row of ``z_test``.

        ``x_test`` is accepted for interface compatibility but is not used.

        Returns:
            A 1-D tensor with one sampled error per context row.
        """
        samples = self._network.sample(1, z_test)
        # With one sample of a 1-D variable per context, the result is expected
        # to be [n_contexts, 1, 1] (per the nflows `sample` documentation).
        # Flatten explicitly: a bare `.squeeze()` would also drop the batch
        # axis when a single context row is passed, yielding a 0-d tensor.
        return samples.reshape(-1)

    @property
    def name(self) -> str:
        """Identifier for this sampler."""
        return "nf_error_sampler"
