"""NeuMiss + MLP"""
import numpy as np
from sklearn.base import BaseEstimator

import torch
from torch import Tensor, nn
from torch.nn import Linear, Sequential, ReLU
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from ..misc.pytorchtools import EarlyStopping
from ..networks.neumiss import NeuMissBlock


class NeuMissMLP(nn.Module):
    """A NeuMiss block followed by an MLP.

    The network is ``NeuMissBlock -> [Linear -> ReLU] x mlp_depth -> Linear(., 1)``.
    With ``mlp_depth <= 0`` the NeuMiss output feeds the final linear layer
    directly.

    Args:
        n_features : dimension of inputs.
        neumiss_depth : number of layers in the NeuMiss block.
        mlp_depth : number of hidden layers in the MLP.
        mlp_width : width of the MLP. If None take mlp_width=n_features. Default: None.
        dtype : Pytorch dtype for the parameters. Default: torch.float.
    """

    def __init__(
        self,
        n_features: int,
        neumiss_depth: int,
        mlp_depth: int,
        mlp_width: int = None,
        dtype=torch.float,
    ):
        super().__init__()
        self.n_features = n_features
        self.neumiss_depth = neumiss_depth
        self.mlp_depth = mlp_depth
        self.dtype = dtype
        mlp_width = n_features if mlp_width is None else mlp_width
        self.mlp_width = mlp_width

        modules = [NeuMissBlock(n_features, neumiss_depth, dtype)]
        in_width = n_features
        # Instantiate a fresh Linear/ReLU pair per hidden layer. Replicating a
        # list with ``*`` would reuse the same Linear object, silently tying
        # the weights of every hidden layer together.
        for _ in range(max(mlp_depth, 0)):
            modules.append(Linear(in_width, mlp_width, dtype=dtype))
            modules.append(ReLU())
            in_width = mlp_width
        # Scalar regression head.
        modules.append(Linear(in_width, 1, dtype=dtype))
        self.layers = Sequential(*modules)

    def forward(self, x: Tensor) -> Tensor:
        """Return one prediction per row of ``x``.

        Only the trailing output dimension (of size 1) is removed, so a batch
        containing a single sample still yields a 1-d tensor of length 1.
        """
        out = self.layers(x)
        return out.squeeze(-1)


