import torch
import torch.nn as nn


class EventSampler(nn.Module):
    """Event Sequence Sampler based on thinning algorithm, which corresponds to Algorithm 2 of
    The Neural Hawkes Process: A Neurally Self-Modulating Multivariate Point Process,
    https://arxiv.org/abs/1612.09328.
    """

    def __init__(self, num_sample, num_exp, over_sample_rate, num_samples_boundary, dtime_max, patience_counter,
                 device):
        """Initialize the event sampler.

        Args:
            num_sample (int): number of sampled next event times via thinning algo for computing predictions.
            num_exp (int): number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm
            over_sample_rate (float): multiplier for the intensity up bound.
            num_samples_boundary (int): number of sampled event times to compute the boundary of the intensity.
            dtime_max (float): max value of delta times in sampling
            patience_counter (int): the maximum iteration used in adaptive thinning.
            device (torch.device): torch device index to select.
        """
        super(EventSampler, self).__init__()
        self.num_sample = num_sample
        self.num_exp = num_exp
        self.over_sample_rate = over_sample_rate
        self.num_samples_boundary = num_samples_boundary
        self.dtime_max = dtime_max
        self.patience_counter = patience_counter
        self.device = device

    def compute_intensity_upper_bound(self, time_seq, time_delta_seq, event_seq, intensity_fn,
                                      compute_last_step_only, dt_boundary=None):
        """Compute the upper bound of intensity at each event timestamp.

        Args:
            time_seq (tensor): [batch_size, seq_len], timestamp seqs.
            time_delta_seq (tensor): [batch_size, seq_len], time delta seqs.
            event_seq (tensor): [batch_size, seq_len], event type seqs.
            intensity_fn (fn): a function that computes the intensity.
            compute_last_step_only (bool): whether to compute the last time step only.

        Returns:
            tensor: [batch_size, seq_len]
        """
        batch_size, seq_len = time_seq.size()

        # [1, 1, num_samples_boundary]
        time_for_bound_sampled = torch.linspace(start=0.0,
                                                end=1.0,
                                                steps=self.num_samples_boundary,
                                                device=self.device)[None, None, :]

        # [batch_size, seq_len, num_samples_boundary]  # TODO: why this check
        if dt_boundary is None:
            dtime_for_bound_sampled = time_delta_seq[:, :, None] * time_for_bound_sampled
        else:
            dtime_for_bound_sampled = dt_boundary[:, :, None] * time_for_bound_sampled

        # [batch_size, seq_len, num_samples_boundary, num_marks]
        intensities_for_bound = intensity_fn(time_seq,
                                             time_delta_seq,
                                             event_seq,
                                             sample_dtimes=dtime_for_bound_sampled,
                                             compute_last_step_only=compute_last_step_only)

        # [batch_size, seq_len]
        bounds = intensities_for_bound.sum(dim=-1).max(dim=-1)[0] * self.over_sample_rate

        return bounds

    def sample_exp_distribution(self, sample_rate):
        """Sample an exponential distribution.

        Args:
            sample_rate (tensor): [batch_size, seq_len], intensity rate.

        Returns:
            tensor: [batch_size, seq_len, num_exp], exp numbers at each event timestamp.
        """

        batch_size, seq_len = sample_rate.size()

        # For fast approximation, we reuse the rnd for all samples
        # [batch_size, seq_len, num_sample, num_exp]
        exp_numbers = torch.empty(size=[batch_size, seq_len, self.num_sample, self.num_exp],
                                    dtype=torch.float32,
                                    device=self.device)

        # [batch_size, seq_len, num_exp]
        # exp_numbers.exponential_(1.0)
        exp_numbers.exponential_(1.0)

        # [batch_size, seq_len, num_exp]
        # exp_numbers = torch.tile(exp_numbers, [1, 1, self.num_sample, 1])

        # [batch_size, seq_len, num_exp]
        # div by sample_rate is equivalent to exp(sample_rate),
        # see https://en.wikipedia.org/wiki/Exponential_distribution
        exp_numbers = exp_numbers / sample_rate[:, :, None, None]

        return exp_numbers

    def sample_uniform_distribution(self, intensity_upper_bound):
        """Sample an uniform distribution

        Args:
            intensity_upper_bound (tensor): upper bound intensity computed in the previous step.

        Returns:
            tensor: [batch_size, seq_len, num_sample, num_exp]
        """
        batch_size, seq_len = intensity_upper_bound.size()

        unif_numbers = torch.empty(size=[batch_size, seq_len, self.num_sample, self.num_exp],
                                   dtype=torch.float32,
                                   device=self.device)
        unif_numbers.uniform_(0.0, 1.0)

        return unif_numbers

    def sample_accept(self, unif_numbers, sample_rate, total_intensities):
        """Do the sample-accept process.

        For each parallel draw, find its min criterion： if that < 1.0, the 1st (i.e. smallest) sampled time
        with cri < 1.0 is accepted; if none is accepted, use boundary / maxsampletime for that draw

        Args:
            unif_numbers (tensor): [batch_size, max_len, num_sample, num_exp], sampled uniform random number.
            sample_rate (tensor): [batch_size, max_len], sample rate (intensity).
            total_intensities (tensor): [batch_size, seq_len, num_sample, num_exp]

        Returns:
            list: two tensors,
            criterion, [batch_size, max_len, num_sample, num_exp]
            who_has_accepted_times, [batch_size, max_len, num_sample]
        """

        # [batch_size, max_len, num_sample, num_exp]
        criterion = unif_numbers * sample_rate[:, :, None, None] / total_intensities

        # [batch_size, max_len, num_sample]
        min_cri_each_draw, _ = criterion.min(dim=-1)

        # find out unif_numbers * sample_rate < intensity
        # [batch_size, max_len, num_sample]
        who_has_accepted_times = min_cri_each_draw < 1.0

        # This is equivalent to doing (criterion < 1.0).any(dim=-1) but is faster

        return criterion, who_has_accepted_times

    def draw_next_time_one_step(self, time_seq, time_delta_seq, event_seq, dtime_boundary,
                                intensity_fn, compute_last_step_only=False):
        """Sample next-event delta times with a batched, adaptive thinning loop.

        Every iteration proposes ``num_exp`` candidate delta times per sample (cumulative sums of
        exponential draws at the current intensity upper bound, continuing from where the previous
        iteration stopped), evaluates the total intensity at those candidates and keeps, for each
        sample, the earliest candidate that passes the acceptance test and lies inside the boundary.
        A sample whose candidates have all moved past the boundary is marked finished and keeps the
        boundary value it was initialised with. If the total intensity at an in-boundary candidate
        exceeds the current bound, the bound of that (sequence, position) is raised and all of its
        samples are restarted from zero.

        NOTE(review): ``self.patience_counter`` and ``self.dtime_max`` are not consulted in this
        method; the loop has no iteration cap and runs until every sample is marked finished.

        Args:
            time_seq (tensor): [batch_size, seq_len], timestamp seqs.
            time_delta_seq (tensor): [batch_size, seq_len], time delta seqs.
            event_seq (tensor): [batch_size, seq_len], event type seqs.
            dtime_boundary (tensor): [batch_size, seq_len], dtime upper bound.
            intensity_fn (fn): a function to compute the intensity. It is called with the three
                sequences, the candidate delta times and ``compute_last_step_only``; its output is
                reshaped here to [batch, seq_len, num_sample, num_exp, -1], the last dim presumably
                being the marks.
            compute_last_step_only (bool, optional): whether to compute last event timestep only. Defaults to False.

        Returns:
            tuple: ``(sampled_dtimes, weights)``, both [batch_size, seq_len, num_sample], where
            seq_len is 1 when ``compute_last_step_only`` is True. ``weights`` is uniform
            (``1 / num_sample`` everywhere).
        """
        dtime_boundary_ = dtime_boundary[:, -1:] if compute_last_step_only else dtime_boundary  # [batch_size, seq_len]
        # Every sample starts at the boundary; it is only overwritten when a candidate is accepted.
        sampled_dtimes = torch.tile(dtime_boundary_[..., None], [1, 1, self.num_sample])  # [batch_size, seq_len, num_samples]
        # True once a sample is finished (accepted a time or ran past the boundary).
        accepted_dtimes = torch.zeros(sampled_dtimes.shape, device=self.device, dtype=torch.bool)  # [batch_size, seq_len, num_samples]
        weights = torch.ones_like(sampled_dtimes, device=self.device) / self.num_sample

        batch_size, seq_len, _ = sampled_dtimes.shape

        # Offset from which the next round of candidates continues, per sample.
        current_base_dt = torch.zeros_like(sampled_dtimes, device=self.device)  # [batch_size, seq_len, num_samples]

        # 1. compute the initial upper bound of the total intensity over [0, dtime_boundary_]
        # for every position; expected shape [batch_size, seq_len]
        intensity_upper_bound = self.compute_intensity_upper_bound(
            time_seq,
            time_delta_seq,
            event_seq,
            intensity_fn,
            compute_last_step_only,
            dtime_boundary_,
        )

        # Original batch indices of the sequences still being sampled. The working tensors
        # (current_base_dt, intensity_upper_bound, accepted_dtimes) shrink as sequences finish,
        # while sampled_dtimes and dtime_boundary_ keep the full batch and are addressed through
        # batch_idx.
        batch_idx = torch.arange(batch_size, device=self.device)

        while (accepted_dtimes.shape[0] > 0) and (not accepted_dtimes.all()):
            # 2. draw exp distribution with intensity = intensity_upper_bound,
            # We draw them independently for each sample for each event; the cumulative sum turns
            # the inter-arrival draws into non-decreasing candidate times.
            # [active_batch, seq_len, num_sample, num_exp]
            exp_numbers = self.sample_exp_distribution(intensity_upper_bound).cumsum(dim=-1) + current_base_dt[..., None]
            exceeds_boundary = (exp_numbers > dtime_boundary_[batch_idx][..., None, None])
        
            # 3. compute intensity at sampled times from exp distribution (active sequences only)
            # reshaped below to [active_batch, seq_len, num_sample, num_exp, -1]
            intensities_at_sampled_times = intensity_fn(
                time_seq[batch_idx],
                time_delta_seq[batch_idx],
                event_seq[batch_idx],
                exp_numbers.view(batch_idx.shape[0], seq_len, -1),  # [active_batch, seq_len, num_sample * num_exp]
                compute_last_step_only=compute_last_step_only,
            )
            intensities_at_sampled_times = intensities_at_sampled_times.view(batch_idx.shape[0], seq_len, self.num_sample, self.num_exp, -1)

            # [active_batch, seq_len, num_sample, num_exp]
            total_intensities = intensities_at_sampled_times.sum(dim=-1)  # sum over the last dim (presumably marks)

            # 4. draw uniform distribution
            # [active_batch, seq_len, num_sample, num_exp]
            unif_numbers = self.sample_uniform_distribution(intensity_upper_bound)

            # 5. find out accepted intensities
            acc_ratio = total_intensities / intensity_upper_bound[:, :, None, None]
            not_rejected = unif_numbers < acc_ratio
            # [active_batch, seq_len]: True where some in-boundary candidate had a total intensity
            # above the current bound, i.e. the bound was too low for that position.
            any_violations = ((acc_ratio > 1.0) & (~exceeds_boundary)).any(dim=-1).any(dim=-1)
            # Accept a candidate time t iff:
            # - passes the initial rejection step, total_intensity(t) / intensity_upper_bound > u for u~Unif(0,1)
            # - t < right window (dt_boundary)
            # - total_intensity(t) < intensity_upper_bound (for all samples!)
            # - there is not already an accepted candidate time from prior iteration
            accepted_times = not_rejected & (~exceeds_boundary) & (~any_violations[..., None, None]) & (~accepted_dtimes[..., None])

            if accepted_times.any():
                sample_accepted = accepted_times.any(dim=-1) # [active_batch, seq_len, num_samples]

                # since this is 1s and 0s, argmax will return first 1, which corresponds to earliest since they are in sorted order
                earliest_accepted_time_idx = accepted_times.int().argmax(dim=-1)  # [active_batch, seq_len, num_sample]
                
                # Scatter the flags of the active sequences back to full-batch coordinates so the
                # accepted times can be written into the full-size sampled_dtimes.
                sampled_dtimes_mask = torch.zeros_like(sampled_dtimes, dtype=torch.bool)
                sampled_dtimes_mask[batch_idx] = sample_accepted
                sampled_dtimes[sampled_dtimes_mask] = torch.gather(exp_numbers, dim=-1, index=earliest_accepted_time_idx[..., None]).squeeze(-1)[sample_accepted]  # Must filter in case a given sample did not accept any times this iteration
                accepted_dtimes = accepted_dtimes | sample_accepted


            # Set sample as accepted if we have sampled past the boundary (all of its candidates
            # this round lie beyond it); its sampled_dtimes entry keeps the boundary value.
            accepted_dtimes = accepted_dtimes | exceeds_boundary.all(dim=-1)

            # Fix upper bound for next iteration
            if any_violations.any():
                # New bound: twice the largest total intensity seen this round for that position.
                intensity_upper_bound[any_violations] = 2. * total_intensities[any_violations].max(dim=-1).values.max(dim=-1).values

                current_base_dt = torch.where(
                    any_violations[..., None],
                    current_base_dt*0.0,  # reset progress to earliest sample for a given sequence
                    exp_numbers.max(dim=-1).values,
                )

                # Discard everything sampled so far for the violating positions: clear their
                # finished flags and put their sampled times back to the boundary value.
                accepted_dtimes[any_violations] = False
                sampled_dtimes_mask = torch.zeros_like(sampled_dtimes, dtype=torch.bool)
                sampled_dtimes_mask[batch_idx] = any_violations[..., None]
                sampled_dtimes[sampled_dtimes_mask] = dtime_boundary_[..., None].expand(-1, -1, self.num_sample)[sampled_dtimes_mask]
            else:
                # No violation: continue from the latest candidate of each sample.
                current_base_dt = exp_numbers.max(dim=-1).values

            # Reduce batch size: drop sequences whose samples are all finished at every position.
            seq_done = accepted_dtimes.all(dim=-1).all(dim=-1)
            if seq_done.any():
                seq_not_done = ~seq_done
                current_base_dt = current_base_dt[seq_not_done]
                intensity_upper_bound = intensity_upper_bound[seq_not_done]
                accepted_dtimes = accepted_dtimes[seq_not_done]
                batch_idx = batch_idx[seq_not_done]

        # add a upper bound here in case it explodes, e.g., in ODE models
        # return sampled_dtimes.clamp(max=1e5), weights  # TODO: check ODE models

        return sampled_dtimes, weights