import torch
import torch.nn.functional as F
import time
import random
from data.ass_recall.ar_task import AssociativeRecallTask


# -------------------------------------------------
# Dataloader Class (Handles AssociativeRecall Task)
# -------------------------------------------------
class InContextAssociationDataloader:
    """Serve training batches for the associative-recall task.

    A fixed test set (``test_tensor``) is generated once at construction.
    Training batches are sliced from an in-memory buffer (``train_tensor``)
    of ``train_data_size`` sequences; the buffer is regenerated whenever fewer
    than ``batch_size`` unread sequences remain. Iterating over the loader
    yields at most ``iters`` batches per ``__iter__`` call, while the read
    position inside the buffer persists across calls.
    """

    def __init__(
        self,
        vocab_size,
        n_pairs,
        device,
        batch_size,
        block_size,
        iters,
        train_data_size,
        test_data_size,
        # Task-specific distribution arguments
        train_dist_args=None,  # Dict like {'alpha': float, 'burstiness': float, ...}
        test_dist_args=None,  # Dict like {'alpha': float, 'burstiness': float, ...}
    ) -> None:
        """
        Args:
            vocab_size (int): Size of the vocabulary.
            n_pairs (int): Number of in-context pairs (previously called E).
            device: Torch device (string or ``torch.device``).
            batch_size (int): Number of sequences per yielded batch.
            block_size (int): Sequence length; must equal ``2 * n_pairs + 2``.
            iters (int): Number of batches yielded per iteration pass.
            train_data_size (int, optional): Number of sequences in the training
                buffer. When None, defaults to ``int(51200000 / block_size)``.
            test_data_size (int): Number of sequences in the fixed test set.
            train_dist_args (dict, optional): Distribution parameters for the training task generator.
            test_dist_args (dict, optional): Distribution parameters for the test task generator.

        Raises:
            ValueError: If ``block_size`` is not ``2 * n_pairs + 2``, or if
                ``batch_size`` is not in ``1..train_data_size``.
        """
        self.device = torch.device(device)
        self.batch_size = batch_size
        self.block_size = block_size
        self.iters = iters
        self.train_data_size = (
            int(51200000 / block_size) if train_data_size is None else int(train_data_size)
        )
        self.test_data_size = test_data_size
        self.n_pairs = n_pairs
        self.vocab_size = vocab_size

        # Verify block size matches expected size for the task
        expected_block_size = 2 * n_pairs + 2
        if self.block_size != expected_block_size:
            raise ValueError(
                f"Block size must be 2*n_pairs + 2. "
                f"Got block_size={self.block_size}, n_pairs={n_pairs}, expected {expected_block_size}"
            )

        # A batch larger than the buffer can never be served in full: every
        # iteration would regenerate the buffer and still yield a short batch.
        if not 0 < self.batch_size <= self.train_data_size:
            raise ValueError(
                f"batch_size must be between 1 and train_data_size. "
                f"Got batch_size={self.batch_size}, train_data_size={self.train_data_size}"
            )

        # --- Instantiate the Task Generator ---
        print("Initializing Task Generators...")
        self.train_task = self._build_task(device, train_dist_args)
        self.test_task = self._build_task(device, test_dist_args)
        print("Done initializing generators.")

        # --- Generate Data ---
        print(f"Generating fixed test set ({self.test_data_size} sequences)...")
        start_time = time.time()
        self.test_tensor = self.test_task.get_batch(self.test_data_size)
        print(f"Test set generation took {time.time() - start_time:.2f} seconds.")

        self.train_tensor = None
        self._generate_train_data()

    def _build_task(self, device, dist_args):
        """Create a task generator, filling unspecified distribution args with defaults."""
        args = dist_args if dist_args is not None else {}
        return AssociativeRecallTask(
            vocab_size=self.vocab_size,
            n_pairs=self.n_pairs,
            device=device,
            alpha=args.get("alpha", 0.0),
            burstiness=args.get("burstiness", 1.0),
            p_celebs=args.get("p_celebs", 0.0),
            w_celebs=args.get("w_celebs", 1.0),
        )

    def _generate_train_data(self):
        """Generates or regenerates the training data buffer and rewinds the read index."""
        print(f"Generating training buffer ({self.train_data_size} sequences)...")
        if self.train_tensor is not None:
            # Drop the old buffer before allocating the new one to keep peak
            # memory down. Rebinding (instead of `del`) keeps the attribute
            # defined even if generation below raises.
            self.train_tensor = None
            torch.cuda.empty_cache()
        start_time = time.time()
        self.train_tensor = self.train_task.get_batch(self.train_data_size)
        self.current_train_idx = 0
        print(f"Training buffer generation took {time.time() - start_time:.2f} seconds.")

    def __iter__(self):
        """Yield up to ``iters`` batches of ``batch_size`` rows from the training buffer.

        The buffer is regenerated when the remaining rows cannot fill a batch.
        """
        iteration_counter = 0
        while iteration_counter < self.iters:
            if self.current_train_idx + self.batch_size > self.train_data_size:
                print(f"Regenerating training buffer at iteration {iteration_counter}")
                self._generate_train_data()
            start_idx = self.current_train_idx
            end_idx = start_idx + self.batch_size
            batch = self.train_tensor[start_idx:end_idx, :]
            # Advance the bookkeeping BEFORE yielding: if the consumer stops
            # early (e.g. `break`), the generator never resumes past the
            # yield, and the batch just handed out must still count as
            # consumed so it is not served again on the next pass.
            self.current_train_idx = end_idx
            iteration_counter += 1
            yield batch
        if iteration_counter > 0:
            print("Maximum iterations reached.")


