import random
import math
import numpy as np

from src.utils.probabilities import bernoulli
from src.utils.strings import *

from src.solvers.kp_values_solver import Solver
from src.generators.kp_generator import KnapsackGenerator


class KnapsackValuesGenerator(KnapsackGenerator):
    """Generate knapsack instances whose item values depend on a random input vector.

    Item weights and the capacity are sampled in ``_pre_generation`` and reused by every
    instance produced afterwards; only the item values change between instances. Each
    instance's values are a noisy polynomial of a fixed random projection
    (``self._bernoulli_matrix``) of a freshly sampled input vector ``x``.
    """

    def __init__(self, name: str, solver: Solver, input_dim: int, output_dim: int, deg: int = 2,
                 multiplicative_noise: float = 0.5, additive_noise: float = 0, relative_capacity: float = 0.5,
                 correlate_values_and_weights: int = 0, rho: float = 0.5):
        """Store the generation settings; the random structures are created in ``_pre_generation``.

        Args:
            name: Generator name, forwarded to ``KnapsackGenerator``.
            solver: Solver used to compute each instance's solution and metrics.
            input_dim: Length of the input vector ``x``.
            output_dim: Number of knapsack items (length of values/weights).
            deg: Exponent of the polynomial mapping from projected input to value.
            multiplicative_noise: Each value is scaled by ``uniform(1 - m, 1 + m)``.
            additive_noise: Added to every value before rounding (see note in
                ``_generate_instance``).
            relative_capacity: Capacity as a fraction of the total item weight.
            correlate_values_and_weights: ``1`` multiplies values by weights, ``2`` mixes
                them as ``rho * w + (1 - rho) * v``; any other value leaves them uncorrelated.
            rho: Mixing coefficient used when ``correlate_values_and_weights == 2``.
        """
        super().__init__(name, solver, input_dim, output_dim, deg, multiplicative_noise, additive_noise,
                         relative_capacity, correlate_values_and_weights, rho)

        # Item weights (length output_dim) and knapsack capacity; set in _pre_generation.
        self._weights: np.ndarray | None = None
        self._capacity: float | None = None

        # (output_dim, input_dim) matrix of bernoulli(0.5) draws projecting x onto the items;
        # set in _pre_generation.
        self._bernoulli_matrix: np.ndarray | None = None

    def _pre_generation(self, path: str, num_instances: int, seed: int | None) -> None:
        """Run the parent setup, then sample the projection matrix, weights and capacity.

        The weights are drawn from ``np.random.uniform(0, 1)`` and the capacity is
        ``relative_capacity`` times their sum; both are shared by all subsequent instances.
        """
        super()._pre_generation(path, num_instances, seed)
        self._bernoulli_matrix = np.array([[bernoulli(0.5) for _ in range(self._input_dim)]
                                           for _ in range(self._output_dim)])

        self._weights = np.random.uniform(0, 1, self._output_dim)
        self._capacity = self._weights.sum() * self._relative_capacity

    def _generate_instance(self) -> dict:
        """Sample one instance, solve it with ``self._solver`` and return it as a dict.

        Returns:
            A dict with keys INPUT (the sampled vector ``x``), VALUES (item values),
            WEIGHTS (a copy of the shared weights), CAPACITY, SOLUTION (first element
            returned by ``solver.solve``) and COST (``TOTAL_COST`` from
            ``solver.compute_metrics``).

        Raises:
            RuntimeError: If called before ``_pre_generation``.
        """
        if self._bernoulli_matrix is None or self._weights is None:
            raise RuntimeError('_pre_generation must be called before _generate_instance')

        # Input features: input_dim standard-normal draws, rounded to 3 decimals.
        x = np.array([round(random.gauss(0, 1), 3) for _ in range(self._input_dim)])
        b_matmul_x = np.matmul(self._bernoulli_matrix, x)

        # Loop-invariant normalisation of the projected input.
        scale = math.sqrt(self._input_dim)

        values = []
        for i, pred in enumerate(b_matmul_x):
            val = 1 + (pred / scale + 3) ** self._deg
            val = val * random.uniform(1 - self._multiplicative_noise, 1 + self._multiplicative_noise)

            if self._correlate_values_and_weights == 1:
                val *= self._weights[i]
            elif self._correlate_values_and_weights == 2:
                val = self._rho * self._weights[i] + (1 - self._rho) * val

            # NOTE(review): additive_noise is added as a deterministic offset, not sampled
            # noise — confirm this is the intended semantics.
            val = round(val + self._additive_noise, 5)

            values.append(val)

        values = np.array(values)

        # Each instance (and the solver call) gets its own copy of the weights so that an
        # in-place modification cannot corrupt the generator state or other instances.
        weights = self._weights.copy()
        params: dict[str, np.ndarray | float] = {WEIGHTS: weights, CAPACITY: self._capacity}

        solution, _ = self._solver.solve(x=x, y=values, params=params)
        metrics = self._solver.compute_metrics(values, solution, params)

        instance = dict()
        instance[INPUT] = x
        instance[VALUES] = values
        instance[WEIGHTS] = weights
        instance[CAPACITY] = self._capacity
        instance[SOLUTION] = solution
        instance[COST] = metrics[TOTAL_COST]

        return instance
