import os
import random
import math
import numpy as np
import pandas as pd
from tqdm import tqdm

from src.solvers.wsmc_solver import WeightedSetMultiCoverSolver

from src.utils.probabilities import set_seed, bernoulli
from src.utils.strings import *


class WeightedSetMultiCoverGenerator:

    def __init__(self, name: str, input_dim: int, n_items: int, n_sets: int, density: float, penalty: float,
                 deg: int, multiplicative_noise: float, additive_noise: float):

        assert 0.0 < density < 1.0
        assert 0.0 < multiplicative_noise < 1.0
        assert n_sets > 0
        assert n_items > 0

        self._name = name
        self._input_dim = input_dim
        self._n_items = n_items
        self._n_sets = n_sets
        self._density = density
        self._penalty = penalty
        self._deg = deg
        self._multiplicative_noise = multiplicative_noise
        self._additive_noise = additive_noise

        self._cover_matrix = None
        self._sets_costs = None
        self._items_costs = None

        self._bernoulli_matrix = None

        self._solver = WeightedSetMultiCoverSolver()

    def generate(self, path: str, num_instances: int, seed: int | None) -> None:

        self._pre_generation(path, num_instances, seed)

        rows = []
        for _ in tqdm(range(num_instances), total=num_instances, desc='Data generation'):
            x, y, z, cost = self._generate_instance()
            row = [x, y, z, cost, self._sets_costs, self._items_costs]
            rows.append(row)

        dataframe = pd.DataFrame(rows, columns=[INPUT, DEMANDS, SOLUTION, COST, SETS_COSTS, ITEMS_COSTS])

        self._save_instance(path, dataframe)

    def _save_instance(self, path: str, dataframe: pd.DataFrame) -> None:

        dataframe_save_path = os.path.join(path, self._name + ".pkl")
        dataframe.to_pickle(dataframe_save_path)

        cover_matrix_save_path = os.path.join(path, self._name + "_cover_matrix.npy")
        np.save(cover_matrix_save_path, self._cover_matrix)

        print("Dataset saved to", dataframe_save_path)

    def _pre_generation(self, path: str, num_instances: int, seed: int | None) -> None:

        assert num_instances >= 1

        if seed is not None:
            set_seed(seed)

        if not os.path.exists(path):
            os.makedirs(path)

        self._build_cover_matrix()
        self._build_sets_costs()
        self._build_items_costs()

        self._bernoulli_matrix = np.array([[bernoulli(0.5) for _ in range(self._input_dim)]
                                           for _ in range(self._n_items)])

    def _build_cover_matrix(self) -> None:

        # Initialize the availability matrix
        availability = np.zeros(shape=(self._n_items, self._n_sets), dtype=np.int8)

        # For each product (item)...
        for row in range(self._n_items):
            first_col = -1
            second_col = -1

            while first_col == second_col:
                first_col = np.random.randint(low=0, high=self._n_sets, size=1)
                second_col = np.random.randint(low=0, high=self._n_sets, size=1)
                availability[row, first_col] = 1
                availability[row, second_col] = 1

        for col in range(self._n_sets):
            row = np.random.randint(low=0, high=self._n_items, size=1)
            availability[row, col] = 1

        # Check that all the products are available in at least two sets
        available_products = np.sum(availability, axis=1) > 1

        # Check that all the sets have at least one product
        at_least_a_prod = np.sum(availability, axis=0) > 0

        density = np.clip(self._density - np.mean(availability), a_min=0, a_max=None)
        availability += np.random.choice([0, 1], size=(self._n_items, self._n_sets), p=[1 - density, density])
        availability = np.clip(availability, a_min=0, a_max=1)

        assert available_products.all(), "Not all the products are available"
        assert at_least_a_prod.all(), "Not all set cover at least a product"

        self._cover_matrix = availability

    def _build_items_costs(self) -> None:

        items_costs = np.zeros(shape=(self._n_items,))

        for idx in range(self._n_items):
            prod_availability = self._cover_matrix[idx]

            # First of all, we check the costs of the set that cover the current product (possible_prod_cost)
            possible_prod_cost = prod_availability * self._sets_costs
            possible_prod_cost = possible_prod_cost[np.nonzero(possible_prod_cost)]

            # Then we choose the max among them
            max_cost = np.max(possible_prod_cost)
            items_costs[idx] = max_cost * self._penalty

        self._items_costs = items_costs

    def _build_sets_costs(self) -> None:
        self._sets_costs = np.random.randint(low=1, high=100, size=self._n_sets)

    def _generate_instance(self) -> tuple:

        x = np.array([round(random.gauss(0, 1), 3) for _ in range(self._input_dim)])

        b_matmul_x = np.matmul(self._bernoulli_matrix, x)

        demands = []
        for j in range(self._n_items):

            pred = b_matmul_x[j]

            demand = (1 + (pred / math.sqrt(self._input_dim) + 3) ** self._deg)
            demand *= random.uniform(1 - self._multiplicative_noise, 1 + self._multiplicative_noise)
            demand = round(demand + self._additive_noise, 5)

            demands.append(demand)

        demands = np.random.poisson(demands, size=self._n_items)

        params = {COVER_MATRIX: self._cover_matrix, SETS_COSTS: self._sets_costs, ITEMS_COSTS: self._items_costs}

        solution, _ = self._solver.solve(x, demands, params)
        cost = self._solver.compute_metrics(demands, solution, params)[TOTAL_COST]

        return x, demands, solution, cost
