import torch
import torch.nn as nn
from ML_Models.base_model import BaseMLModel
# import typing as tp
from torch.optim import RMSprop, Adam, SGD
from sklearn.linear_model import LogisticRegression
from Tools.jackknife import *
from torch.nn.parameter import Parameter
# Implement Linear Regression (task="regression") and Logistic Regression (task="classification")


class Regression(BaseMLModel):
    """Linear regression (task="regression") or logistic regression (task="classification").

    All parameters live in one vector ``self.weights`` of length ``input_dim + 1``:
    the first ``input_dim`` entries are the feature coefficients and the last entry
    is the bias (inputs are extended with a trailing column of ones before the
    matrix product).

    ``self.data_weights_vector`` and ``self.loss_objective_for_jackknife`` are not
    defined in this class; they are presumably provided by ``BaseMLModel``
    (NOTE(review): confirm against the base class).
    """

    def __init__(self, input_dim: int, task: str = 'classification',
                 weighted_model: bool = False, train_set_size=None,
                 l2_pen: float = 1.0, n_epochs: int = 5000, sgd: bool = False):
        """
            :param input_dim: number of input features D.
            :param task: either "classification" or "regression".
            :param weighted_model: whether per-sample data weights are used when fitting.
            :param train_set_size: forwarded to the base class and stored on this object.
            :param l2_pen: factor multiplying the data term of the logistic loss
                (see loss_objective); passed to sklearn as its ``C`` parameter.
            :param n_epochs: maximum number of SGD iterations (only used if sgd=True).
            :param sgd: fit the logistic model with torch SGD instead of sklearn.
            :raises ValueError: if task is not one of the two supported values.
        """
        super().__init__(weighted_model, train_set_size)

        # Auxiliary information
        self.input_dim = input_dim
        self.l2_pen = l2_pen
        self.n_epochs = n_epochs
        self.sgd_optim = sgd
        if task not in ("regression", "classification"):
            raise ValueError(f"Illegal task {task}. Only classification and regression are supported.")
        self.task = task
        self.train_set_size = train_set_size

        # Parameters: [D coefficients, bias]. Gradients are only enabled temporarily during SGD fitting.
        self.weights = Parameter(torch.randn(self.input_dim+1), requires_grad=False)

    @staticmethod
    def _with_bias_column(x: torch.Tensor) -> torch.Tensor:
        """ Append a column of ones to x ([B, D] -> [B, D+1]), created on the device of x. """
        return torch.cat((x, torch.ones(len(x), 1, device=x.device)), dim=1)

    def _weighted_least_squares(self, X: torch.Tensor, y: torch.Tensor, data_weights=None) -> torch.Tensor:
        """ Solve the (optionally weighted) normal equations; returns a flat [D+1] vector.

            Uses differentiable torch operations only, so gradients can flow into data_weights.
        """
        X_extend = self._with_bias_column(X)
        if data_weights is not None:
            # Scaling rows of X and entries of y by sqrt(w_i) turns OLS into weighted least squares.
            root_w = torch.sqrt(data_weights)
            X_extend = root_w.reshape(-1, 1)*X_extend
            y = root_w*y
        gram = X_extend.t().matmul(X_extend)
        return torch.linalg.inv(gram).matmul(X_extend.t()).matmul(y.unsqueeze(1)).flatten()

    def _fit_logistic_sgd(self, X: torch.Tensor, y: torch.Tensor, patience: int = 5):
        """ Minimize loss_objective with SGD; stops after `patience` consecutive loss increases. """
        # Detached so that the optimizer never receives gradients for the data weights.
        sample_weights = self.data_weights_vector.detach() if self.weighted_model else None
        self.weights.requires_grad_(True)
        try:
            optim = SGD(self.parameters(), lr=1e-2)
            best_loss = float("inf")
            stalled = 0
            converged = False
            for _ in range(self.n_epochs):
                optim.zero_grad()
                loss = self.loss_objective(self.get_all_params(), X, y, sample_weights)
                loss.backward()
                optim.step()
                current = loss.item()
                if current > best_loss:
                    stalled += 1
                    if stalled == patience:
                        converged = True
                        break
                else:
                    best_loss = current
                    stalled = 0
            if not converged:
                print("Warning: SGD has not converged.")
        finally:
            # Always restore the frozen state, even if an iteration raised.
            self.weights.requires_grad_(False)

    def _fit_logistic_sklearn(self, X: torch.Tensor, y: torch.Tensor):
        """ Fit the logistic model with sklearn and copy coefficients and intercept into self.weights. """
        sample_weight = None
        if self.weighted_model:
            sample_weight = self.data_weights_vector.detach().cpu().numpy()
        # C=l2_pen makes sklearn minimize the same objective as loss_objective:
        # C * sum(log-loss) + 0.5 * ||coefficients||^2.
        clf = LogisticRegression(C=self.l2_pen, max_iter=2000)
        clf.fit(X.detach().cpu().numpy(), y.long().detach().cpu().numpy(), sample_weight)
        # The reshape fails loudly unless coef_ holds exactly input_dim values (binary problem).
        self.weights.data[:-1] = torch.tensor(clf.coef_, dtype=torch.float32).reshape(self.input_dim)
        self.weights.data[-1] = float(clf.intercept_[0])

    def get_all_params(self):
        """ Return the concatenation of model weights and bias. Shape [1, D+1]. """
        return self.weights.reshape(1, -1)

    def fit(self, X: torch.Tensor, y: torch.Tensor):
        """
            Fit the regression model.
            The linear regression model (task=regression) is fitted using the OLS equations
            (weighted by the data weights if weighted_model is set).
            The logistic regression model (task=classification) is fitted either with
            sklearn's LogisticRegression (default) or, if sgd=True was passed to the
            constructor, by running SGD on loss_objective.
            :param X: [N, D] inputs.
            :param y: [N] targets (labels in {0, 1} for classification).
        """
        if self.task == "classification":
            if self.sgd_optim:
                self._fit_logistic_sgd(X, y)
            else:
                self._fit_logistic_sklearn(X, y)
        elif self.task == "regression":
            data_weights = self.data_weights_vector if self.weighted_model else None
            # detach: the stored weights must not carry autograd history from the data weights.
            self.weights.data = self._weighted_least_squares(X, y, data_weights).detach()
        else:
            raise ValueError("Unsupported task.")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Return sigmoid probabilities (classification) or raw linear outputs (regression). Shape [B]. """
        if self.task == 'classification':
            return torch.sigmoid(self.predict_with_logits(x))
        else:
            return self.predict_with_logits(x)

    def predict_with_logits(self, x: torch.Tensor) -> torch.Tensor:
        """ Return the linear output x @ w + b for each row of x, without the sigmoid. Shape [B]. """
        return torch.matmul(self._with_bias_column(x), self.weights.reshape(-1, 1)).flatten()

    def predict_from_parameters(self, x: torch.Tensor, parameters: torch.Tensor) -> torch.Tensor:
        """
            Predict logits using different model parameters (instead of the ones stored with this object) supplied as inputs.
            :param x: [B, D] inputs.
            :param parameters: [C, D+1] matrix, each row being [coefficients, bias].
            :returns: [B, C]-matrix.
        """
        return self._with_bias_column(x).matmul(parameters.t())

    def loss_objective(self, parameters: torch.Tensor, X: torch.Tensor, y: torch.Tensor, data_weights = None):
        """
            Training objective evaluated at the supplied parameters.
            Classification: l2_pen * sum_i w_i * log(1 + exp(-yb_i * logit_i)) + 0.5 * ||coefficients||^2,
            with yb = 2*y - 1 in {-1, +1}; the bias (last column of parameters) is not penalized.
            Regression: sum_i w_i * (prediction_i - y_i)^2.
            :param parameters: [1, D+1] parameter matrix; if None, the stored parameters are used.
            :param data_weights: optional per-sample weights w_i; all ones if None.
        """
        if parameters is None:
            parameters = self.get_all_params()
        pred = self.predict_from_parameters(X, parameters).flatten()
        if self.task == "classification":
            yb = 2*y-1 # convert labels to +-1
            # softplus(z) = log(1 + exp(z)), evaluated without overflow for large z.
            per_sample = nn.functional.softplus(-yb*pred)
            if data_weights is not None:
                per_sample = data_weights*per_sample
            loss = torch.sum(per_sample)
            weights = parameters[:,:-1]
            param_norm = 0.5*torch.sum(weights.pow(2))
            return self.l2_pen*loss + param_norm
        else:
            sq_err = (pred-y).pow(2)
            # `is not None` instead of truthiness: bool() of a multi-element tensor raises.
            if data_weights is not None:
                sq_err = sq_err*data_weights
            return torch.sum(sq_err)

    def parameter_change_under_removal(self, X: torch.Tensor, y: torch.Tensor, ind: torch.Tensor = ...):
        """ Compute the change of the parameters under removal of data points.
            The analytical solution for linear regression is taken from D. Cook:
            "Detection of Influential Observation in Linear Regression" (1977)
            https://www.jstor.org/stable/1268249
            The change is given by ((X^T X)^-1 x_i) * r_i / (1 - v_i), with the residual
            r_i = prediction_i - y_i and the leverage v_i = x_i^T (X^T X)^-1 x_i.
            For logistic regression, the jackknife approximation is used.
            :param ind: indices of the points to remove (default: all points).
            :returns: one row of parameter changes per selected point.
        """
        if self.task == "regression":
            # NOTE(review): this branch does not use the data weights, even for a weighted model - confirm intended.
            res = self.forward(X[ind, :]) - y[ind] # residuals.
            X_extend = self._with_bias_column(X)
            kernel_mat = torch.linalg.inv(X_extend.t().matmul(X_extend))
            kernel_mat_xi = kernel_mat.matmul(X_extend[ind, :].t()) # shape [D+1, B]
            vi = torch.sum(X_extend[ind, :].t()*kernel_mat_xi, dim=0)
            diff = kernel_mat_xi*(res/(1-vi))
            return diff.t()
        else: # classification, use jackknife approximation.
            opt_weights = self.get_all_params()
            jackknife_obj = self.loss_objective_for_jackknife(X, y) # function of (data weights, theta)
            # Presumably [len(params), #data weights] after the squeeze - TODO confirm against Tools.jackknife.
            j_mat = jackknife_compute_jacobians(opt_weights, jackknife_obj, self.data_weights_vector, additional_params=None).squeeze(0)
            return j_mat[:,ind].t()

    def compute_parameters_from_data_weights(self, data_weights: torch.Tensor, X: torch.Tensor, y: torch.Tensor):
        """ Compute the model parameters under soft removal of specific points.
            Note: this function should be differentiable w.r.t. data_weights to be useful in
            end-to-end gradient descent optimization.
            For linear regression, the weighted normal equations are solved directly
            for the given data_weights.
            For logistic regression, this function returns a (differentiable) first-order
            approximation around the current parameters and data weights, i.e.,
            parameters(data_weights) = self.parameters - J*(data_weights - self.data_weights_vector),
            where J is the matrix returned by jackknife_compute_jacobians.
        """
        assert self.weighted_model, "compute_parameters_from_data_weights requires a weighted model."
        if self.task == "regression":
            return self._weighted_least_squares(X, y, data_weights)
        else: # classification
            # Apply the infinitesimal jackknife approximation.
            opt_weights = self.get_all_params()
            jackknife_obj = self.loss_objective_for_jackknife(X, y) # function of (data weights, theta)
            # Presumably [len(params), #data weights] after the squeeze - TODO confirm against Tools.jackknife.
            j_mat = jackknife_compute_jacobians(opt_weights, jackknife_obj, self.data_weights_vector, additional_params=None).squeeze(0)
            return self.get_all_params().detach() - j_mat.matmul(data_weights - self.data_weights_vector.detach())
