# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import contextlib
import numbers


def str_hash(text: str):
    """Return a deterministic 32-bit hash of ``text``.

    Unlike the builtin ``hash`` on strings, the result depends only on the
    characters of ``text``. Each step multiplies the running value by 281,
    XORs it with the character's code point times 997, and keeps the low
    32 bits. The empty string hashes to 0.
    """
    acc = 0
    for code_point in map(ord, text):
        acc = ((acc * 281) ^ (code_point * 997)) & 0xFFFFFFFF
    return acc


@contextlib.contextmanager
def numpy_seed(seed, *addl_seeds, key=None):
    """Context manager which seeds the NumPy PRNG with the specified seed and
    restores the state afterward.

    Args:
        seed: base integer seed (Python ``int`` or any NumPy integer scalar).
            If ``None``, the body runs without touching the global PRNG.
        *addl_seeds: extra integer seeds; when given, they are hashed
            together with ``seed`` to derive the effective seed.
        key: optional string that is mixed into the effective seed through
            ``str_hash``.

    Raises:
        TypeError: if ``seed`` or any of ``addl_seeds`` is not an integer.
    """
    if seed is None:
        yield
        return

    def check_seed(s):
        # Explicit raise instead of ``assert`` so the validation survives
        # ``python -O``. ``bool`` is an ``int`` subclass but was never an
        # accepted seed type, so it is still rejected.
        if isinstance(s, bool) or not isinstance(s, (int, np.integer)):
            raise TypeError(
                f"seed must be an integer, got {type(s).__name__}: {s!r}"
            )

    check_seed(seed)
    for s in addl_seeds:
        check_seed(s)
    # NOTE: the float modulus (1e8) is kept on purpose; switching to an
    # integer modulus would change the derived seeds and break
    # reproducibility of existing runs.
    if addl_seeds:
        seed = int(hash((seed, *addl_seeds)) % 1e8)
    if key is not None:
        seed = int(hash((seed, str_hash(key))) % 1e8)
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        # Restore the caller's PRNG state even if the body raised.
        np.random.set_state(state)


def convert_to_single_emb(x, sizes):
    """Shift every feature column of ``x`` into its own index range, in place.

    Column ``i`` of the last axis is increased by ``1 + sum(sizes[:i])``.
    The starting offset of 1 presumably leaves index 0 free (e.g. for
    padding) -- confirm against the embedding table that consumes this.

    ``x`` is modified in place and also returned. The last axis of ``x``
    must have ``len(sizes)`` entries and every value in column ``i`` must be
    smaller than ``sizes[i]``; both conditions are checked with ``assert``.
    """
    num_features = len(sizes)
    assert x.shape[-1] == num_features
    shift = 1
    for col, size in enumerate(sizes):
        column = x[..., col]
        assert (column < size).all()
        x[..., col] = column + shift
        shift += size
    return x


def pad_flatten(samples):
    """Concatenate the tensors in ``samples`` along dimension 0.

    No padding is added; the samples are simply joined end to end.
    """
    flat = torch.cat(samples, dim=0)
    return flat


def pad_flatten_dim1(samples):
    """Concatenate the tensors in ``samples`` along dimension 1.

    No padding is added; the samples must agree in every other dimension.
    """
    joined = torch.cat(samples, 1)
    return joined


def pad_1d(samples, pad_len, pad_value=0, pad_type=None):
    """Stack variable-length samples into a ``(len(samples), pad_len)`` tensor.

    Each sample is copied to the start of its row. The remaining positions
    hold ``pad_value``, unless ``pad_type == "first"``, in which case they
    are filled with the sample's first element. The result has the dtype of
    ``samples[0]``.
    """
    repeat_first = pad_type == "first"
    batch = torch.full((len(samples), pad_len), pad_value, dtype=samples[0].dtype)
    for row, sample in enumerate(samples):
        length = sample.shape[0]
        batch[row, :length] = sample
        if repeat_first:
            batch[row, length:] = sample[0]
    return batch


def pad_1d_feat(samples, pad_len, pad_value=0, pad_type=None):
    """Stack ``(length, feat)`` samples into a ``(batch, pad_len, feat)`` tensor.

    The feature size and dtype are taken from ``samples[0]``, which must be
    2-D. Each sample is copied to the start of its slot. The remaining rows
    hold ``pad_value``, unless ``pad_type == "first"``, in which case they
    repeat the sample's first row.
    """
    first = samples[0]
    assert first.ndim == 2
    out_shape = (len(samples), pad_len, first.shape[-1])
    batch = torch.full(out_shape, pad_value, dtype=first.dtype)
    repeat_first = pad_type == "first"
    for idx, sample in enumerate(samples):
        length = sample.shape[0]
        batch[idx, :length] = sample
        if repeat_first:
            batch[idx, length:] = sample[0]
    return batch


