from multiprocessing import reduction
import torch
from torch import nn
from ML_Models.base_model import BaseMLModel
from torch.optim import Adam, RMSprop, SGD
from Tools.jackknife import jackknife_compute_jacobians
from tqdm import tqdm
from sklearn.neural_network import MLPClassifier, MLPRegressor
class ANN(BaseMLModel):
    """ A simple neural network with a single hidden layer.

        The weights live in two ``nn.Linear`` layers (``input1``: input -> hidden,
        ``input2``: hidden -> output). ``get_all_params`` flattens them into a single
        row vector laid out as ``[W1 (H*D), b1 (H), W2 (C*H), b2 (C)]``. That is the layout
        ``predict_from_parameters`` and ``loss_objective`` slice back apart
        (D = input_dim, H = hidden_layer, C = num_of_classes).
    """

    # Maps the user-facing activation name to (torch module class, sklearn activation name).
    # sklearn's MLP has no CELU, so "celu" cannot be used together with fit_scipy.
    _ACTIVATIONS = {
        "sigmoid": (nn.Sigmoid, "logistic"),
        "relu": (nn.ReLU, "relu"),
        "celu": (nn.CELU, "celu"),
        "tanh": (nn.Tanh, "tanh"),
    }

    def __init__(self, input_dim: int, hidden_layer: int = 100,
                 num_of_classes: int = 1, task: str = 'classification',
                 weighted_model: bool = False, train_set_size=None, n_epochs=1000, base_lr=1e-3, optim="adam",
                 l2_reg_lambda=0.01, activation="sigmoid", fit_scipy=False):
        """
            input_dim: Number of input features (D).
            hidden_layer: Number of units in the hidden layer (H).
            num_of_classes: Number of output units (C). loss_objective flattens the predictions
                against y, so training effectively expects a single output unit.
            task: "classification" (logits + binary cross-entropy) or "regression" (squared error).
            weighted_model: Passed to BaseMLModel; must be True for the jackknife methods.
            train_set_size: Passed to BaseMLModel.
            n_epochs: Maximum number of epochs for gradient-based training.
            base_lr: Learning rate that the optimizer is started with. The LR is divided by 10
                whenever the loss stops improving for several consecutive epochs.
            optim: adam, rmsprop, sgd
            l2_reg_lambda: L2-weight regularization strength (biases are not regularized).
            activation: sigmoid, relu, celu, tanh activation function to use. celu is not supported for fit_scipy.
            fit_scipy: Use sklearn's L-BFGS MLP to fit the NN. Otherwise gradient descent with `optim` is used.

            Raises:
                ValueError: if `activation` is not one of the supported names.
        """
        super().__init__(weighted_model, train_set_size)

        # Layers and hyper-parameters
        self.input_dims = input_dim
        self.hidden_layer = hidden_layer
        self.num_of_classes = num_of_classes
        self.n_epochs = n_epochs
        self.input1 = nn.Linear(input_dim, hidden_layer)
        self.input2 = nn.Linear(hidden_layer, num_of_classes)
        self.task = task
        self.base_lr = base_lr
        self.optim = optim
        self.l2_reg_lambda = l2_reg_lambda
        self.fit_scipy = fit_scipy

        # Activations. activation_str is the name handed to sklearn in _fit_scipy.
        if activation not in self._ACTIVATIONS:
            raise ValueError(
                f"Unknown activation '{activation}'. Use one of: {', '.join(self._ACTIVATIONS)}.")
        activation_cls, self.activation_str = self._ACTIVATIONS[activation]
        self.activation = activation_cls()

        # Output squashing used by forward() for classification.
        self.sigmoid = nn.Sigmoid()

    def fit(self, X: torch.tensor, y: torch.tensor):
        """
            Fit the ANN in place.
            Labels are in a 0,1 binary format for classification and continuous for regression.
        """
        if self.fit_scipy:
            self._fit_scipy(X, y)
        else:
            self._fit_sgd(X, y)

    def _fit_sgd(self, X: torch.tensor, y: torch.tensor, *, patience: int = 3,
                 lr_factor: float = 0.1, min_lr: float = 1e-7):
        """ Fit with full-batch gradient descent using the configured torch optimizer.

            Each epoch performs one optimizer step on the full (weighted, summed) loss.
            If the loss has not improved for `patience` consecutive epochs, the learning
            rate is multiplied by `lr_factor`. Training stops once it falls to `min_lr` or below.
            Note: the loss is a sum over data points, not a mean, so plain SGD may
            need a smaller base_lr than Adam.

            Raises:
                ValueError: if `self.optim` is not adam, rmsprop or sgd.
        """
        optimizer_classes = {"adam": Adam, "rmsprop": RMSprop, "sgd": SGD}
        if self.optim not in optimizer_classes:
            raise ValueError("Currently adam, rmsprop and sgd are supported.")
        current_lr = self.base_lr
        optimizer = optimizer_classes[self.optim](self.parameters(), lr=current_lr)

        best_loss = float("inf")
        eps_no_impr = 0
        epoch_cnt = 0
        for epoch in tqdm(range(self.n_epochs), position=0):
            epoch_cnt = epoch + 1
            optimizer.zero_grad()
            loss = self.loss_objective(self.get_all_params(), X, y, self.data_weights_vector)
            loss.backward()
            optimizer.step()
            loss_value = loss.item()
            if loss_value > best_loss:
                eps_no_impr += 1
                if eps_no_impr >= patience:  # Reschedule lr
                    current_lr = lr_factor * current_lr
                    # torch optimizers read the LR from their param groups; setting an
                    # attribute on the optimizer object itself has no effect.
                    for group in optimizer.param_groups:
                        group["lr"] = current_lr
                    # Start counting afresh so the LR can be reduced again on a later plateau.
                    eps_no_impr = 0
                    if current_lr <= min_lr:
                        break
            else:
                best_loss = loss_value
                eps_no_impr = 0
        if current_lr > min_lr:
            print("Warning: Max iterations reached. Check fit.")
        else:
            print(f"Converged after {epoch_cnt} epochs.")

    def _fit_scipy(self, X: torch.tensor, y: torch.tensor):
        """ Fit with sklearn's L-BFGS MLP and copy the learned weights into this module.

            Raises:
                ValueError: if `self.task` is neither "classification" nor "regression".
        """
        mlp_kwargs = dict(hidden_layer_sizes=(self.hidden_layer,), solver='lbfgs',
                          activation=self.activation_str, alpha=self.l2_reg_lambda, max_iter=5000,
                          random_state=0, tol=1e-8, max_fun=25000)
        if self.task == "classification":
            sklearn_ann = MLPClassifier(**mlp_kwargs)
        elif self.task == "regression":
            sklearn_ann = MLPRegressor(**mlp_kwargs)
        else:
            raise ValueError(f"Unknown task '{self.task}'. Use 'classification' or 'regression'.")
        sklearn_ann.fit(X.detach().cpu().numpy(), y.detach().cpu().numpy())

        # sklearn stores weight matrices as [in, out]; nn.Linear expects [out, in].
        self.input1.weight.data = torch.tensor(sklearn_ann.coefs_[0], dtype=torch.float32).t()
        self.input2.weight.data = torch.tensor(sklearn_ann.coefs_[1], dtype=torch.float32).t()
        self.input1.bias.data = torch.tensor(sklearn_ann.intercepts_[0], dtype=torch.float32).flatten()
        self.input2.bias.data = torch.tensor(sklearn_ann.intercepts_[1], dtype=torch.float32).flatten()
        # Presumably sklearn's loss_ is averaged over samples; scaling by len(X) makes it
        # comparable to the summed loss_objective (verify against the sklearn version in use).
        print("Sklearn loss:", len(X)*sklearn_ann.loss_)

    def get_all_params(self):
        """ Return all parameters concatenated into one [1, P] row vector.

            Layout: [input1.weight (H*D), input1.bias (H), input2.weight (C*H), input2.bias (C)].
        """
        return torch.cat((self.input1.weight.reshape(1, -1),
                          self.input1.bias.reshape(1, -1),
                          self.input2.weight.reshape(1, -1),
                          self.input2.bias.reshape(1, -1)), dim=1)

    def _raw_output(self, x: torch.tensor) -> torch.tensor:
        """ Pre-squashing network output (logits for classification) using the stored weights. """
        return self.input2(self.activation(self.input1(x)))

    def predict_with_logits(self, x: torch.tensor) -> torch.tensor:
        """
            The prediction function: raw network outputs (logits for classification), flattened.
        """
        return self._raw_output(x).flatten()

    def predict_from_parameters(self, x: torch.tensor, parameters: torch.tensor) -> torch.tensor:
        """
            Predict logits using different model parameters, supplied as inputs, instead of the ones stored with this object.
            :param x: [B, D] inputs.
            :param parameters: [S, P] - S flattened parameter vectors in the get_all_params layout.
            :returns: [B, S] matrix when num_of_classes == 1, otherwise [B, S, num_of_classes].
        """
        num_sets = len(parameters)
        d, h, c = self.input_dims, self.hidden_layer, self.num_of_classes
        # Split each flattened vector back into its parts. The leading singleton axis
        # broadcasts over the batch dimension.
        offset = 0
        i1weight = parameters[:, offset:offset + h * d].reshape(1, num_sets, h, d)
        offset += h * d
        i1bias = parameters[:, offset:offset + h].reshape(1, num_sets, h, 1)
        offset += h
        i2weight = parameters[:, offset:offset + c * h].reshape(1, num_sets, c, h)
        offset += c * h
        i2bias = parameters[:, offset:offset + c].reshape(1, num_sets, c, 1)
        # x: [B, D] -> [B, 1, D, 1]; hidden: [B, S, H, 1]; output: [B, S, C, 1].
        hidden = self.activation(i1weight.matmul(x.unsqueeze(1).unsqueeze(-1)) + i1bias)
        output = i2weight.matmul(hidden) + i2bias
        # The second squeeze only removes the class axis when C == 1.
        return output.squeeze(-1).squeeze(-1)

    def forward(self, x: torch.tensor) -> torch.tensor:
        """ Network output with the stored weights; sigmoid-squashed for classification. """
        output = self._raw_output(x)
        if self.task == 'classification':
            output = self.sigmoid(output)
        return output

    def loss_objective(self, parameters: torch.tensor, X: torch.tensor, y: torch.tensor, data_weights=None):
        """ Weighted, L2-regularized training loss evaluated at `parameters`.

            Regression: 0.5 * sum_i w_i * (f(x_i) - y_i)^2 (MSE-style).
            Classification: sum_i w_i * BCEWithLogits(f(x_i), y_i) (NLL).
            Both add 0.5 * l2_reg_lambda * (||W1||^2 + ||W2||^2). Biases are not regularized.

            :param parameters: flattened parameters in the get_all_params layout; predictions
                are flattened against y, so presumably a single [1, P] row.
            :param data_weights: per-sample weights w_i; all ones when None.
            :returns: scalar loss tensor.
            Raises:
                ValueError: if `self.task` is neither "regression" nor "classification".
        """
        preds = self.predict_from_parameters(X, parameters)
        d, h, c = self.input_dims, self.hidden_layer, self.num_of_classes
        i1weight = parameters[:, :h * d]
        offset = h * d + h  # skip W1 and b1
        i2weight = parameters[:, offset:offset + h * c]
        weight_loss = 0.5*self.l2_reg_lambda*(torch.sum(i1weight.pow(2)) + torch.sum(i2weight.pow(2)))
        if data_weights is None:
            data_weights = torch.ones(len(X), dtype=preds.dtype, device=preds.device)
        if self.task == "regression":
            return 0.5*torch.sum(data_weights*torch.pow(preds.flatten() - y, 2)) + weight_loss
        if self.task == "classification":
            logits = preds.reshape(-1)
            loss = nn.functional.binary_cross_entropy_with_logits(logits, y.float(), reduction="none")
            return torch.sum(data_weights*loss) + weight_loss
        raise ValueError(f"Unknown task '{self.task}'. Use 'classification' or 'regression'.")

    def _jackknife_jacobian(self, X: torch.tensor, y: torch.tensor):
        """ Return (current flattened params, Jacobian of the optimal params w.r.t. the data weights).

            The Jacobian is computed by the infinitesimal jackknife and has shape
            [len params, #data weights].
        """
        assert self.weighted_model, "The jackknife approximation requires weighted_model=True."
        opt_weights = self.get_all_params()
        jackknife_obj = self.loss_objective_for_jackknife(X, y)  # function of (data weights, theta)
        j_mat = jackknife_compute_jacobians(opt_weights, jackknife_obj, self.data_weights_vector,
                                            additional_params=None).squeeze(0)  # [Len params, #data weights]
        return opt_weights, j_mat

    def parameter_change_under_removal(self, X: torch.tensor, y: torch.tensor, ind: torch.tensor = ...):
        """
            The jackknife approximation is used.
            Returns the Jacobian columns (d params / d data weight) of the points selected
            by `ind`, transposed so that each row belongs to one data point. The default
            (Ellipsis) selects all points.
        """
        _, j_mat = self._jackknife_jacobian(X, y)
        return j_mat[:, ind].t()

    def compute_parameters_from_data_weights(self, data_weights: torch.tensor, X: torch.tensor, y: torch.tensor):
        """ Compute model weight change under soft removal of specific points.
            Note: this function should be differentiable w.r.t. data_weights to be useful in
            end-to-end gradient descent optimization.
            This returns a (differentiable) first-order Taylor approximation of the parameter
            vector, using the current model parameters and data_weights as the center point, i.e.,
            parameters(data_weights) = self.parameters + J*(data_weights - self.data_weights)
        """
        # Apply the infinitesimal jackknife approximation.
        opt_weights, j_mat = self._jackknife_jacobian(X, y)
        return opt_weights.detach() + j_mat.matmul(data_weights - self.data_weights_vector.detach())
