import os
import numpy as np
import pandas as pd
from tqdm import tqdm

from src.utils.probabilities import set_seed


class ToyGenerator:

    def __init__(self, name: str, input_dim: int, output_dim: int, x_lb: float = 0.0, x_up: float = 10.0):

        assert x_up > x_lb

        self._name = name
        self._input_dim = input_dim
        self._output_dim = output_dim
        self._x_lb = x_lb
        self._x_up = x_up

        self._weights = None

    def generate(self, path: str, num_instances: int, seed: int | None) -> None:

        self._pre_generation(path, num_instances, seed)

        rows = []
        for _ in tqdm(range(num_instances), total=num_instances, desc='Data generation'):
            x, y, z, cost = self._generate_instance()
            row = [x, y, z, cost]
            rows.append(row)

        dataframe = pd.DataFrame(rows, columns=["x", "y", "z", "cost"])

        self._save_instance(path, dataframe)

    def _pre_generation(self, path: str, num_instances: int, seed: int | None) -> None:

        assert num_instances >= 1

        if seed is not None:
            set_seed(seed)

        if not os.path.exists(path):
            os.makedirs(path)

        self._build_weights()

    def _build_weights(self) -> None:

        weights = np.random.uniform(low=0.0, high=1.0, size=(self._output_dim, self._input_dim))
        sums = weights.sum(axis=1, keepdims=True)
        self._weights = weights / sums

    def _save_instance(self, path: str, dataframe: pd.DataFrame) -> None:

        dataframe_save_path = os.path.join(path, self._name + ".pkl")
        dataframe.to_pickle(dataframe_save_path)

        weights_save_path = os.path.join(path, self._name + "_weights.npy")
        np.save(weights_save_path, self._weights)

        print("Dataset saved to", dataframe_save_path)

    def _generate_instance(self) -> tuple:

        assert self._weights is not None
        assert len(self._weights.shape) == 2
        assert self._weights.shape[0] == self._output_dim
        assert self._weights.shape[1] == self._input_dim

        x = np.random.uniform(high=self._x_up, low=self._x_lb, size=(self._input_dim, 1))
        y = self._weights @ x
        z = np.array([0.0])
        cost = 0.0

        x = np.squeeze(x)
        y = y[:, 0]

        return x, y, z, cost
