import torch
import time
import torch.distributions


class AssociativeRecallTask:
    """
    Synthetic associative-recall task generator.

    Every generated sequence is a context of ``n_pairs`` (trigger, answer)
    pairs followed by a query trigger and its target answer:

        [t_0, a_0, ..., t_{n_pairs-1}, a_{n_pairs-1}, query_trigger, target_answer]

    The (query_trigger, target_answer) pair occurs at least once in the
    context. All remaining context pairs use triggers that differ from the
    query trigger (and from each other) and answers that differ from the
    target answer (and from each other).
    """

    def __init__(
        self, vocab_size, n_pairs, device, alpha=0.0, burstiness=1, p_celebs=0.0, w_celebs=1.0
    ):
        """
        An associative recall task.

        Args:
            vocab_size (int): The size of the vocabulary from which triggers and answers are drawn.
            n_pairs (int): The number of (trigger, answer) pairs in the context.
            device: The torch device to use for tensor operations.
            alpha (float, optional): The exponent for the Zipfian distribution for sampling triggers.
                                     alpha = 0 corresponds to a uniform distribution. Only used
                                     when p_celebs is 0. Defaults to 0.0.
            burstiness (float, optional): The *average* number of times the query trigger (and its answer)
                                          should appear in the context pairs. Must be >= 1 and <= n_pairs.
                                          Defaults to 1.
            p_celebs (float, optional): Proportion of vocabulary to be considered "celebrities".
                                        Must be in [0, 1]. When > 0 the celebrity distribution
                                        replaces the Zipfian one. Defaults to 0.0.
            w_celebs (float, optional): Total probability mass given to the celebrity tokens.
                                        Defaults to 1.0.

        Raises:
            ValueError: If any argument is out of range, or if the trigger distribution
                has fewer non-zero entries than n_pairs (unique triggers could not be drawn).
        """
        if not isinstance(vocab_size, int) or vocab_size <= 0:
            raise ValueError("vocab_size must be a positive integer.")
        if not isinstance(n_pairs, int) or n_pairs <= 0:
            raise ValueError("n_pairs must be a positive integer.")
        if n_pairs > vocab_size:
            raise ValueError(f"n_pairs ({n_pairs}) cannot exceed vocab_size ({vocab_size}).")
        if not isinstance(burstiness, (int, float)):
            raise ValueError("burstiness (average) must be a number.")
        if burstiness < 1:
            raise ValueError("burstiness (average) must be >= 1.")
        if burstiness > n_pairs:
            raise ValueError(
                f"burstiness (average) ({burstiness}) cannot exceed n_pairs ({n_pairs})."
            )
        # Reject out-of-range values up front; a negative p_celebs would otherwise
        # be silently ignored and fall through to the Zipfian branch.
        if not (0.0 <= p_celebs <= 1.0):
            raise ValueError("p_celebs must be between 0.0 and 1.0")

        self.vocab_size = vocab_size
        self.n_pairs = n_pairs
        self.device = device
        self.alpha = alpha
        self.burstiness_avg = float(burstiness)
        self.p_celebs = p_celebs
        self.w_celebs = w_celebs

        if self.p_celebs > 0.0:
            self.pi_trigger = self._get_celebrity_dist(
                self.vocab_size, self.p_celebs, self.w_celebs, self.device
            )
        else:
            self.pi_trigger = self._get_zipfian_dist(self.vocab_size, self.alpha, self.device)

        # Sampling n_pairs distinct triggers needs at least that many tokens with
        # non-zero probability (e.g. w_celebs == 1.0 zeroes out all non-celebrities).
        n_nonzero = int(self.pi_trigger.count_nonzero())
        if n_nonzero < self.n_pairs:
            raise ValueError(
                f"Cannot sample {self.n_pairs} unique triggers from a distribution with "
                f"{n_nonzero} non-zero elements over vocab_size {self.vocab_size}."
            )

    def _get_zipfian_dist(self, vocab_size, alpha, device):
        """
        Generates a Zipfian distribution or uniform if alpha is 0.

        Token k (1-based) receives probability proportional to k ** (-alpha).

        Returns:
            torch.Tensor: Float tensor of shape (vocab_size,) summing to 1
            (an empty tensor when vocab_size is 0).
        """
        if vocab_size == 0:
            return torch.empty(0, device=device)
        if alpha == 0:
            return torch.ones(vocab_size, device=device, dtype=torch.float) / vocab_size
        # Values for k=1, 2, ..., vocab_size
        k_values = torch.arange(1, vocab_size + 1, device=device, dtype=torch.float)
        probs = k_values ** (-alpha)
        return probs / probs.sum()

    def _get_celebrity_dist(self, vocab_size, p_celebs, w_celebs, device):
        """
        Generates a distribution where:
        - The first int(p_celebs * vocab_size) tokens (celebrities) share a total probability of w_celebs
        - The remaining tokens share the remaining (1 - w_celebs) probability

        If the celebrity group is empty or covers the whole vocabulary, a uniform
        distribution is returned instead.

        Args:
            vocab_size: Size of the vocabulary
            p_celebs: Proportion of vocabulary to be considered "celebrities"
            w_celebs: Total probability mass allocated to all celebrities combined
            device: The torch device to use

        Returns:
            A probability distribution over the vocabulary

        Raises:
            ValueError: If p_celebs or w_celebs is outside [0, 1].
        """
        if vocab_size == 0:
            return torch.empty(0, device=device)
        if not (0.0 <= p_celebs <= 1.0):
            raise ValueError("p_celebs must be between 0.0 and 1.0")
        if not (0.0 <= w_celebs <= 1.0):
            raise ValueError("w_celebs must be between 0.0 and 1.0")

        num_celebs = int(p_celebs * vocab_size)
        num_non_celebs = vocab_size - num_celebs

        # Degenerate split: only one group exists, so fall back to uniform.
        if num_celebs == 0 or num_celebs == vocab_size:
            return torch.ones(vocab_size, device=device, dtype=torch.float) / vocab_size

        # Normal case: split probability mass between celebrities and non-celebrities
        weights = torch.empty(vocab_size, device=device, dtype=torch.float)
        weights[:num_celebs] = w_celebs / num_celebs
        weights[num_celebs:] = (1.0 - w_celebs) / num_non_celebs
        return weights

    def _sample_burst_counts(self, batch_size):
        """
        Draws, per batch row, how many context slots hold the query pair.

        The count is 1 + Binomial(n_pairs - 1, p) with p chosen so that the
        expected count equals ``burstiness_avg``.

        Returns:
            torch.Tensor: Long tensor of shape (batch_size,) with values in [1, n_pairs].
        """
        n_trials = self.n_pairs - 1
        if self.burstiness_avg <= 1.0:
            extra = torch.zeros(batch_size, dtype=torch.long, device=self.device)
        elif self.burstiness_avg >= self.n_pairs:
            extra = torch.full((batch_size,), n_trials, dtype=torch.long, device=self.device)
        else:
            p_success = (self.burstiness_avg - 1.0) / n_trials
            p_success_tensor = torch.clamp(
                torch.tensor(p_success, device=self.device, dtype=torch.float32), 0.0, 1.0
            )
            binomial_dist = torch.distributions.Binomial(n_trials, p_success_tensor)
            extra = binomial_dist.sample((batch_size,)).long()
        return extra + 1

    def _sample_other_pairs(self, query_trigger, target_answer):
        """
        Samples the n_pairs - 1 distractor (trigger, answer) pairs for every row.

        Triggers are drawn without replacement from ``pi_trigger`` with the row's
        query trigger removed; answers are drawn uniformly without replacement
        from the vocabulary with the row's target answer removed.

        NOTE: allocates temporaries of shape (batch_size, vocab_size).

        Returns:
            tuple[torch.Tensor, torch.Tensor]: (other_triggers, other_answers),
            both long tensors of shape (batch_size, n_pairs - 1).
        """
        batch_size = query_trigger.shape[0]
        n_others = self.n_pairs - 1

        # One copy of the trigger distribution per row, with that row's query
        # trigger zeroed out. torch.multinomial does not require normalized rows.
        trigger_probs = self.pi_trigger.repeat(batch_size, 1)
        trigger_probs.scatter_(1, query_trigger.unsqueeze(1), 0.0)
        other_triggers = torch.multinomial(trigger_probs, n_others, replacement=False)

        # Uniform sampling without replacement via top-k of random scores. The
        # target answer gets score -1, strictly below every score in [0, 1), so it
        # can never be selected, not even on a tie with an exact 0.0 draw.
        answer_scores = torch.rand(batch_size, self.vocab_size, device=self.device)
        answer_scores.scatter_(1, target_answer.unsqueeze(1), -1.0)
        other_answers = torch.topk(answer_scores, n_others, dim=1).indices

        # multinomial returns items in draw order (likely tokens tend to come
        # first); shuffle so position carries no information about probability.
        shuffle = torch.argsort(torch.rand(batch_size, n_others, device=self.device), dim=1)
        return other_triggers.gather(1, shuffle), other_answers.gather(1, shuffle)

    def get_batch(self, batch_size):
        """
        Generates a batch of sequences for the task.

        Args:
            batch_size (int): The number of sequences to generate in the batch.

        Returns:
            torch.Tensor: A long tensor of shape (batch_size, seq_length) containing token indices.
                          seq_length is 2 * n_pairs + 2.
                          The format is [t_0, a_0, ..., t_{n_pairs-1}, a_{n_pairs-1}, query_trigger, target_answer].

        Raises:
            ValueError: If batch_size is not a positive integer.
        """
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError("batch_size must be a positive integer.")

        n_pairs = self.n_pairs

        # 1. Sample one query trigger per row from pi_trigger (i.i.d. draws).
        query_trigger = torch.multinomial(self.pi_trigger, batch_size, replacement=True)

        # 2. Sample target_answer uniformly from vocabulary
        target_answer = torch.randint(0, self.vocab_size, (batch_size,), device=self.device)

        if n_pairs == 1:
            # The single context slot must hold the query pair.
            context_triggers = query_trigger.unsqueeze(1)
            context_answers = target_answer.unsqueeze(1)
        else:
            # 3. Number of copies of the query pair in each row's context.
            k_b = self._sample_burst_counts(batch_size)

            # 4. Pick k_b random slots per row: give every slot a rank from a
            #    uniformly random permutation and keep the k_b lowest ranks.
            slot_rank = torch.argsort(
                torch.rand(batch_size, n_pairs, device=self.device), dim=1
            )
            is_query_slot = slot_rank < k_b.unsqueeze(1)

            # 5. Distractor pairs; the j-th non-query slot (left to right) takes
            #    the j-th distractor. Query slots that precede the first
            #    non-query slot would index -1, so clamp; their gathered value
            #    is discarded by torch.where below.
            other_triggers, other_answers = self._sample_other_pairs(
                query_trigger, target_answer
            )
            other_idx = ((~is_query_slot).long().cumsum(dim=1) - 1).clamp_(min=0)

            context_triggers = torch.where(
                is_query_slot, query_trigger.unsqueeze(1), other_triggers.gather(1, other_idx)
            )
            context_answers = torch.where(
                is_query_slot, target_answer.unsqueeze(1), other_answers.gather(1, other_idx)
            )

        # 6. Interleave triggers and answers, then append the query and its answer.
        context = torch.stack((context_triggers, context_answers), dim=2).reshape(
            batch_size, 2 * n_pairs
        )
        return torch.cat(
            (context, query_trigger.unsqueeze(1), target_answer.unsqueeze(1)), dim=1
        )
