from copy import deepcopy
from datetime import datetime
import scipy.integrate as integrate
import numpy as np
import scipy.optimize
import torch
import torchquad
from cubature import cubature
from scipy.stats import norm, truncnorm, multivariate_normal
# from mosek.fusion import Model, Domain, Expr, ObjectiveSense
import time

from models.utils_AOP import Sample_Generator, Evaluator
from models.ERM import ERM


class ObjectivePert(object):
    """Objective-perturbation estimator built on a Gaussian-smoothed loss.

    Every entry of ``algo_param`` and then ``exp_param`` is copied onto the
    instance with ``setattr`` (so ``exp_param`` wins when a key appears in
    both).  Keys read here: ``n``, ``m``, ``d``, ``D``, ``loss_func``
    (``'l1'`` or ``'piecewise'``), ``A_oper_norm_ub``, ``x_true``,
    ``A_mean``, ``gpu_on``, ``varepsilon`` and ``delta``; additionally
    ``pieces`` for the piecewise loss, ``N_eachdomain_conv`` for the GPU
    quadrature and ``x_star`` for ``Run``.

    ``Build_Model`` produces the perturbed objective

        x -> F_beta(x) + lamda * ||x||_2^2 + <x, b> / n,   b ~ N(0, sigma^2 I_d)

    where ``F_beta(x)`` integrates the sample-average loss of
    ``ys - As @ x + mu * v`` against the standard normal density in ``v``
    (dimension ``m``).  ``Solve_x`` minimises it with SLSQP, warm-started at
    the non-private ERM solution.
    """

    def __init__(self, exp_param, algo_param):
        self.algo_param = algo_param
        self.exp_param = exp_param

        for k, v in algo_param.items():
            setattr(self, k, v)
        for k, v in exp_param.items():
            setattr(self, k, v)

        # L bounds the Lipschitz constant of the loss as a function of x:
        # an operator-norm bound on A times the Lipschitz constant in u.
        loss_func = exp_param['loss_func']
        if loss_func == 'l1':
            # ||u||_1 is sqrt(m)-Lipschitz with respect to ||u||_2.
            self.L = self.A_oper_norm_ub * np.sqrt(self.m)
        elif loss_func == 'piecewise':
            # NOTE(review): the row norm includes the intercept column, which
            # can only over-estimate max_k ||a_k||_2 (a conservative bound).
            self.L = self.A_oper_norm_ub * np.max(np.linalg.norm(self.pieces, ord=2, axis=1))
        else:
            # Fail fast instead of an AttributeError on self.L further down.
            raise ValueError(
                "unsupported loss_func {!r}; expected 'l1' or 'piecewise'".format(loss_func))

        # Noise / regularisation parameters; each one depends on the previous.
        self.sigma = self._Get_sigma()
        self.sigma_square = self.sigma ** 2
        self.lamda = self._Get_lamda()
        self.beta = self._Get_beta()
        self.mu = self._Get_mu()
        # Standard normal on R^m (smoothing) and N(0, sigma^2 I_d) (objective noise).
        self.Gaussian_m = multivariate_normal(np.zeros(self.m))
        self.Gaussian_d = multivariate_normal(mean=np.zeros(self.d), cov=self.sigma_square)

        self.model = None
        self.vars = {}
        self.samples = None
        self.x = None

        self.sample_generator = Sample_Generator(self.x_true, self.A_mean, A_oper_norm_ub=self.A_oper_norm_ub)

        if self.gpu_on:
            torchquad.set_up_backend('torch', data_type='float32')
            torch.set_default_dtype(torch.float32)

        if self.loss_func == 'piecewise':
            # Each row of `pieces` is [a_1, ..., a_m, b]: slopes then intercept.
            self.piece_as, self.piece_bs = self.pieces[:, :-1], self.pieces[:, -1]

    def _Get_lamda(self):
        """Return the ridge weight sqrt(2 L^2 / n + d sigma^2 / n^2) / D."""
        lamda = np.sqrt(2 * self.L ** 2 / self.n + self.d * self.sigma ** 2 / (self.n ** 2)) / self.D
        return lamda

    def _Get_sigma(self):
        """Return the noise scale sqrt(L^2 (8 log(2/delta) + 4 varepsilon) / varepsilon^2)."""
        sigma = np.sqrt(self.L ** 2 * (8 * np.log(2 / self.delta) + 4 * self.varepsilon) / self.varepsilon ** 2)
        return sigma

    def _Get_beta(self):
        """Return beta = lamda * n * varepsilon / m."""
        return self.lamda * self.n * self.varepsilon / self.m

    def _Get_mu(self):
        """Return the smoothing scale mu = L / beta used inside F_beta."""
        mu = self.L / self.beta
        return mu

    @staticmethod
    def Gaussian_pdf_tensor(v):
        """Return the standard normal density of every row of a 2-D tensor ``v``."""
        d = v.size()[1]
        return (2 * torch.pi) ** (-d/2) * torch.exp(-torch.pow(torch.linalg.norm(v, ord=2, dim=1), 2) / 2)

    # ------------------------------------------------------------------ CPU

    def _cubature_integrate(self, integrand, us):
        """Integrate ``integrand(v, us)`` over [-20, 20]^m with cubature and print the time taken."""
        t = time.time()
        val, _ = cubature(integrand, ndim=self.m, fdim=1,
                          xmin=[-20] * self.m, xmax=[20] * self.m,
                          abserr=5e-5, relerr=5e-5,
                          kwargs={'us': us})
        print('Cubature took {} seconds'.format(time.time() - t))
        return val

    def F_beta_l1(self, x):
        """Smoothed l1 loss at ``x``, computed with adaptive cubature (CPU)."""
        us = self.ys - np.matmul(self.As, x)  # residuals, one row per sample

        def integrand(v_array, us):
            # mean_i ||u_i + mu v||_1, weighted by the standard normal density of v
            return np.average(np.linalg.norm(us + self.mu * v_array, ord=1, axis=1)) * self.Gaussian_m.pdf(v_array)

        return self._cubature_integrate(integrand, us)

    def F_beta_piecewise(self, x):
        """Smoothed piecewise-linear loss max_k(a_k.u + b_k) at ``x``, via cubature (CPU)."""
        us = self.ys - np.matmul(self.As, x)

        def integrand(v_array, us):
            pieces_vals = np.matmul(us + self.mu * v_array, self.piece_as.transpose()) + self.piece_bs
            return np.average(np.max(pieces_vals, axis=1)) * self.Gaussian_m.pdf(v_array)

        return self._cubature_integrate(integrand, us)

    # ------------------------------------------------------------------ GPU

    def _residuals_tensor(self, x):
        """Return ``ys_tensor - As_tensor @ x`` with ``x`` moved to the data's device."""
        x_tensor = torch.asarray(x, device=self.As_tensor.device)
        return self.ys_tensor - torch.matmul(self.As_tensor, x_tensor)

    def _torchquad_integrate(self, integrand):
        """Integrate ``integrand`` over [-10, 10]^m with torchquad's Boole rule.

        Uses ``N_eachdomain_conv ** m`` nodes (stored in ``self.N_points``);
        a large value can run out of GPU memory.
        """
        self.N_points = self.N_eachdomain_conv ** self.m
        result = torchquad.Boole().integrate(integrand,
                                             dim=self.m,
                                             N=self.N_points,  # number of points to approximate quadrature
                                             integration_domain=torch.Tensor([[-10, 10] for _ in range(self.m)])
                                             )
        return result.item()

    def integrand_l1_GPU(self, x):
        """Return a vectorised integrand ``f(v)`` for the smoothed l1 loss at ``x``.

        ``v`` holds one quadrature node per row; ``f`` returns, per node,
        mean_i ||u_i + mu v||_1 times the standard normal density of ``v``.
        """
        def f(v, x=x):
            us = self._residuals_tensor(x)
            # (1, n, m) + (N, 1, m) -> (N, n, m): each node is added to every
            # residual row by broadcasting, without materialising repeats.
            shifted = us.unsqueeze(0) + self.mu * v.unsqueeze(1)
            per_node = torch.mean(torch.linalg.norm(shifted, ord=1, dim=2), dim=1)
            return per_node * self.Gaussian_pdf_tensor(v)
        return f

    def integrand_piecewise_GPU(self, x):
        """Return a vectorised integrand ``f(v)`` for the smoothed piecewise loss at ``x``.

        Per node, ``f`` returns mean_i max_k(a_k.(u_i + mu v) + b_k) times the
        standard normal density of ``v``, matching ``F_beta_piecewise``.
        """
        def f(v, x=x):
            us = self._residuals_tensor(x)
            shifted = us.unsqueeze(0) + self.mu * v.unsqueeze(1)
            # (N, n, m) @ (m, k) -> (N, n, k); intercepts (k,) broadcast on the last axis.
            pieces_vals = torch.matmul(shifted, self.piece_as_tensor.T) + self.piece_bs_tensor
            per_node = torch.mean(torch.max(pieces_vals, dim=2).values, dim=1)
            return per_node * self.Gaussian_pdf_tensor(v)
        return f

    def F_beta_l1_GPU(self, x):
        """Smoothed l1 loss at ``x`` via torchquad quadrature."""
        return self._torchquad_integrate(self.integrand_l1_GPU(x))

    def F_beta_piecewise_GPU(self, x):
        """Smoothed piecewise-linear loss at ``x`` via torchquad quadrature."""
        return self._torchquad_integrate(self.integrand_piecewise_GPU(x))

    # ------------------------------------------------------------- solving

    def Build_Model(self, data=None):
        """Draw (or accept) samples and build the perturbed objective ``self.model``.

        Args:
            data: optional array-like indexed as (samples, m, 1 + d); entry 0
                of the last axis holds ``y`` and the rest holds the rows of
                ``A``.  When omitted, ``self.n`` samples are drawn from
                ``self.sample_generator``.

        Returns:
            The objective ``x -> F_beta(x) + lamda ||x||^2 + <x, b> / n`` with
            a fresh noise vector ``b`` drawn from ``self.Gaussian_d``.

        Raises:
            ValueError: if ``self.loss_func`` is neither ``'l1'`` nor ``'piecewise'``.
        """
        if data is not None:
            self.samples = np.asarray(data)
        else:
            self.samples = self.sample_generator.rvs(self.n)
        # Unpack in both branches: previously this ran only for generated
        # samples, leaving ys/As unset (or stale) when `data` was supplied.
        self.ys, self.As = self.samples[:, :, 0], self.samples[:, :, 1:]

        if self.gpu_on:
            cuda = torch.device('cuda')
            self.As_tensor = torch.asarray(self.As, device=cuda)
            self.ys_tensor = torch.asarray(self.ys, device=cuda)
            if self.loss_func == 'piecewise':
                # Keep intercepts in the same dtype as the slopes.
                self.piece_as_tensor = torch.asarray(self.piece_as, dtype=torch.double, device=cuda)
                self.piece_bs_tensor = torch.asarray(self.piece_bs, dtype=torch.double, device=cuda)

        b_noise = self.Gaussian_d.rvs()

        if self.loss_func == 'l1':
            F_beta = self.F_beta_l1_GPU if self.gpu_on else self.F_beta_l1
        elif self.loss_func == 'piecewise':
            F_beta = self.F_beta_piecewise_GPU if self.gpu_on else self.F_beta_piecewise
        else:
            raise ValueError(
                "unsupported loss_func {!r}; expected 'l1' or 'piecewise'".format(self.loss_func))

        self.model = lambda x: F_beta(x) + self.lamda * np.linalg.norm(x, 2) ** 2 + np.dot(x, b_noise) / self.n
        return self.model

    def Solve_x(self, verbose=False, data=None):
        """Build the perturbed objective and minimise it with SLSQP.

        The starting point is the non-private ERM solution on the same
        samples.  The minimiser is stored in ``self.x`` and returned.
        """
        self.Build_Model(data=data)
        init_guess = ERM(self.exp_param).Solve_x(data=self.samples)
        res = scipy.optimize.minimize(self.model, init_guess, method='SLSQP',
                                      tol=1e-8, options={'disp': verbose})
        self.x = res.x
        return res.x

    def Run(self):
        """Solve privately, evaluate, and report next to the ERM solution.

        Requires an ``x_star`` parameter (used for ``l2_dist``).  Returns a
        dict with keys ``x_i`` (private solution), ``cost``, ``n``,
        ``l2_dist``, ``solving_time`` and ``x_i_erm`` (non-private solution).
        """
        t = time.time()
        x_hat_private = self.Solve_x()
        solving_time = time.time() - t
        print(f'---n={self.n}---solving_time:', solving_time)
        cost_private = Evaluator(self.exp_param, n=10000).Evaluate(x_hat_private)['cost']
        l2_dist = np.linalg.norm(x_hat_private - self.x_star, 2)
        output = {f'x_{i}': v for i, v in enumerate(x_hat_private)}
        output.update({'cost': cost_private,
                       'n': self.n, 'l2_dist': l2_dist,
                       'solving_time': solving_time})

        x_hat_nonpriv = ERM(self.exp_param).Solve_x(data=self.samples)
        output.update({f'x_{i}_erm': v for i, v in enumerate(x_hat_nonpriv)})
        return output


