import numpy as np
import scipy.optimize as sco
from scipy.special import logit
from sklearn.metrics import accuracy_score
import time


class OptimizationTask:
    """Base class for a decentralized optimization problem over ``n_honest`` workers.

    Iterates are passed around as a 2-D array ``x_honest`` whose row ``i`` is the
    parameter vector held by honest worker ``i``. Subclasses describe worker
    ``i``'s local cost ``f_i`` through ``local_cost_functions`` /
    ``local_gradient`` and may override the vectorized helpers
    (``gradient``, ``dual_gradient``) with closed forms.
    """

    def __init__(self, n_honest, learning_rate=None, *args, **kwargs):
        """
        Args:
            n_honest: int, number of honest workers.
            learning_rate: step size used by ``gradient_descent_step``.
        """
        self.n_honest = n_honest
        self.learning_rate = learning_rate
        self.dim = None  # subclasses set this once the parameter dimension is known

    def variance(self, x_honest):
        """Return the squared distance of the rows of ``x_honest`` to their mean, averaged over rows."""
        deviations = x_honest - x_honest.mean(axis=0)[None, :]
        return np.sum(deviations ** 2) / x_honest.shape[0]

    def worker_wise_error(self, x_honest):
        raise NotImplementedError

    def mean_error(self, x_honest):
        raise NotImplementedError

    def mean_accuracy(self, x_honest):
        """Return None: tasks without a notion of accuracy keep this default."""
        return None

    def optimization_loss(self, x_honest, subsampling=0):
        raise NotImplementedError

    def loss_train(self, x_honest, subsampling=0):
        raise NotImplementedError

    def loss_test(self, x_honest, subsampling=0):
        raise NotImplementedError

    def local_cost_functions(self, x_honest_i, i_honest):
        """Return f_i(x_honest_i) for worker ``i_honest``; must be provided by subclasses."""
        # Raising (instead of silently returning None) keeps the generic
        # ``dual_gradient`` from optimizing a meaningless objective.
        raise NotImplementedError

    def local_gradient(self, x_honest_i, i_honest):
        """Return grad f_i(x_honest_i) for worker ``i_honest``; must be provided by subclasses."""
        # Returning None here would be written as NaN into the float array built
        # by ``gradient``, so fail loudly instead.
        raise NotImplementedError

    def gradient(self, x_honest):
        """Stack every worker's local gradient into an array shaped like ``x_honest``."""
        jac = np.zeros_like(x_honest, np.float64)
        for i_honest in range(self.n_honest):
            jac[i_honest, :] = self.local_gradient(x_honest[i_honest, :], i_honest)
        return jac

    def dual_gradient(self, y_honest, x_guess=None):
        """Approximately maximize <x, y_i> - f_i(x) over x, independently for each worker.

        Each row is solved with at most 10 L-BFGS-B iterations, warm-started from
        the matching row of ``x_guess`` (zeros when omitted).

        Args:
            y_honest: array whose row ``i`` is the dual variable of worker ``i``.
            x_guess: optional warm start with the same shape as ``y_honest``.

        Returns:
            float64 array shaped like ``y_honest`` holding the maximizers.
        """
        if x_guess is None:
            x_guess = np.zeros_like(y_honest, dtype=np.float64)

        # Always float64: allocating with y_honest's dtype would truncate the
        # L-BFGS-B solutions when the duals happen to be integers.
        x_new_honest = np.zeros_like(y_honest, dtype=np.float64)

        for i_honest in range(self.n_honest):
            y_i = y_honest[i_honest, :]

            # Default arguments bind this worker's data explicitly.
            def minus_dual_cost_i(x_honest_i, y_i=y_i, i_honest=i_honest):
                return -(np.sum(x_honest_i * y_i) - self.local_cost_functions(x_honest_i, i_honest))

            def minus_dual_cost_gradient_i(x_honest_i, y_i=y_i, i_honest=i_honest):
                return -(y_i - self.local_gradient(x_honest_i, i_honest))

            optim_result = sco.minimize(minus_dual_cost_i, x_guess[i_honest, :], method='L-BFGS-B',
                                        jac=minus_dual_cost_gradient_i, options={'maxiter': 10})
            x_new_honest[i_honest, :] = optim_result.x

        return x_new_honest

    def gradient_descent_step(self, x_honest):
        """Return x_honest - learning_rate * gradient(x_honest)."""
        return x_honest - self.learning_rate * self.gradient(x_honest)




