import torch
from torch.utils.data import Dataset, DataLoader, Sampler
import numpy as np
from typing import Optional, List, Tuple, Iterator


class MixedRatioDataloader:
    """
    A dataloader that mixes samples from two datasets with an adjustable ratio.

    One epoch is one pass over the imitation-learning (IL) dataloader. For each
    IL batch, the first ``get_il_batch_size()`` IL samples are kept (the rest of
    that IL batch is discarded) and ``get_rl_batch_size()`` samples drawn from
    the reinforcement-learning (RL) dataloader are appended, concatenated per key
    along dim 0. The RL dataloader is restarted whenever it runs out. When the RL
    share is zero or no RL data is available, IL batches are yielded unchanged.

    Batches are assumed to be dicts of tensors with a leading batch dimension,
    and both datasets must produce the same keys (a mismatch raises ValueError).
    """

    def __init__(
        self,
        il_dataset: Dataset,
        rl_dataset: Optional[Dataset] = None,
        batch_size: int = 32,
        rl_ratio: float = 0.0,
        shuffle: bool = True,
        num_workers: int = 0,
        pin_memory: bool = True,
        drop_last: bool = False,
        seed: Optional[int] = None,
    ):
        """
        Initialize the mixed ratio dataloader.

        Args:
            il_dataset: The imitation learning dataset
            rl_dataset: The reinforcement learning dataset (can be None initially)
            batch_size: Total batch size
            rl_ratio: Initial ratio of RL samples in each batch (0.0 to 1.0)
            shuffle: Whether to shuffle the datasets
            num_workers: Number of workers for the IL dataloader
            pin_memory: Whether to pin memory for faster GPU transfer
            drop_last: Whether to drop the last incomplete IL batch
            seed: Random seed for reproducibility; note this seeds the *global*
                NumPy and torch RNGs

        Raises:
            ValueError: If ``rl_ratio`` is not within [0, 1].
        """
        # Validate up front so the constructor enforces the same contract as set_rl_ratio.
        self._validate_ratio(rl_ratio)

        self.il_dataset = il_dataset
        self.rl_dataset = rl_dataset
        self.batch_size = batch_size
        self._rl_ratio = rl_ratio
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.seed = seed

        if seed is not None:
            np.random.seed(seed)
            torch.manual_seed(seed)

        self._create_dataloaders()

    @staticmethod
    def _validate_ratio(ratio: float) -> None:
        """Raise ValueError unless 0 <= ratio <= 1 (NaN is rejected as well)."""
        if not 0 <= ratio <= 1:
            raise ValueError(f"RL ratio must be between 0 and 1, got {ratio}")

    def _create_dataloaders(self):
        """Create or update the internal dataloaders based on current settings."""
        self.il_dataloader = DataLoader(
            self.il_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            drop_last=self.drop_last,
        )
        self._create_rl_dataloader()

    def _create_rl_dataloader(self):
        """(Re)build the RL dataloader and its iterator, or clear both if there is no RL data."""
        if self.rl_dataset is not None and len(self.rl_dataset) > 0:
            self.rl_dataloader = DataLoader(
                self.rl_dataset,
                batch_size=self.batch_size,
                shuffle=self.shuffle,
                num_workers=0,  # in-process loading keeps restarting the iterator cheap (easier resampling)
                pin_memory=self.pin_memory,
                drop_last=False,
            )
            self.rl_iterator = iter(self.rl_dataloader)
        else:
            self.rl_dataloader = None
            self.rl_iterator = None

    @property
    def rl_ratio(self) -> float:
        """Get the current RL ratio."""
        return self._rl_ratio

    def set_rl_ratio(self, ratio: float):
        """
        Set the ratio of RL samples in each batch.

        Args:
            ratio: New ratio of RL samples in each batch (0.0 to 1.0)

        Raises:
            ValueError: If ``ratio`` is not within [0, 1].
        """
        self._validate_ratio(ratio)
        self._rl_ratio = ratio

        print(f"Updated RL ratio to {self._rl_ratio:.3f}; RL batch size: {self.get_rl_batch_size()}")

    def get_rl_batch_size(self) -> int:
        """Calculate the number of RL samples in each batch (floor of batch_size * ratio)."""
        # Round away floating-point noise before flooring: e.g. 100 * 0.29 evaluates
        # to 28.999999999999996 and plain int() would truncate it to 28.
        return int(round(self.batch_size * self._rl_ratio, 9))

    def get_il_batch_size(self) -> int:
        """Calculate the number of IL samples in each batch."""
        return self.batch_size - self.get_rl_batch_size()

    def update_rl_dataset(self, new_rl_dataset: Dataset):
        """
        Update the RL dataset.

        Args:
            new_rl_dataset: New RL dataset to use
        """
        self.rl_dataset = new_rl_dataset
        self._create_dataloaders()

    def _get_next_rl_batch(self):
        """
        Get the next batch from the RL dataset, restarting if needed.

        Returns:
            The next collated RL batch, or None when the RL ratio is 0 or no RL
            data is available.
        """
        if self._rl_ratio == 0:
            return None

        if self.rl_dataloader is None:
            # The RL dataset may have been empty at construction and filled in
            # place afterwards; build the loader lazily once data exists.
            self._create_rl_dataloader()
            if self.rl_dataloader is None:
                return None

        try:
            return next(self.rl_iterator)
        except StopIteration:
            # Restart the iterator
            self.rl_iterator = iter(self.rl_dataloader)

        # If the dataset has since become empty, report "no data" instead of
        # leaking StopIteration (it would turn into RuntimeError inside __iter__).
        return next(self.rl_iterator, None)

    def _take_rl_samples(self, num_samples: int):
        """
        Collect ``num_samples`` RL samples as a single dict batch.

        Pulls RL batches (restarting the RL iterator as needed) and concatenates
        them until enough samples are gathered, so the RL loader's short final
        batch does not shrink the mixed batch. Samples may repeat if the RL
        dataset holds fewer than ``num_samples`` items.

        Returns:
            A dict mapping each key to exactly ``num_samples`` rows, or None if
            no RL data is available.
        """
        pieces = []
        collected = 0
        while collected < num_samples:
            batch = self._get_next_rl_batch()
            if batch is None:
                break
            available = len(next(iter(batch.values()), ()))
            if available == 0:
                break  # defensive: never spin forever on an empty batch
            take = min(num_samples - collected, available)
            pieces.append({key: value[:take] for key, value in batch.items()})
            collected += take

        if not pieces:
            return None
        if len(pieces) == 1:
            return pieces[0]
        return {key: torch.cat([piece[key] for piece in pieces], dim=0) for key in pieces[0]}

    def __len__(self) -> int:
        """Return the number of batches in an epoch (the IL dataloader's length)."""
        return len(self.il_dataloader)

    def __iter__(self) -> Iterator:
        """
        Create and return an iterator over mixed batches.

        Raises:
            ValueError: If the IL and RL batches do not share the same keys.
        """
        for il_batch in self.il_dataloader:
            # Re-read per batch so set_rl_ratio() takes effect mid-epoch.
            rl_batch_size = self.get_rl_batch_size()

            if rl_batch_size == 0 or self.rl_dataset is None or len(self.rl_dataset) == 0:
                yield il_batch
                continue

            rl_batch = self._take_rl_samples(rl_batch_size)
            if rl_batch is None:
                # No RL data could be drawn; fall back to the plain IL batch.
                yield il_batch
                continue

            il_batch_size = self.get_il_batch_size()
            mixed_batch = {}

            for key, il_value in il_batch.items():
                if key not in rl_batch:
                    raise ValueError(f"Key {key} not found in RL batch")
                # IL samples first, then RL samples, stacked along the batch dim.
                mixed_batch[key] = torch.cat([il_value[:il_batch_size], rl_batch[key]], dim=0)

            for key in rl_batch:
                if key not in mixed_batch:
                    raise ValueError(f"Key {key} not found in IL batch")

            yield mixed_batch
