from copy import deepcopy
import time

import numpy as np
import scipy
import scipy.optimize

from models.utils_AOP import Evaluator, Sample_Generator


class ERM(object):
    """Empirical risk minimisation baseline.

    Fits ``x`` by minimising the sample average of a loss of the residuals
    ``y - A @ x`` with ``scipy.optimize.minimize``.

    Every key of ``exp_param`` becomes an instance attribute. Keys read by this
    class: ``x_true``, ``A_mean``, ``A_oper_norm_ub`` (passed to the sample
    generator), ``n`` (sample size), ``d`` (length of ``x``), ``loss_func``
    (``'l1'`` or ``'piecewise'``), ``pieces`` (only for ``'piecewise'``) and
    ``x_star`` (reference point used by ``Run``).
    """

    def __init__(self, exp_param):
        # Private copy: 'n' is updated in it later without touching the caller's dict.
        self.exp_param = deepcopy(exp_param)
        # NOTE: attributes are taken from the caller's ``exp_param`` (not the deep
        # copy), so mutable values such as arrays are shared with the caller.
        for k, v in exp_param.items():
            setattr(self, k, v)
        self.model = None
        self.vars = {}
        self.sample_generator = Sample_Generator(self.x_true, self.A_mean, A_oper_norm_ub=self.A_oper_norm_ub)

    def Solve_x(self, sample_size=None, data=None):
        """Fit ``x`` on a set of samples and return the minimiser.

        Args:
            sample_size: if given, draw this many fresh samples from the
                sample generator and record it as ``n``.
            data: if given, use these samples (overrides any samples drawn via
                ``sample_size``). Array indexed as ``data[i, j, 0]`` = ``y`` entry
                and ``data[i, j, 1:]`` = matching row of ``A`` (length ``d``);
                ``data.shape[0]`` is recorded as ``n``.
            If neither is given, the samples from the previous call are reused.

        Returns:
            The minimiser found (also stored on ``self.x_hat``), length ``d``.

        Raises:
            ValueError: if ``self.loss_func`` is not a supported loss.
        """
        if sample_size is not None:
            self.exp_param['n'] = sample_size
            self.n = sample_size
            self.samples = self.sample_generator.rvs(self.n)
        if data is not None:
            self.samples = data
            self.exp_param['n'] = data.shape[0]
            self.n = data.shape[0]

        self._Build()
        # NOTE: both objectives are non-smooth and are minimised with the default
        # unconstrained method and finite-difference gradients; ``res.success`` is
        # not checked, the returned point is whatever the optimiser ended on.
        res = scipy.optimize.minimize(self.model, np.zeros(self.d), tol=10**-8)
        self.x_hat = res.x
        return res.x

    def _Build(self):
        """Set ``self.model`` to the empirical-risk objective for ``self.loss_func``.

        Raises:
            ValueError: if ``self.loss_func`` is neither ``'l1'`` nor ``'piecewise'``.
        """
        ys, As = self.samples[:, :, 0], self.samples[:, :, 1:]

        if self.loss_func == 'l1':
            def obj_l1(x):
                # Average over samples of the l1 norm of each residual vector.
                us = ys - np.matmul(As, x)
                return np.average(np.linalg.norm(us, ord=1, axis=1))

            self.model = obj_l1
        elif self.loss_func == 'piecewise':
            # Each row of ``pieces`` is (coefficients..., offset) of one affine piece;
            # the loss of a residual ``u`` is max_k(coef_k @ u + offset_k).
            # Slicing is hoisted so it is not redone on every objective evaluation.
            coefs_t = self.pieces[:, :-1].transpose()
            offsets = self.pieces[:, -1]

            def obj_piecewise(x):
                us = ys - np.matmul(As, x)
                return np.average(np.max(np.matmul(us, coefs_t) + offsets, axis=-1))

            self.model = obj_piecewise
        else:
            # Fail loudly instead of leaving ``self.model`` as None or as a stale
            # objective from an earlier build.
            raise ValueError(
                f"Unknown loss_func {self.loss_func!r}; expected 'l1' or 'piecewise'"
            )

    def Run(self):
        """Draw ``self.n`` fresh samples, fit ``x`` and report metrics.

        Returns:
            dict with one ``'x_<i>'`` entry per coordinate of the fitted ``x``,
            plus ``'cost'`` (from ``Evaluator.Evaluate``), ``'n'``,
            ``'l2_dist'`` (Euclidean distance to ``self.x_star``) and
            ``'solving_time'`` (seconds spent in ``Solve_x``).
        """
        # -- solve
        t = time.perf_counter()  # monotonic, high-resolution clock for durations
        x_hat = self.Solve_x(sample_size=self.n)
        solving_time = time.perf_counter() - t
        # -- output
        evaluator_output = Evaluator(self.exp_param, 10 ** 5).Evaluate(x_hat)
        l2_dist = np.linalg.norm(x_hat - self.x_star, 2)
        output = {f'x_{i}': v for i, v in enumerate(x_hat)}
        output.update({'cost': evaluator_output['cost'], 'n': self.n, 'l2_dist': l2_dist, 'solving_time': solving_time})
        return output



