from collections import OrderedDict, defaultdict
from typing import Iterator

from accelerate import Accelerator
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DistributedSampler

from .aspect import get_num_pexels_from_name
from .bucket import Bucket
from .parallel import pandarallel
from .utils import sync_object_across_devices

from ..train.data_word import T3DataSetWarp

def format_numel_str(numel: int) -> str:
    """
    Render an element count compactly using binary (1024-based) K/M/B suffixes.

    Args:
        numel (int): The number of elements.

    Returns:
        str: Two-decimal value plus suffix (e.g. ``"1.50 K"``) once ``numel``
        reaches 1024; otherwise the plain number.
    """
    # Checked from the largest unit down so the first match is the coarsest one.
    for scale, suffix in ((1024**3, "B"), (1024**2, "M"), (1024, "K")):
        if numel >= scale:
            return f"{numel / scale:.2f} {suffix}"
    return f"{numel}"


# use pandarallel to accelerate bucket processing
# NOTE: pandarallel should only access local variables
def apply(data, method=None, seed=None, num_bucket=None, fps_max=16):
    """Call ``method`` on one row's video metadata with a row-specific seed.

    ``method`` receives, positionally: num_frames, height, width, fps, path,
    ``seed + data["id"] * num_bucket`` and ``fps_max``.
    """
    row_seed = seed + data["id"] * num_bucket
    row_fields = (data[column] for column in ("num_frames", "height", "width", "fps", "path"))
    return method(*row_fields, row_seed, fps_max)