# -------------------------------------------------
# Example Usage
# -------------------------------------------------
if __name__ == "__main__":
    # Smoke-test demo: build a small dataloader and inspect its first batches.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    batch_size = 128
    n_pairs = 5
    block_size = 2 * n_pairs + 2  # Must be 2*n_pairs + 2
    iters = 500
    train_data_size = 10240
    test_data_size = 1024
    vocab_size = 50

    # Test with Zipfian distribution for training.
    # Only alpha / burstiness / p_celebs / w_celebs are read from these dicts;
    # vocab_size is passed to the dataloader directly.
    train_params = {
        "alpha": 0.5,  # Zipfian exponent
        "burstiness": 2.0,  # Average query pair appearances
        "p_celebs": 0.1,  # 10% celebrity tokens
        "w_celebs": 0.6,  # 60% of probability mass to celebrities
    }

    # Use uniform distribution for testing
    test_params = {
        "alpha": 0.0,  # Uniform distribution
        "burstiness": 1.0,  # One appearance on average
    }

    try:
        ar_dataloader = InContextAssociationDataloader(
            vocab_size=vocab_size,
            n_pairs=n_pairs,
            device=device,
            batch_size=batch_size,
            block_size=block_size,
            iters=iters,
            train_data_size=train_data_size,
            test_data_size=test_data_size,
            train_dist_args=train_params,
            test_dist_args=test_params,
        )

        print(f"Train Tensor Shape: {ar_dataloader.train_tensor.shape}")
        print(f"Test Tensor Shape: {ar_dataloader.test_tensor.shape}")

        # Example of iterating through the dataloader
        print("\nIterating through dataloader (first 2 batches):")
        # range() comes first in zip so iteration stops before a third batch
        # is pulled from the dataloader and thrown away.
        for i, batch_data in zip(range(2), ar_dataloader):
            print(f"Batch {i+1} shape: {batch_data.shape}")
            # Show the first sequence to verify structure
            if i == 0:
                print(f"First sequence structure:\n{batch_data[0]}")
                # Calculate the number of times query token appears in context
                query_token = batch_data[0, -2].item()
                context_tokens = batch_data[0, :-2:2]  # Every other position contains triggers
                appearances = (context_tokens == query_token).sum().item()
                print(f"Query token {query_token} appears {appearances} times in the context")
        print("Finished dataloader iteration example.")

    except Exception as e:
        print(f"Error in dataloader: {e}")
        # Re-raise so the traceback is visible and the script exits non-zero.
        raise
