import numpy as np

import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.distributions as D

from rampwf.utils import BaseGenerativeRegressor

from mbrltools.pytorch_utils import train, EarlyStopping
from rampwf.hyperopt import Hyperparameter



# Seed torch's global RNG at import time so torch-driven randomness (weight
# initialisation, dropout masks) is repeatable for a given environment.
torch.manual_seed(7)

# Tunable hyperparameters; the rest of the file reads them through
# int(...) / float(...). NOTE(review): the lines between the START/END markers
# are presumably parsed by rampwf's hyperopt tooling -- keep their format.
# RAMP START HYPERPARAMETERS
BATCH_SIZE = Hyperparameter(
    dtype='int', default=512, values=[128, 256, 512])
LAYER_SIZE = Hyperparameter(
    dtype='int', default=50, values=[50, 100, 200])
LR = Hyperparameter(
    dtype='float', default=1e-4, values=[1e-4, 1e-3, 1e-2])
N_LAYERS_COMMON = Hyperparameter(
    dtype='int', default=2, values=[1, 2, 3, 4])
# RAMP END HYPERPARAMETERS


# Number of epochs passed to ``train`` in ``GenerativeRegressor.fit``.
n_epochs = 1
# Commented-out fixed alternatives to the Hyperparameter objects above
# (not used):
# LR = 2e-3
# N_LAYERS_COMMON = 2
# LAYER_SIZE = 200
# BATCH_SIZE = 512
# Passed to ``train`` as ``validation_fraction``.
VALIDATION_FRACTION = 0.1
# Number of Gaussian mixture components per target dimension.
N_GAUSSIANS=1
# Dropout probability applied after the input projection.
DROP_FIRST = 0
# Dropout probability inserted after every other layer of the common block.
DROP_REPEATED = 1e-1

# sqrt(2 * pi): normalising constant of the normal density (used by gauss_pdf).
CONST = np.sqrt(2 * np.pi)

# Flags not referenced in the code visible here -- presumably read elsewhere;
# verify before relying on them.
SCALE_TARGET = False
USE_DIFF = False

# Keyword arguments for mbrltools' EarlyStopping, built in ``fit``.
PARAMS_EARLY_STOPPING = {"patience" : 40, "min_delta" : 1e-2}


def gauss_pdf(x, mean, sd):
    """Evaluate the normal density N(mean, sd**2) at ``x``.

    All arguments may be tensors; the computation is element-wise with
    torch's usual broadcasting.
    """
    z = (x - mean) / sd
    normaliser = sd * CONST
    return torch.exp(-0.5 * z ** 2) / normaliser


class CustomLoss:
    """Negative log-likelihood of a diagonal-covariance Gaussian mixture.

    ``y_pred`` follows the layout produced by ``SimpleBinnedNoBounds``:
    means, standard deviations and mixture weights concatenated along
    dim 0, each chunk shaped ``(batch, target_dim, n_gaussians)``.
    """

    def __call__(self, y_true, y_pred):
        """Return the batch-mean negative log-likelihood of ``y_true``.

        Parameters
        ----------
        y_true : torch.Tensor, shape (batch, target_dim)
        y_pred : torch.Tensor, shape (3 * batch, target_dim, n_gaussians)

        Returns
        -------
        torch.Tensor
            Scalar mean NLL over the batch.
        """
        n = len(y_true)
        mus = y_pred[:n]
        sigmas = y_pred[n:2 * n]
        w = y_pred[2 * n:]

        # torch.distributions expects (batch, n_gaussians, target_dim).
        # This must be a transpose: a reshape of (batch, target_dim,
        # n_gaussians) would mix values of different components and targets
        # as soon as both of those dimensions are larger than 1.
        mus = mus.transpose(1, 2)
        sigmas = sigmas.transpose(1, 2)

        # The weights are identical across the target axis, so averaging
        # over it yields one weight per component: (batch, n_gaussians).
        w_components = torch.mean(w, dim=1)

        mix = D.Categorical(probs=w_components)
        comp = D.Independent(D.Normal(mus, sigmas), 1)
        gmm = D.MixtureSameFamily(mix, comp)

        return -gmm.log_prob(y_true).mean()



