from src.portfolio_optimization.load_data import SP500DataLoader, generate_dataset
from cvxpylayers.torch import CvxpyLayer
from src.problem_quadratic import QuadraticProblem
import datetime as dt
import numpy as np
import cvxpy as cp
import torch


class PortfolioProblem(QuadraticProblem):
    """Portfolio allocation posed as a parametric quadratic program.

    The optimisation model built by :meth:`_build_model` is::

        maximize    w_lin @ x - 0.5 * ||W_sq.T @ x||^2
        subject to  0 <= x <= 1,  sum(x) == 1

    where ``w_lin`` and ``W_sq`` are the problem parameters.
    """

    def __init__(self,
                 n_variables: int,
                 path_to_data: str,
                 n_samples: int,
                 proximity_method: str = 'max',
                 alpha: float = 2,
                 use_train_only=True,
                 train_ratio=0.7,
                 validate_ratio=0.1,
                 data_sample_seed=None,
                 ):
        """Load S&P 500 data, generate the dataset and initialise the base problem.

        Args:
            n_variables: Number of decision variables; forwarded to
                ``generate_dataset`` as ``n`` and to the base class.
            path_to_data: Path handed to ``SP500DataLoader``.
            n_samples: Number of samples requested from ``generate_dataset``.
            proximity_method: Default reduction used by :meth:`get_penalty`
                (``'mean'``, ``'max'``; any other value leaves the penalty
                unreduced).
            alpha: Forwarded unchanged to ``generate_dataset``.
            use_train_only: Forwarded to ``QuadraticProblem.__init__``.
            train_ratio: Forwarded to ``QuadraticProblem.__init__``.
            validate_ratio: Forwarded to ``QuadraticProblem.__init__``.
            data_sample_seed: If not ``None``, seeds NumPy's *global* random
                generator before the dataset is generated.
        """
        if data_sample_seed is not None:
            # NOTE: this reseeds the process-wide NumPy RNG, not a local one.
            np.random.seed(data_sample_seed)
        sp500_data = SP500DataLoader(path_to_data, "sp500",
                                     start_date=dt.datetime(2004, 1, 1), end_date=dt.datetime(2017, 1, 1),
                                     collapse="daily", overwrite=False, verbose=True)
        dataset = generate_dataset(sp500_data, n=n_variables, n_samples=n_samples, alpha=alpha)
        super().__init__(n_variables, dataset, use_train_only, train_ratio, validate_ratio)
        # NOTE(review): hard-coded; presumably the feature width produced by
        # generate_dataset -- confirm against the data loader.
        self.observation_size = 28
        self.proximity_method = proximity_method

    def _build_model(self):
        """Create the cvxpy problem and wrap it in a differentiable ``CvxpyLayer``.

        Sets ``self.params_dict``, ``self.params_list``, ``self.model`` and
        ``self.layer``.
        """
        self.params_dict = {'w_lin': cp.Parameter((self.n_variables,), name='w_lin'),
                            'W_sq': cp.Parameter((self.n_variables, self.n_variables), name='W_sq')}
        # Fixes the order in which parameters are handed to the layer.
        self.params_list = ['w_lin', 'W_sq']
        x = cp.Variable(self.n_variables)
        constraints = [x >= 0, x <= 1, cp.sum(x) == 1]
        # Objective
        linear_term = self.params_dict['w_lin'] @ x
        square_term = cp.sum_squares(self.params_dict['W_sq'].T @ x)
        objective = cp.Maximize(linear_term - 0.5 * square_term)
        self.model = cp.Problem(objective, constraints)
        self.layer = CvxpyLayer(self.model, parameters=[self.params_dict[key] for key in self.params_list],
                                variables=self.model.variables())

    def get_x_u(self, prediction_dict, param_values_batch):
        """Return the unconstrained maximiser of the objective for each batch item.

        With ``Q = W_sq @ W_sq^T`` the stationary point of
        ``w_lin @ x - 0.5 * x^T Q x`` satisfies ``Q x = w_lin``.

        Args:
            prediction_dict: Mapping holding either a precomputed ``'x_u'``
                (returned as is) or the tensors ``'w_lin'`` (indexed as 2-D)
                and ``'W_sq'`` (indexed as 3-D, batch first).
            param_values_batch: Unused by this implementation.

        Returns:
            Tensor with one solution row per batch item.
        """
        if 'x_u' in prediction_dict:
            return prediction_dict['x_u']
        W_sq = prediction_dict['W_sq']
        w_lin = prediction_dict['w_lin']
        Q_pred = W_sq @ W_sq.swapdims(1, 2)
        # Q_pred is symmetric by construction, so w Q^{-1} equals the solution
        # of Q x = w. Solving the system directly avoids forming the explicit
        # inverse, which is slower and less numerically stable.
        return torch.linalg.solve(Q_pred, w_lin.unsqueeze(-1)).squeeze(-1)

    def get_penalty(self, x_u, param_values_batch, proximity_method_to_use=None):
        """Measure how far ``x_u`` is from satisfying the model constraints.

        The per-sample violation vector concatenates, along dim 1:
        ``relu(-x_u)`` (lower bound 0), ``relu(x_u - 1)`` (upper bound 1) and
        ``|sum(x_u) - 1|`` (budget constraint).

        Args:
            x_u: Candidate solutions; treated as 2-D with the batch on dim 0.
            param_values_batch: Unused by this implementation.
            proximity_method_to_use: ``'mean'`` or ``'max'`` to reduce over
                dim 1; defaults to ``self.proximity_method`` when ``None``.

        Returns:
            The mean or max violation per sample, or the full unreduced
            violation tensor for any other method value.
        """
        if proximity_method_to_use is None:
            proximity_method_to_use = self.proximity_method
        penalty_tensor = torch.concat([torch.relu(0 - x_u), torch.relu(x_u - 1),
                                       torch.abs(x_u.sum(dim=1, keepdim=True) - 1)], dim=1)
        if proximity_method_to_use == 'mean':
            return penalty_tensor.mean(dim=1)
        elif proximity_method_to_use == 'max':
            return penalty_tensor.amax(dim=1)
        else:
            # Any other value deliberately falls through to the raw tensor.
            return penalty_tensor

    def get_reward(self, solution, param_values_batch):
        """Evaluate the objective ``w_lin @ x - 0.5 * ||W_sq^T x||^2`` per sample.

        Args:
            solution: Batch of decision vectors, one row per sample.
            param_values_batch: Mapping with tensors ``'w_lin'`` and ``'W_sq'``
                (``'W_sq'`` is batched 3-D, as required by ``torch.bmm``).

        Returns:
            1-D tensor of objective values, one per sample.
        """
        linear_term = torch.sum(param_values_batch['w_lin'] * solution, dim=1)
        square_term = torch.sum(torch.bmm(param_values_batch['W_sq'].transpose(2, 1),
                                          solution[:, :, None])[:, :, 0] ** 2, 1)
        return linear_term - 0.5 * square_term