class AverageConsensus(OptimizationTask):
    """Average-consensus task: the workers should agree on the mean of their initial values.

    Worker ``i`` holds the local cost ``f_i(x) = 0.5 * ||x - x_honest_init[i]||^2``.
    The sum of these costs is minimized by the mean of the rows of ``x_honest_init``.
    """

    def __init__(self, n_honest, x_honest_init, learning_rate=0, *args, **kwargs):
        """
        Args:
            n_honest: int, number of honest workers.
            x_honest_init: 2-D ndarray whose row ``i`` is worker ``i``'s initial value.
            learning_rate: stored on the task; ``gradient_descent_step`` ignores it here.
        """
        super().__init__(n_honest, learning_rate, *args, **kwargs)
        self.x_honest_init = x_honest_init
        self.dim = x_honest_init.shape[1]

    def local_cost_functions(self, x_honest_i, i_honest):
        """Return 0.5 * ||x_honest_i - x_honest_init[i_honest]||^2."""
        return 0.5 * np.sum((x_honest_i - self.x_honest_init[i_honest, :]) ** 2)

    # The base-class interface (and its generic dual solver) calls
    # ``local_cost_functions``; keep the historical singular name as an alias.
    local_cost_function = local_cost_functions

    def local_gradient(self, x_honest_i, i_honest):
        """Return x_honest_i - x_honest_init[i_honest]."""
        return x_honest_i - self.x_honest_init[i_honest, :]

    def gradient(self, x_honest):
        """Vectorized local gradients for all workers: x_honest - x_honest_init."""
        return x_honest - self.x_honest_init

    def dual_gradient(self, y_honest, x_guess=None):
        """Closed-form maximizer of <x, y_i> - f_i(x), i.e. y_i + x_honest_init[i]; ``x_guess`` is unused."""
        return y_honest + self.x_honest_init

    def gradient_descent_step(self, x_honest):
        """Return ``x_honest`` unchanged: this task performs no local gradient step."""
        return x_honest

    def optimization_loss(self, x_honest, subsampling=0):
        """Return the squared distance to the mean of ``x_honest_init``, averaged over the rows of ``x_honest``."""
        target = np.mean(self.x_honest_init, axis=0)[None, :]
        return np.sum((target - x_honest) ** 2) / x_honest.shape[0]

    def loss_train(self, x_honest, subsampling=0):
        """Alias for ``optimization_loss``."""
        return self.optimization_loss(x_honest, subsampling=subsampling)

    def mean_error(self, x_honest):
        """Alias for ``optimization_loss``."""
        return self.optimization_loss(x_honest)

    def loss_test(self, x_honest, subsampling=0):
        """Return None: this task has no test set."""
        return None
    
    


    


