import numpy as np
import os
import pandas as pd
from tqdm import tqdm
from scipy.stats import uniform, binom, multivariate_normal

from src.solvers.production_planning_solver import ProductionPlanningSolver

from src.utils.strings import *


class ProductionPlanningGenerator:

    """
    A sampler class for production planning problems.

    Problem instances are sampled based on user provided data, assuming that
    a fixed set of customers may or may not request one unit of each available
    product.

    The probability that a customer requires a product follows a classical
    Bernoulli distribution, conditioned on a number of abstract observables.
    As a result, the total demand for each product follows a Binomial
    distribution. All observables are assumed to be uniformly distributed.

    The underproduction and overproduction costs are built according to
    deterministic rules. In particular:

    - The underproduction costs grow linearly over the product id in the
      [0.1, 1] interval
    - The overproduction costs are generated according to a sinusoidal law
      in the [0.1, 1] interval

    This approach ensures that many distinct combination of under- and over-
    production costs are considered.
    """

    def __init__(self, name: str, n_products: int, n_observables: int, capacity: int,
                 use_binomial_distribution: bool = True, n_customers: int = 0, sigma: float = 0.0,
                 max_demands: int = 100, costs_asymmetry: float = 0.5):

        """
        Build a product sampler object.

        Parameters
        ----------
        n_products : int
            The number of distinct product types. This determines the number
            of variables in the optimization problem
        n_observables : int
            The number of observables. This determines the difficulty of the
            prediction problem (which is in any case quite easy)
        capacity : int
            The total number of units of all products that can be manufactured
        use_binomial_distribution: bool
            Whether to use the binomial distribution or the normal distribution to generate demands
        n_customers : int
            The number of customers that may require the products, thus
            determining the demand (used only if use_binomial_distribution is True)
        sigma : float
            The standard deviation for the normal distributions for sampling
            demands (used only if use_binomial_distribution is False)
        max_demands: int
            The maximum number of demands for a product, used to scale the generation output
            (used only if use_binomial_distribution is False)
        costs_asymmetry: float
            Level of asymmetry in underproduction and overproduction costs, between 0.0 (perfect symmetry)
            and 1.0 (highest asymmetry)
        """

        if use_binomial_distribution:
            assert n_customers > 0, "Binomial distribution requires at least 1 customer"
        else:
            assert sigma > 0.0 and max_demands > 0, "Normal distribution requires a positive standard deviation (sigma)"

        assert 0.0 <= costs_asymmetry <= 1.0

        self._name = name
        self._n_products = n_products
        self._n_observables = n_observables
        self._capacity = capacity
        self._use_binomial_distribution = use_binomial_distribution
        self._n_customers = n_customers
        self._sigma = sigma * max_demands
        self._max_demands = max_demands
        self._costs_asymmetry = costs_asymmetry

        # self._underproduction_costs = 0.1 + np.linspace(0, 0.9, n_products)
        # self._overproduction_costs = 0.5 * (1.1 + 0.9 * np.sin(np.linspace(0, 2 * np.pi, n_products)))

        up_avg = 0.1 * costs_asymmetry + 0.5 * (1.0 - costs_asymmetry)
        up_delta = 0.05
        up_low = up_avg - up_delta
        up_high = up_avg + up_delta
        self._underproduction_costs = np.array(np.linspace(up_low, up_high, n_products // 2).tolist() +
                                               np.linspace(1.0 - up_high, 1.0 - up_low, n_products - n_products // 2).tolist())
        self._overproduction_costs = 1.0 - self._underproduction_costs

        self._observable_bounds = [[0, 1]] * n_observables
        self._weights = None

        self._solver = ProductionPlanningSolver()
        self._params = {
            UNDERPRODUCTION_COSTS: self._underproduction_costs,
            OVERPRODUCTION_COSTS: self._overproduction_costs,
            CAPACITY: self._capacity
        }

    def generate(self, path: str, num_instances: int, seed: int | None) -> None:

        np.random.seed(seed)

        self._weights = uniform.rvs(-1, 1, size=(self._n_products, self._n_observables))

        rows = []
        for _ in tqdm(range(num_instances), total=num_instances, desc='Data generation'):
            x, y, z, cost = self._generate_instance()
            row = [x, y, z, cost]
            rows.append(row)

        dataframe = pd.DataFrame(rows, columns=[INPUT, DEMANDS, SOLUTION, COST])

        self._save_instance(path, dataframe)

    def _demand_distribution(self, observables: np.ndarray) -> binom:

        """
        Obtain the demand distribution for a given sample of observables

        Parameters
        ----------
        observables : n_samples x n_observables matrix of floats
            Values for the observables, in multiple scenarios

        Returns
        -------
        distributions : n_samples x n_products binomial or normal distribution
        """

        logits = np.array([self._weights @ observables[i, :] for i in range(observables.shape[0])])
        probs = 1 / (1 + np.exp(-logits))

        if self._use_binomial_distribution:
            dist = binom(n=self._n_customers, p=probs)

        else:
            means = probs * self._max_demands
            dist = [multivariate_normal(mean=mean, cov=self._sigma) for mean in means]

        return dist

    def _demand_mean(self, observables: np.ndarray) -> np.ndarray:

        """
        Obtain the mean of the demand distribution for a given sample of
        observables. This is the best classical point estimate for this
        problem.

        Parameters
        ----------
        observables : n_samples x n_observables matrix of floats
            Values for the observables, in multiple scenarios

        Returns
        -------
        mean : n_samples x n_products array of floats
        """

        return self._demand_distribution(observables).mean()

    def _sample_demands(self, n_samples: int, observables: np.ndarray) -> np.ndarray:

        """
        Sample the demand distribution corresponding to a given sample of
        observables.

        Parameters
        ----------
        observables : n_samples x n_observables matrix of floats
            Values for the observables, in multiple scenarios

        Returns
        -------
        sample : n_products demands, for each observables sample row
        """

        demand_dist = self._demand_distribution(observables)
        if isinstance(demand_dist, list):
            samples = np.array([dist.rvs() for dist in demand_dist])
        else:
            samples = demand_dist.rvs(size=(n_samples, self._n_products))

        samples = (samples * (samples > 0.0)).astype(np.int32)

        return samples

    def _sample_observables(self, n_samples: int) -> np.ndarray:

        """
        Sample the observables

        Parameters
        ----------
        n_samples : int
            the number of samples to draw

        Returns
        -------
        sample : n_observables values, for each requested sample
        """

        obs_rngs = [uniform(b[0], b[1]) for b in self._observable_bounds]
        obs_sample = np.vstack([rng.rvs(n_samples) for rng in obs_rngs]).T

        return obs_sample

    def _sample(self, n_samples: int) -> tuple[np.ndarray, np.ndarray]:

        """
        Jointly sample the observables and the demands

        Parameters
        ----------
        n_samples : int
            the number of samples to draw

        Returns
        -------
        obs_sample : n_observables values, for each requested sample
        demand_sample : n_products values, for each requested sample
        """

        obs_sample = self._sample_observables(n_samples)
        demand_sample = self._sample_demands(n_samples, obs_sample)

        return obs_sample, demand_sample

    def _save_instance(self, path: str, dataframe: pd.DataFrame) -> None:

        dataframe_save_path = os.path.join(path, self._name + ".pkl")
        dataframe.to_pickle(dataframe_save_path)

        u_save_path = os.path.join(path, self._name + "_underproduction_costs.npy")
        np.save(u_save_path, self._underproduction_costs)

        o_save_path = os.path.join(path, self._name + "_overproduction_costs.npy")
        np.save(o_save_path, self._overproduction_costs)

        print("Dataset saved to", dataframe_save_path)

    def _generate_instance(self) -> tuple:

        x, y = self._sample(n_samples=1)
        x = np.squeeze(x)
        y = np.squeeze(y)

        z, _ = self._solver.solve(x, y, self._params)

        metrics = self._solver.compute_metrics(y, z, self._params)
        cost = metrics[TOTAL_COST]

        return x, y, z, cost
