import numpy as np
import torch
from torch.utils.data.distributed import DistributedSampler
import math
from typing import TypeVar, Optional, Iterator
from torch.utils.data import Sampler, Dataset
import torch.distributed as dist
from tqdm import tqdm

# Covariant type variable used to parameterize ``Sampler`` in the sampler classes below.
T_co = TypeVar("T_co", covariant=True)
from torch_geometric.graphgym.config import cfg


def chunk(indices, chunk_size):
    """Split ``indices`` into consecutive 1-D tensors of at most ``chunk_size`` elements.

    The last chunk is shorter when ``len(indices)`` is not a multiple of
    ``chunk_size`` (``torch.split`` semantics).

    Tensor inputs are copied with ``detach().clone()``. That gives the same detached,
    private copy as ``torch.tensor(tensor)`` without the copy-construct ``UserWarning``
    torch emits for the latter.

    Args:
        indices: a list/array of indices or a tensor.
        chunk_size: maximum number of elements per chunk.

    Returns:
        tuple of tensors, all views into one private copy of ``indices``.
    """
    if isinstance(indices, torch.Tensor):
        data = indices.detach().clone()
    else:
        data = torch.tensor(indices)
    return torch.split(data, chunk_size)


class UGTBatchSampler(Sampler[T_co]):
    """Distributed batch sampler that orders each global batch by graph id.

    Each epoch proceeds in four steps:

    1. The (optionally shuffled) dataset indices are padded or truncated to
       ``total_size``.
    2. They are cut into global batches of ``batch_size * num_replicas`` indices.
    3. Each global batch is sorted by ``graph_ids``.
    4. Each global batch is split evenly between ranks.

    This rank yields its slice of every global batch as a 1-D index tensor. Every
    rank yields the same number of batches.
    """

    def __init__(
        self,
        dataset: Dataset,
        graph_ids: torch.Tensor,
        num_nodes: torch.Tensor,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        batch_size: int = 1,
        seed: int = 0,
        drop_last: bool = True,
    ) -> None:
        """
        Args:
            dataset: dataset to sample from; only ``len(dataset)`` is used.
            graph_ids: per-sample key used to sort indices inside a global batch.
            num_nodes: stored as ``self.num_nodes``; not used by this class.
            num_replicas: number of ranks (defaults to the distributed world size).
            rank: this process's rank (defaults to the distributed rank).
            shuffle: permute indices deterministically from ``seed + epoch``.
            batch_size: indices per rank per full global batch.
            seed: base seed for shuffling.
            drop_last: truncate (True) or pad by repetition (False) to ``total_size``.
        """
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1)
            )
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.batch_size = batch_size
        self.graph_ids = graph_ids
        self.num_nodes = num_nodes

    def __len__(self) -> int:
        """Number of batches yielded per epoch (the same on every rank)."""
        return math.ceil(self.total_size / (self.batch_size * self.num_replicas))

    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[
                    :padding_size
                ]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        if not indices:
            return iter([])

        global_batch = self.batch_size * self.num_replicas
        index_tensor = torch.tensor(indices)
        graph_id_tensor = self.graph_ids[indices]

        rank_batches = []
        for batch_index, batch_graph_id in zip(
            torch.split(index_tensor, global_batch),
            torch.split(graph_id_tensor, global_batch),
        ):
            sorted_index = batch_index[torch.argsort(batch_graph_id)]
            # total_size is a multiple of num_replicas, so every global batch (including
            # a short tail) divides evenly. For full batches per_rank == batch_size.
            # Splitting the tail this way keeps the batch count equal on all ranks.
            per_rank = len(sorted_index) // self.num_replicas
            start = self.rank * per_rank
            rank_batches.append(sorted_index[start : start + per_rank])

        return iter(rank_batches)

    def set_epoch(self, epoch: int) -> None:
        r"""
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.
        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch


def distribute_objects_optimized(
    node_index_list, graph_sizes, num_buckets, batch_size, max_unique_graphs
):
    """Pack groups of node indices into ``num_buckets`` fixed-capacity buckets.

    Groups are visited in descending order of their ``graph_sizes`` entry. Ties keep
    input order because ``sorted`` is stable. Groups are dealt to buckets in snake
    order: the cursor advances one bucket per placement and bounces at either end,
    so an end bucket is revisited immediately after the bounce. A group that exceeds
    the remaining capacity of a bucket spills into the following buckets.

    Args:
        node_index_list: one sequence of node indices per group (supports ``len`` and
            slicing).
        graph_sizes: sort key per group, same length/order as ``node_index_list``.
        num_buckets: number of buckets to fill.
        batch_size: node capacity of every bucket.
        max_unique_graphs: number of group placements each bucket accepts.

    Returns:
        list of ``num_buckets`` dicts with the following keys:

        - ``"graph_index"``: group position in the input, repeated once per placed node.
        - ``"node_index"``: the placed nodes.
        - ``"graph_limit"``: placements still allowed.
        - ``"free"``: remaining capacity.

        Nodes that fit nowhere are dropped.
    """

    def has_room(bucket):
        # A bucket can take more nodes only while both budgets are positive.
        return bucket["graph_limit"] > 0 and bucket["free"] > 0

    def advance(position, step):
        # Move one bucket in the current direction; past an end, reverse and step back.
        position += step
        if position >= num_buckets or position < 0:
            step = -step
            position += step
        return position, step

    ordered_groups = sorted(
        zip(graph_sizes, np.arange(len(graph_sizes)), node_index_list),
        key=lambda entry: entry[0],
        reverse=True,
    )

    buckets = []
    for _ in range(num_buckets):
        buckets.append(
            {
                "graph_index": [],
                "node_index": [],
                "graph_limit": max_unique_graphs,
                "free": batch_size,
            }
        )

    cursor, step = 0, 1
    dropped_nodes = 0  # debugging tally of nodes that could not be placed
    for _size, group_pos, pending in ordered_groups:
        remaining = len(pending)

        if not any(has_room(b) for b in buckets):
            dropped_nodes += remaining
            break  # every bucket is exhausted; later groups cannot be placed either

        while remaining > 0:
            if not any(has_room(b) for b in buckets):
                dropped_nodes += remaining
                break

            # Skip buckets that are full or have used up their group placements.
            while buckets[cursor]["free"] == 0 or buckets[cursor]["graph_limit"] == 0:
                cursor, step = advance(cursor, step)

            target = buckets[cursor]
            take = min(remaining, target["free"])
            target["graph_index"].extend([group_pos] * take)
            target["node_index"].extend(pending[:take])
            target["free"] -= take
            target["graph_limit"] -= 1

            remaining -= take
            pending = pending[take:]
            cursor, step = advance(cursor, step)

    # if cfg.rank == 0:
    #     print('Number of dropped nodes: ', dropped_nodes)
    return buckets


# def distribute_objects_optimized(node_index_list, graph_sizes, num_buckets, batch_size, max_unique_graphs):
#     # node_index_list = list of node indices from each graph
#     # graph_sizes: list of graph sizes, same length/order as node_index_list

#     # Sort objects by size
#     sorted_objects = sorted(
#         zip(graph_sizes, np.arange(len(graph_sizes)), node_index_list),
#         key=lambda x: x[0],
#         reverse=True,
#     )

#     # Initialize buckets
#     buckets = [
#         {"graph_index": [], "node_index": [], "free": batch_size, "unique_graphs": set()}
#         for _ in range(num_buckets)
#     ]


#     for graph_size, graph_idx, nodes_sampled_from_graph in sorted_objects:
#         num_nodes = len(nodes_sampled_from_graph)
#         bucket_index = 0
#         direction = 1  # 1 for left-to-right, -1 for right-to-left

#         while num_nodes > 0:
#             # Find the next bucket where the object can be added
#             while True:
#                 if buckets[bucket_index]["free"] > 0 and (
#                     len(buckets[bucket_index]["unique_graphs"]) < max_unique_graphs or
#                     graph_idx in buckets[bucket_index]["unique_graphs"]
#                 ):
#                     break
#                 bucket_index += direction
#                 if bucket_index >= num_buckets or bucket_index < 0:
#                     # Change direction when the end or start is reached
#                     direction *= -1
#                     bucket_index += direction

#             # Add the object to the bucket
#             num_nodes_to_add = min(num_nodes, buckets[bucket_index]["free"])
#             buckets[bucket_index]["graph_index"].extend([graph_idx] * num_nodes_to_add)
#             buckets[bucket_index]["node_index"].extend(
#                 nodes_sampled_from_graph[:num_nodes_to_add]
#             )
#             buckets[bucket_index]["unique_graphs"].add(graph_idx)
#             buckets[bucket_index]["free"] -= num_nodes_to_add

#             num_nodes -= num_nodes_to_add
#             nodes_sampled_from_graph = nodes_sampled_from_graph[num_nodes_to_add:]

#     return buckets


# def distribute_objects_optimized(node_index_list, graph_sizes, num_buckets, batch_size, max_unique_graphs):
#     # node_index_list = list of node indices from each graph
#     # graph_sizes: list of graph sizes, same length/order as node_index_list

#     # Sort objects by size
#     sorted_objects = sorted(
#         zip(graph_sizes, np.arange(len(graph_sizes)), node_index_list),
#         key=lambda x: x[0],
#         reverse=True,
#     )

#     # Initialize buckets
#     buckets = [
#         {"graph_index": [], "node_index": [], "free": batch_size, "unique_graphs": set()}
#         for _ in range(num_buckets)
#     ]

#     dropped_nodes = 0  # Counter for dropped nodes due to graph limit

#     for graph_size, graph_idx, nodes_sampled_from_graph in sorted_objects:
#         num_nodes = len(nodes_sampled_from_graph)
#         bucket_index = 0
#         direction = 1  # 1 for left-to-right, -1 for right-to-left
#         successfully_added = False

#         while num_nodes > 0:
#             successfully_added = False
#             bucket_index = 0  # Reset to start for each graph to try all buckets

#             while not successfully_added:
#                 if buckets[bucket_index]["free"] > 0 and (
#                     len(buckets[bucket_index]["unique_graphs"]) < max_unique_graphs or
#                     graph_idx in buckets[bucket_index]["unique_graphs"]
#                 ):
#                     # Add the object to the bucket
#                     num_nodes_to_add = min(num_nodes, buckets[bucket_index]["free"])
#                     buckets[bucket_index]["graph_index"].extend([graph_idx] * num_nodes_to_add)
#                     buckets[bucket_index]["node_index"].extend(
#                         nodes_sampled_from_graph[:num_nodes_to_add]
#                     )
#                     buckets[bucket_index]["unique_graphs"].add(graph_idx)
#                     buckets[bucket_index]["free"] -= num_nodes_to_add

#                     num_nodes -= num_nodes_to_add
#                     nodes_sampled_from_graph = nodes_sampled_from_graph[num_nodes_to_add:]
#                     successfully_added = True

#                 bucket_index += direction
#                 if bucket_index >= num_buckets or bucket_index < 0:
#                     # Change direction when the end or start is reached
#                     direction *= -1
#                     bucket_index += direction

#             if not successfully_added:
#                 # All buckets are full or have reached the graph limit
#                 dropped_nodes += num_nodes
#                 break

#     print("Number of dropped nodes due to graph limit:", dropped_nodes)
#     return buckets


class UGTSnakeSampler(Sampler[T_co]):
    """Distributed batch sampler over a ratio-weighted subset of the dataset.

    Sampling (see ``sample_indices``): the dataset is treated as consecutive segments
    whose lengths are given by ``seq_len``. Segment ``i`` contributes
    ``floor(ratio_in_epoch[i] * seq_len[i])`` indices per epoch, with repetition when
    the ratio exceeds 1.

    Batching: every global batch of ``batch_size * num_replicas * accum_gradient_steps``
    sampled indices is grouped by ``num_nodes`` value and packed into
    ``num_replicas * accum_gradient_steps`` buckets with
    ``distribute_objects_optimized``. This rank yields buckets
    ``rank + k * num_replicas`` for ``k in range(accum_gradient_steps)``.
    """

    def __init__(
        self,
        dataset: Dataset,
        graph_ids: torch.Tensor,
        num_nodes: torch.Tensor,
        ratio_in_epoch: torch.Tensor,
        seq_len: torch.Tensor,
        max_unique_graphs: int,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        batch_size: int = 1,
        seed: int = 0,
        accum_gradient_steps: int = 1,
        drop_last: bool = True,
    ) -> None:
        """
        Args:
            dataset: dataset to sample from; only ``len(dataset)`` is used.
            graph_ids: stored as ``self.graph_ids``; not used by this class.
            num_nodes: per-sample key used to group indices before packing.
            ratio_in_epoch: per-segment sampling ratio (tensor elements are expected).
            seq_len: per-segment lengths; segments are laid out back to back.
            max_unique_graphs: per-bucket placement limit for the packer.
            num_replicas: number of ranks (defaults to the distributed world size).
            rank: this process's rank (defaults to the distributed rank).
            shuffle: randomly permute the sampled indices (otherwise ascending order).
            batch_size: nodes per bucket for full global batches.
            seed: base seed; the epoch is added to it.
            accum_gradient_steps: buckets this rank yields per global batch.
            drop_last: truncate (True) or pad by repetition (False) to ``total_size``.
        """
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1)
            )
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        self.max_unique_graphs = max_unique_graphs
        self.shuffle = shuffle
        self.seed = seed
        self.batch_size = batch_size
        self.graph_ids = graph_ids
        self.num_nodes = num_nodes
        self.accum_gradient_steps = accum_gradient_steps
        self.num_buckets = self.num_replicas * self.accum_gradient_steps
        self.ratio_in_epoch = ratio_in_epoch
        self.seq_len = seq_len

        self.num_sample_bucket = self.num_replicas * accum_gradient_steps
        # The number of sampled indices depends only on seq_len and ratio_in_epoch,
        # not on the seed, so one draw fixes the epoch length for all epochs.
        indices = self.sample_indices()
        dataset_len = len(indices)
        # If the sampled length is evenly divisible by # of buckets, there is no need
        # to drop any data, since it will be split equally.
        if self.drop_last and dataset_len % self.num_sample_bucket != 0:
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (dataset_len - self.num_sample_bucket) / self.num_sample_bucket
            )
        else:
            self.num_samples = math.ceil(dataset_len / self.num_sample_bucket)
        self.total_size = self.num_samples * self.num_sample_bucket
        print(f"Total size: {self.total_size}")
        print(f"Dataset size: {dataset_len}")

    def sample_indices(self):
        """Draw this epoch's sampled dataset indices (deterministic in ``seed + epoch``).

        Segment ``i`` spans ``[start_i, start_i + seq_len[i])``, where ``start_i`` is the
        sum of the preceding ``seq_len`` entries. Each non-empty segment is sampled
        according to its ratio:

        - ``ratio <= 1``: ``floor(ratio * length)`` indices are drawn without
          replacement.
        - ``ratio > 1``: whole random permutations of the segment are concatenated and
          topped up with a prefix of one more permutation.

        The concatenation over all segments is finally randomly permuted.

        Returns:
            list of int dataset indices; also stored in ``self.sampled_indices``.
        """
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)

        # as_tensor avoids torch's copy-construct warning when seq_len is a tensor.
        seq_len = torch.as_tensor(self.seq_len)
        start_indices = torch.cumsum(seq_len, dim=0) - seq_len

        sampled = []
        for start, length, ratio in zip(
            start_indices, self.seq_len, self.ratio_in_epoch
        ):
            if not length > 0:
                continue  # empty segment: nothing to draw, generator untouched
            start = int(start)
            end = int(start + length)
            span = end - start
            num_samples = int((ratio * length).floor())

            if ratio > 1:
                # Oversampling: whole passes over the segment plus a partial pass.
                # The partial randperm is drawn even when the remainder is 0, which
                # keeps the generator stream identical across versions.
                full_passes, remainder = divmod(num_samples, span)
                parts = [torch.randperm(span, generator=g) for _ in range(full_passes)]
                parts.append(torch.randperm(span, generator=g)[:remainder])
                picked = torch.cat(parts) + start
            else:
                picked = torch.randperm(span, generator=g)[:num_samples] + start
            sampled.extend(picked.tolist())

        pool = torch.tensor(sampled)
        self.sampled_indices = pool[torch.randperm(len(pool), generator=g)].tolist()
        return self.sampled_indices

    def __len__(self) -> int:
        """Number of batches this rank yields per epoch."""
        global_batches = math.ceil(
            self.total_size / (self.batch_size * self.num_buckets)
        )
        return global_batches * self.accum_gradient_steps

    def __iter__(self) -> Iterator[T_co]:
        # total_size was derived from the *sampled* length, so both modes must iterate
        # over the sampled population. Without shuffling, use ascending order; with all
        # ratios equal to 1 this is range(len(dataset)) when seq_len covers the dataset.
        indices = self.sample_indices()
        if not self.shuffle:
            indices = sorted(indices)

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices = indices + indices[:padding_size]
            else:
                indices = (
                    indices
                    + (indices * math.ceil(padding_size / len(indices)))[:padding_size]
                )
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        if not indices:
            return

        global_batch = self.batch_size * self.num_buckets
        graph_sizes = self.num_nodes[indices]
        index_tensor = torch.tensor(indices)

        for batch_sizes, batch_index in zip(
            torch.split(graph_sizes, global_batch),
            torch.split(index_tensor, global_batch),
        ):
            if len(batch_sizes) < global_batch:
                # Tail batch: its length is a multiple of num_buckets, so split evenly.
                per_bucket = max(1, len(batch_sizes) // self.num_buckets)
            else:
                per_bucket = self.batch_size

            unique_sizes = torch.unique(batch_sizes)
            grouped_indices = [batch_index[batch_sizes == size] for size in unique_sizes]
            buckets = distribute_objects_optimized(
                grouped_indices,
                unique_sizes,
                self.num_buckets,
                per_bucket,
                self.max_unique_graphs,
            )

            for acc_step in range(self.accum_gradient_steps):
                node_index = buckets[self.rank + acc_step * self.num_replicas][
                    "node_index"
                ]
                if len(node_index) == 0:
                    print("no batch")
                yield node_index

    def set_epoch(self, epoch: int) -> None:
        r"""
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.
        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