class StatefulDistributedSampler(DistributedSampler):
    def __init__(
        self,
        dataset: Dataset,
        num_replicas: int | None = None,
        rank: int | None = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
        self.start_index: int = 0

    def __iter__(self) -> Iterator:
        iterator = super().__iter__()
        indices = list(iterator)
        indices = indices[self.start_index :]
        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples - self.start_index

    def reset(self) -> None:
        self.start_index = 0

    def state_dict(self, step) -> dict:
        return {"start_index": step}

    def load_state_dict(self, state_dict: dict) -> None:
        self.__dict__.update(state_dict)


class VariableVideoBatchSampler(DistributedSampler):
    """Distributed micro-batch sampler that groups samples into ``Bucket`` buckets.

    Bucket ids for every row of ``dataset.data`` are computed on global rank 0
    (with pandarallel) and synced to all ranks. Each bucket's samples are padded
    (or trimmed when ``drop_last``) to a multiple of that bucket's per-GPU batch
    size and optionally shuffled. The buckets' micro-batches are then laid out in
    an access order that is dealt to the ``num_replicas`` ranks round by round.

    Each yielded item is one micro-batch for this rank: a list of strings
    ``"{sample_index}-{t}-{h}-{w}"``, where ``(t, h, w)`` is
    ``self.bucket.get_thw(bucket_id)``.
    """

    def __init__(
        self,
        dataset: T3DataSetWarp,
        bucket_config: dict,
        num_replicas: int | None = None,
        rank: int | None = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
        verbose: bool = False,
        num_bucket_build_workers: int = 1,
        num_groups: int = 1,
        accelerator: Accelerator = None,
    ) -> None:
        """
        Args:
            dataset: Must have ``bucket_class == "Bucket"`` and ``fps_max``. It must
                also have a pandas ``data`` frame whose rows provide num_frames,
                height, width, fps and path (consumed by ``apply``).
            bucket_config: Passed to ``Bucket``.
            num_replicas, rank, shuffle, seed: As for ``DistributedSampler``.
            drop_last: Trim (instead of pad) both each bucket's sample list and the
                global access order to make them divisible.
            verbose: Print bucket statistics on global rank 0.
            num_bucket_build_workers: pandarallel worker count for bucket assignment.
            num_groups: Divisor applied to the batch count in ``__len__``.
            accelerator: Stored only; synchronization here uses ``torch.distributed``.
        """
        super().__init__(
            dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last
        )
        self.dataset = dataset
        assert dataset.bucket_class == "Bucket", "Only support Bucket class for now"
        self.bucket = Bucket(bucket_config)
        self.verbose = verbose
        self.last_micro_batch_access_index = 0
        self.num_bucket_build_workers = num_bucket_build_workers
        self._cached_bucket_sample_dict = None
        self._cached_num_total_batch = None
        self.num_groups = num_groups
        self.accelerator = accelerator

        # `DistributedSampler.__init__` has already resolved `self.rank` (falling
        # back to `dist.get_rank()` when `rank` is None). Do not re-assign it from
        # the raw argument: that turned `self.rank` into None and broke `__iter__`.
        if self.rank == 0:
            pandarallel.initialize(
                nb_workers=self.num_bucket_build_workers,
                progress_bar=False,
                verbose=0,
                use_memory_fs=False,
            )

    @staticmethod
    def _cyclic_take(items: list, n: int) -> list:
        """Return ``n`` elements taken from ``items`` cyclically (``items`` non-empty if ``n > 0``)."""
        return [items[i % len(items)] for i in range(n)]

    def __iter__(self) -> Iterator[list[str]]:
        """Yield this rank's micro-batches for the current epoch.

        Starts after ``last_micro_batch_access_index`` (see ``set_step`` /
        ``load_state_dict``). Resets it to 0 once the epoch is exhausted.
        """
        bucket_sample_dict, _ = self.group_by_bucket()
        self.clear_cache()

        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        bucket_micro_batch_count = OrderedDict()
        bucket_last_consumed = OrderedDict()

        # Make every bucket a multiple of its per-GPU batch size, then shuffle it.
        for bucket_id, data_list in bucket_sample_dict.items():
            bs_per_gpu = self.bucket.get_batch_size(bucket_id)
            remainder = len(data_list) % bs_per_gpu
            if remainder > 0:
                if self.drop_last:
                    data_list = data_list[:-remainder]
                else:
                    # Pad cyclically. A plain `data_list[:n]` is too short when the
                    # bucket has fewer samples than the padding needed, which left it
                    # non-divisible and silently dropped its last micro-batch.
                    data_list = data_list + self._cyclic_take(data_list, bs_per_gpu - remainder)

            if self.shuffle:
                data_indices = torch.randperm(len(data_list), generator=g).tolist()
                data_list = [data_list[i] for i in data_indices]

            bucket_sample_dict[bucket_id] = data_list
            bucket_micro_batch_count[bucket_id] = len(data_list) // bs_per_gpu

        # One entry per micro-batch, so a bucket id appears once per micro-batch it owns.
        bucket_id_access_order = [
            bucket_id
            for bucket_id, num_micro_batch in bucket_micro_batch_count.items()
            for _ in range(num_micro_batch)
        ]
        if self.shuffle:
            order = torch.randperm(len(bucket_id_access_order), generator=g).tolist()
            bucket_id_access_order = [bucket_id_access_order[i] for i in order]

        # Make the number of accesses divisible by the number of replicas.
        remainder = len(bucket_id_access_order) % self.num_replicas
        if remainder > 0:
            if self.drop_last:
                bucket_id_access_order = bucket_id_access_order[:-remainder]
            else:
                bucket_id_access_order += self._cyclic_take(bucket_id_access_order, self.num_replicas - remainder)

        num_iters = len(bucket_id_access_order) // self.num_replicas
        # Clamp so resuming from a step beyond this epoch's length cannot index past
        # the end of the access order.
        start_iter_idx = min(self.last_micro_batch_access_index // self.num_replicas, num_iters)

        # Replay the consumption of already-served rounds. Re-deriving the index from
        # the round count keeps this correct when resuming with a different GPU count.
        self.last_micro_batch_access_index = start_iter_idx * self.num_replicas
        for bucket_id in bucket_id_access_order[: self.last_micro_batch_access_index]:
            bucket_last_consumed[bucket_id] = bucket_last_consumed.get(bucket_id, 0) + self.bucket.get_batch_size(
                bucket_id
            )

        for i in range(start_iter_idx, num_iters):
            bucket_access_list = bucket_id_access_order[i * self.num_replicas : (i + 1) * self.num_replicas]
            self.last_micro_batch_access_index += self.num_replicas

            # Every rank computes the boundaries of the whole round so that all
            # ranks' consumption counters stay identical.
            bucket_access_boundaries = []
            for bucket_id in bucket_access_list:
                bucket_bs = self.bucket.get_batch_size(bucket_id)
                consumed = bucket_last_consumed.get(bucket_id, 0)
                # Wrap around when access-order padding revisits a bucket that is
                # already exhausted; otherwise the slice below would be empty.
                # Bucket lengths are multiples of their batch size, so the wrapped
                # slice is always full.
                start = consumed % len(bucket_sample_dict[bucket_id])
                bucket_access_boundaries.append((start, start + bucket_bs))
                bucket_last_consumed[bucket_id] = consumed + bucket_bs

            bucket_id = bucket_access_list[self.rank]
            start, end = bucket_access_boundaries[self.rank]
            cur_micro_batch = bucket_sample_dict[bucket_id][start:end]

            # Encode t, h, w into each sample index.
            real_t, real_h, real_w = self.bucket.get_thw(bucket_id)
            yield [f"{idx}-{real_t}-{real_h}-{real_w}" for idx in cur_micro_batch]

        self.reset()

    def __len__(self) -> int:
        # NOTE(review): this is the sum over buckets of floor(#samples / batch size),
        # divided by num_groups. It can differ from the number of micro-batches
        # __iter__ yields on this rank (which pads/drops and splits across
        # num_replicas). Confirm intended.
        return self.get_num_batch() // self.num_groups

    def get_num_batch(self) -> int:
        """Total micro-batch count over all buckets (builds buckets if not cached)."""
        _, num_total_batch = self.group_by_bucket()
        return num_total_batch

    def clear_cache(self):
        """Drop the cached bucket grouping so the next call rebuilds it."""
        self._cached_bucket_sample_dict = None
        self._cached_num_total_batch = None

    def group_by_bucket(self) -> tuple[dict, int]:
        """
        Group the dataset samples into buckets.

        On global rank 0, bucket ids are computed for every row of
        ``self.dataset.data`` with pandarallel. They are then shared via
        ``sync_object_across_devices`` (presumably a broadcast from rank 0; verify).
        Every rank must call this method because it contains ``dist.barrier()``.
        Samples whose bucket id is None are skipped. The result is cached until
        ``clear_cache``.

        Returns:
            tuple[dict, int]: bucket id -> list of positional sample indices, and the
            total micro-batch count from ``print_bucket_info``.
        """
        if self._cached_bucket_sample_dict is not None:
            return self._cached_bucket_sample_dict, self._cached_num_total_batch

        print(f"Building buckets using {self.num_bucket_build_workers} workers...")
        bucket_ids = None
        if dist.get_rank() == 0:
            data = self.dataset.data.copy(deep=True)
            data["id"] = data.index
            bucket_ids = data.parallel_apply(
                apply,
                axis=1,
                method=self.bucket.get_bucket_id,
                seed=self.seed + self.epoch,
                num_bucket=self.bucket.num_bucket,
                fps_max=self.dataset.fps_max,
            )
        dist.barrier()
        bucket_ids = sync_object_across_devices(bucket_ids)
        dist.barrier()

        # Each sample goes into the bucket with a similar image/video size.
        # Plain iteration (rather than np.array) avoids numpy turning a list of
        # equal-length tuple ids into a 2-D array.
        bucket_sample_dict = defaultdict(list)
        for i, bucket_id in enumerate(bucket_ids):
            if bucket_id is not None:
                bucket_sample_dict[bucket_id].append(i)

        self._cached_bucket_sample_dict = bucket_sample_dict
        num_total_batch = self.print_bucket_info(bucket_sample_dict)
        self._cached_num_total_batch = num_total_batch

        return bucket_sample_dict, num_total_batch

    def print_bucket_info(self, bucket_sample_dict: dict) -> int:
        """Compute per-bucket statistics and return the total micro-batch count.

        The statistics are printed on global rank 0 when ``verbose``. The way keys
        are indexed here suggests bucket ids are tuples of (resolution name,
        num_frames, ..., aspect-ratio key), with num_frames == 1 meaning an image.
        This is inferred from usage; confirm against ``Bucket``.
        """
        num_total_samples = num_total_batch = 0
        num_total_img_samples = num_total_vid_samples = 0
        num_total_img_batch = num_total_vid_batch = 0
        num_total_vid_batch_256 = num_total_vid_batch_768 = 0
        num_aspect_dict = defaultdict(lambda: [0, 0])
        num_hwt_dict = defaultdict(lambda: [0, 0])
        for k, v in bucket_sample_dict.items():
            size = len(v)
            num_batch = size // self.bucket.get_batch_size(k[:-1])

            num_total_samples += size
            num_total_batch += num_batch

            if k[1] == 1:
                num_total_img_samples += size
                num_total_img_batch += num_batch
            else:
                if k[0] == "256px":
                    num_total_vid_batch_256 += num_batch
                elif k[0] == "768px":
                    num_total_vid_batch_768 += num_batch
                num_total_vid_samples += size
                num_total_vid_batch += num_batch

            num_aspect_dict[k[-1]][0] += size
            num_aspect_dict[k[-1]][1] += num_batch
            num_hwt_dict[k[:-1]][0] += size
            num_hwt_dict[k[:-1]][1] += num_batch

        num_aspect_dict = dict(sorted(num_aspect_dict.items(), key=lambda x: x[0]))
        num_hwt_dict = dict(
            sorted(num_hwt_dict.items(), key=lambda x: (get_num_pexels_from_name(x[0][0]), x[0][1]), reverse=True)
        )
        num_hwt_img_dict = {k: v for k, v in num_hwt_dict.items() if k[1] == 1}
        num_hwt_vid_dict = {k: v for k, v in num_hwt_dict.items() if k[1] > 1}

        if dist.get_rank() == 0 and self.verbose:
            # f-strings: print() does not apply logging-style "%s" arguments.
            print("Bucket Info:")
            print("Bucket [#sample, #batch] by aspect ratio:")
            for k, v in num_aspect_dict.items():
                print(f"({k}): #sample: {format_numel_str(v[0])}, #batch: {format_numel_str(v[1])}")
            print("===== Image Info =====")
            print("Image Bucket by HxWxT:")
            for k, v in num_hwt_img_dict.items():
                print(f"{k}: #sample: {format_numel_str(v[0])}, #batch: {format_numel_str(v[1])}")
            print("--------------------------------")
            print(
                f"#image sample: {format_numel_str(num_total_img_samples)}, "
                f"#image batch: {format_numel_str(num_total_img_batch)}"
            )
            print("===== Video Info =====")
            print("Video Bucket by HxWxT:")
            for k, v in num_hwt_vid_dict.items():
                print(f"{k}: #sample: {format_numel_str(v[0])}, #batch: {format_numel_str(v[1])}")
            print("--------------------------------")
            print(
                f"#video sample: {format_numel_str(num_total_vid_samples)}, "
                f"#video batch: {format_numel_str(num_total_vid_batch)}"
            )
            print("===== Summary =====")
            print(f"#non-empty buckets: {len(bucket_sample_dict)}")
            sample_ratio = num_total_img_samples / num_total_vid_samples if num_total_vid_samples > 0 else 0
            print(f"Img/Vid sample ratio: {sample_ratio:.2f}")
            batch_ratio = num_total_img_batch / num_total_vid_batch if num_total_vid_batch > 0 else 0
            print(f"Img/Vid batch ratio: {batch_ratio:.2f}")
            print(
                f"vid batch 256: {format_numel_str(num_total_vid_batch_256)}, "
                f"vid batch 768: {format_numel_str(num_total_vid_batch_768)}"
            )
            res_ratio = num_total_vid_batch_256 / num_total_vid_batch_768 if num_total_vid_batch_768 > 0 else 0
            print(f"Vid batch ratio (256px/768px): {res_ratio:.2f}")
            print(
                f"#training sample: {format_numel_str(num_total_samples)}, "
                f"#training batch: {format_numel_str(num_total_batch)}"
            )
        return num_total_batch

    def reset(self):
        """Restart from the first round on the next ``__iter__``."""
        self.last_micro_batch_access_index = 0

    def set_step(self, start_step: int):
        """Resume the next ``__iter__`` after ``start_step`` rounds."""
        self.last_micro_batch_access_index = start_step * self.num_replicas

    def state_dict(self, num_steps: int) -> dict:
        # The last_micro_batch_access_index tracked in __iter__ is often inaccurate
        # with multiple workers and data prefetching, so the caller passes the number
        # of steps actually executed and the index is derived from it.
        return {"seed": self.seed, "epoch": self.epoch, "last_micro_batch_access_index": num_steps * self.num_replicas}

    def load_state_dict(self, state_dict: dict) -> None:
        """Copy every entry of ``state_dict`` onto the instance as an attribute."""
        self.__dict__.update(state_dict)


class BatchDistributedSampler(DistributedSampler):
    """
    Rank-contiguous sampler meant to be used with BatchDataset.

    The dataset's ``num_buffers`` buffers, each ``len_buffer`` samples long, are
    split into equal contiguous runs, one run per rank. Any
    ``num_buffers % num_replicas`` leftover buffers are not visited. Example with
    len_buffer == 5, num_buffers == 6 and 3 GPUs:
           | buffer {i}          | buffer {i+1}
    ------ | ------------------- | -------------------
    rank 0 |  0,  1,  2,  3,  4, |  5,  6,  7,  8,  9
    rank 1 | 10, 11, 12, 13, 14, | 15, 16, 17, 18, 19
    rank 2 | 20, 21, 22, 23, 24, | 25, 26, 27, 28, 29
    """

    def __init__(self, dataset: Dataset, **kwargs):
        super().__init__(dataset, **kwargs)
        # Offset, within this rank's run, of the first index to yield.
        self.start_index = 0

    def __iter__(self):
        """Iterate this rank's global indices, starting ``start_index`` into its run."""
        buffers_per_rank = self.dataset.num_buffers // self.num_replicas
        samples_per_rank = self.dataset.len_buffer * buffers_per_rank
        run_offset = self.rank * samples_per_rank
        local_positions = np.arange(self.start_index, samples_per_rank)
        return iter((local_positions + run_offset).tolist())

    def reset(self):
        """Start from the beginning of this rank's run again."""
        self.start_index = 0

    def state_dict(self, step) -> dict:
        """Package ``step`` as the saved position."""
        return dict(start_index=step)

    def load_state_dict(self, state_dict: dict):
        """Resume one position past the saved ``start_index``.

        Presumably the saved step was already consumed; confirm with the caller.
        """
        saved = state_dict["start_index"]
        self.start_index = saved + 1
