from src.vopf.random_grid import generate_random_grid
from src.vopf.load_data import generate_dataset
from src.problem_quadratic import QuadraticProblem


class VOPFProblem(QuadraticProblem):
    """Quadratic problem whose dataset is generated from a randomly built grid.

    Construction first calls ``generate_random_grid`` to obtain generator
    indices, load indices and a conductance matrix, then calls
    ``generate_dataset`` on that grid, and finally hands the dataset to
    ``QuadraticProblem.__init__``.
    """

    def __init__(
            self,
            n_loads: int,
            n_generators: int,
            n_samples: int,
            proximity_method: str = 'max',
            use_train_only: bool = True,
            train_ratio: float = 0.7,
            validate_ratio: float = 0.1,
            data_sample_seed: float = None,
            features_rank: int = None,
            topology_seed: float = None,
            g: float = 6,
            i_max_mean: float = 1000,
            branching_factor: float = 0,
            norm_factor: float = 1,
            vary_constraints: bool = False,
            noise_std: float = 0.2):
        """Generate the grid and dataset, then initialise the base problem.

        Args:
            n_loads: Number of loads; forwarded to ``generate_random_grid``.
            n_generators: Number of generators; forwarded to
                ``generate_random_grid``.
            n_samples: Forwarded to ``generate_dataset``.
            proximity_method: Stored on the instance and forwarded to the
                base-class initialiser.
            use_train_only: Forwarded to the base-class initialiser.
            train_ratio: Forwarded to the base-class initialiser.
            validate_ratio: Forwarded to the base-class initialiser.
            data_sample_seed: Forwarded to ``generate_dataset``
                (presumably an RNG seed -- confirm in that function).
            features_rank: Forwarded to ``generate_dataset``.
            topology_seed: Forwarded to ``generate_random_grid``
                (presumably an RNG seed -- confirm in that function).
            g: Forwarded to ``generate_random_grid``.
            i_max_mean: Forwarded to ``generate_dataset``.
            branching_factor: Forwarded to ``generate_random_grid``.
            norm_factor: Forwarded to both ``generate_dataset`` and the
                base-class initialiser.
            vary_constraints: Forwarded to ``generate_dataset``.
            noise_std: Stored on the instance and forwarded to
                ``generate_dataset`` as a keyword argument.
        """
        self.n_generators = n_generators
        self.n_loads = n_loads
        self.noise_std = noise_std

        # Build the grid; the helper yields three values in this order.
        grid = generate_random_grid(n_loads, n_generators, g, topology_seed, branching_factor)
        self.generator_inds, self.load_inds, self.conductance_matrix = grid

        dataset = generate_dataset(
            self.conductance_matrix, self.generator_inds, self.load_inds,
            n_samples, i_max_mean, norm_factor, data_sample_seed,
            vary_constraints, features_rank, noise_std=noise_std,
        )

        # Size of dimension 1 of the first entry's 'A' array.
        constraint_matrix = dataset[0]['A']
        self.n_constraints = constraint_matrix.shape[1]
        self.proximity_method = proximity_method

        n_variables = n_loads + n_generators - 1
        super().__init__(
            n_variables, dataset, use_train_only, train_ratio,
            validate_ratio, proximity_method, norm_factor,
        )

        # Looked up only after the base initialiser has run, mirroring the
        # original statement order in case the base class touches the dataset.
        feature_block = dataset[0]['features']
        self.observation_size = feature_block.shape[1]

    def get_reward(self, solution, param_values_batch):
        """Return the base-class reward plus an offset derived from ``w_lin_0``.

        The offset is ``param_values_batch['w_lin_0']`` multiplied by
        ``350 / self.norm_factor`` and summed over dimension 1.

        NOTE(review): ``self.norm_factor`` is not assigned in this class; it is
        presumably set by ``QuadraticProblem.__init__`` -- confirm. The meaning
        of the constant 350 is not visible from this file.
        """
        linear_weights = param_values_batch['w_lin_0']
        scale = 350 / self.norm_factor
        offset = (linear_weights * scale).sum(1)
        base_reward = super().get_reward(solution, param_values_batch)
        return offset + base_reward