import torch
import normflows as nf
import matplotlib.pyplot as plt

from torch.utils.data import TensorDataset, DataLoader
from sklearn.neighbors import KernelDensity
from sklearn.mixture import GaussianMixture
from tqdm.notebook import tqdm

from .flows import HyperCouplingBlock
from .nets import VAE

class Model:
    """ Common interface shared by every generative-model wrapper.

    Subclasses store the underlying estimator and implement ``fit`` and
    ``sample``; the base implementations only raise.
    """

    def __init__(self, name):
        """ Remember the human-readable model name (used e.g. for labelling). """
        self.name = name

    def fit(self, X, device):
        """ Train the model on a data set.

        Arguments:
        ----------
        X : torch.Tensor, shape (n_samples, n_features)
            Training data set.
        device : torch.device
            Device to train on (wrappers backed by CPU-only libraries may
            ignore it).

        Raises:
        -------
        NotImplementedError
            Always; concrete models override this method.
        """
        raise NotImplementedError()

    def sample(self, n_samples, device):
        """ Draw samples from the trained model.

        Arguments:
        ----------
        n_samples : int
            Number of samples to draw.
        device : torch.device
            Device on which the returned tensor should live.

        Returns:
        --------
        sample : torch.Tensor, shape (n_samples, n_features)
            Sample from the model.

        Raises:
        -------
        NotImplementedError
            Always; concrete models override this method.
        """
        raise NotImplementedError()

######################################################################################################
### Wrappers

def get_batch_size(n_samples):
    """ Choose a default mini-batch size for a data set of ``n_samples`` rows.

    Data sets with at most 100 samples are processed as one full batch;
    larger ones use a quarter of the data per batch (i.e. four batches, plus
    a small remainder batch when ``n_samples`` is not divisible by 4).
    """
    fits_in_one_batch = n_samples <= 100
    return n_samples if fits_in_one_batch else n_samples // 4

class KDEModel(Model):
    """ Kernel density estimator backed by scikit-learn's ``KernelDensity``. """

    def __init__(self, name=None, kernel='gaussian', bandwidth=0.5, **kwargs):
        """ Configure the estimator.

        Arguments:
        ----------
        name : str, optional
            Display name; defaults to 'KDE'.
        kernel : str
            Kernel passed to ``KernelDensity``.
        bandwidth : float
            Kernel bandwidth passed to ``KernelDensity``.
        **kwargs
            Any further ``KernelDensity`` keyword arguments.
        """
        self.name = 'KDE' if name is None else name
        self.model = KernelDensity(kernel=kernel, bandwidth=bandwidth, **kwargs)

    def fit(self, X, device):
        """ Fit the density estimator.

        Arguments:
        ----------
        X : torch.Tensor or array-like, shape (n_samples, n_features)
            Training data set.
        device : torch.device
            Unused; scikit-learn always runs on the CPU.
        """
        # scikit-learn converts its input with np.asarray, which fails on
        # tensors that carry autograd history or live on a GPU; hand it a
        # detached CPU numpy array instead.
        self.model.fit(torch.as_tensor(X).detach().cpu().numpy())

    def sample(self, n_samples, device='cpu'):
        """ Draw ``n_samples`` points and return them as a float32 tensor on ``device``. """
        sample = torch.tensor(self.model.sample(n_samples), dtype=torch.float, device=device)
        return sample


