import random
import math
import numpy as np

from src.utils.probabilities import bernoulli
from src.utils.strings import *

from src.solvers.kp_values_solver import Solver
from src.generators.kp_generator import KnapsackGenerator


class KnapsackWeightsGenerator(KnapsackGenerator):
    """Knapsack instance generator whose item weights depend on the input features.

    Item values and the knapsack capacity are drawn once per generation run
    (in ``_pre_generation``) and shared by every instance; only the feature
    vector and the resulting weights change from one instance to the next.
    """

    def __init__(self, name: str, solver: Solver, input_dim: int, output_dim: int, deg: int = 2,
                 multiplicative_noise: float = 0.5, additive_noise: float = 0, relative_capacity: float = 0.5,
                 correlate_values_and_weights: int = 0, rho: float = 0.5, penalty: float = 1.0):
        """Create the generator.

        Every argument except ``penalty`` is forwarded unchanged to
        ``KnapsackGenerator.__init__``.

        :param penalty: value handed to the solver under the ``PENALTY`` key of its params.
        """

        super().__init__(name, solver, input_dim, output_dim, deg, multiplicative_noise, additive_noise,
                         relative_capacity, correlate_values_and_weights, rho)

        # Run-level state; everything except the penalty is filled in by _pre_generation.
        self._bernoulli_matrix = None
        self._values = None
        self._capacity = None

        self._penalty = penalty

    def _pre_generation(self, path: str, num_instances: int, seed: int | None) -> None:
        """Draw the quantities shared by all instances of one generation run.

        After delegating to the parent hook, this samples, in order: the
        ``output_dim x input_dim`` matrix of ``bernoulli(0.5)`` draws, the
        per-item values (uniform in [0, 1)), and finally the capacity.
        """

        super()._pre_generation(path, num_instances, seed)

        # One row of Bernoulli draws per item, sampled row by row.
        rows = []
        for _ in range(self._output_dim):
            rows.append([bernoulli(0.5) for _ in range(self._input_dim)])
        self._bernoulli_matrix = np.array(rows)

        self._values = np.random.uniform(0, 1, self._output_dim)

        # Must come last: the capacity estimate relies on the matrix (and possibly the values).
        self._generate_capacity()

    def _generate_instance(self) -> dict:
        """Sample one feature vector, derive its weights and solve the resulting problem.

        :return: dict holding the input features, values, weights, capacity,
                 the solver's solution and the ``TOTAL_COST`` metric.
        """

        # Standard-normal features, each rounded to 3 decimals.
        features = np.array([round(random.gauss(0, 1), 3) for _ in range(self._input_dim)])
        item_weights = self._generate_weights(features)

        solver_params: dict = {VALUES: self._values, CAPACITY: self._capacity, PENALTY: self._penalty}

        # The solver returns a pair; only its first element is kept.
        solution, _ = self._solver.solve(x=features, y=item_weights, params=solver_params)
        metrics = self._solver.compute_metrics(item_weights, solution, solver_params)

        return {
            INPUT: features,
            VALUES: self._values,
            WEIGHTS: item_weights,
            CAPACITY: self._capacity,
            SOLUTION: solution,
            COST: metrics[TOTAL_COST],
        }

    def _generate_capacity(self) -> None:
        """Set the capacity to ``relative_capacity`` times the mean total weight of 10 sampled inputs."""

        def sampled_total_weight():
            # Same sampling scheme as _generate_instance: rounded standard-normal features.
            features = np.array([round(random.gauss(0, 1), 3) for _ in range(self._input_dim)])
            return np.sum(self._generate_weights(features))

        totals = [sampled_total_weight() for _ in range(10)]

        self._capacity = np.mean(totals) * self._relative_capacity

    def _generate_weights(self, x: np.ndarray) -> np.ndarray:
        """Compute the noisy item weights induced by the feature vector ``x``.

        For item ``i`` with projection ``p_i = (B @ x)[i]`` the weight is
        ``1 + (p_i / sqrt(input_dim) + 3) ** deg``, scaled by a uniform factor
        in ``[1 - multiplicative_noise, 1 + multiplicative_noise]``. Depending
        on ``correlate_values_and_weights`` it is then multiplied by the item's
        value (mode 1) or blended with it as ``rho * value + (1 - rho) * weight``
        (mode 2). Finally ``additive_noise`` is added and the result is rounded
        to 5 decimals.

        :param x: feature vector of one instance.
        :return: array with one weight per item.
        """

        projections = np.matmul(self._bernoulli_matrix, x)

        def weight_of(item: int):
            base = 1 + (projections[item] / math.sqrt(self._input_dim) + 3) ** self._deg

            spread = self._multiplicative_noise
            noisy = base * random.uniform(1 - spread, 1 + spread)

            mode = self._correlate_values_and_weights
            if mode == 1:
                noisy = noisy * self._values[item]
            elif mode == 2:
                noisy = self._rho * self._values[item] + (1 - self._rho) * noisy

            return round(noisy + self._additive_noise, 5)

        # Items are processed in index order so the random draws keep their sequence.
        return np.array([weight_of(item) for item in range(self._output_dim)])