class LinearRegressionClassif(OptimizationTask):
    """Decentralized ridge least-squares regression, evaluated as a binary classifier.

    Worker ``i`` holds the local cost
    ``f_i(x) = 0.5 * mean((labels_train[i] - covariate_train[i] @ x) ** 2)
    + 0.5 * ridge_penalty * ||x||^2``.
    Test predictions ``covariate_test @ x`` are thresholded at 0 into {-1, +1}
    when accuracy is computed.
    """

    def __init__(self, n_honest, labels_train, covariate_train, labels_test, covariate_test, learning_rate=None,
                 ridge_penalty=0., *args, **kwargs):
        """
        Args:
            n_honest: int
            labels_train: list of n_honest (subsample_size) ndarray
            covariate_train: list of n_honest (subsample_size, dim) ndarray
            labels_test : (test_size) ndarray
            covariate_test : (test_size, dim) ndarray
            learning_rate: step size used by ``gradient_descent_step``.
            ridge_penalty: L2 regularization strength.
        """
        super().__init__(n_honest, learning_rate, *args, **kwargs)
        self.labels_train = labels_train
        self.covariate_train = covariate_train
        self.labels_test = labels_test
        self.covariate_test = covariate_test
        self.dim = covariate_train[0].shape[1]
        self.ridge_penalty = ridge_penalty
        self.train_size = sum(self.labels_train[i].shape[0] for i in range(self.n_honest))
        self.test_size = self.labels_test.shape[0]

        # Per-worker sufficient statistics M_i = X_i^T X_i / n_i and b_i = X_i^T y_i / n_i,
        # so that grad f_i(x) = M_i x - b_i + ridge_penalty * x.
        self.moment_2_data = np.empty((self.n_honest, self.dim, self.dim))
        self.label_projection = np.empty((self.n_honest, self.dim))
        for i_honest in range(self.n_honest):
            covariate_i = self.covariate_train[i_honest]
            n_samples_i = covariate_i.shape[0]
            self.moment_2_data[i_honest, ...] = covariate_i.T @ covariate_i / n_samples_i
            self.label_projection[i_honest, ...] = covariate_i.T @ self.labels_train[i_honest] / n_samples_i

        # (M_i + ridge_penalty * I)^{-1} for every worker, computed once as a batched
        # inverse so that dual_gradient reduces to a batched matrix-vector product.
        self.covariate_inverse = np.linalg.inv(
            self.moment_2_data + self.ridge_penalty * np.eye(self.dim)[None, :, :]
        )

    def accuracy_honest(self, x_honest):
        """Return the test accuracy of each row of ``x_honest`` (one entry per row)."""
        labels_pred = x_honest @ self.covariate_test.T
        labels_pred_binary = 2 * (labels_pred >= 0.).astype(int) - 1

        accuracy_honest = np.zeros((x_honest.shape[0],))
        # Fill every allocated entry, not just the first n_honest.
        for i_row in range(x_honest.shape[0]):
            accuracy_honest[i_row] = accuracy_score(self.labels_test, labels_pred_binary[i_row])

        return accuracy_honest

    def mean_accuracy(self, x_honest):
        """Return the test accuracy averaged over the rows of ``x_honest``."""
        # Previously called a nonexistent ``error_honest`` and raised AttributeError.
        return self.accuracy_honest(x_honest).mean()

    def mean_error(self, x_honest):
        """Return 1 - mean_accuracy(x_honest)."""
        return 1 - self.mean_accuracy(x_honest)

    def optimization_loss(self, x_honest, subsampling=0):
        """Return the regularized training loss, averaged over a subset of workers.

        Each evaluated worker's parameters are scored on every worker's training
        batch. The batch losses are averaged, and the ridge term is added.

        Args:
            x_honest: array whose row ``j`` is worker ``j``'s parameters.
            subsampling: if > 0, evaluate only every ``n_honest // subsampling``-th
                worker (at least every worker) to save time; 0 evaluates all workers.
        """
        if subsampling > 0:
            # Clamp the stride to 1: subsampling > n_honest previously gave a zero
            # range step and raised ValueError.
            subsampling_set = range(0, self.n_honest, max(1, self.n_honest // subsampling))
        else:
            subsampling_set = range(self.n_honest)

        loss = 0
        for j_honest in subsampling_set:  # loss of worker j's parameters
            x_honest_j = x_honest[j_honest]
            loss_worker_j = 0
            for batch_i in range(self.n_honest):  # on each batch of the dataset
                residuals = self.labels_train[batch_i] - self.covariate_train[batch_i] @ x_honest_j
                loss_worker_j += 0.5 * np.mean(residuals ** 2)

            loss += loss_worker_j / self.n_honest + (0.5 * np.sum(x_honest_j ** 2) * self.ridge_penalty)

        return loss / len(subsampling_set)  # average train loss over the evaluated workers

    def loss_train(self, x_honest, subsampling=0):
        """Alias for ``optimization_loss``."""
        return self.optimization_loss(x_honest, subsampling=subsampling)

    def loss_test(self, x_honest, subsampling=0):
        """Return the regularized test loss averaged over all workers (``subsampling`` is ignored)."""
        loss = 0
        for j_honest in range(self.n_honest):  # test loss of worker j's parameters
            x_honest_j = x_honest[j_honest]
            loss += 0.5 * np.mean((self.labels_test - self.covariate_test @ x_honest_j) ** 2)
            loss += 0.5 * np.sum(x_honest_j ** 2) * self.ridge_penalty

        return loss / self.n_honest

    def local_cost_functions(self, x_honest_i, i_honest):
        """Return f_i(x_honest_i) on worker ``i_honest``'s training batch."""
        return 0.5 * np.mean(
            (self.labels_train[i_honest] - self.covariate_train[i_honest] @ x_honest_i) ** 2
        ) + 0.5 * np.sum(x_honest_i ** 2) * self.ridge_penalty

    def local_gradient(self, x_honest_i, i_honest):
        """Return M_i x - b_i + ridge_penalty * x for worker ``i_honest``."""
        return (self.moment_2_data[i_honest, ...] @ x_honest_i - self.label_projection[i_honest, ...]
                + self.ridge_penalty * x_honest_i)

    def gradient(self, x_honest):
        """Vectorized local gradients for all workers (row i is M_i x_i - b_i + ridge_penalty * x_i)."""
        return np.sum(self.moment_2_data * x_honest[:, None, :], axis=-1) - self.label_projection \
            + self.ridge_penalty * x_honest

    def dual_gradient(self, y_honest, x_guess=None):
        """Return argmax_x <x, y_i> - f_i(x) for every worker, in closed form.

        For this quadratic cost the maximizer solves
        (M_i + ridge_penalty * I) x = y_i + b_i. The inverse matrices are
        precomputed in ``__init__``, so no solver is run here. ``x_guess`` is
        accepted only for interface compatibility with the iterative base-class
        solver and is unused.
        """
        rhs = y_honest + self.label_projection
        return np.sum(self.covariate_inverse * rhs[:, None, :], axis=-1)

class LogisticRegression(OptimizationTask):
    """Decentralized L2-regularized logistic regression.

    Worker ``i`` holds the local cost
    ``f_i(x) = mean(log(1 + exp(-labels_train[i] * (covariate_train[i] @ x))))
    + 0.5 * ridge_penalty * ||x||^2``.
    Every log(1 + exp(.)) term is evaluated with ``np.logaddexp(0, .)``, which
    neither overflows nor loses accuracy for large margins. Test predictions
    are thresholded at 0 into {-1, +1}; labels are presumably +/-1 as well
    (TODO confirm against the data loader).
    """

    def __init__(self, n_honest, labels_train, covariate_train, labels_test, covariate_test, learning_rate=None,
                 ridge_penalty=0., *args, **kwargs):
        """
        Args:
            n_honest: int
            labels_train: list of n_honest (subsample_size) ndarray
            covariate_train: list of n_honest (subsample_size, dim) ndarray
            labels_test : (test_size) ndarray
            covariate_test : (test_size, dim) ndarray
            learning_rate: step size used by ``gradient_descent_step``.
            ridge_penalty: L2 regularization strength.
        """
        super().__init__(n_honest, learning_rate, *args, **kwargs)
        self.labels_train = labels_train
        self.covariate_train = covariate_train
        self.labels_test = labels_test
        self.covariate_test = covariate_test
        self.dim = covariate_train[0].shape[1]
        self.ridge_penalty = ridge_penalty
        self.train_size = sum(self.labels_train[i].shape[0] for i in range(self.n_honest))
        self.test_size = self.labels_test.shape[0]

    def accuracy_honest(self, x_honest):
        """Return the test accuracy of each row of ``x_honest`` (one entry per row)."""
        labels_pred = x_honest @ self.covariate_test.T
        labels_pred_binary = 2 * (labels_pred >= 0.).astype(int) - 1

        accuracy_honest = np.zeros((x_honest.shape[0],))
        # Fill every allocated entry, not just the first n_honest.
        for i_row in range(x_honest.shape[0]):
            accuracy_honest[i_row] = accuracy_score(self.labels_test, labels_pred_binary[i_row])

        return accuracy_honest

    def mean_accuracy(self, x_honest):
        """Return the test accuracy averaged over the rows of ``x_honest``."""
        # Previously called a nonexistent ``error_honest`` and raised AttributeError.
        return self.accuracy_honest(x_honest).mean()

    def mean_error(self, x_honest):
        """Return 1 - mean_accuracy(x_honest)."""
        return 1 - self.mean_accuracy(x_honest)

    def optimization_loss(self, x_honest, subsampling=0):
        """Return the regularized training loss, averaged over a subset of workers.

        Each evaluated worker's parameters are scored on every worker's training
        batch. The per-batch mean logistic losses are averaged, and the ridge
        term is added.

        Args:
            x_honest: array whose row ``j`` is worker ``j``'s parameters.
            subsampling: if > 0, evaluate only every ``n_honest // subsampling``-th
                worker (at least every worker); 0 evaluates all workers.
        """
        if subsampling > 0:
            # Clamp the stride to 1: subsampling > n_honest previously gave a zero
            # range step and raised ValueError.
            subsampling_set = range(0, self.n_honest, max(1, self.n_honest // subsampling))
        else:
            subsampling_set = range(self.n_honest)

        loss = 0
        for j_honest in subsampling_set:  # loss of worker j's parameters
            x_honest_j = x_honest[j_honest]
            loss_j = 0
            for batch_i in range(self.n_honest):  # on each batch of the data set
                margins = self.labels_train[batch_i] * (self.covariate_train[batch_i] @ x_honest_j)
                # mean log(1 + exp(-margin)). The previous clipping summed (not
                # averaged) the large-margin terms and added a spurious log(2) for them.
                loss_j += np.mean(np.logaddexp(0., -margins))

            loss += loss_j / self.n_honest + (0.5 * np.sum(x_honest_j ** 2) * self.ridge_penalty)

        return loss / len(subsampling_set)  # average over the evaluated workers

    def loss_train(self, x_honest, subsampling=0):
        """Alias for ``optimization_loss``."""
        return self.optimization_loss(x_honest, subsampling=subsampling)

    def loss_test(self, x_honest, subsampling=0):
        """Return the regularized test logistic loss averaged over all workers (``subsampling`` is ignored)."""
        loss = 0
        for j_honest in range(self.n_honest):  # test loss of worker j's parameters
            x_honest_j = x_honest[j_honest]
            margins = self.labels_test * (self.covariate_test @ x_honest_j)
            loss += np.mean(np.logaddexp(0., -margins)) + (0.5 * np.sum(x_honest_j ** 2) * self.ridge_penalty)

        return loss / self.n_honest

    def local_cost_functions(self, x_honest_i, i_honest):
        """Return f_i(x_honest_i) on worker ``i_honest``'s training batch."""
        u = -(self.labels_train[i_honest] * (self.covariate_train[i_honest] @ x_honest_i))
        return np.mean(np.logaddexp(0., u)) + 0.5 * np.sum(x_honest_i ** 2) * self.ridge_penalty

    def local_gradient(self, x_honest_i, i_honest):
        """Return grad f_i(x_honest_i) = -X^T (y * sigmoid(u)) / n + ridge_penalty * x, with u = -y * (X @ x)."""
        y = self.labels_train[i_honest]
        X = self.covariate_train[i_honest]
        n_samples = y.shape[0]
        u = -y * (X @ x_honest_i)
        # sigmoid(u) = exp(-log(1 + exp(-u))) is overflow-free for any finite u.
        # The previous code applied np.exp to a boolean mask, which made this
        # factor a constant (e/(1+e) or 1.5) that did not depend on x at all.
        sigmoid_u = np.exp(-np.logaddexp(0., -u))
        v = y * sigmoid_u

        return -(X.T @ v) / n_samples + self.ridge_penalty * x_honest_i



