from copy import deepcopy
from datetime import datetime
from math import factorial
from random import choice, choices

import time
import numpy as np
import scipy.optimize
from scipy.optimize import minimize
from scipy.stats import norm, truncnorm, multivariate_normal

from models.Objective_Pert import ObjectivePert
from models.utils_AOP import Sample_Generator, Evaluator
from models.ERM import ERM

# import gurobipy as grb


class GradientPert(object):
    """Base class for the gradient-perturbation private solvers in this module.

    Every entry of ``algo_param`` and then of ``exp_param`` is copied onto the
    instance as an attribute, so subclasses read e.g. ``self.n``, ``self.D``
    or ``self.varepsilon`` directly. Subclasses implement :meth:`Solve_x`.

    Supported ``loss_func`` values are ``'l1'`` and ``'piecewise'``; the loss
    is evaluated on residuals ``u = y - A x`` (see :meth:`_Set_F_func`).
    """

    def __init__(self, exp_param, algo_param):
        """Store the parameters, derive ``self.L`` and build the sample generator.

        Args:
            exp_param: experiment settings. Must contain ``'loss_func'``,
                ``'A_oper_norm_ub'``, ``'x_true'``, ``'A_mean'`` and either
                ``'m'`` (for 'l1') or ``'pieces'`` (for 'piecewise').
            algo_param: algorithm settings (e.g. ``varepsilon``, ``delta``).

        Raises:
            ValueError: if ``exp_param['loss_func']`` is not supported.
        """
        self.x = None
        self.algo_param = algo_param
        self.exp_param = exp_param
        # exp_param is applied last, so it wins on duplicate keys.
        for k, v in algo_param.items():
            setattr(self, k, v)
        for k, v in exp_param.items():
            setattr(self, k, v)

        # self.L is used downstream as the Lipschitz constant of the per-sample
        # loss in x: ||A^T g||_2 <= A_oper_norm_ub * ||g||_2 for a subgradient g
        # of the loss in u (assuming A_oper_norm_ub bounds every sampled A's
        # operator norm -- TODO confirm against Sample_Generator).
        loss_func = exp_param['loss_func']
        if loss_func == 'l1':
            # Subgradients of ||u||_1 are sign vectors, with L2 norm <= sqrt(m).
            self.L = self.A_oper_norm_ub * np.sqrt(self.m)
        elif loss_func == 'piecewise':
            # Each row of `pieces` is [a_k, b_k]: slope vector, then scalar offset.
            self.piece_as, self.piece_bs = self.pieces[:, :-1], self.pieces[:, -1]
            # Subgradients of max_k (a_k . u + b_k) are slopes a_k; the offsets
            # b_k do not contribute, so only the a_k rows enter the bound.
            self.L = self.A_oper_norm_ub * np.max(np.linalg.norm(self.piece_as, ord=2, axis=1))
        else:
            raise ValueError(
                f"Unsupported loss_func {loss_func!r}; expected 'l1' or 'piecewise'")

        self.sample_generator = Sample_Generator(self.x_true, self.A_mean, A_oper_norm_ub=self.A_oper_norm_ub)

    def _Get_sigma(self):
        """Hook for subclasses: noise scale. The base class returns None."""
        return None

    def _Get_eta(self):
        """Hook for subclasses: step size. The base class returns None."""
        return None

    def _Get_beta(self):
        """Hook for subclasses: smoothing parameter. The base class returns None."""
        return None

    def _Set_F_func(self):
        """Build the empirical loss ``F_func`` and publish it in ``exp_param``.

        ``F_func(us)`` reduces the last axis of ``us`` with the per-residual
        loss (``||u||_1`` for 'l1', ``max_k (a_k . u + b_k)`` for 'piecewise')
        and averages the result over all remaining entries.

        Raises:
            ValueError: if ``self.loss_func`` is not supported.
        """
        if self.loss_func == 'l1':
            def loss(us):
                return np.average(np.linalg.norm(us, ord=1, axis=-1))
        elif self.loss_func == 'piecewise':
            def loss(us):
                return np.average(np.max(np.matmul(us, self.piece_as.transpose()) + self.piece_bs, axis=-1))
        else:
            raise ValueError(
                f"Unsupported loss_func {self.loss_func!r}; expected 'l1' or 'piecewise'")
        self.F_func = loss
        self.exp_param['F_func'] = self.F_func

    def Solve_x(self, data=None):
        """Return the private estimate of x. Must be implemented by subclasses."""
        raise NotImplementedError(f'{type(self).__name__} must implement Solve_x')

    def Run(self):
        """Solve once and summarize the result.

        Returns:
            dict with one ``x_{i}`` entry per coordinate of the private
            solution, ``cost`` (from ``Evaluator(self.exp_param, 10000)``),
            ``n``, ``l2_dist`` (Euclidean distance to ``self.x_star``, which
            must already be set, e.g. via an ``'x_star'`` parameter entry),
            ``solving_time`` in seconds, and one ``x_{i}_erm`` entry per
            coordinate of a non-private ERM fit on ``self.samples`` (the data
            the solver stored).
        """
        self._Set_F_func()

        # perf_counter is monotonic and high-resolution; time.time() is neither.
        start = time.perf_counter()
        x_hat = self.Solve_x()
        solving_time = time.perf_counter() - start

        cost = Evaluator(self.exp_param, 10000).Evaluate(x_hat)['cost']
        l2_dist = np.linalg.norm(x_hat - self.x_star, 2)
        output = {f'x_{i}': v for i, v in enumerate(x_hat)}
        output.update({'cost': cost,
                       'n': self.n,
                       'l2_dist': l2_dist,
                       'solving_time': solving_time})
        x_hat_nonpriv = ERM(self.exp_param).Solve_x(data=self.samples)
        output.update({f'x_{i}_erm': v for i, v in enumerate(x_hat_nonpriv)})
        return output


