import numpy as np
import torch
from torch.utils.data import Sampler, Dataset
import torch.distributed as dist

from typing import TypeVar, Optional, Iterator
import math

# Covariant type variable used to parameterize ``Sampler`` in DistributedSSSampler.
T_co = TypeVar("T_co", covariant=True)
from torch_geometric.graphgym.config import cfg


def chunk(indices, chunk_size):
    """Split ``indices`` into consecutive 1-D tensor chunks of ``chunk_size``.

    The last chunk may be shorter. ``indices`` may be a list, NumPy array or
    tensor. ``torch.as_tensor`` reuses an existing tensor instead of
    copy-constructing it; ``torch.tensor`` would copy and emit a UserWarning.

    Args:
        indices: sequence of values to split.
        chunk_size: number of elements per chunk.

    Returns:
        tuple of tensors, as returned by ``torch.split``.
    """
    return torch.split(torch.as_tensor(indices), chunk_size)


def distribute_objects_optimized(
    node_index_list,
    graph_sizes,
    num_buckets,
    batch_size,
    max_unique_graphs,
    verbose: Optional[bool] = None,
):
    """Pack per-graph node samples into ``num_buckets`` fixed-capacity buckets.

    Graphs are visited from largest to smallest ``graph_sizes`` value (the sort
    is stable for ties). Their nodes are dealt to the buckets in a
    back-and-forth ("snake") order: left to right, then right to left, and so
    on. At each turn the end bucket is visited twice in a row.

    A graph whose nodes do not fit into the current bucket fills it and
    continues into the next open bucket. A bucket is open while it has free
    capacity and graph slots left. Every placement of (part of) a graph uses
    one of the bucket's ``max_unique_graphs`` slots. Nodes that cannot be
    placed because no bucket is open are dropped and counted.

    Args:
        node_index_list: per graph, the node indices sampled from it (a
            sliceable sequence such as a list or 1-D tensor).
        graph_sizes: per graph, its size; same length/order as
            ``node_index_list``. Used only as the sort key.
        num_buckets: number of buckets to fill.
        batch_size: node capacity of every bucket.
        max_unique_graphs: maximum number of graph placements per bucket.
        verbose: whether to print the number of dropped nodes. ``None`` (the
            default) prints only when ``cfg.rank == 0``.

    Returns:
        list of ``num_buckets`` dicts with keys ``"graph_index"`` (position of
        the source graph in ``node_index_list``, repeated once per node),
        ``"node_index"`` (assigned node indices), ``"graph_limit"`` (remaining
        graph slots) and ``"free"`` (remaining node capacity).
    """
    # Largest graphs first, so they are split across buckets before the
    # small ones fill the gaps.
    sorted_objects = sorted(
        zip(graph_sizes, np.arange(len(graph_sizes)), node_index_list),
        key=lambda x: x[0],
        reverse=True,
    )

    buckets = [
        {
            "graph_index": [],
            "node_index": [],
            "graph_limit": max_unique_graphs,
            "free": batch_size,
        }
        for _ in range(num_buckets)
    ]

    bucket_index = 0
    direction = 1  # 1 for left-to-right, -1 for right-to-left
    dropped_nodes = 0

    def is_open(bucket):
        """True if the bucket can still accept nodes from another graph."""
        return bucket["graph_limit"] > 0 and bucket["free"] > 0

    def step():
        """Move to the next bucket, reversing direction at either end."""
        nonlocal bucket_index, direction
        bucket_index += direction
        if bucket_index >= num_buckets or bucket_index < 0:
            direction *= -1
            bucket_index += direction

    for _, graph_idx, nodes in sorted_objects:
        remaining = len(nodes)
        while remaining > 0:
            if not any(is_open(bucket) for bucket in buckets):
                # Nothing can take more nodes: drop the rest of this graph.
                dropped_nodes += remaining
                break

            # Terminates because at least one bucket is open and the snake
            # walk visits every bucket.
            while not is_open(buckets[bucket_index]):
                step()

            bucket = buckets[bucket_index]
            take = min(remaining, bucket["free"])
            bucket["graph_index"].extend([graph_idx] * take)
            bucket["node_index"].extend(nodes[:take])
            bucket["free"] -= take
            bucket["graph_limit"] -= 1

            remaining -= take
            nodes = nodes[take:]
            step()

    if verbose is None:
        verbose = cfg.rank == 0
    if verbose:
        print("Number of dropped nodes: ", dropped_nodes)

    return buckets