class Neumiss(BaseEstimator):
    """Predict with a Neumiss block followed by an MLP.

    Args:
        depth: the depth of the NeuMiss block.
        n_epochs: the maximum number of epochs.
        batch_size: the batch size.
        lr: the learning rate.
        weight_decay: the weight decay parameter.
        early_stopping: if True, early stopping is used based on the validation set
            (only effective when validation data is passed to ``fit``).
        optimizer: one of `sgd` or `adam`.
        mlp_depth: the depth of the MLP stacked on top of the Neumann iterations.
        width_factor: the width of the MLP stacked on top of the NeuMiss block is calculated
            as width_factor times n_features.
        add_mask: if True, the mask is concatenated to the output of the NeuMiss block.
        verbose: flag to print detailed information about training to the console.

    Attributes set by ``fit``:
        net: the trained ``NeuMissMLP``.
        optimizer_: the torch optimizer used for training.
        scheduler: the ``ReduceLROnPlateau`` scheduler, stepped on the validation MSE.
        mse_train, r2_train: per-epoch training MSE and R2 (floats).
        mse_val, r2_val: per-epoch validation MSE and R2 (only filled when
            validation data is given).
    """

    def __init__(
        self,
        depth: int,
        n_epochs: int = 1000,
        batch_size: int = 100,
        lr: float = 0.01,
        weight_decay: float = 1e-4,
        early_stopping: bool = False,
        optimizer: str = "sgd",
        mlp_depth: int = 0,
        width_factor: int = 1,
        add_mask: bool = False,
        verbose: bool = False,
    ):
        self.depth = depth
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.lr = lr
        self.weight_decay = weight_decay
        # Stored under the constructor argument's name so that sklearn's
        # get_params / clone / repr can find it.
        self.early_stopping = early_stopping
        self.optimizer = optimizer
        self.mlp_depth = mlp_depth
        self.width_factor = width_factor
        # NOTE(review): add_mask is stored but not read anywhere in this class
        # — TODO implement mask concatenation or drop the parameter.
        self.add_mask = add_mask
        self.verbose = verbose

        self.r2_train = []
        self.mse_train = []
        self.r2_val = []
        self.mse_val = []

    @property
    def early_stop(self):
        """Backward-compatible alias of the ``early_stopping`` hyper-parameter."""
        return self.early_stopping

    @early_stop.setter
    def early_stop(self, value):
        self.early_stopping = value

    def _make_optimizer(self):
        """Build the torch optimizer named by ``self.optimizer`` for ``self.net``.

        Parameters whose own name is ``mu`` or ``b`` get no weight decay;
        every other parameter uses ``self.weight_decay``.

        Raises:
            ValueError: if ``self.optimizer`` is neither ``"sgd"`` nor ``"adam"``.
        """
        group_wd = []
        group_no_wd = []
        for name, param in self.net.named_parameters():
            # named_parameters yields qualified names such as "layers.0.mu",
            # so only the last component identifies the parameter.
            if name.rsplit(".", 1)[-1] in ("mu", "b"):
                group_no_wd.append(param)
            else:
                group_wd.append(param)
        param_groups = [
            {"params": group_wd, "weight_decay": self.weight_decay},
            {"params": group_no_wd, "weight_decay": 0},
        ]

        if self.optimizer == "sgd":
            return optim.SGD(param_groups, lr=self.lr)
        if self.optimizer == "adam":
            return optim.Adam(param_groups, lr=self.lr)
        raise ValueError(
            "optimizer must be 'sgd' or 'adam', got {!r}".format(self.optimizer)
        )

    def _score(self, criterion, X, y):
        """Return ``(mse, r2)`` of ``self.net`` on ``(X, y)`` as Python floats."""
        with torch.no_grad():
            mse = criterion(self.net(X), y).item()
            var = ((y - y.mean()) ** 2).mean()
            # Divide by the tensor so a zero variance yields inf/nan instead
            # of raising ZeroDivisionError.
            r2 = (1 - mse / var).item()
        return mse, r2

    def fit(self, X, y, X_val=None, y_val=None):
        """Train the network on ``(X, y)``.

        Args:
            X: (n_samples, n_features) covariates, possibly with missing values.
            y: outcomes, of shape (n_samples,) or (n_samples, 1).
            X_val, y_val: optional validation set. It drives the learning-rate
                scheduler and, if ``early_stopping`` is True, early stopping.

        Returns:
            self
        """
        X = torch.as_tensor(X, dtype=torch.double)
        n_samples, n_features = X.shape
        # reshape (not squeeze) so a mismatched y raises instead of broadcasting.
        y = torch.as_tensor(y, dtype=torch.double).reshape(n_samples)

        has_val = X_val is not None
        if has_val:
            X_val = torch.as_tensor(X_val, dtype=torch.double)
            y_val = torch.as_tensor(y_val, dtype=torch.double).reshape(len(X_val))

        # Start a fresh history so refitting does not append to old curves.
        self.r2_train = []
        self.mse_train = []
        self.r2_val = []
        self.mse_val = []

        self.net = NeuMissMLP(
            n_features=n_features,
            neumiss_depth=self.depth,
            mlp_depth=self.mlp_depth,
            mlp_width=int(self.width_factor * n_features),
            dtype=torch.double,
        )
        # Kept separate from ``self.optimizer`` (the hyper-parameter string)
        # so that a second call to fit still knows which optimizer to build.
        self.optimizer_ = self._make_optimizer()
        self.scheduler = ReduceLROnPlateau(
            self.optimizer_, mode="min", factor=0.2, patience=10, threshold=1e-4
        )

        stopper = None
        if self.early_stopping and has_val:
            stopper = EarlyStopping(verbose=self.verbose)

        criterion = nn.MSELoss()

        for i_epoch in range(self.n_epochs):
            if self.verbose:
                print("epoch nb {}".format(i_epoch))

            # Shuffle tensors to have different batches at each epoch
            ind = torch.randperm(n_samples)
            X = X[ind]
            y = y[ind]

            xx = torch.split(X, split_size_or_sections=self.batch_size, dim=0)
            yy = torch.split(y, split_size_or_sections=self.batch_size, dim=0)

            lr = self.optimizer_.param_groups[0]["lr"]
            if self.verbose:
                print("Current learning rate is: {}".format(lr))
            # Stop once the LR has been decayed below 1e-4. The scheduler is
            # only stepped with validation data; note an initial lr < 1e-4
            # also stops before any training step.
            if lr < 1e-4:
                break

            for bx, by in zip(xx, yy):
                self.optimizer_.zero_grad()
                loss = criterion(self.net(bx), by)
                loss.backward()
                self.optimizer_.step()

            mse, r2 = self._score(criterion, X, y)
            self.mse_train.append(mse)
            self.r2_train.append(r2)
            if self.verbose:
                print("Train loss - r2: {}, mse: {}".format(r2, mse))

            if not has_val:
                continue

            mse_val, r2_val = self._score(criterion, X_val, y_val)
            self.mse_val.append(mse_val)
            self.r2_val.append(r2_val)
            if self.verbose:
                print("Validation loss - r2: {}, mse: {}".format(r2_val, mse_val))

            if stopper is not None:
                stopper(mse_val, self.net)
                if stopper.early_stop:
                    if self.verbose:
                        print("Early stopping")
                    break

            self.scheduler.step(mse_val)

        # load the last checkpoint with the best model
        if stopper is not None and stopper.early_stop:
            self.net.load_state_dict(stopper.checkpoint)

        return self

    def predict(self, X):
        """Predict the outcome from partially-observed data.

        Args:
            X: original (n, d) covariates w/ missingness

        Returns:
            numpy array of predicted outcomes, shape (n,).
        """
        X = torch.as_tensor(X, dtype=torch.double)

        with torch.no_grad():
            y_hat = self.net(X)

        return np.array(y_hat)
