import os
from abc import ABC, abstractmethod

import pandas as pd
from tqdm import tqdm

from src.utils.probabilities import set_seed
from src.utils.strings import *

from src.solvers.solver import Solver


class KnapsackGenerator(ABC):
    """Abstract base class for knapsack dataset generators.

    Subclasses implement :meth:`_generate_instance`. This base class drives the
    generation loop, collects the generated instances into a ``pandas.DataFrame``
    and pickles it to ``<path>/<name>.pkl``.
    """

    def __init__(self, name: str, solver: Solver, input_dim: int, output_dim: int, deg: int = 2,
                 multiplicative_noise: float = 0.5, additive_noise: float = 0, relative_capacity: float = 0.5,
                 correlate_values_and_weights: int = 0, rho: float = 0.5):
        """Validate and store the generation configuration.

        Args:
            name: Dataset name; also used as the stem of the saved ``.pkl`` file.
            solver: Solver object; stored for subclasses (not used by this base class).
            input_dim: Stored for subclasses (not used by this base class).
            output_dim: Stored for subclasses (not used by this base class).
            deg: Stored for subclasses (not used by this base class).
            multiplicative_noise: Stored for subclasses (not used by this base class).
            additive_noise: Stored for subclasses (not used by this base class).
            relative_capacity: Must lie in ``[0, 1]``. Presumably the knapsack
                capacity relative to the total item weight -- TODO confirm in subclasses.
            correlate_values_and_weights: Must be 0, 1 or 2; the meaning of each
                mode is defined by subclasses (TODO confirm).
            rho: Stored for subclasses (not used by this base class).

        Raises:
            ValueError: If ``relative_capacity`` is outside ``[0, 1]`` or
                ``correlate_values_and_weights`` is not one of 0, 1, 2.
        """
        # Explicit checks rather than ``assert``: assertions are stripped under ``python -O``.
        if not 0.0 <= relative_capacity <= 1.0:
            raise ValueError(f"relative_capacity must be in [0, 1], got {relative_capacity!r}")
        if correlate_values_and_weights not in (0, 1, 2):
            raise ValueError(
                f"correlate_values_and_weights must be 0, 1 or 2, got {correlate_values_and_weights!r}"
            )

        self._name = name
        self._solver = solver
        self._input_dim = input_dim
        self._output_dim = output_dim
        self._deg = deg
        self._multiplicative_noise = multiplicative_noise
        self._additive_noise = additive_noise
        self._relative_capacity = relative_capacity
        self._correlate_values_and_weights = correlate_values_and_weights
        self._rho = rho

    def generate(self, path: str, num_instances: int, seed: int | None) -> None:
        """Generate ``num_instances`` instances and pickle them as a DataFrame.

        The DataFrame has one row per instance, with columns
        ``INPUT, VALUES, WEIGHTS, CAPACITY, SOLUTION, COST`` in that order.
        It is saved to ``<path>/<name>.pkl``.

        Args:
            path: Output directory; created if it does not exist. An empty
                string denotes the current working directory.
            num_instances: Number of instances to generate; must be >= 1.
            seed: If not ``None``, passed to ``set_seed`` before generating.

        Raises:
            ValueError: If ``num_instances`` < 1.
        """
        self._pre_generation(path, num_instances, seed)

        # Single source of truth for column order: used for both rows and header.
        columns = [INPUT, VALUES, WEIGHTS, CAPACITY, SOLUTION, COST]

        rows = []
        for _ in tqdm(range(num_instances), desc='Data generation'):
            instance = self._generate_instance()
            rows.append([instance[column] for column in columns])

        dataframe = pd.DataFrame(rows, columns=columns)

        save_path = os.path.join(path, self._name + ".pkl")
        dataframe.to_pickle(save_path)
        print("Dataset saved to", save_path)

    def _pre_generation(self, path: str, num_instances: int, seed: int | None):
        """Validate arguments, optionally seed the RNG and ensure ``path`` exists.

        Raises:
            ValueError: If ``num_instances`` < 1.
        """
        if num_instances < 1:
            raise ValueError(f"num_instances must be >= 1, got {num_instances!r}")

        if seed is not None:
            set_seed(seed)

        # exist_ok avoids the race of a separate exists() check; an empty path
        # means "current directory", which always exists and cannot be created.
        if path:
            os.makedirs(path, exist_ok=True)

    @abstractmethod
    def _generate_instance(self) -> dict:
        """Generate a single knapsack instance.

        Returns:
            A dict that must contain at least the keys ``INPUT``, ``VALUES``,
            ``WEIGHTS``, ``CAPACITY``, ``SOLUTION`` and ``COST`` (the string
            constants from ``src.utils.strings``), which :meth:`generate` reads.
        """
        pass
