import torch as th
from typing import *

from common.scm import *

# Type alias: a dict mapping string keys to torch tensors.
TensorDict = Dict[str, th.Tensor]


class MultinomialObservedSampler:
    """Randomly choose which endogenous variables of an SCM are observed.

    Every draw picks exactly ``int(len(endogenous_variables) *
    observed_proportion)`` distinct variables, uniformly at random and
    without replacement, via ``th.multinomial`` on uniform weights.
    """

    def __init__(self,
                 scm: TensorSCM,
                 observed_proportion: float = 1.0,
                 ) -> None:
        """
        Args:
            scm: SCM providing ``endogenous_variables`` (indexed by position)
                and ``device`` (where batched masks are allocated).
            observed_proportion: fraction of endogenous variables to observe
                per draw; the resulting count is rounded down.

        Raises:
            ValueError: if ``observed_proportion`` is outside [0, 1] (or NaN).
        """
        # Validate here: values > 1 (or < 0) would otherwise only fail later,
        # inside th.multinomial, with a less informative RuntimeError.
        if not 0.0 <= observed_proportion <= 1.0:
            raise ValueError(
                "observed_proportion must lie in [0, 1], "
                f"got {observed_proportion}"
            )
        self._scm = scm
        self._observed_proportion = observed_proportion
        self._endos = scm.endogenous_variables
        # Number of variables observed per draw (floor of the proportion).
        self._observed_variables = int(
            len(self._endos) * self._observed_proportion
        )

    def sample(self, h: None, j: None, U: TensorDict) -> Set[Any]:
        """Draw a single random set of observed endogenous variables.

        ``h``, ``j`` and ``U`` are not used; they are accepted to keep the
        sampler's call signature.

        Returns:
            A set containing ``self._observed_variables`` entries of the
            SCM's endogenous variables (empty when that count is 0).
        """
        if self._observed_variables == 0:
            # th.multinomial rejects num_samples == 0.
            return set()
        chosen = th.multinomial(th.ones(len(self._endos)),
                                num_samples=self._observed_variables,
                                replacement=False)
        # Read the drawn indices directly instead of building a device-side
        # mask and testing it element by element (one sync per element).
        return {self._endos[i] for i in chosen.tolist()}

    def batched_sample(self, h: th.Tensor, j: th.Tensor, u: th.Tensor) -> th.Tensor:
        """Draw an independent observation mask for every batch row.

        ``h`` and ``j`` are not used. Only ``u.size(0)`` is read from ``u``,
        as the batch size.

        Returns:
            Boolean tensor of shape ``(u.size(0), len(endogenous_variables))``
            on ``scm.device``. Each row has exactly ``self._observed_variables``
            True entries; column order follows ``scm.endogenous_variables``.
        """
        batch_size = u.size(0)
        device = self._scm.device
        contained = th.zeros((batch_size, len(self._endos)),
                             dtype=th.bool, device=device)
        if self._observed_variables == 0 or batch_size == 0:
            # Nothing to mark; also avoids th.multinomial's num_samples > 0 check.
            return contained
        weights = th.ones(len(self._endos), device=device).expand(batch_size, -1)
        chosen = th.multinomial(weights,
                                num_samples=self._observed_variables,
                                replacement=False)
        # (B, 1) row indices broadcast against the (B, k) column indices.
        rows = th.arange(batch_size, device=device).unsqueeze(1)
        contained[rows, chosen] = True
        return contained