class GP_NSGD(GradientPert):
    """Noisy projected SGD with iterate averaging (Noisy-SGD).

    Reference: Bassily et al., 2020, NeurIPS, "Stability of stochastic gradient
    descent on nonsmooth convex losses" -- Remark 5.3, Algorithm 2.

    Runs ``T = n**2`` single-sample steps, each perturbed by N(0, sigma^2 I)
    noise and projected onto the radius-``D`` ball, and returns the running
    average of the iterates.
    """

    def __init__(self, exp_param, algo_param):
        """Set the step count, noise scale, step size and subgradient oracle.

        Raises:
            ValueError: if ``exp_param['loss_func']`` is not 'l1' or 'piecewise'.
        """
        super().__init__(exp_param, algo_param)

        # Choose the single-sample subgradient oracle first so an unsupported
        # loss fails with a clear message instead of a later AttributeError.
        grad_oracles = {'l1': self.Grad_l1, 'piecewise': self.Grad_piecewise}
        loss_func = exp_param['loss_func']
        if loss_func not in grad_oracles:
            raise ValueError(
                f"Unsupported loss_func {loss_func!r}; expected one of {sorted(grad_oracles)}")
        self.Grad = grad_oracles[loss_func]

        self.T = self.n ** 2  # number of SGD steps
        self.sigma = self._Get_sigma()
        self.sigma_square = self.sigma ** 2
        self.eta = self._Get_eta()
        # NOTE(review): Solve_x draws its noise with norm.rvs and never uses this
        # object; it is kept only so the attribute keeps existing.
        self.Gaussian_d = multivariate_normal(cov=self.sigma_square)

    def Proj(self, point):
        """Project ``point`` onto the Euclidean ball of radius ``self.D`` at the origin."""
        return point * self.D / max(self.D, np.linalg.norm(point, 2))

    def Grad_l1(self, x):
        """Subgradient of ||y - A x||_1 in x, on one uniformly drawn sample.

        Each sample is a 2-D array whose column 0 is y and whose remaining
        columns form A.
        """
        sample = choice(self.samples)
        y, A = sample[:, 0], sample[:, 1:]
        u = y - np.matmul(A, x)
        return - np.matmul(A.T, np.sign(u))

    def Grad_piecewise(self, x):
        """Subgradient of max_k (a_k . (y - A x) + b_k) in x, on one drawn sample.

        Uses the slope of the active piece: -A^T a_{k*}, where k* is the argmax.
        """
        sample = choice(self.samples)
        y, A = sample[:, 0], sample[:, 1:]
        u = y - np.matmul(A, x)
        i_argmax = np.argmax(np.matmul(self.piece_as, u) + self.piece_bs)
        return - np.matmul(A.T, self.piece_as[i_argmax, :])

    def _Get_sigma(self):
        """Per-step noise std: sigma^2 = 8 L^2 log(1/delta) / varepsilon^2."""
        sigma_square = 8 * self.L ** 2 * np.log(1 / self.delta) / self.varepsilon ** 2
        return np.sqrt(sigma_square)

    def _Get_eta(self):
        """Step size D / (L n max(sqrt(n), sqrt(d log(1/delta)) / varepsilon))."""
        return self.D / (self.L * self.n * np.fmax(np.sqrt(self.n),
                                                   np.sqrt(self.d * np.log(1/self.delta)) / self.varepsilon)
                         )

    def Solve_x(self, data=None):
        """Run ``T`` noisy projected SGD steps from the origin; return the mean iterate.

        Args:
            data: optional sample array stored as ``self.samples``; when None,
                ``n`` samples are drawn from ``self.sample_generator``.
        """
        self._Set_F_func()

        if data is None:
            self.samples = self.sample_generator.rvs(self.n)
        else:
            self.samples = data
        T, d = self.T, self.d
        x, x_avg = np.zeros(d), np.zeros(d)
        for t in range(T):
            if t % 50000 == 0:
                print(t, datetime.now())
            w = norm.rvs(loc=0, scale=self.sigma, size=d)
            grad = self.Grad(x)
            x = self.Proj(x - self.eta * (grad + w))
            # Incremental mean of the iterates produced so far.
            x_avg = (t * x_avg + x) / (t + 1)
        self.x = x_avg
        return self.x


