import torch
import time
import random
import numpy as np


# -------------------------------------------------
# Task Class: Generates Single Location Linear Regression Data
# -------------------------------------------------
class SingleLocationLinearRegressionTask:
    """Synthetic linear-regression task with a single relevant token per sequence.

    Each sequence of length ``sequence_length`` contains a relevant token ``x_0``
    copied into ``burstiness`` positions; every other position holds independent
    N(0, 1/d) noise. The regression target is ``Y = x_0 @ W.T`` for a matrix ``W``
    sampled once at construction. With probability ``p_repeat`` the relevant token
    is taken from a fixed pool of ``size_pool_repetition`` vectors instead of being
    freshly sampled.
    """

    def __init__(
        self,
        dimension,
        sequence_length,
        p_repeat,
        burstiness,
        device,
        show_relevant_token=False,
        random_relevant_token_positions=False,
        size_pool_repetition=1,
    ):
        """
        Args:
            dimension (int): Dimension d of the token vectors.
            sequence_length (int): Number of tokens L per sequence (must be >= 2).
            p_repeat (float): Per-sequence probability that x_0 is drawn from the
                fixed repetition pool rather than sampled fresh.
            burstiness (int): Number B of positions holding a copy of x_0
                (1 <= B <= sequence_length - 1, so at least one position is noise).
            device: Torch device (anything accepted by ``torch.device``).
            show_relevant_token (bool): If True, append a 0/1 channel marking the
                positions that hold x_0.
            random_relevant_token_positions (bool): If True, draw the B positions
                independently for every sequence; otherwise draw them once here
                and reuse them for every batch.
            size_pool_repetition (int): Number of distinct vectors in the
                repetition pool. Defaults to 1.

        Raises:
            ValueError: If any argument is outside its valid range.
        """
        # --- Parameter Validation ---
        if not 0.0 <= p_repeat <= 1.0:
            raise ValueError("p_repeat must be between 0 and 1.")
        if not isinstance(burstiness, int) or burstiness < 1:
            raise ValueError("burstiness must be an integer >= 1.")
        # Check the sequence length before the burstiness bound that depends on
        # it, so a too-short sequence is reported as such.
        if sequence_length < 2:
            raise ValueError("sequence_length must be at least 2.")
        if burstiness > sequence_length - 1:
            raise ValueError(
                f"burstiness B={burstiness} cannot exceed sequence_length-1 ({sequence_length-1})."
            )
        if not isinstance(size_pool_repetition, int) or size_pool_repetition < 1:
            raise ValueError("size_pool_repetition must be an integer >= 1.")

        self.d = dimension
        self.L = sequence_length
        self.p_repeat = p_repeat
        self.B = burstiness
        self.device = torch.device(device)
        self.show_relevant_token = show_relevant_token
        self.random_relevant_token_positions = random_relevant_token_positions
        self.size_pool_repetition = size_pool_repetition

        # --- Fixed Parameters for the Dataset ---
        # Fixed linear mapping W ~ N(0, 1/d).
        self.W = torch.randn(self.d, self.d, device=self.device) / np.sqrt(self.d)
        # Fixed pool of repeatable relevant tokens, each ~ N(0, 1/d).
        self.repeated_token_pool = torch.randn(
            self.size_pool_repetition, self.d, device=self.device
        ) / np.sqrt(self.d)

        # --- Fixed relevant-token positions (only used when not randomized) ---
        # NOTE: drawn with Python's `random` module, not torch's RNG, so seeding
        # torch alone does not make these positions reproducible.
        self.burst_indices = []
        if not self.random_relevant_token_positions:
            self.burst_indices = random.sample(list(range(self.L)), self.B)
            print(
                f"Fixed relevant-token positions: repeating first token at indices {self.burst_indices}"
            )

    def get_batch(self, batch_size):
        """Generate a batch of sequences and their regression targets.

        Args:
            batch_size (int): Number of sequences to generate (0 is allowed).

        Returns:
            tuple: ``(X, Y)`` where
                X has shape (batch_size, L, d), or (batch_size, L, d + 1) when
                ``show_relevant_token`` is set (last channel is a 0/1 mask of
                the positions holding x_0);
                Y has shape (batch_size, d) and equals ``x_0 @ W.T``.
        """
        # --- Sample the relevant token x_0 for each sequence ---
        # Per sequence: use a pool token with probability p_repeat, else a fresh one.
        use_repeat_mask = torch.rand(batch_size, 1, device=self.device) < self.p_repeat
        normal_first_tokens = torch.randn(batch_size, self.d, device=self.device) / np.sqrt(self.d)
        repeat_indices = torch.randint(
            0, self.size_pool_repetition, (batch_size,), device=self.device
        )
        selected_repeated_tokens = self.repeated_token_pool[repeat_indices]  # (batch, d)
        x_0 = torch.where(use_repeat_mask, selected_repeated_tokens, normal_first_tokens)

        # --- Background noise for every position ---
        X_data = torch.randn(batch_size, self.L, self.d, device=self.device) / np.sqrt(self.d)

        # --- Overwrite the relevant positions with x_0 ---
        relevant_indices = None
        if self.random_relevant_token_positions:
            # argsort of i.i.d. uniforms is a uniformly random permutation per
            # row; its first B entries are a uniform B-subset of positions.
            # Vectorized (no per-sequence Python loop) and valid for batch_size == 0.
            relevant_indices = torch.rand(
                batch_size, self.L, device=self.device
            ).argsort(dim=1)[:, : self.B]  # (batch, B)
            X_data.scatter_(
                1,
                relevant_indices.unsqueeze(-1).expand(-1, -1, self.d),
                x_0.unsqueeze(1).expand(-1, self.B, -1),
            )
        else:
            # Broadcast (batch, 1, d) over the B fixed positions.
            X_data[:, self.burst_indices, :] = x_0.unsqueeze(1)

        # --- Regression target ---
        Y = x_0 @ self.W.T  # (batch, d) @ (d, d) -> (batch, d)

        if not self.show_relevant_token:
            return X_data, Y

        # --- Append the relevant-position indicator channel ---
        relevant_mask = torch.zeros(
            batch_size, self.L, 1, device=self.device, dtype=X_data.dtype
        )
        if relevant_indices is not None:
            relevant_mask.scatter_(1, relevant_indices.unsqueeze(-1), 1.0)
        else:
            relevant_mask[:, self.burst_indices, 0] = 1.0
        return torch.cat((X_data, relevant_mask), dim=-1), Y