class NFModel(Model):
    """ RealNVP-style normalizing flow built from affine coupling blocks.

    Wraps a ``normflows.NormalizingFlow`` whose base distribution is a
    non-trainable diagonal Gaussian of dimension ``n_features``.
    """

    def __init__(self, n_features, name=None, num_layers=16, hidden_layers=None):
        """ Build the flow.

        Arguments:
        ----------
        n_features : int
            Dimensionality of the data.
        name : str, optional
            Display name; defaults to 'RealNVP'.
        num_layers : int
            Number of (affine coupling block, random permutation) pairs.
        hidden_layers : sequence of int, optional
            Hidden-layer widths of every coupling network; defaults to [16, 16].
        """
        self.name = 'RealNVP' if name is None else name
        # Fresh list per instance (avoids a shared mutable default argument).
        hidden_layers = [16, 16] if hidden_layers is None else list(hidden_layers)

        # normflows' channel split (torch.chunk(2)) conditions on the first
        # ceil(n/2) features and transforms the remaining floor(n/2); the
        # parameter network must output one shift and one scale per
        # transformed feature. For n_features == 2 this is MLP([1, ..., 2]).
        n_cond = (n_features + 1) // 2
        n_params = 2 * (n_features // 2)

        base = nf.distributions.base.DiagGaussian(n_features, trainable=False)
        flows = []
        for _ in range(num_layers):
            param_map = nf.nets.MLP([n_cond] + hidden_layers + [n_params], init_zeros=True)
            flows.append(nf.flows.AffineCouplingBlock(param_map))
            flows.append(nf.flows.Permute(n_features, mode='shuffle'))
        self.model = nf.NormalizingFlow(base, flows)

    def fit(self, X, device, epochs=100, bs=None, lr=1e-3, milestones=None, gamma=0.1, verbose=True):
        """ Train by minimising the forward KL divergence (``forward_kld``).

        Arguments:
        ----------
        X : torch.Tensor, shape (n_samples, n_features)
            Training data set.
        device : torch.device
            Device to train on.
        epochs : int
            Number of passes over the data.
        bs : int, optional
            Batch size; defaults to ``get_batch_size(n_samples)``.
        lr : float
            Adam learning rate.
        milestones : list of int, optional
            Epochs at which ``MultiStepLR`` multiplies the learning rate by
            ``gamma``; defaults to no decay.
        gamma : float
            Learning-rate decay factor.
        verbose : bool
            Show a progress bar and plot the loss curve when done.
        """
        if bs is None:
            bs = get_batch_size(X.shape[0])
        milestones = [] if milestones is None else milestones

        self.model.to(device)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)
        loader = DataLoader(TensorDataset(X.to(device)), batch_size=bs, shuffle=True)
        losses = []
        loop = tqdm(range(epochs)) if verbose else range(epochs)
        for _ in loop:
            for x, in loader:
                optimizer.zero_grad()
                loss = self.model.forward_kld(x)
                # Skip the update on a NaN/inf loss so one bad batch cannot
                # corrupt the parameters.
                if torch.isfinite(loss):
                    loss.backward()
                    optimizer.step()

            scheduler.step()
            # One entry per epoch: the loss of the epoch's last batch.
            last_loss = loss.item()
            losses.append(last_loss)
            if verbose:
                loop.set_postfix(loss=last_loss, lr=optimizer.param_groups[0]["lr"])

        if verbose:
            plt.plot(losses, label='loss')
            plt.legend()
            plt.show()

    def sample(self, n_samples, device='cpu'):
        """ Draw ``n_samples`` points from the flow and move them to ``device``. """
        # Sampling needs no gradients; without no_grad the result would carry
        # the whole flow's autograd graph.
        with torch.no_grad():
            sample = self.model.sample(n_samples)[0]
        return sample.to(device)


class GNFModel(Model):
    """ Graphical normalizing flow: one ``HyperCouplingBlock`` per DAG node.

    The base distribution is a non-trainable diagonal Gaussian with one
    dimension per node of ``dag``.
    """

    def __init__(self, dag, name=None, hypernet_nhs=None, nodenet_nh=16):
        """ Build the flow.

        Arguments:
        ----------
        dag : graph object
            Graph whose ``nodes`` define the data dimensions (presumably a
            networkx DiGraph -- confirm against callers).
        name : str, optional
            Display name; defaults to 'G-NF'.
        hypernet_nhs : list of int, optional
            Hidden widths forwarded to every ``HyperCouplingBlock``;
            defaults to [16, 16].
        nodenet_nh : int
            Forwarded to every ``HyperCouplingBlock``.
        """
        self.name = 'G-NF' if name is None else name
        # Fresh list per instance (avoids a shared mutable default argument).
        hypernet_nhs = [16, 16] if hypernet_nhs is None else hypernet_nhs
        base = nf.distributions.base.DiagGaussian(len(dag.nodes), trainable=False)
        flows = [
            HyperCouplingBlock(dag, node, hypernet_nhs=hypernet_nhs, nodenet_nh=nodenet_nh)
            for node in dag.nodes
        ]
        self.model = nf.NormalizingFlow(base, flows)

    def fit(self, X, device, epochs=100, bs=None, lr=1e-3, milestones=None, gamma=0.1, verbose=True):
        """ Train by minimising the forward KL divergence (``forward_kld``).

        Arguments:
        ----------
        X : torch.Tensor, shape (n_samples, n_features)
            Training data set.
        device : torch.device
            Device to train on.
        epochs : int
            Number of passes over the data.
        bs : int, optional
            Batch size; defaults to ``get_batch_size(n_samples)``.
        lr : float
            Adam learning rate.
        milestones : list of int, optional
            Epochs at which ``MultiStepLR`` multiplies the learning rate by
            ``gamma``; defaults to no decay.
        gamma : float
            Learning-rate decay factor.
        verbose : bool
            Show a progress bar and plot the loss curve when done.
        """
        if bs is None:
            bs = get_batch_size(X.shape[0])
        milestones = [] if milestones is None else milestones

        self.model.to(device)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)
        loader = DataLoader(TensorDataset(X.to(device)), batch_size=bs, shuffle=True)
        losses = []
        loop = tqdm(range(epochs)) if verbose else range(epochs)
        for _ in loop:
            for x, in loader:
                optimizer.zero_grad()
                loss = self.model.forward_kld(x)
                # Skip the update on a NaN/inf loss so one bad batch cannot
                # corrupt the parameters.
                if torch.isfinite(loss):
                    loss.backward()
                    optimizer.step()

            scheduler.step()
            # One entry per epoch: the loss of the epoch's last batch.
            last_loss = loss.item()
            losses.append(last_loss)
            if verbose:
                loop.set_postfix(loss=last_loss, lr=optimizer.param_groups[0]["lr"])

        if verbose:
            plt.plot(losses, label='loss')
            plt.legend()
            plt.show()

    def sample(self, n_samples, device='cpu'):
        """ Draw ``n_samples`` points from the flow and move them to ``device``. """
        # Sampling needs no gradients; without no_grad the result would carry
        # the whole flow's autograd graph.
        with torch.no_grad():
            sample = self.model.sample(n_samples)[0]
        return sample.to(device)