class GP_Moreau(GradientPert):
    """Noisy mini-batch SGD on the Moreau envelope of the loss.

    Reference: Algorithm 1 in Bassily R, Feldman V, Talwar K, et al.
    Private stochastic convex optimization with optimal rates[J]. Advances in
    neural information processing systems, 2019, 32.
    """

    def __init__(self, exp_param, algo_param):
        """Derive the iteration count, smoothing, noise, step and batch sizes."""
        super().__init__(exp_param, algo_param)
        if self.grad_type == 'moreau':
            # NOTE(review): L is doubled for Moreau gradients; rationale not stated here.
            self.L = 2 * self.L

        n, d, eps = self.n, self.d, self.varepsilon
        log_inv_delta = np.log(1 / self.delta)

        # -- Moreau schedule
        self.T = round(np.ceil(min(n / 8, eps ** 2 * n ** 2 / 32 / d / log_inv_delta)))
        self.beta = self.L / self.D * min(np.sqrt(n / 4),
                                          eps * n / 8 / np.sqrt(d * log_inv_delta))
        self.sigma = np.sqrt(8 * self.T * self.L ** 2 * log_inv_delta / n ** 2 / eps ** 2)
        self.sigma_square = self.sigma ** 2
        self.eta = self.D / self.L / np.sqrt(self.T)
        self.batch_size = round(np.ceil(max(n * np.sqrt(eps / 4 / self.T), 1)))

    def Grad(self, x):
        """Mean Moreau-envelope gradient at ``x`` over a random mini-batch.

        Draws ``batch_size`` samples with replacement. For each one it solves
        prox = argmin_v F(y - A v) / beta + 1/2 ||x - v||^2 numerically, then
        returns the average of beta * (x - prox).
        """
        beta = self.beta
        batch_size = self.batch_size
        F_func = self.exp_param['F_func']

        def prox(sample):
            # Column 0 of a sample is y, the remaining columns form A.
            y, A = sample[:, 0], sample[:, 1:]

            def objective(v):
                return F_func(y - np.matmul(A, v)) / beta + 1 / 2 * np.linalg.norm(x - v, 2) ** 2

            # Started from the origin with scipy's default (unconstrained) method.
            return minimize(objective, np.zeros(x.shape[0])).x

        batch = choices(self.samples, k=batch_size)
        proxs = np.asarray([prox(sample) for sample in batch])
        return np.average(beta * (x - proxs), axis=0)

    def Solve_x(self, data=None):
        """Run ``T`` noisy (unprojected) gradient steps from the origin; return the mean iterate.

        Args:
            data: optional sample array stored as ``self.samples``; when None,
                ``n`` samples are drawn from ``self.sample_generator``.
        """
        self._Set_F_func()
        self.samples = self.sample_generator.rvs(self.n) if data is None else data

        iterate = np.zeros(self.d)
        running_mean = np.zeros(self.d)
        for step in range(self.T):
            if step % 10 == 0:
                print(step, datetime.now())
            noise = norm.rvs(loc=0, scale=self.sigma, size=self.d)
            gradient = self.Grad(iterate)
            iterate = iterate - self.eta * (gradient + noise)
            running_mean = (step * running_mean + iterate) / (step + 1)
        self.x = running_mean
        return self.x


