import random
import math
import numpy as np

from src.utils.probabilities import bernoulli
from src.utils.strings import *

from src.solvers.kp_values_solver import Solver
from src.generators.kp_generator import KnapsackGenerator


class KnapsackCapacityGenerator(KnapsackGenerator):
    """Knapsack generator in which the input features determine the knapsack capacity.

    Item values and weights are sampled once per dataset in ``_pre_generation`` and are
    identical for every instance. Each call to ``_generate_instance`` draws a fresh
    non-negative feature vector ``x`` and derives the capacity from it. The capacity is
    then solved with the configured solver.
    """

    def __init__(self, name: str, solver: Solver, input_dim: int, output_dim: int, deg: int = 2,
                 multiplicative_noise: float = 0.5, correlate_values_and_weights: int = 0,
                 rho: float = 0.5, penalty: float = 1.0):
        """Create the generator.

        Args:
            name: generator name, forwarded to ``KnapsackGenerator``.
            solver: solver used to compute the solution and metrics of every instance.
            input_dim: number of input features in ``x``.
            output_dim: number of knapsack items.
            deg: exponent applied to the Bernoulli-masked feature sum.
            multiplicative_noise: half-width of the uniform noise factor
                ``U(1 - noise, 1 + noise)`` applied to the Beta parameter. Values >= 1 can
                produce a non-positive Beta parameter, which ``np.random.beta`` rejects.
            correlate_values_and_weights: 0 keeps values independent of weights, 1
                multiplies values by weights, and 2 mixes weights into values with
                coefficient ``rho``.
            rho: mixing coefficient used when ``correlate_values_and_weights == 2``.
            penalty: forwarded to the solver as ``params[PENALTY]``.
        """
        # The two positional 0.0 arguments fix parent-class parameters that this
        # generator does not use (NOTE(review): presumably additional noise terms of
        # KnapsackGenerator — confirm against its signature).
        super().__init__(name, solver, input_dim, output_dim, deg, multiplicative_noise, 0.0,
                         0.0, correlate_values_and_weights, rho)

        self._penalty = penalty

        # Per-dataset item values and weights, sampled in _pre_generation.
        self._values = None
        self._weights = None

        # Per-dataset 0/1 mask over the input features, sampled in _pre_generation.
        self._bernoulli_vector = None
        # Kept for backward compatibility. This class itself only uses _bernoulli_vector.
        self._bernoulli_matrix = None

    def _pre_generation(self, path: str, num_instances: int, seed: int | None) -> None:
        """Sample the dataset-level quantities shared by every generated instance.

        After delegating to the parent, this samples the feature mask and the item
        values and weights (uniform in [0, 1)), optionally correlating values with
        weights.
        """
        super()._pre_generation(path, num_instances, seed)
        # NOTE(review): bernoulli(0.5) is assumed to return 0 or 1 with probability 0.5.
        self._bernoulli_vector = np.array([bernoulli(0.5) for _ in range(self._input_dim)])

        self._values = np.random.uniform(0, 1, self._output_dim)
        self._weights = np.random.uniform(0, 1, self._output_dim)

        if self._correlate_values_and_weights == 1:
            self._values *= self._weights
        elif self._correlate_values_and_weights == 2:
            # Convex combination of weights and independent values: rho=1 makes the
            # values equal to the weights, rho=0 leaves them independent.
            self._values = self._rho * self._weights + (1 - self._rho) * self._values

    def _generate_instance(self) -> dict:
        """Generate one knapsack instance whose capacity depends on a random input ``x``.

        The capacity is ``round(sum(weights) * r)``, where ``r ~ Beta(val, val)`` is
        clipped below at 0.1. Here ``val = (1 + (b . x) ** deg) * U(1 - noise, 1 + noise)``
        and ``b`` is the feature mask.

        Returns:
            dict with the keys INPUT, VALUES, WEIGHTS, CAPACITY, SOLUTION and COST.

        Raises:
            RuntimeError: if ``_pre_generation`` has not been called yet.
        """
        if self._bernoulli_vector is None or self._weights is None:
            raise RuntimeError("_pre_generation must be called before _generate_instance")

        # Non-negative features: absolute values of standard normals rounded to 3 decimals.
        x = np.array([round(random.gauss(0, 1), 3) for _ in range(self._input_dim)])
        x = np.abs(x)

        b_matmul_x = np.matmul(self._bernoulli_vector, x)

        val = 1 + b_matmul_x ** self._deg
        val *= random.uniform(1 - self._multiplicative_noise, 1 + self._multiplicative_noise)

        # Beta(val, val) is symmetric around 0.5. Larger val concentrates the relative
        # capacity near 0.5, so x controls how spread out the capacities are.
        relative_capacity_val = np.random.beta(a=val, b=val)
        # Lower bound avoids degenerate, almost-empty knapsacks.
        relative_capacity_val = np.clip(relative_capacity_val, a_min=0.1, a_max=np.inf)

        capacity = round(np.sum(self._weights) * relative_capacity_val)
        capacity = np.array([capacity])

        params: dict[str, np.ndarray | float] = {VALUES: self._values, WEIGHTS: self._weights,
                                                 PENALTY: self._penalty}

        solution, _ = self._solver.solve(x=x, y=capacity, params=params)
        metrics = self._solver.compute_metrics(capacity, solution, params)

        # Copy the shared arrays so that mutating one instance cannot affect the others.
        return {
            INPUT: x,
            VALUES: self._values.copy(),
            WEIGHTS: self._weights.copy(),
            CAPACITY: capacity,
            SOLUTION: solution,
            COST: metrics[TOTAL_COST],
        }
