from collections import OrderedDict

import numpy as np

from .aspect import get_closest_ratio, get_resolution_with_aspect_ratio
from .utils import map_target_fps


class Bucket:
    def __init__(self, bucket_config: dict[str, dict[int, tuple[float, int] | tuple[tuple[float, float], int]]]):
        """
        Args:
            bucket_config (dict): A dictionary containing the bucket configuration.
                The dictionary should be in the following format:
                {
                    "bucket_name": {
                        "time": (probability, batch_size),
                        "time": (probability, batch_size),
                        ...
                    },
                    ...
                }

                Or in the following format:
                {
                    "bucket_name": {
                        "time": ((probability, next_probability), batch_size),
                        "time": ((probability, next_probability), batch_size),
                        ...
                    },
                    ...
                }
                The bucket_name should be the name of the bucket, and the time should be the number of frames in the video.
                The probability should be a float between 0 and 1, and the batch_size should be an integer.
                If the probability is a tuple, the second value should be the probability to skip to the next time.
        """

        aspect_ratios = {key: get_resolution_with_aspect_ratio(key) for key in bucket_config.keys()}
        bucket_probs = OrderedDict()
        bucket_bs = OrderedDict()
        bucket_names = sorted(bucket_config.keys(), key=lambda x: aspect_ratios[x][0], reverse=True)

        for key in bucket_names:
            bucket_time_names = sorted(bucket_config[key].keys(), key=lambda x: x, reverse=True)
            bucket_probs[key] = OrderedDict({k: bucket_config[key][k][0] for k in bucket_time_names})
            bucket_bs[key] = OrderedDict({k: bucket_config[key][k][1] for k in bucket_time_names})

        self.hw_criteria = {k: aspect_ratios[k][0] for k in bucket_names}
        self.t_criteria = {k1: {k2: k2 for k2 in bucket_config[k1].keys()} for k1 in bucket_names}
        self.ar_criteria = {
            k1: {k2: {k3: v3 for k3, v3 in aspect_ratios[k1][1].items()} for k2 in bucket_config[k1].keys()}
            for k1 in bucket_names
        }

        bucket_id_cnt = num_bucket = 0
        bucket_id = dict()
        for k1, v1 in bucket_probs.items():
            bucket_id[k1] = dict()
            for k2, _ in v1.items():
                bucket_id[k1][k2] = bucket_id_cnt
                bucket_id_cnt += 1
                num_bucket += len(aspect_ratios[k1][1])

        self.bucket_probs = bucket_probs
        self.bucket_bs = bucket_bs
        self.bucket_id = bucket_id
        self.num_bucket = num_bucket

        print("Number of buckets: %s", num_bucket)

    def get_bucket_id(
        self,
        T: int,
        H: int,
        W: int,
        fps: float,
        path: str | None = None,
        seed: int | None = None,
        fps_max: int = 16,
    ) -> tuple[str, int, int] | None:
        approx = 0.8
        _, sampling_interval = map_target_fps(fps, fps_max)
        T = T // sampling_interval
        resolution = H * W
        rng = np.random.default_rng(seed)

        # Reference to probabilities and criteria for faster access
        bucket_probs = self.bucket_probs
        hw_criteria = self.hw_criteria
        ar_criteria = self.ar_criteria

        # Start searching for the appropriate bucket
        for hw_id, t_criteria in bucket_probs.items():
            # if resolution is too low, skip
            if resolution < hw_criteria[hw_id] * approx:
                continue

            # if sample is an image
            if T == 1:
                if 1 in t_criteria:
                    if rng.random() < t_criteria[1]:
                        return hw_id, 1, get_closest_ratio(H, W, ar_criteria[hw_id][1])
                continue

            # Look for suitable t_id for video
            for t_id, prob in t_criteria.items():
                if T >= t_id and t_id != 1:
                    # if prob is a tuple, use the second value as the threshold to skip
                    # to the next t_id
                    if isinstance(prob, tuple):
                        next_hw_prob, next_t_prob = prob
                        if next_t_prob >= 1 or rng.random() <= next_t_prob:
                            continue
                    else:
                        next_hw_prob = prob
                    if next_hw_prob >= 1 or rng.random() <= next_hw_prob:
                        ar_id = get_closest_ratio(H, W, ar_criteria[hw_id][t_id])
                        return hw_id, t_id, ar_id
                    else:
                        break

        return None

    def get_thw(self, bucket_idx: tuple[str, int, int]) -> tuple[int, int, int]:
        assert len(bucket_idx) == 3
        T = self.t_criteria[bucket_idx[0]][bucket_idx[1]]
        H, W = self.ar_criteria[bucket_idx[0]][bucket_idx[1]][bucket_idx[2]]
        return T, H, W

    def get_prob(self, bucket_idx: tuple[str, int]) -> float:
        return self.bucket_probs[bucket_idx[0]][bucket_idx[1]]

    def get_batch_size(self, bucket_idx: tuple[str, int]) -> int:
        return self.bucket_bs[bucket_idx[0]][bucket_idx[1]]

    def __len__(self) -> int:
        return self.num_bucket