class GenerativeRegressor(BaseGenerativeRegressor):
    """Gaussian mixture density network wrapped for the RAMP workflow.

    ``fit`` trains a ``SimpleBinnedNoBounds`` network by minimising
    ``CustomLoss`` (mixture negative log-likelihood). ``predict`` turns the
    network output into a ``(weights, types, params)`` triple.
    """

    def __init__(self, max_dists, target_dim):
        self.max_dists = max_dists
        # NOTE(review): presumably tells the rampwf workflow not to decompose
        # the targets (one model handles all dimensions) -- confirm.
        self.decomposition = None
        # Keep the value supplied by the caller instead of discarding it.
        self.target_dim = target_dim

    def fit(self, X_in, y_in):
        """Train a fresh network mapping ``X_in`` to ``y_in``.

        Parameters
        ----------
        X_in : array-like, shape (n_samples, n_features)
        y_in : array-like, shape (n_samples, n_targets)

        A new ``SimpleBinnedNoBounds`` is built on every call and trained
        with AMSGrad Adam, a ``ReduceLROnPlateau`` schedule and early
        stopping. The model returned by ``train`` (called with
        ``return_best_model=True``) replaces ``self.model``.
        """
        self.model = SimpleBinnedNoBounds(int(N_GAUSSIANS), X_in.shape[1],
                                          y_in.shape[1])
        dataset = torch.utils.data.TensorDataset(
            torch.Tensor(X_in), torch.Tensor(y_in))
        optimizer = optim.Adam(
            self.model.parameters(), lr=float(LR), amsgrad=True)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 'min', factor=0.1, patience=30, cooldown=20,
            min_lr=1e-7, verbose=True)
        loss = CustomLoss()

        earlystopping = EarlyStopping(**PARAMS_EARLY_STOPPING)

        self.model, _ = train(
            self.model, dataset, validation_fraction=VALIDATION_FRACTION,
            optimizer=optimizer, scheduler=scheduler,
            n_epochs=n_epochs, batch_size=int(BATCH_SIZE), loss_fn=loss,
            verbose=True, early_stopping=earlystopping,
            return_best_model=True, disable_cuda=True,
            tensorboard_path='/tmp/dmdn')

    def predict(self, X):
        """Return the predicted Gaussian mixture for every row of ``X``.

        Returns
        -------
        weights : np.ndarray, shape (n_samples, n_targets * n_gaussians)
        types : list of str
            ``'norm'`` repeated ``n_targets * n_gaussians`` times.
        params : np.ndarray, shape (n_samples, 2 * n_targets * n_gaussians)
            Means and standard deviations interleaved
            (``mu_0, sigma_0, mu_1, sigma_1, ...``) in the same order as
            ``weights``.
        """
        # predict is called sequentially in RL; only switch to eval mode
        # when the model is not already in it.
        if self.model.training:
            self.model.eval()

        with torch.no_grad():
            X_t = torch.Tensor(X)
            n_samples = X_t.shape[0]
            # Output: (3 * n_samples, n_targets, n_gaussians) with means,
            # std devs and weights stacked along dim 0. No graph is built
            # under no_grad, so .numpy() needs no .detach().
            y_pred = self.model(X_t)
            mus = y_pred[:n_samples].numpy()
            sigmas = y_pred[n_samples:2 * n_samples].numpy()
            weights = y_pred[2 * n_samples:].numpy()

        n_targets, n_gaussians = mus.shape[1], mus.shape[2]

        # Put each mu next to its sigma.
        params = np.empty((n_samples, n_targets * n_gaussians * 2))
        params[:, 0::2] = mus.reshape(n_samples, -1)
        params[:, 1::2] = sigmas.reshape(n_samples, -1)
        # Derive the count from the model output so it always matches
        # ``weights`` and ``params``.
        types = ['norm'] * n_gaussians * n_targets

        return weights.reshape(n_samples, -1), types, params