# -------------------------------------------------
# Dataloader Class
# -------------------------------------------------
class SingleLocationLinearRegressionDataloader:
    """Serve mini-batches sliced sequentially from a periodically regenerated buffer.

    A fixed test set is generated once at construction. Training batches are
    consecutive slices of a buffer of ``train_data_size`` sequences; when the
    buffer is exhausted a fresh one is drawn from ``task.get_batch``.
    """

    def __init__(
        self,
        task,
        device,
        batch_size,
        iters,
        train_data_size,
        test_data_size,
    ):
        """
        Args:
            task: Object exposing ``get_batch(n) -> (X, Y)``, e.g. an initialized
                SingleLocationLinearRegressionTask.
            device: Torch device (anything accepted by ``torch.device``).
            batch_size (int): Sequences per training batch (>= 1).
            iters (int): Number of batches yielded per pass of ``__iter__``.
            train_data_size (int): Training buffer size; rounded down to a
                multiple of ``batch_size`` and must be at least ``batch_size``.
            test_data_size (int): Size of the fixed test set.

        Raises:
            ValueError: If ``batch_size < 1`` or ``train_data_size < batch_size``
                (an empty training buffer could never yield a batch).
        """
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1.")
        # Round down so every batch sliced from the buffer is full-sized.
        rounded_train_size = (train_data_size // batch_size) * batch_size
        if rounded_train_size == 0:
            raise ValueError(
                f"train_data_size ({train_data_size}) must be at least batch_size ({batch_size})."
            )

        self.device = torch.device(device)
        self.batch_size = batch_size
        self.iters = iters
        self.train_data_size = rounded_train_size
        self.test_data_size = test_data_size
        self.task = task

        # --- Generate Fixed Test Set ---
        print(f"Generating fixed test set ({self.test_data_size} sequences)...")
        start_time = time.time()
        self.X_test, self.Y_test = self.task.get_batch(self.test_data_size)
        print(f"Test set generation took {time.time() - start_time:.2f} seconds.")

        # --- Generate Initial Training Buffer ---
        self.X_train = None
        self.Y_train = None
        self._generate_train_data()

    def _generate_train_data(self):
        """Replace the training buffer with a fresh draw and reset the read cursor."""
        print(f"Generating training buffer ({self.train_data_size} sequences)...")
        # Drop references to the old buffer before allocating the new one to
        # keep peak memory down.
        self.X_train = None
        self.Y_train = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # Release cached GPU blocks freed above.

        start_time = time.time()
        self.X_train, self.Y_train = self.task.get_batch(self.train_data_size)
        self.current_train_idx = 0
        print(f"Training buffer generation took {time.time() - start_time:.2f} seconds.")

    def __iter__(self):
        """Yield ``iters`` training batches ``(X_batch, Y_batch)``.

        The read cursor persists across passes, so a new pass continues from
        where the previous one stopped.
        """
        iteration_counter = 0
        while iteration_counter < self.iters:
            if self.current_train_idx >= self.train_data_size:
                print(f"Regenerating training buffer at iteration {iteration_counter}")
                self._generate_train_data()

            start_idx = self.current_train_idx
            # The buffer length is a positive multiple of batch_size (enforced
            # in __init__), so this slice is always a full batch.
            end_idx = start_idx + self.batch_size
            # Advance before yielding so a consumer that breaks out of its loop
            # does not receive this batch again on the next pass.
            self.current_train_idx = end_idx
            iteration_counter += 1
            yield self.X_train[start_idx:end_idx], self.Y_train[start_idx:end_idx]

        if iteration_counter > 0:
            print("Maximum iterations reached.")
