import math
from copy import deepcopy
from datetime import datetime

import numpy as np
import scipy.optimize
from scipy import integrate
from scipy.stats import multivariate_normal, norm

import time


class Sample_Generator(object):
    """Draw (y, A) samples from the linear model ``y = A @ x_true + epsilon``.

    Each ``A`` is an ``m x d`` matrix whose flattened entries are drawn from
    ``scipy.stats.multivariate_normal`` centred at ``A_mean`` (default
    covariance), and is kept only if its spectral norm does not exceed
    ``A_oper_norm_ub`` (rejection sampling).

    Parameters
    ----------
    x_true : array-like of length d
        Coefficient vector used to generate the responses.
    A_mean : np.ndarray of shape (m, d)
        Mean of the random design matrix.
    A_oper_norm_ub : float, default 5
        Upper bound on the spectral norm of every accepted ``A``.
    epsilon : object with ``rvs(size=...)``, optional
        Noise distribution. Defaults to a standard normal (``norm()``).
    """

    # Candidates drawn per requested sample in each rejection-sampling round.
    _OVERSAMPLE = 15
    # Upper bound on rejection-sampling rounds before giving up.
    _MAX_ROUNDS = 1000

    def __init__(self, x_true, A_mean, A_oper_norm_ub=5, epsilon=None):
        self.m, self.d = A_mean.shape
        self.x_true = x_true
        self.A_mean = A_mean
        self.A_oper_norm_ub = A_oper_norm_ub
        # Always bind self.epsilon; a caller-supplied distribution used to be
        # silently dropped, leaving the attribute undefined.
        self.epsilon = norm() if epsilon is None else epsilon

    def rvs(self, size):
        """Alias of :meth:`Draw_Samples` (scipy-style ``rvs`` interface)."""
        return self.Draw_Samples(size)

    def Draw_Samples(self, size):
        """Return ``size`` samples as an array of shape ``(size, m, 1 + d)``.

        ``samples[:, :, 0]`` holds the responses ``y`` and
        ``samples[:, :, 1:]`` holds the matrices ``A``.

        Raises
        ------
        RuntimeError
            If fewer than ``size`` matrices satisfy the norm bound after
            ``_MAX_ROUNDS`` rounds of candidates.
        """
        batch = self._OVERSAMPLE * size
        mean_flat = self.A_mean.flatten()

        accepted = []
        n_accepted = 0
        # Keep drawing batches until enough candidates pass the norm filter;
        # a single batch is not guaranteed to contain `size` accepted draws.
        for _ in range(self._MAX_ROUNDS):
            candidates = multivariate_normal.rvs(mean=mean_flat, size=batch)
            candidates = candidates.reshape(batch, self.m, self.d)
            # ord=2 over the last two axes gives the spectral norm per matrix.
            spectral = np.linalg.norm(candidates, ord=2, axis=(1, 2))
            passed = candidates[spectral <= self.A_oper_norm_ub]
            accepted.append(passed)
            n_accepted += passed.shape[0]
            if n_accepted >= size:
                break
        else:
            raise RuntimeError(
                'Could not draw %d matrices with operator norm <= %s '
                '(only %d accepted)' % (size, self.A_oper_norm_ub, n_accepted))

        As = np.concatenate(accepted, axis=0)[:size]
        ys = np.matmul(As, self.x_true) + self.epsilon.rvs(size=(size, self.m))
        # Put y in front of A along the last axis: (size, m, 1 + d).
        return np.concatenate([ys.reshape(size, self.m, 1), As], axis=-1)


class Evaluator(object):
    """Estimate the out-of-sample cost of a candidate solution ``x``.

    Every key/value of ``exp_param`` is copied onto the instance as an
    attribute. The constructor reads ``x_true``, ``A_mean`` and
    ``A_oper_norm_ub`` from those attributes to build the sample generator;
    evaluation reads ``exp_param['loss_func']`` and, for the ``'piecewise'``
    loss, the ``pieces`` attribute.

    Parameters
    ----------
    exp_param : dict
        Experiment parameters (see above for the keys that are read).
    n : int, default 10000
        Number of fresh samples drawn for each evaluation.
    """

    def __init__(self, exp_param, n=10000):
        self.exp_param = exp_param
        self.n_outsample = n
        for k, v in exp_param.items():
            setattr(self, k, v)
        self.samples_generator = Sample_Generator(self.x_true, self.A_mean, A_oper_norm_ub=self.A_oper_norm_ub)

    def Evaluate(self, value_to_be_evaluated):
        """Evaluate a candidate solution given as an ndarray or a list.

        Returns
        -------
        dict
            ``{'cost': <average loss over the drawn samples>}``.

        Raises
        ------
        TypeError
            If ``value_to_be_evaluated`` is neither an ndarray nor a list.
        """
        if isinstance(value_to_be_evaluated, (np.ndarray, list)):
            return self._Evaluate_x(value_to_be_evaluated)
        # Previously fell through to `return output_dict` with the name
        # unbound, surfacing as an opaque UnboundLocalError.
        raise TypeError(
            'Evaluate expects a numpy.ndarray or list, got %s'
            % type(value_to_be_evaluated).__name__)

    def _Evaluate_x(self, x):
        """Draw ``n_outsample`` samples and average the loss of ``y - A @ x``.

        The drawn samples are kept on ``self.samples``.

        Raises
        ------
        ValueError
            If ``exp_param['loss_func']`` is not ``'l1'`` or ``'piecewise'``.
        """
        self.samples = self.samples_generator.rvs(self.n_outsample)
        # Sample layout: column 0 is y, the remaining columns are A.
        ys, As = self.samples[:, :, 0], self.samples[:, :, 1:]
        us = ys - np.matmul(As, x)

        loss_func = self.exp_param['loss_func']
        if loss_func == 'l1':
            avg_cost = np.average(np.linalg.norm(us, ord=1, axis=1))
        elif loss_func == 'piecewise':
            # Each row of `pieces` is (slope vector, intercept); the loss is
            # the maximum over the affine pieces.
            avg_cost = np.average(np.max(np.matmul(us, self.pieces[:, :-1].transpose()) + self.pieces[:, -1], axis=-1))
        else:
            raise ValueError(
                "Unknown loss_func %r; expected 'l1' or 'piecewise'" % (loss_func,))

        return {'cost': avg_cost}




if __name__ == '__main__':
    # -- experiment parameters
    # Problem sizes: m rows and d columns in A_mean; n is left unbounded.
    n = np.inf
    m = 3
    d = 5
    D = 5
    A_oper_norm_ub = 5

    x_star = np.asarray([0.5, -0.5, 1, -1, 2])
    A_mean = np.asarray([
        [1, 0.5, 0, 1, 1],
        [0.5, 0.5, 0, 1, 1],
        [0, 0, -0.5, 1, 1],
    ])

    # NOTE(review): Evaluator reads `x_true` from the attributes it copies out
    # of exp_param, while this dict provides `x_star` — confirm which key is
    # intended before passing this dict to Evaluator.
    exp_param = {
        'n': n,
        'm': m,
        'd': d,
        'x_star': x_star,
        'A_mean': A_mean,
        'A_oper_norm_ub': A_oper_norm_ub,
        'D': D,
        'loss_func': 'l1',  # options: l1, piecewise
    }

    # -- algorithm parameters
    algo_param = {'varepsilon': 1, 'delta': 0.1}