class OutputModule(nn.Module):
    """Per-target head producing mixture means and standard deviations.

    Two independent branches (``mu`` and ``sigma``) each map a feature
    vector of width ``LAYER_SIZE`` to ``n_sigmas`` outputs through
    Linear -> Tanh -> Linear.
    """

    def __init__(self, n_sigmas):
        super().__init__()
        width = int(LAYER_SIZE)

        def _branch():
            # Linear -> Tanh -> Linear, ending in one output per component.
            return nn.Sequential(
                nn.Linear(width, width),
                torch.nn.Tanh(),
                nn.Linear(width, n_sigmas),
            )

        # mu is built before sigma; with the module-level torch seed this
        # order determines the initial weights.
        self.mu = _branch()
        self.sigma = _branch()

    def forward(self, x):
        """Return ``(mu, sigma)``; exp keeps every sigma strictly positive."""
        return self.mu(x), torch.exp(self.sigma(x))


class SimpleBinnedNoBounds(nn.Module):
    """Mixture density network: shared trunk plus one head per target.

    Parameters
    ----------
    n_sigmas : int
        Number of mixture components per target.
    input_size : int
        Number of input features.
    nb_y : int
        Number of target dimensions (one ``OutputModule`` each).

    Architecture (hidden width ``LAYER_SIZE``):

    * ``linear0`` -> Tanh -> dropout(``DROP_FIRST``) projects the input;
    * ``common_block`` stacks ``N_LAYERS_COMMON`` Linear/BatchNorm1d/Tanh
      layers, with dropout(``DROP_REPEATED``) after layers 1, 3, 5, ...,
      and its output is added back to its input (residual connection);
    * ``detached_blocks`` holds the per-target heads;
    * ``w`` maps the shared features to softmax mixture weights.

    ``forward`` expects a 2-D input ``(batch, input_size)`` and returns a
    tensor of shape ``(3 * batch, nb_y, n_sigmas)``: means, then standard
    deviations, then the weights (repeated identically for every target).
    """

    def __init__(self, n_sigmas, input_size, nb_y):
        super().__init__()
        depth = int(N_LAYERS_COMMON)
        width = int(LAYER_SIZE)

        self.linear0 = nn.Linear(input_size, width)
        self.act0 = nn.Tanh()
        self.drop = nn.Dropout(p=DROP_FIRST)

        trunk = nn.Sequential()
        self.common_block = trunk
        for i in range(depth):
            trunk.add_module(f'layer{i + 1}-lin', nn.Linear(width, width))
            trunk.add_module(f'layer{i + 1}-bn', nn.BatchNorm1d(width))
            trunk.add_module(f"layer{i + 1}-act", nn.Tanh())
            # Dropout only after even (0-based) layer indices.
            if not i % 2:
                trunk.add_module(
                    f'layer{i + 1}-drop', nn.Dropout(p=DROP_REPEATED))

        self.detached_blocks = nn.ModuleList(
            [OutputModule(n_sigmas) for _ in range(nb_y)])

        self.w = nn.Sequential(
            nn.Linear(width, n_sigmas),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        shared = self.drop(self.act0(self.linear0(x)))
        # Residual connection around the common trunk.
        features = self.common_block(shared) + shared

        head_outputs = [head(features) for head in self.detached_blocks]
        mu = torch.stack([m for m, _ in head_outputs], dim=1)
        sigma = torch.stack([s for _, s in head_outputs], dim=1)

        weights = self.w(features)
        # Repeat the per-sample mixture weights once for every target.
        weights = torch.stack([weights] * mu.shape[1], dim=1)
        return torch.cat([mu, sigma, weights], dim=0)
