

import copy
import os
import random
from typing import Any, Dict, Iterable, Optional, Union

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F


def enable_full_determinism(seed: int):
    """
    Make runs reproducible, including under distributed training.

    Seeds every RNG via `set_seed`, then switches PyTorch (and cuBLAS / cuDNN) into deterministic mode.
    See https://pytorch.org/docs/stable/notes/randomness.html for the PyTorch details.

    Args:
        seed (`int`): The seed forwarded to `set_seed`.
    """
    set_seed(seed)  # seeding must happen before anything else

    # PyTorch's deterministic mode may need either `CUDA_LAUNCH_BLOCKING` or
    # `CUBLAS_WORKSPACE_CONFIG` depending on the CUDA version, so both are exported.
    os.environ.update({"CUDA_LAUNCH_BLOCKING": "1", "CUBLAS_WORKSPACE_CONFIG": ":16:8"})
    torch.use_deterministic_algorithms(True)

    # cuDNN: force deterministic kernels and turn off benchmark-based autotuning
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def set_seed(seed: int):
    """
    Seed every random number generator used in training, for reproducible behavior.

    Seeds, in order: Python's `random`, NumPy's global RNG, PyTorch's CPU generator, and the
    generators of all CUDA devices.

    Args:
        seed (`int`): The seed to set.
    """
    # `torch.cuda.manual_seed_all` is safe to call even if CUDA is not available
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMA:
    """
    Exponential Moving Average (EMA) of model weights.

    Keeps detached "shadow" copies of a set of parameters. On every `step()` each shadow tensor moves
    toward the current parameter value by a factor of `1 - decay`. Hyper-parameters and shadow weights
    can be exported and restored with `state_dict()` / `load_state_dict()`. When `model_cls` and
    `model_config` are provided, the averaged weights can be saved and loaded through
    `save_pretrained()` / `from_pretrained()`.
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        decay: float = 0.9999,
        min_decay: float = 0.0,
        update_after_step: int = 0,
        use_ema_warmup: bool = False,
        inv_gamma: Union[float, int] = 1.0,
        power: Union[float, int] = 2 / 3,
        model_cls: Optional[Any] = None,
        model_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """
        Args:
            parameters (Iterable[torch.nn.Parameter]): The parameters to track. Each one is cloned and
                detached, so the shadow copies start on the same device and dtype as the originals.
            decay (float): Upper bound on the decay factor for the exponential moving average.
            min_decay (float): Lower bound on the decay factor once averaging has started.
            update_after_step (int): The number of steps to wait before starting to update the EMA weights.
                Until then, `step()` simply copies the parameters into the shadow weights.
            use_ema_warmup (bool): Whether to use EMA warmup.
            inv_gamma (float):
                Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
            power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
            model_cls: Model class used by `save_pretrained` (via `from_config`) and recorded by
                `from_pretrained`.
            model_config: Config passed to `model_cls.from_config` in `save_pretrained`.
            **kwargs: Accepted and ignored.

        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
            at 215.4k steps).
        """

        parameters = list(parameters)
        self.shadow_params = [p.clone().detach() for p in parameters]

        # Filled by `store()`, consumed (and cleared) by `restore()`.
        self.temp_stored_params = None

        self.decay = decay
        self.min_decay = min_decay
        self.update_after_step = update_after_step
        self.use_ema_warmup = use_ema_warmup
        self.inv_gamma = inv_gamma
        self.power = power
        self.optimization_step = 0
        self.cur_decay_value = None  # set in `step()`

        self.model_cls = model_cls
        self.model_config = model_config

    @classmethod
    def from_pretrained(cls, path, model_cls) -> "EMA":
        """
        Build an EMA whose shadow weights are the model weights stored at `path`.

        The EMA hyper-parameters are expected to come back from
        `model_cls.load_config(path, return_unused_kwargs=True)` as unused config kwargs. That is where
        `save_pretrained` registers them. They are then applied with `load_state_dict`.
        """
        _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
        model = model_cls.from_pretrained(path)

        ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config)

        ema_model.load_state_dict(ema_kwargs)
        return ema_model

    def save_pretrained(self, path):
        """
        Save the averaged weights as a `model_cls` checkpoint at `path`.

        A fresh model is built from `model_config`. The EMA hyper-parameters (everything in
        `state_dict()` except `shadow_params`) are registered on its config, the shadow weights are
        copied into it, and it is saved.

        Raises:
            ValueError: if `model_cls` or `model_config` was not given at construction.
        """
        if self.model_cls is None:
            raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")

        if self.model_config is None:
            raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")

        model = self.model_cls.from_config(self.model_config)
        state_dict = self.state_dict()
        # the weights themselves are saved as the model's parameters, not as config entries
        state_dict.pop("shadow_params", None)

        model.register_to_config(**state_dict)
        self.copy_to(model.parameters())
        model.save_pretrained(path)

    def get_decay(self, optimization_step: int) -> float:
        """
        Compute the decay factor for the exponential moving average at `optimization_step`.

        Returns 0.0 (shadow weights are overwritten by the parameters) while
        `optimization_step <= update_after_step + 1`. After that it uses either the warmup schedule
        `1 - (1 + step / inv_gamma) ** -power` or `(1 + step) / (10 + step)`. The result is capped at
        `decay` and floored at `min_decay`.
        """
        step = max(0, optimization_step - self.update_after_step - 1)

        if step <= 0:
            return 0.0

        if self.use_ema_warmup:
            cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
        else:
            cur_decay_value = (1 + step) / (10 + step)

        cur_decay_value = min(cur_decay_value, self.decay)
        # make sure decay is not smaller than min_decay
        cur_decay_value = max(cur_decay_value, self.min_decay)
        return cur_decay_value

    @torch.no_grad()
    def step(self, parameters: Iterable[torch.nn.Parameter]):
        """
        Advance the step counter and update the shadow weights from `parameters` (matched by position).

        Parameters that require grad are blended: `shadow -= (1 - decay) * (shadow - param)`. Parameters
        that do not require grad are copied verbatim.
        """
        parameters = list(parameters)

        self.optimization_step += 1

        # Compute the decay factor for the exponential moving average.
        decay = self.get_decay(self.optimization_step)
        self.cur_decay_value = decay
        one_minus_decay = 1 - decay

        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                s_param.sub_(one_minus_decay * (s_param - param))
            else:
                s_param.copy_(param)

        # NOTE(review): releases cached CUDA memory on every step, presumably to limit peak usage;
        # this has a per-step cost.
        torch.cuda.empty_cache()

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy the current averaged (shadow) weights into `parameters`, matched by position.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored
                moving averages. Each shadow tensor is moved to the target parameter's device first.
        """
        parameters = list(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.to(param.device).data)

    def to(self, device=None, dtype=None) -> None:
        r"""Move internal buffers of the ExponentialMovingAverage to `device`.

        Args:
            device: like `device` argument to `torch.Tensor.to`
            dtype: like `dtype` argument to `torch.Tensor.to`; only applied to floating-point buffers.
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
            for p in self.shadow_params
        ]

    def state_dict(self) -> dict:
        r"""
        Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
        checkpointing to save the ema state dict.
        """
        # Following PyTorch conventions, references to tensors are returned:
        # "returns a reference to the state and not its copy!" -
        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
        return {
            "decay": self.decay,
            "min_decay": self.min_decay,
            "optimization_step": self.optimization_step,
            "update_after_step": self.update_after_step,
            "use_ema_warmup": self.use_ema_warmup,
            "inv_gamma": self.inv_gamma,
            "power": self.power,
            "shadow_params": self.shadow_params,
        }

    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Save the current parameters (as detached CPU clones) for restoring later with `restore()`.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored.
        """
        self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]

    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Restore the parameters stored with the `store` method.

        This is useful for validating the model with EMA parameters without affecting the original
        optimization process. Call `store()` before `copy_to()`, then call this after validation (or
        model saving) to bring back the former parameters. The stored copies are released afterwards.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored
                parameters, matched by position.

        Raises:
            RuntimeError: if nothing has been stored.
        """
        if self.temp_stored_params is None:
            raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
        for c_param, param in zip(self.temp_stored_params, parameters):
            param.data.copy_(c_param.data)

        # Better memory-wise.
        self.temp_stored_params = None

    def load_state_dict(self, state_dict: dict) -> None:
        r"""
        Load the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to
        restore the ema state dict.

        Keys missing from `state_dict` keep their current values. Every value is validated before any of
        them is applied, so a `ValueError` leaves this object unchanged.

        Args:
            state_dict (dict): EMA state. Should be an object returned from a call to :meth:`state_dict`.

        Raises:
            ValueError: if a value has the wrong type or `decay` is outside [0, 1].
        """
        # deepcopy, to be consistent with module API
        state_dict = copy.deepcopy(state_dict)

        decay = state_dict.get("decay", self.decay)
        if decay < 0.0 or decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        # ints are valid too (e.g. `min_decay=0`), consistent with the inv_gamma / power checks below
        min_decay = state_dict.get("min_decay", self.min_decay)
        if not isinstance(min_decay, (float, int)):
            raise ValueError("Invalid min_decay")

        optimization_step = state_dict.get("optimization_step", self.optimization_step)
        if not isinstance(optimization_step, int):
            raise ValueError("Invalid optimization_step")

        update_after_step = state_dict.get("update_after_step", self.update_after_step)
        if not isinstance(update_after_step, int):
            raise ValueError("Invalid update_after_step")

        use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
        if not isinstance(use_ema_warmup, bool):
            raise ValueError("Invalid use_ema_warmup")

        inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
        if not isinstance(inv_gamma, (float, int)):
            raise ValueError("Invalid inv_gamma")

        power = state_dict.get("power", self.power)
        if not isinstance(power, (float, int)):
            raise ValueError("Invalid power")

        shadow_params = state_dict.get("shadow_params", None)
        if shadow_params is not None:
            if not isinstance(shadow_params, list):
                raise ValueError("shadow_params must be a list")
            if not all(isinstance(p, torch.Tensor) for p in shadow_params):
                raise ValueError("shadow_params must all be Tensors")

        # all values validated: apply them together
        self.decay = decay
        self.min_decay = min_decay
        self.optimization_step = optimization_step
        self.update_after_step = update_after_step
        self.use_ema_warmup = use_ema_warmup
        self.inv_gamma = inv_gamma
        self.power = power
        if shadow_params is not None:
            self.shadow_params = shadow_params


# calculates entropy over each pixel distribution
def pixel_entropy_per_percent_masked_bucket(logits, input_ids, mask_id):
    """
    Average per-pixel predictive entropy, grouped into "percent masked" buckets.

    For each image, the entropy of the predicted distribution is computed at every masked position
    (`input_ids == mask_id`) and averaged over those positions. The per-image values are then averaged
    within each of 10 masked-percentage buckets.

    Args:
        logits: logits over the vocabulary in the last dimension; `logits.shape[:-1]` must match
            `input_ids.shape`.
        input_ids: token ids of the (partially masked) images.
        mask_id: id of the mask token.

    Returns:
        Tensor with one mean entropy per bucket (0 for buckets that received no images).
    """
    # only calculate entropy over image tokens that were masked in the original image
    masked_tokens = input_ids == mask_id
    # Clamp so an image with no masked tokens scores 0 instead of 0 / 0 = NaN. Such images are put in
    # bucket 0, and a single NaN would poison that bucket's average.
    num_masked_pixels = masked_tokens.sum(-1).clamp(min=1)

    probs = F.softmax(logits, dim=-1)
    log_probs = F.log_softmax(logits, dim=-1)

    entropy_per_pixel = -((probs * log_probs).sum(-1))

    # the predictions for non-masked aren't used, so set their entropies to zero
    entropy_per_pixel = entropy_per_pixel.masked_fill(~masked_tokens, 0)

    entropy_per_image = entropy_per_pixel.sum(-1) / num_masked_pixels

    total_buckets = 10
    masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)

    return average_by_buckets(entropy_per_image, masked_buckets, total_buckets)


# calculates entropy over the averaged distribution of pixels for the whole image
def image_entropy_per_percent_masked_bucket(logits, input_ids, mask_id):
    """
    Entropy of each image's averaged token distribution, grouped into "percent masked" buckets.

    For each image, the predicted distributions at its masked positions (`input_ids == mask_id`) are
    averaged into a single distribution, whose entropy is taken. The per-image entropies are then
    averaged within each of 10 masked-percentage buckets.

    Args:
        logits: logits over the vocabulary in the last dimension; `logits.shape[:-1]` must match
            `input_ids.shape`.
        input_ids: token ids of the (partially masked) images.
        mask_id: id of the mask token.

    Returns:
        Tensor with one mean entropy per bucket (0 for buckets that received no images).
    """
    # only calculate entropy over image tokens that were masked in the original image
    masked_tokens = input_ids == mask_id
    # Clamp so an image with no masked tokens gets an all-zero distribution (entropy 0) instead of
    # 0 / 0 = NaN probabilities.
    num_masked_pixels = masked_tokens.sum(-1, keepdim=True).clamp(min=1)

    pixel_probs = F.softmax(logits, dim=-1)
    # drop distributions at unmasked positions (out of place, so the softmax output is untouched)
    pixel_probs = pixel_probs.masked_fill(~masked_tokens.unsqueeze(-1), 0)
    image_probs = pixel_probs.sum(-2) / num_masked_pixels

    # xlogy(0, 0) == 0, whereas 0 * log(0) = 0 * -inf = NaN for any zero-probability entry
    entropy_per_image = -torch.xlogy(image_probs, image_probs).sum(-1)

    total_buckets = 10
    masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)

    return average_by_buckets(entropy_per_image, masked_buckets, total_buckets)


def cross_entropy_per_percent_masked_bucket(logits, labels, input_ids, mask_id, output_size, label_smoothing):
    """
    Mean cross-entropy per image, grouped into "percent masked" buckets.

    Args:
        logits: logits whose last dimension has size `output_size`.
        labels: target ids, one per logit row; positions equal to -100 are ignored. The first dimension
            is assumed to be the batch (image) dimension, matching `input_ids`.
        input_ids: token ids of the (partially masked) images, used to assign buckets.
        mask_id: id of the mask token.
        output_size: number of classes (vocabulary size).
        label_smoothing: forwarded to `F.cross_entropy`.

    Returns:
        Tensor with one mean cross-entropy per bucket (0 for buckets that received no images).
    """
    batch_size = labels.shape[0]

    # reduction="none" yields one loss per *token*; reshape to (batch, tokens) so it can be reduced
    # per image. Previously the flat per-token vector was bucketed directly, which silently used only
    # the first `batch_size` token losses.
    cross_entropy_per_token = F.cross_entropy(
        logits.reshape(-1, output_size),
        labels.reshape(-1),
        ignore_index=-100,
        label_smoothing=label_smoothing,
        reduction="none",
    ).reshape(batch_size, -1)

    # ignored positions contribute a loss of 0; average only over the scored ones (clamped to
    # avoid 0 / 0 for an image with nothing to score)
    num_scored_tokens = (labels.reshape(batch_size, -1) != -100).sum(-1).clamp(min=1)
    cross_entropy_per_image = cross_entropy_per_token.sum(-1) / num_scored_tokens

    total_buckets = 10
    masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)

    return average_by_buckets(cross_entropy_per_image, masked_buckets, total_buckets)


def token_probability_distributions_per_percent_masked_bucket(logits, input_ids, mask_id):
    """
    Sample one predicted token distribution per "percent masked" bucket.

    For each bucket, this takes the first image in the batch that falls into it and finds that image's
    first masked position. It then emits the full predicted probability vector at that position, one row
    per vocabulary entry.

    Args:
        logits: logits over the vocabulary in the last dimension; assumed (batch, tokens, vocab) with
            `input_ids` of shape (batch, tokens) — TODO confirm against callers.
        input_ids: token ids of the (partially masked) images.
        mask_id: id of the mask token.

    Returns:
        `pd.DataFrame` with columns "bucket" and "masked_pixel_prob". It is empty if no bucket has a
        usable image.
    """
    probs = F.softmax(logits, dim=-1)

    total_buckets = 10
    masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)

    data = []

    for bucket_idx in range(total_buckets):
        # batch indices of the images in this bucket (not the bucket values themselves)
        indices_for_bucket = (masked_buckets == bucket_idx).nonzero(as_tuple=True)[0]

        # It's ok if none were noised in the range of this bucket. This
        # function will be called for a later training step where it's likely
        # there will be an element noised in the range.
        if indices_for_bucket.numel() == 0:
            continue

        index_for_bucket = indices_for_bucket[0]

        image_probs = probs[index_for_bucket]

        # find the index of a masked pixel for the image
        input_ids_for_image = input_ids[index_for_bucket]
        masked_pixels_probs = image_probs[input_ids_for_image == mask_id]

        # bucket 0 also receives images with no masked tokens at all
        if masked_pixels_probs.shape[0] == 0:
            continue

        masked_pixel_probs = masked_pixels_probs[0].detach().cpu()
        if masked_pixel_probs.dtype == torch.bfloat16:
            # NumPy has no bfloat16
            masked_pixel_probs = masked_pixel_probs.float()

        data.extend(
            {"bucket": bucket_idx, "masked_pixel_prob": masked_pixel_prob}
            for masked_pixel_prob in masked_pixel_probs.numpy()
        )

    return pd.DataFrame(data)


def average_by_buckets(values, masked_buckets, total_buckets):
    """
    Average `values` within each bucket.

    Args:
        values: 1-D tensor, one value per item.
        masked_buckets: 1-D integer tensor of bucket indices in [0, total_buckets), one per item.
        total_buckets: number of buckets.

    Returns:
        Tensor of shape (total_buckets,) holding the mean value per bucket, with 0 for empty buckets.
        Its dtype is `values.dtype` promoted to at least float32.
    """
    # scatter_add_ requires the accumulator and source to share a dtype. Accumulating in at least
    # float32 lets half-precision, float64 and integer values work. float32 input is unchanged.
    accumulate_dtype = torch.promote_types(values.dtype, torch.float32)
    values = values.to(accumulate_dtype)

    numerator = torch.zeros(total_buckets, device=values.device, dtype=accumulate_dtype)
    numerator.scatter_add_(0, masked_buckets, values)

    unique_buckets, bucket_counts = masked_buckets.unique(dim=0, return_counts=True)

    # default value is one because the buckets for which there aren't
    # any values will have a numerator of zero. So we just need to not divide
    # by zero.
    denominator = torch.ones(total_buckets, device=values.device, dtype=torch.long)
    denominator[unique_buckets] = bucket_counts

    return numerator / denominator


def input_ids_to_masked_buckets(input_ids, mask_id, total_buckets=10):
    """
    Map each image to the index of its "percent masked" bucket.

    The unit interval is split into `total_buckets` equal, right-closed bins: bucket i covers
    (i / total_buckets, (i + 1) / total_buckets]. Images with no masked tokens (and an empty last
    dimension) are assigned to bucket 0.

    Args:
        input_ids: token ids; the masked fraction is computed over the last dimension.
        mask_id: id of the mask token.
        total_buckets: number of buckets (default 10).

    Returns:
        int64 tensor of shape `input_ids.shape[:-1]` with bucket indices in [0, total_buckets).

    Raises:
        ValueError: if `total_buckets` is not a positive integer.
    """
    # We do not formally use timesteps to noise images. Instead, we mask a percent of the pixels.
    # We don't want to log metrics for every mask percent between 0 and 1, and we want to track how
    # they evolve over time within a range of mask percents that should behave similarly. So the
    # masked percents are grouped into a fixed number of equal-width buckets.
    if total_buckets < 1 or total_buckets != int(total_buckets):
        raise ValueError(f"total_buckets must be a positive integer, got {total_buckets!r}")
    total_buckets = int(total_buckets)

    num_tokens = input_ids.shape[-1]
    num_masked = (input_ids == mask_id).sum(-1)

    if num_tokens == 0:
        return torch.zeros_like(num_masked)

    # Exact integer form of ceil(num_masked / num_tokens * total_buckets) - 1. A fraction lying exactly
    # on a bucket edge (e.g. 20% with 10 buckets) goes to the lower bucket, with no float rounding.
    masked_buckets = torch.div(num_masked * total_buckets + num_tokens - 1, num_tokens, rounding_mode="floor") - 1

    # 0% masked gives -1 above; fold it into bucket 0
    return masked_buckets.clamp(min=0)
