from src.problem import Problem
from cvxpylayers.torch import CvxpyLayer
import cvxpy as cp
import torch



class QuadraticProblem(Problem):
    """Concave quadratic program exposed as a differentiable ``CvxpyLayer``.

    The cvxpy problem built by ``_build_model`` is::

        maximize    w_lin^T x - ||W_sq^T x||^2
        subject to  (A / norm_factor) x + b / norm_factor <= 0

    The remaining methods evaluate rewards, reward gradients, constraint
    violations and distances for batches of candidate solutions in torch.
    """

    def __init__(self,
                 n_variables: int,
                 dataset: list,
                 use_train_only: bool = True,
                 train_ratio: float = 0.7,
                 validate_ratio: float = 0.1,
                 proximity_method: str = 'mean',
                 norm_factor: float = 1,):
        """Derive problem sizes from the dataset and delegate to ``Problem``.

        Args:
            n_variables: number of decision variables.
            dataset: list of samples. The first sample's ``'A'`` gives the
                number of constraints (its second-to-last dimension), and its
                ``'features'`` gives the observation size (last dimension).
            use_train_only: forwarded to ``Problem``.
            train_ratio: forwarded to ``Problem``.
            validate_ratio: forwarded to ``Problem``.
            proximity_method: default reduction used by ``get_penalty``
                (``'mean'``, ``'max'``, anything else = no reduction).
            norm_factor: divisor applied to ``A`` and ``b`` inside the cvxpy
                constraints.
        """
        # Set before super().__init__: presumably the base constructor calls
        # _build_model(), which reads n_constraints / norm_factor. TODO confirm.
        self.n_constraints = dataset[0]['A'].shape[-2]
        self.observation_size = dataset[0]['features'].shape[-1]
        self.proximity_method = proximity_method
        self.norm_factor = norm_factor
        super().__init__(n_variables, dataset, use_train_only, train_ratio, validate_ratio)

    def _build_model(self):
        """Build the parametrized cvxpy problem and wrap it in a ``CvxpyLayer``.

        Sets ``self.params_dict`` (name -> ``cp.Parameter``),
        ``self.params_list`` (parameter order expected by ``self.layer``),
        ``self.model`` and ``self.layer``.
        """
        self.params_dict = {'A': cp.Parameter((self.n_constraints, self.n_variables), name='A'),
                            'b': cp.Parameter((self.n_constraints,), name='b'),
                            'w_lin': cp.Parameter((self.n_variables,), name='w_lin'),
                            'W_sq': cp.Parameter((self.n_variables, self.n_variables), name='W_sq')}
        self.params_list = ['A', 'b', 'w_lin', 'W_sq']
        x = cp.Variable(self.n_variables)
        # Constraints: normalised affine inequalities.
        constraints = [(self.params_dict['A'] / self.norm_factor) @ x + self.params_dict['b'] / self.norm_factor <= 0]
        # Objective. A single inner product replaces n scalar terms, keeping
        # the expression tree small. Parameter @ variable is still DPP-compliant.
        linear_term = self.params_dict['w_lin'] @ x
        square_term = cp.sum_squares(self.params_dict['W_sq'].T @ x)
        objective = cp.Maximize(linear_term - square_term)
        self.model = cp.Problem(objective, constraints)
        self.layer = CvxpyLayer(self.model, parameters=[self.params_dict[key] for key in self.params_list],
                                variables=self.model.variables())

    def get_observation(self, param_values_batch):
        """Return the ``'features'`` entry of the batch."""
        return param_values_batch['features']

    def get_reward_gradient(self, x, param_values_batch):
        """Return the gradient of the reward with respect to ``x``.

        With Q = W_sq W_sq^T (symmetric), the reward is
        w_lin^T x - x^T Q x, so its gradient is w_lin - 2 Q x.

        Args:
            x: batch of points, either ``(batch, n)`` or ``(batch, 1, n)``.
            param_values_batch: dict with batched ``'W_sq'`` (3-D) and
                ``'w_lin'``.

        Returns:
            Tensor of shape ``(batch, n)``.
        """
        W_sq = param_values_batch['W_sq']
        Q = W_sq @ W_sq.swapdims(1, 2)
        # Treat each point as a row vector so the matmul is batched per sample.
        # A plain (batch, n) input would otherwise broadcast against every Q in
        # the batch, and [:, 0] would select x[0] @ Q[b] instead of x[b] @ Q[b].
        x_row = x[:, None, :] if x.dim() == 2 else x
        return param_values_batch['w_lin'] - (2 * x_row @ Q)[:, 0]

    def get_x_u(self, prediction_dict, param_values_batch):
        """Return the unconstrained maximiser of the predicted objective.

        If ``prediction_dict`` already contains ``'x_u'``, it is returned
        unchanged. Otherwise this returns the stationary point of
        w_lin^T x - x^T Q x with Q = W_sq W_sq^T, i.e. x_u = Q^{-1} w_lin / 2.
        ``param_values_batch`` is unused; it is kept for interface
        compatibility.

        Returns:
            Tensor of shape ``(batch, n)``.
        """
        if 'x_u' in prediction_dict:
            return prediction_dict['x_u']
        W_sq = prediction_dict['W_sq']
        Q_pred = W_sq @ W_sq.swapdims(1, 2)
        # Solve Q x = w_lin instead of forming Q^{-1} explicitly. This is the
        # same point (Q is symmetric) but numerically better conditioned.
        x_u = torch.linalg.solve(Q_pred, prediction_dict['w_lin'][:, :, None])[:, :, 0] / 2
        return x_u

    def get_reward(self, solution, param_values_batch):
        """Evaluate w_lin^T x - ||W_sq^T x||^2 for each row of ``solution``.

        Args:
            solution: ``(batch, n)`` tensor of points.
            param_values_batch: dict with batched ``'w_lin'`` and 3-D ``'W_sq'``.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        linear_term = torch.sum(param_values_batch['w_lin'] * solution, dim=1)
        square_term = torch.sum(torch.bmm(param_values_batch['W_sq'].transpose(2, 1),
                                          solution[:, :, None])[:, :, 0] ** 2, 1)
        return linear_term - square_term

    def get_penalty(self, x_u, param_values_batch, proximity_method_to_use=None):
        """Measure constraint violation relu(A x_u + b) per sample.

        Note: uses the raw ``A`` and ``b`` from the batch, not the
        ``norm_factor``-scaled versions used in the cvxpy constraints.

        Args:
            x_u: ``(batch, n)`` tensor of points.
            param_values_batch: dict with batched ``'A'`` (3-D) and ``'b'``.
            proximity_method_to_use: ``'mean'`` or ``'max'`` to reduce over
                constraints. Any other value returns the unreduced
                ``(batch, n_constraints)`` violations. Defaults to
                ``self.proximity_method``.
        """
        if proximity_method_to_use is None:
            proximity_method_to_use = self.proximity_method
        penalty_tensor = torch.relu(torch.bmm(param_values_batch['A'], x_u[:, :, None])[:, :, 0]
                                    + param_values_batch['b'])
        if proximity_method_to_use == 'mean':
            return penalty_tensor.mean(dim=1)
        elif proximity_method_to_use == 'max':
            return penalty_tensor.amax(dim=1)
        else:
            return penalty_tensor

    def get_solution_distance(self, x_u, param_values_batch):
        """Return the per-sample Euclidean distance from ``x_u`` to the stored
        (detached) ``'optimal_solution'``."""
        optimal_solution = param_values_batch['optimal_solution'].detach()
        distance_to_solution = torch.norm(optimal_solution - x_u, dim=1)
        return distance_to_solution

    def get_projection_distance(self, x_u, solution, param_values_batch, tolerance: float = 1e-5):
        """Return the per-sample distance between ``x_u`` and the detached
        ``solution``. Distances ``<= tolerance`` are reported as exactly 0.

        ``param_values_batch`` is unused; it is kept for interface
        compatibility.
        """
        projection_distance = torch.norm(solution.detach() - x_u, dim=1)
        # Out-of-place on purpose: autograd saves torch.norm's output for its
        # backward pass. Zeroing it in place (tensor[mask] = 0) would make any
        # backward through this value fail with an in-place modification error.
        return projection_distance.masked_fill(projection_distance <= tolerance, 0)

    def get_mse(self, prediction_dict, param_values_batch):
        """Sum per-sample errors over every predicted key present in the batch.

        2-D predictions contribute the element-wise mean squared error over
        dim 1. 3-D predictions contribute the Frobenius norm of the difference
        (note: a norm, not a mean squared error).

        Returns:
            Tensor of shape ``(batch,)``.

        Raises:
            TypeError: if a matched prediction is neither 2-D nor 3-D.
            ValueError: if no key of ``prediction_dict`` is in
                ``param_values_batch``.
        """
        per_prediction_mse = []
        for key, val in prediction_dict.items():
            if key not in param_values_batch:
                continue
            val_true = param_values_batch[key]
            # Exact rank checks: a 1-D/0-D value previously slipped through
            # `<= 2` and crashed in `.mean(dim=[1])` with an unrelated error.
            if val.dim() == 2:
                mse = torch.nn.functional.mse_loss(val, val_true, reduction='none').mean(dim=[1])
            elif val.dim() == 3:
                mse = torch.norm(val - val_true, dim=[1, 2])
            else:
                raise TypeError('Shape of %s is %s. Only 2- and 3- dim predictions are supported! ' % (key, val.shape))
            per_prediction_mse.append(mse)
        if not per_prediction_mse:
            raise ValueError('None of the predicted keys %s are present in the batch.'
                             % (list(prediction_dict),))
        total_mse = torch.stack(per_prediction_mse).sum(dim=0)
        return total_mse
