import numpy as np
import torch
from cvxpylayers.torch import CvxpyLayer
import cvxpy as cp


class SurrogateProjectionProblem:
    """Differentiable convex surrogate problem wrapped as a cvxpylayers layer.

    For each parameter setting the layer solves::

        maximize    w_lin @ x - ||Q(x)||^2
        subject to  ||x - o||^2 <= r_square

    where ``Q(x) = W_sq.T @ x`` when ``flexible_W_sq`` is True and
    ``Q(x) = gamma_sqrt * x`` otherwise. An auxiliary variable ``t = x - o``
    carries the ball constraint.

    Args:
        n_variables: Dimension of the decision variable ``x``.
        flexible_W_sq: If True, the quadratic penalty uses a full
            ``(n_variables, n_variables)`` matrix parameter ``W_sq``; if False,
            it uses a single parameter ``gamma_sqrt`` of shape ``(1,)``.
    """

    def __init__(self, n_variables, flexible_W_sq=True):
        self.n_variables = n_variables
        self.flexible_W_sq = flexible_W_sq
        self._build_model()

    def _build_model(self):
        """Build the cvxpy problem and its ``CvxpyLayer``.

        Sets ``self.params_list`` (the order in which parameter tensors are
        passed to the layer), ``self.model`` and ``self.layer``.
        """
        n = self.n_variables
        params_dict = {
            'w_lin': cp.Parameter((n,), name='w_lin'),
            'o': cp.Parameter(n, name='o'),
            'r_square': cp.Parameter(1, name='r_square'),
        }
        self.params_list = ['w_lin', 'o', 'r_square']

        x = cp.Variable(n)
        t = cp.Variable(n)
        # t is the offset of x from the centre o; bounding its squared norm
        # keeps x inside the ball of squared radius r_square around o.
        constraints = [cp.sum_squares(t) <= params_dict['r_square'],
                       x - params_dict['o'] == t]
        linear_term = params_dict['w_lin'] @ x

        if self.flexible_W_sq:
            params_dict['W_sq'] = cp.Parameter((n, n), name='W_sq')
            self.params_list.append('W_sq')
            square_term = cp.sum_squares(params_dict['W_sq'].T @ x)
        else:
            params_dict['gamma_sqrt'] = cp.Parameter(1, name='gamma_sqrt')
            self.params_list.append('gamma_sqrt')
            square_term = cp.sum_squares(params_dict['gamma_sqrt'] * x)

        objective = cp.Maximize(linear_term - square_term)
        model = cp.Problem(objective, constraints)
        # Pass the variables explicitly so the layer's outputs are guaranteed
        # to be (x, t) in that order, as `solve` unpacks them, instead of
        # relying on the discovery order of Problem.variables().
        layer = CvxpyLayer(model,
                           parameters=[params_dict[key] for key in self.params_list],
                           variables=[x, t])
        self.model = model
        self.layer = layer

    def solve(self, param_values_batch):
        """Solve the problem for a (possibly batched) set of parameter tensors.

        Args:
            param_values_batch: Mapping from parameter name to tensor. It must
                contain every name in ``self.params_list``. The one exception
                is ``'gamma_sqrt'``, which defaults to ones when omitted.
                Presumably each tensor has either the parameter's own shape or
                that shape with a leading batch dimension (cvxpylayers
                convention); this is not checked here.

        Returns:
            dict with key ``'solution'`` holding the optimal ``x`` returned by
            the layer.
        """
        # Copy so the caller's mapping is never mutated.
        values = dict(param_values_batch)
        if 'gamma_sqrt' in self.params_list and 'gamma_sqrt' not in values:
            w_lin = values['w_lin']
            # The cvxpy parameter has shape (1,), so the default must be
            # (*batch_dims, 1): (1,) unbatched, (B, 1) batched. `new_ones`
            # also inherits w_lin's dtype and device.
            values['gamma_sqrt'] = w_lin.new_ones((*w_lin.shape[:-1], 1))
        layer_input = [values[key] for key in self.params_list]
        x, _ = self.layer(*layer_input)
        return {'solution': x}
