import numpy as np
import torch


class Problem:
    """Base class pairing a dataset of parameter dictionaries with a solver layer.

    The constructor splits ``dataset`` into contiguous train / validate / test
    index ranges and then calls ``_build_model``, which subclasses must
    implement. ``solve`` reads ``self.layer`` and ``self.params_list``, so
    ``_build_model`` is presumably where subclasses populate them (it is the
    only hook invoked by the constructor) -- confirm against the subclasses.

    Attributes set by ``__init__``:
        n_variables: Value passed by the caller, stored as is.
        observation_size: Kept if already defined on the instance (for
            example by a subclass before calling ``super().__init__``),
            otherwise set to 0.
        dataset: The list passed by the caller. Items are mutated in place by
            ``sample_parameters`` to cache solutions.
        train_size, validate_size, test_size: Number of items per split.
        dataset_indices: Maps ``'train'`` / ``'validate'`` / ``'test'`` to a
            numpy array of dataset positions.
        dataset_keys: Keys gathered into every sampled batch.
        params_dict, params_list, model, layer: Initialised empty / ``None``.
    """

    def __init__(self,
                 n_variables: int,
                 dataset: list,
                 use_train_only: bool = True,
                 train_ratio: float = 0.7,
                 validate_ratio: float = 0.1
                 ):
        """Store the configuration, split the dataset and build the model.

        Args:
            n_variables: Stored on the instance; not otherwise used here.
            dataset: Non-empty list of dictionaries. The keys of the first
                item define which entries are gathered when sampling.
            use_train_only: When true, the whole dataset becomes the train
                split and ``train_ratio`` / ``validate_ratio`` are ignored.
            train_ratio: Fraction of the dataset assigned to the train split.
            validate_ratio: Fraction assigned to the validate split; the
                remainder is the test split.

        Raises:
            ValueError: If ``dataset`` is empty.
        """
        # Save parameters
        self.n_variables = n_variables
        if not hasattr(self, 'observation_size'):
            self.observation_size = 0
        # Save data related things
        if len(dataset) == 0:
            # The first item is needed below to discover the dataset keys.
            raise ValueError('dataset must contain at least one item')
        self.dataset = dataset
        self.train_ratio = 1 if use_train_only else train_ratio
        self.validate_ratio = 0 if use_train_only else validate_ratio
        n_items = len(self.dataset)
        # Both sizes are rounded independently, so their sum could exceed the
        # dataset length (e.g. 3 items with ratios 0.5 / 0.5 round to 2 + 2).
        # Clamp them so the splits never overlap or run past the end.
        train_size = int(np.round(n_items * self.train_ratio))
        self.train_size = min(max(train_size, 0), n_items)
        validate_size = int(np.round(n_items * self.validate_ratio))
        self.validate_size = min(max(validate_size, 0), n_items - self.train_size)
        self.test_size = n_items - self.train_size - self.validate_size
        validate_end = self.train_size + self.validate_size
        self.dataset_indices = {'train': np.arange(0, self.train_size),
                                'validate': np.arange(self.train_size, validate_end),
                                'test': np.arange(validate_end, n_items)}
        # 'r_max' and 'optimal_solution' are filled in lazily by
        # sample_parameters; skip them if the first item already carries them
        # so no key is gathered twice.
        self.dataset_keys = list(self.dataset[0].keys())
        for cached_key in ('r_max', 'optimal_solution'):
            if cached_key not in self.dataset_keys:
                self.dataset_keys.append(cached_key)
        self.params_dict = dict()
        self.params_list = list()
        self.model = None
        self.layer = None
        self._build_model()

    def _build_model(self):
        """Build the model; must be implemented by subclasses."""
        raise NotImplementedError

    def solve(self, param_values_batch):
        """Run ``self.layer`` on the parameters named in ``self.params_list``.

        Args:
            param_values_batch: Mapping that contains every key listed in
                ``self.params_list``.

        Returns:
            Dictionary with the single key ``'solution'``.
        """
        layer_input = [param_values_batch[key] for key in self.params_list]
        # The layer must return exactly one output; anything else fails to unpack.
        solution, = self.layer(*layer_input)
        return {'solution': solution}

    def get_reward_gradient(self, x, param_values_batch):
        """Must be implemented by subclasses."""
        raise NotImplementedError

    def get_observation(self, param_values_batch):
        """Must be implemented by subclasses."""
        raise NotImplementedError

    def get_x_u(self, prediction_dict, param_values_batch):
        """Must be implemented by subclasses."""
        raise NotImplementedError

    def get_penalty(self, x_u, param_values_batch):
        """Must be implemented by subclasses."""
        raise NotImplementedError

    def get_reward(self, solution, param_values_batch):
        """Must be implemented by subclasses.

        ``sample_parameters`` stores the result as ``'r_max'`` and later
        concatenates it with ``torch.cat``, so implementations need to return
        a tensor with at least one dimension.
        """
        raise NotImplementedError

    def sample_parameters(self, batch_size=1, mode='train', indices=None):
        """Gather a batch of dataset items, solving and caching them if needed.

        For every selected item that has no ``'optimal_solution'`` entry yet,
        the item is solved, its reward is computed with ``get_reward`` and
        both results are written back into the dataset item.

        Args:
            batch_size: Number of items in the batch.
            mode: One of ``'train'``, ``'validate'`` or ``'test'``; selects
                the split to draw from when ``indices`` is ``None``.
            indices: Optional explicit dataset positions. When given, nothing
                is drawn at random and ``mode`` is only validated.

        Returns:
            Tuple ``(sample_batch, indices)`` where ``sample_batch`` maps each
            key of ``self.dataset_keys`` to the selected items' values joined
            with ``torch.cat``, and ``indices`` are the positions used.
        """
        assert mode in ['train', 'validate', 'test'], 'mode should be train, validate, or test'
        if indices is None:
            # Sampling without replacement: numpy raises ValueError when
            # batch_size exceeds the size of the requested split.
            indices = np.random.choice(self.dataset_indices[mode], replace=False, size=batch_size)
        else:
            assert len(indices) == batch_size, 'Provided %d indices, but batch_size=%d' % (len(indices), batch_size)
        for ind in indices:
            item = self.dataset[ind]
            if 'optimal_solution' not in item:
                optimal_solution = self.solve(item)['solution']
                r_max = self.get_reward(optimal_solution, item)
                item['optimal_solution'] = optimal_solution
                item['r_max'] = r_max
        sample_batch = {key: torch.cat([self.dataset[i][key] for i in indices])
                        for key in self.dataset_keys}
        return sample_batch, indices