class GP_Phased_ERM(GradientPert):
    """Phased regularized ERM with Gaussian output perturbation.

    Reference: Algorithm 3 in Feldman V, Koren T, Talwar K. Private stochastic
    convex optimization: optimal rates in linear time[C]//Proceedings of the
    52nd Annual ACM SIGACT Symposium on Theory of Computing. 2020: 439-449.

    Phase i (1..T) uses the next round(n / 2**i) unused samples and step size
    eta / 4**i. It solves a proximally regularized ERM around the previous
    point and then adds N(0, sigma_i^2 I) noise.
    """

    def __init__(self, exp_param, algo_param):
        """Set the number of phases ``T = ceil(log2 n)`` and the base step size."""
        super().__init__(exp_param, algo_param)
        self.T = np.ceil(np.log2(self.n)).astype('int')
        cap_sample = 4 / np.sqrt(self.n)
        cap_privacy = self.varepsilon / np.sqrt(self.d * np.log(1 / self.delta))
        self.eta = self.D / self.L * min(cap_sample, cap_privacy)
        self.sigma = None  # the noise scale is per phase; see Solve_x

    def Grad(self, samples, x, n_i, eta_i):
        """Return argmin_w F(ys - As w) + ||w - x||^2 / (n_i * eta_i).

        Despite its name this returns a minimizer, not a gradient. ``samples``
        is 3-D: ``samples[:, :, 0]`` holds y and ``samples[:, :, 1:]`` holds A.
        """
        ys = samples[:, :, 0]
        As = samples[:, :, 1:]
        F_func = self.exp_param['F_func']

        def regularized_risk(w):
            return (F_func(ys - np.matmul(As, w))) + np.linalg.norm(w - x, 2) ** 2 / (n_i * eta_i)

        return minimize(regularized_risk, np.zeros(x.shape[0])).x

    def Solve_x(self, data=None):
        """Run the T phases and return the final noisy iterate.

        Args:
            data: optional sample array stored as ``self.samples``; when None,
                ``n`` samples are drawn from ``self.sample_generator``.
        """
        self._Set_F_func()
        self.samples = self.sample_generator.rvs(self.n) if data is None else data

        pool = deepcopy(self.samples)
        d = self.d
        # Starting point: an ERM fit with sample_size=20 on a copy of exp_param.
        w = ERM(deepcopy(self.exp_param)).Solve_x(sample_size=20)

        start = 0  # index of the first sample not yet used by any phase
        for phase in range(1, self.T + 1):
            n_i = round(2 ** (-phase) * self.n)
            eta_i = 4 ** (-phase) * self.eta
            sigma_i = 4 * self.L * eta_i / self.varepsilon * np.sqrt(np.log(1 / self.delta))

            batch = pool[start:start + n_i, :, :]
            start += n_i

            minimizer = self.Grad(batch, w, n_i, eta_i)
            noise = norm.rvs(loc=0, scale=sigma_i, size=d)
            w = minimizer + noise
        self.x = w
        return self.x


if __name__ == '__main__':
    # -- experiment: a 3 x 10 mean matrix A_mean made of a 3 x 5 base matrix
    #    repeated column-wise, with the true parameter repeated to match.
    n = 500
    D = 5
    x_half = np.asarray([0.5, -0.5, 1, -1, 1])
    A_half = np.asarray([[1, 0.5, 0, 1, 1],
                         [0.5, 0.5, 0, 1, 1],
                         [0, 0, -0.5, 1, 1]])
    A_mean = np.column_stack([A_half, A_half])
    x_true = np.concatenate([x_half, x_half])
    m, d = A_mean.shape
    A_oper_norm_ub = 3 * d / 5

    # Rows are [a_k, b_k] pieces of the piecewise-linear loss.
    pieces = np.asarray([[1, 1, 1, 0],
                         [1, 1, -1, 0],
                         [1, -1, 1, 0],
                         [-1, 1, 1, 0],
                         [1, -1, -1, 0]])
    exp_param = {
        'n': n, 'm': m, 'd': d,
        'x_true': x_true, 'A_mean': A_mean, 'A_oper_norm_ub': A_oper_norm_ub, 'D': D,
        'loss_func': 'piecewise',  # options: l1, piecewise
        'pieces': pieces,
    }

    # Reference (non-private) ERM solution.
    x_star = ERM(exp_param).Solve_x(100000)
    # -- algo params
    algo_param = {'varepsilon': 1, 'delta': 0.05, 'grad_type': 'moreau'}
    print(Evaluator(exp_param, 50000).Evaluate(x_star))

    started = datetime.now()
    # To run noisy SGD instead: print(GP_NSGD(exp_param, algo_param).Solve_x())
    moreau = GP_Moreau(exp_param, algo_param)
    print(moreau.Solve_x())

    print(datetime.now() - started)
    print('----')