def pad_2d(samples, pad_len, pad_value=0):
    """Stack 2-D samples into a ``(batch, pad_len, pad_len)`` tensor.

    Each sample is copied into the top-left corner of its slot; everything
    else holds ``pad_value``. The result has the dtype of ``samples[0]``.
    """
    out_shape = (len(samples), pad_len, pad_len)
    batch = torch.full(out_shape, pad_value, dtype=samples[0].dtype)
    for idx, sample in enumerate(samples):
        rows, cols = sample.shape[0], sample.shape[1]
        batch[idx, :rows, :cols] = sample
    return batch


def pad_2d_feat(samples, pad_len, pad_value=0):
    """Stack 3-D samples into a ``(batch, pad_len, pad_len, feat)`` tensor.

    The feature size and dtype are taken from ``samples[0]``, which must be
    3-D. Each sample is copied into the top-left corner of its slot (first
    two axes); everything else holds ``pad_value``.
    """
    first = samples[0]
    assert first.ndim == 3
    out_shape = (len(samples), pad_len, pad_len, first.shape[-1])
    batch = torch.full(out_shape, pad_value, dtype=first.dtype)
    for idx, sample in enumerate(samples):
        rows, cols = sample.shape[0], sample.shape[1]
        batch[idx, :rows, :cols] = sample
    return batch


def pad_attn_bias(samples, pad_len):
    """Stack 2-D bias matrices into a ``(batch, pad_len + 1, pad_len + 1)`` tensor.

    The output starts out filled with ``-inf``. Each sample is copied into
    the top-left corner of its slot, and the rows below the sample are set
    to 0 within the sample's columns. Columns beyond the sample's width stay
    at ``-inf`` for every row.

    The extra ``+ 1`` on the padded size presumably accounts for an
    additional (e.g. virtual/graph) token -- confirm against the caller.
    """
    side = pad_len + 1
    out_shape = (len(samples), side, side)
    batch = torch.full(out_shape, float("-inf"), dtype=samples[0].dtype)
    for idx, sample in enumerate(samples):
        rows, cols = sample.shape[0], sample.shape[1]
        batch[idx, :rows, :cols] = sample
        # Padded rows get zeros in the sample's columns instead of -inf.
        batch[idx, rows:, :cols] = 0
    return batch


def pad_3d(samples, pad_len_x, pad_len_y, pad_len_z, pad_value=0):
    """Stack 3-D samples into a ``(batch, pad_len_x, pad_len_y, pad_len_z)`` tensor.

    Each sample is copied into the leading corner of its slot; everything
    else holds ``pad_value``. The result has the dtype of ``samples[0]``.
    """
    out_shape = (len(samples), pad_len_x, pad_len_y, pad_len_z)
    batch = torch.full(out_shape, pad_value, dtype=samples[0].dtype)
    for idx, sample in enumerate(samples):
        nx, ny, nz = sample.shape[0], sample.shape[1], sample.shape[2]
        batch[idx, :nx, :ny, :nz] = sample
    return batch


def pad_3d_feat(samples, pad_len_x, pad_len_y, pad_len_z, pad_value=0):
    """Stack 4-D samples into a
    ``(batch, pad_len_x, pad_len_y, pad_len_z, feat)`` tensor.

    The feature size and dtype are taken from ``samples[0]``, which must be
    4-D. Each sample is copied into the leading corner of its slot (first
    three axes). Padding positions hold ``pad_value``, which may be a plain
    number or any value that broadcasts against the output (for example a
    per-feature tensor).
    """
    first = samples[0]
    assert first.ndim == 4
    out_shape = (len(samples), pad_len_x, pad_len_y, pad_len_z, first.shape[-1])
    scalar_pad = isinstance(pad_value, numbers.Number)
    batch = torch.full(out_shape, pad_value if scalar_pad else 0, dtype=first.dtype)
    if not scalar_pad:
        # Non-scalar padding is broadcast over the whole output.
        batch[...] = pad_value
    for idx, sample in enumerate(samples):
        nx, ny, nz = sample.shape[0], sample.shape[1], sample.shape[2]
        batch[idx, :nx, :ny, :nz] = sample
    return batch