class VAEModel(Model):
    """ Variational auto-encoder wrapper around the project's ``VAE`` network. """

    def __init__(self, name=None, encoder_layers=None, decoder_layers=None, beta=None):
        """ Build the network.

        Arguments:
        ----------
        name : str, optional
            Display name; defaults to 'VAE'.
        encoder_layers : list of int, optional
            Forwarded to ``VAE``; defaults to [2, 32, 32, 16].
        decoder_layers : list of int, optional
            Forwarded to ``VAE``; defaults to [10, 32, 32, 2].
        beta : optional
            Forwarded unchanged to ``VAE``.
        """
        self.name = 'VAE' if name is None else name
        # Fresh lists per instance (avoids shared mutable default arguments).
        if encoder_layers is None:
            encoder_layers = [2, 32, 32, 16]
        if decoder_layers is None:
            decoder_layers = [10, 32, 32, 2]
        self.model = VAE(encoder_layers=encoder_layers, decoder_layers=decoder_layers, beta=beta)

    def fit(self, X, device, epochs=100, bs=None, lr=1e-3, milestones=None, gamma=0.1, verbose=True):
        """ Train with the network's own ``loss_function``.

        Arguments:
        ----------
        X : torch.Tensor, shape (n_samples, n_features)
            Training data set.
        device : torch.device
            Device to train on.
        epochs : int
            Number of passes over the data.
        bs : int, optional
            Batch size; defaults to ``get_batch_size(n_samples)``.
        lr : float
            Adam learning rate.
        milestones : list of int, optional
            Epochs at which ``MultiStepLR`` multiplies the learning rate by
            ``gamma``; defaults to no decay.
        gamma : float
            Learning-rate decay factor.
        verbose : bool
            Show a progress bar and plot the loss curve when done.
        """
        if bs is None:
            bs = get_batch_size(X.shape[0])
        milestones = [] if milestones is None else milestones

        self.model.to(device)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)
        loader = DataLoader(TensorDataset(X.to(device)), batch_size=bs, shuffle=True)
        losses = []
        loop = tqdm(range(epochs)) if verbose else range(epochs)
        for _ in loop:
            for x, in loader:
                optimizer.zero_grad()
                recon_x, mu, logvar = self.model(x)
                loss = self.model.loss_function(recon_x, x, mu, logvar)
                # Same guard as the flow models: never step on a NaN/inf loss,
                # which would otherwise write NaNs into every parameter.
                if torch.isfinite(loss):
                    loss.backward()
                    optimizer.step()

            scheduler.step()
            # One entry per epoch: the loss of the epoch's last batch.
            last_loss = loss.item()
            losses.append(last_loss)
            if verbose:
                loop.set_postfix(loss=last_loss, lr=optimizer.param_groups[0]["lr"])

        if verbose:
            plt.plot(losses, label='loss')
            plt.legend()
            plt.show()

    def sample(self, n_samples, device='cpu'):
        """ Draw ``n_samples`` points via ``VAE.sample`` and move them to ``device``. """
        # Generation needs no gradients; avoid building an autograd graph.
        with torch.no_grad():
            sample = self.model.sample(n_samples)
        return sample.to(device)



class GMMModel(Model):
    """ Gaussian mixture model backed by scikit-learn's ``GaussianMixture``. """

    def __init__(self, name=None, n_components=10, **kwargs):
        """ Configure the mixture.

        Arguments:
        ----------
        name : str, optional
            Display name; defaults to 'GMM'.
        n_components : int
            Number of mixture components.
        **kwargs
            Any further ``GaussianMixture`` keyword arguments.
        """
        self.name = 'GMM' if name is None else name
        # Forward the caller's n_components (previously hard-coded to 10,
        # silently ignoring the argument).
        self.model = GaussianMixture(n_components=n_components, **kwargs)

    def fit(self, X, device):
        """ Fit the mixture.

        Arguments:
        ----------
        X : torch.Tensor or array-like, shape (n_samples, n_features)
            Training data set.
        device : torch.device
            Unused; scikit-learn always runs on the CPU.
        """
        # scikit-learn converts its input with np.asarray, which fails on
        # tensors that carry autograd history or live on a GPU.
        self.model.fit(torch.as_tensor(X).detach().cpu().numpy())

    def sample(self, n_samples, device='cpu'):
        """ Draw ``n_samples`` points and return them as a float32 tensor on ``device``. """
        # GaussianMixture.sample returns (points, component_labels).
        points, _ = self.model.sample(n_samples)
        return torch.tensor(points, dtype=torch.float, device=device)
