import torch
import itertools
import numpy as np

import torch
import numpy as np
from typing import Optional, List, Union
from torch.utils.data import DataLoader, Subset
import math


class ScheduledILDataloader:
    """
    A dataloader that manages scheduling between IL and RL by controlling the amount of IL data used.

    Round ``i`` draws ``schedule[i]`` (a fraction of ``len(dataset)``) sample
    indices from the pool of not-yet-used indices. It wraps that subset in a
    ``FixedStepsDataloader`` that yields ``ceil(len(subset) / batch_size)``
    batches, so subsets drawn between two pool resets never overlap.

    The pool is reset (every index becomes available again) in two cases:
    when it is too small for the next round, and when the schedule is
    exhausted. In the second case the schedule also restarts from its first
    entry.
    """

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        batch_size: int,
        schedule: Union[List[float], str] = "decay",
        initial_percentage: float = 0.5,
        decay_factor: float = 0.8,
        min_percentage: float = 0.01,
        num_workers: int = 0,
        shuffle: bool = True,
        pin_memory: bool = True,
        **kwargs
    ):
        """
        Args:
            dataset: The full demonstration dataset (must support ``len``).
            batch_size: Batch size for the dataloader.
            schedule: Either "decay" for automatic decay or a sequence of
                fractions of the dataset to use, one per round.
            initial_percentage: Starting fraction of data to use if using decay.
            decay_factor: Factor to multiply the fraction by each round if using decay.
            min_percentage: Minimum fraction of data to use if using decay.
            num_workers: Number of workers for the dataloader.
            shuffle: Whether to shuffle the data.
            pin_memory: Whether to pin memory.
            **kwargs: Additional arguments to pass to the DataLoader.

        Raises:
            ValueError: if ``dataset`` is empty, ``schedule`` is a string other
                than "decay", the resulting schedule is empty, or the decay
                parameters would never terminate.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.dataloader_kwargs = kwargs

        self.total_samples = len(dataset)
        if self.total_samples == 0:
            raise ValueError("ScheduledILDataloader requires a non-empty dataset")

        # Generate the schedule
        if isinstance(schedule, str):
            if schedule != "decay":
                raise ValueError(
                    f"Unknown schedule {schedule!r}; expected 'decay' or a list of fractions"
                )
            self.schedule = self._generate_decay_schedule(
                initial_percentage, decay_factor, min_percentage
            )
        else:
            self.schedule = schedule
        # len() rather than truthiness so numpy arrays are accepted as schedules.
        if len(self.schedule) == 0:
            raise ValueError("schedule must contain at least one percentage")

        # Initialize tracking variables
        self.current_schedule_idx = 0
        self.current_dataloader = None
        self._reset_index_pool()

        # Create first dataloader
        self._create_new_dataloader()

    def _reset_index_pool(self) -> None:
        """Mark every dataset index as unused again."""
        self.used_indices = set()
        self.remaining_indices = set(range(self.total_samples))

    def _generate_decay_schedule(
        self, initial: float, decay: float, minimum: float
    ) -> List[float]:
        """
        Generate a geometric schedule ``initial, initial*decay, ...``.

        The schedule stops before the first value below ``minimum``.

        Raises:
            ValueError: if ``minimum <= 0`` or ``decay >= 1``. Either would make
                the loop below run forever: the value never drops below a
                non-positive minimum, and it never shrinks when decay >= 1.
        """
        if minimum <= 0:
            raise ValueError(f"min_percentage must be > 0, got {minimum}")
        if decay >= 1:
            raise ValueError(f"decay_factor must be < 1, got {decay}")
        schedule = []
        current = initial
        while current >= minimum:
            schedule.append(current)
            current = current * decay
        return schedule

    def _compute_num_samples(self, percentage: float) -> int:
        """
        Convert a fraction of the dataset into a sample count in ``[1, total_samples]``.

        The lower bound keeps tiny fractions from producing an empty subset
        (a RandomSampler over zero samples raises). The upper bound keeps
        fractions above 1 from requesting more unique indices than exist.
        """
        requested = int(self.total_samples * percentage)
        return min(self.total_samples, max(1, requested))

    def _create_new_dataloader(self) -> DataLoader:
        """Create a new dataloader with unused samples based on current schedule."""
        if self.current_schedule_idx >= len(self.schedule):
            # If we've used all schedules, reset tracking and start over
            self.current_schedule_idx = 0
            self._reset_index_pool()

        # Calculate number of samples needed for this round
        num_samples = self._compute_num_samples(
            self.schedule[self.current_schedule_idx]
        )

        # If we don't have enough remaining samples, reset
        if len(self.remaining_indices) < num_samples:
            self._reset_index_pool()

        # Randomly select indices from remaining samples; .tolist() yields plain
        # Python ints for the tracking sets and the Subset.
        selected_indices = np.random.choice(
            list(self.remaining_indices),
            size=num_samples,
            replace=False,
        ).tolist()

        # Update tracking sets
        self.used_indices.update(selected_indices)
        self.remaining_indices.difference_update(selected_indices)

        # Create and return the new dataloader (one pass over the subset)
        subset = Subset(self.dataset, selected_indices)
        self.current_dataloader = FixedStepsDataloader(
            subset,
            n_batches=math.ceil(len(subset) / self.batch_size),
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=self.shuffle,
            pin_memory=self.pin_memory,
            **self.dataloader_kwargs
        )

        return self.current_dataloader

    def get_next_dataloader(self) -> DataLoader:
        """Advance to the next schedule entry (wrapping around) and build its dataloader."""
        self.current_schedule_idx += 1
        return self._create_new_dataloader()

    @property
    def current_percentage(self) -> float:
        """Get the current fraction of data being used."""
        return self.schedule[self.current_schedule_idx]

    def get_current_dataloader(self) -> DataLoader:
        """Get the current dataloader."""
        return self.current_dataloader

    def __iter__(self):
        """Iterator for the current dataloader."""
        return iter(self.current_dataloader)

    def __len__(self):
        """Length of the current dataloader."""
        return len(self.current_dataloader)


class FixedStepsDataloader(torch.utils.data.DataLoader):
    """
    Dataloader that always yields exactly ``n_batches`` batches per iteration.

    If ``n_batches`` is smaller than one epoch, iteration stops early. With
    ``shuffle=True`` this yields a random subset of the batches.

    If ``n_batches`` is larger than one epoch, a fresh pass over the underlying
    DataLoader starts whenever an epoch runs out. The sampler is re-invoked
    each time (a new random order when ``shuffle=True``). Batches are not
    cached, unlike ``itertools.cycle``, which would replay the first epoch
    verbatim and hold it in memory.
    """

    def __init__(self, *args, n_batches, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_batches = n_batches

    def __iter__(self):
        remaining = self.n_batches
        while remaining > 0:
            produced_any = False
            for batch in super().__iter__():
                produced_any = True
                yield batch
                remaining -= 1
                if remaining == 0:
                    return
            if not produced_any:
                # Restarting an empty loader would spin forever.
                raise RuntimeError(
                    "FixedStepsDataloader: underlying DataLoader produced no batches"
                )

    def __len__(self):
        return self.n_batches


class EndlessDataloader(torch.utils.data.DataLoader):
    """
    Dataloader that cycles through the dataset indefinitely.

    Each time an epoch ends a new pass over the underlying DataLoader starts,
    so the sampler is re-invoked (a new random order when ``shuffle=True``).
    Batches are not cached between epochs. If an epoch produces no batches
    at all (e.g. an empty dataset), iteration simply stops.
    """

    def __iter__(self):
        while True:
            produced_any = False
            for batch in super().__iter__():
                produced_any = True
                yield batch
            if not produced_any:
                # Nothing to cycle over; stop instead of spinning forever.
                return

    def __len__(self):
        # NOTE: builtin len() only accepts integers, so len(loader) raises
        # TypeError; this value is visible only via a direct __len__() call.
        return float("inf")


class WeightedDataLoader:
    """
    Infinite iterator that mixes several dataloaders, choosing a source per batch.

    Each ``next`` call picks dataloader ``i`` with probability proportional to
    ``len(dataloaders[i]) ** (1/3)`` and returns its next batch. An exhausted
    dataloader is restarted and a source is drawn again, so iteration never
    raises StopIteration.
    """

    # Thanks to Lirui Wang for this code
    def __init__(self, dataloaders, weight_type="root"):
        """
        :param dataloaders: list of pytorch dataloaders (anything supporting ``len`` and ``iter``)
        :param weight_type: weighting scheme; only "root" (cube root of each
            dataloader's length) is supported
        :raises ValueError: if ``dataloaders`` is empty, ``weight_type`` is
            unknown, or every dataloader has length 0
        """
        self.dataloaders = dataloaders
        if len(dataloaders) == 0:
            raise ValueError("WeightedDataLoader needs at least one dataloader")
        if weight_type != "root":
            # Previously this only printed, leaving self.weights unset and
            # failing later with AttributeError in __next__.
            raise ValueError(f"weight type {weight_type} not defined")

        datasizes = np.power([len(d) for d in dataloaders], 1.0 / 3)  # np.sqrt(datasizes)
        total = np.sum(datasizes)
        if total <= 0:
            raise ValueError("all dataloaders are empty; cannot compute sampling weights")
        self.weights = datasizes / total

        self.loader_iters = [iter(dataloader) for dataloader in self.dataloaders]

    def __iter__(self):
        return self

    def __next__(self):
        # Loop instead of recursing so repeated exhaustion cannot hit the
        # recursion limit. The draw sequence matches the original: a fresh
        # source is drawn after every restart.
        restarted = set()
        while True:
            # Choose a dataloader based on weights
            idx = int(np.random.choice(len(self.dataloaders), p=self.weights))
            try:
                return next(self.loader_iters[idx])
            except StopIteration:
                if idx in restarted:
                    # A freshly restarted iterator is already empty: the loader
                    # yields nothing and would otherwise loop forever.
                    raise RuntimeError(
                        f"dataloader {idx} yielded no batches after being re-initialized"
                    ) from None
                # Handle case where a dataloader is exhausted. Reinitialize the iterator.
                self.loader_iters[idx] = iter(self.dataloaders[idx])
                restarted.add(idx)

    def __len__(self):
        """Sum of the member dataloaders' lengths (iteration itself never ends)."""
        return sum(len(dataloader) for dataloader in self.dataloaders)