class DistributedSSSampler(Sampler[T_co]):
    """Distributed batch sampler that packs sampled nodes into per-rank batches.

    Each epoch:

    1. :meth:`sample_indices` draws indices from each contiguous block of
       dataset indices described by ``seq_len``. It keeps
       ``floor(ratio * length)`` of them, with repetition when the ratio is
       above 1, and shuffles the result.
    2. The indices are trimmed (``drop_last``) or padded so that their count
       is ``total_size``, a multiple of ``num_replicas * accum_gradient_steps``.
    3. They are cut into chunks of
       ``batch_size * num_replicas * accum_gradient_steps`` indices. Each
       chunk is grouped by ``unique_graph_id`` and packed into
       ``num_replicas * accum_gradient_steps`` buckets by
       ``distribute_objects_optimized``.
    4. This rank yields bucket ``rank + k * num_replicas`` for every
       accumulation step ``k``.

    Every yielded item is a list of dataset indices forming one batch. The
    sampler is presumably meant to be used as a DataLoader ``batch_sampler``
    — TODO confirm against the caller.

    Args:
        dataset: indexed dataset; only its length is used here.
        graph_ids: stored as ``self.graph_ids``; not read by this class.
        unique_graph_id: indexed by dataset index; groups samples by graph.
        num_nodes: indexed by dataset index; used as the graph size when
            ordering graphs for packing.
        ratio_in_epoch: per-block sampling ratio, aligned with ``seq_len``.
        seq_len: per-block number of dataset indices. Blocks are contiguous
            and laid out in order.
        max_unique_graphs: maximum number of graph placements per batch.
        num_replicas: world size; taken from ``torch.distributed`` if None.
        rank: this process's rank; taken from ``torch.distributed`` if None.
        shuffle: use the ratio-based resampling; otherwise use every dataset
            index in order.
        batch_size: node capacity of every batch.
        seed: base seed; the epoch is added to it.
        accum_gradient_steps: batches yielded per rank per chunk.
        drop_last: trim instead of pad to reach ``total_size``.
    """

    def __init__(
        self,
        dataset: Dataset,
        graph_ids: torch.Tensor,
        unique_graph_id: torch.Tensor,
        num_nodes: torch.Tensor,
        ratio_in_epoch: torch.Tensor,
        seq_len: torch.Tensor,
        max_unique_graphs: int,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        batch_size: int = 1,
        seed: int = 0,
        accum_gradient_steps: int = 1,
        drop_last: bool = True,
    ) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1)
            )
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        self.max_unique_graphs = max_unique_graphs
        self.shuffle = shuffle
        self.seed = seed
        self.batch_size = batch_size
        self.graph_ids = graph_ids
        self.num_nodes = num_nodes
        self.unique_graph_id = unique_graph_id
        self.accum_gradient_steps = accum_gradient_steps
        # One bucket (batch) per (replica, accumulation step) pair.
        self.num_buckets = self.num_replicas * self.accum_gradient_steps
        self.ratio_in_epoch = ratio_in_epoch
        self.seq_len = seq_len

        self.num_sample_bucket = self.num_replicas * accum_gradient_steps
        # The number of sampled indices depends only on seq_len and
        # ratio_in_epoch (floor(ratio * length) per block), not on the seed,
        # so drawing once here fixes it for every epoch.
        dataset_len = len(self.sample_indices())
        if self.drop_last and dataset_len % self.num_sample_bucket != 0:
            # Round down to the nearest multiple of num_sample_bucket so that
            # every bucket receives the same amount of data.
            self.num_samples = math.ceil(
                (dataset_len - self.num_sample_bucket) / self.num_sample_bucket
            )
        else:
            self.num_samples = math.ceil(dataset_len / self.num_sample_bucket)
        self.total_size = self.num_samples * self.num_sample_bucket
        print(f"Total size: {self.total_size}")
        print(f"Dataset size: {dataset_len}")

    def sample_indices(self):
        """Draw this epoch's sampled dataset indices.

        For every block ``i`` with ``length = seq_len[i] > 0`` and
        ``start = sum(seq_len[:i])``, the method draws
        ``floor(ratio_in_epoch[i] * length)`` indices from
        ``[start, start + length)``:

        - without replacement when the ratio is at most 1;
        - as repeated full permutations plus a partial one when it is above 1.

        The collected indices are then shuffled. All randomness comes from a
        generator seeded with ``seed + epoch``.

        Returns:
            list: the shuffled indices, also stored in ``self.sampled_indices``.
        """
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)

        # as_tensor: no copy (and no copy-construct warning) for tensor input.
        seq_len = torch.as_tensor(self.seq_len)
        start_indices = torch.cumsum(seq_len, dim=0) - seq_len

        self.sampled_indices = []
        for start, length, ratio in zip(start_indices, seq_len, self.ratio_in_epoch):
            if not length > 0:
                continue  # empty block: nothing to sample
            start = int(start)
            span = int(length)
            num_samples = int((ratio * length).floor())

            if ratio > 1:
                # Oversampling: whole permutations, then a partial one. The
                # partial permutation is drawn even when the remainder is 0,
                # which keeps the generator stream (and so the sampling)
                # unchanged.
                full_batches, remainder = divmod(num_samples, span)
                perms = [torch.randperm(span, generator=g) for _ in range(full_batches)]
                perms.append(torch.randperm(span, generator=g)[:remainder])
                sampled = torch.cat(perms) + start
            else:
                sampled = torch.randperm(span, generator=g)[:num_samples] + start
            self.sampled_indices.extend(sampled.tolist())

        # Shuffle the indices collected from all blocks.
        order = torch.randperm(len(self.sampled_indices), generator=g)
        self.sampled_indices = torch.tensor(self.sampled_indices)[order].tolist()
        return self.sampled_indices

    def __len__(self) -> int:
        """Return the number of batches this rank yields per epoch."""
        chunk_size = self.batch_size * self.num_buckets
        # torch.split returns at least one chunk even for an empty tensor, so
        # __iter__ always processes at least one chunk.
        num_chunks = max(1, math.ceil(self.total_size / chunk_size))
        return num_chunks * self.accum_gradient_steps

    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # Resampling is seeded with seed + epoch inside sample_indices.
            indices = self.sample_indices()
        else:
            # NOTE(review): this path uses every dataset index rather than the
            # ratio-based sample. Its length matches total_size (checked by the
            # assert below) only when ratio_in_epoch/seq_len happen to agree
            # with len(dataset) — TODO confirm this is intended.
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[
                    :padding_size
                ]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        graph_sizes = self.num_nodes[indices]
        graph_unique_ids = self.unique_graph_id[indices]

        # Each "big batch" holds one batch for every (replica, accum step).
        chunk_size = self.batch_size * self.num_buckets
        batches_indices = chunk(indices, chunk_size)
        batches_graph_size = chunk(graph_sizes, chunk_size)
        batches_graph_unique_id = chunk(graph_unique_ids, chunk_size)

        for batch_graph_size, batch_index, batch_graph_unique_id in zip(
            batches_graph_size, batches_indices, batches_graph_unique_id
        ):
            if len(batch_graph_size) < chunk_size:
                # Last, partial big batch: shrink the per-bucket capacity.
                bucket_capacity = max(1, len(batch_graph_size) // self.num_buckets)
            else:
                bucket_capacity = self.batch_size

            unique_ids, inverse_idx = torch.unique(
                batch_graph_unique_id, return_inverse=True
            )

            # Single pass: group sample indices by graph and remember where
            # each graph first appears, so its size can be looked up.
            num_unique = unique_ids.size(0)
            groups = [[] for _ in range(num_unique)]
            first_occurrence = [-1] * num_unique
            for pos, (group_id, sample_idx) in enumerate(
                zip(inverse_idx.tolist(), batch_index.tolist())
            ):
                if first_occurrence[group_id] == -1:
                    first_occurrence[group_id] = pos
                groups[group_id].append(sample_idx)

            grouped_indices = [torch.tensor(group) for group in groups]
            unique_graph_sizes = batch_graph_size[
                torch.tensor(first_occurrence, dtype=torch.long)
            ]

            result = distribute_objects_optimized(
                grouped_indices,
                unique_graph_sizes,
                self.num_buckets,
                bucket_capacity,
                self.max_unique_graphs,
            )

            for acc_step in range(self.accum_gradient_steps):
                node_index = result[self.rank + acc_step * self.num_replicas][
                    "node_index"
                ]
                if len(node_index) == 0:
                    print("no batch")
                yield node_index

    def set_epoch(self, epoch: int) -> None:
        r"""
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.
        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