if __name__ == '__main__':
    # ---------------- experiment configuration ----------------
    n = 500
    D = 5
    x_true = np.asarray([0.5, -0.5, 1, -1, 1])
    A_mean = np.asarray([
        [1, 0.5, 0, 1, 1],
        [0.5, 0.5, 0, 1, 1],
        [0, 0, -0.5, 1, 1],
    ])
    m, d = A_mean.shape
    A_oper_norm_ub = 3 * d / 5

    # To try the piecewise-linear loss, set loss_func='piecewise' and add a
    # 'pieces' entry whose rows are [slope_1, ..., slope_m, intercept], e.g.
    #   np.asarray([[1, 1, 1, 0], [1, 1, -1, 0], [1, -1, 1, 0],
    #               [-1, 1, 1, 0], [1, -1, -1, 0]])
    exp_param = dict(
        n=n, m=m, d=d,
        x_true=x_true, A_mean=A_mean,
        A_oper_norm_ub=A_oper_norm_ub, D=D,
        loss_func='l1',  # options: l1, piecewise
        gpu_on=True,
    )

    # Non-private reference solution.
    # NOTE(review): not stored as exp_param['x_star'], so op.Run() (which
    # reads self.x_star) would fail with this configuration.
    x_star = ERM(exp_param).Solve_x(300)

    # ---------------- algorithm configuration ----------------
    algo_param = dict(
        varepsilon=0.2,
        delta=0.05,
        gpu_on=True,
        N_eachdomain_conv=60,
    )

    op = ObjectivePert(exp_param, algo_param)

    # Time a single private solve.
    t = datetime.now()
    print(op.Solve_x())
    print(datetime.now() - t)
    print('----')


