#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensor utility."""

from itertools import chain

import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core


try:
    gpu_prop = paddle.device.cuda.get_device_properties()
    if "A100" in gpu_prop.name:
        TENSOR_CORE_MULTI = 1
    elif "V100" in gpu_prop.name:
        TENSOR_CORE_MULTI = 8
    else:
        TENSOR_CORE_MULTI = 1
except:
    print("You can upgarde PaddlePaddle >= 2.2.0 for better AMP performance.")
    TENSOR_CORE_MULTI = 1


def to_optimized_size(sz):
    """Padding sequence to speedup matmul OP.

    According to the tensor cores documentation from NVIDIA, the matmul OP in fp16 must all be multiples of
    TENSOR_CORE_MULTI to get the speedup from fp16.
    See more detail:
    https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
    """
    return (sz + TENSOR_CORE_MULTI - 1) // TENSOR_CORE_MULTI * TENSOR_CORE_MULTI


def get_tensor(tensor_name, to_np=True):
    """Get tensor by name."""
    var = fluid.global_scope().find_var(tensor_name)
    if var is None:
        return None
    tensor = var.get_tensor()
    if tensor is None:
        return None
    if to_np:
        return np.array(tensor)
    else:
        return tensor


def to_lodtensor(data, place, dtype=None):
    """Convert data to LoDTensor."""
    if place is None:
        return data
    lengths = []
    while isinstance(data[0], list):
        lengths.append(list(map(len, data)))
        data = [x for xs in data for x in xs]
    if dtype is None:
        if isinstance(data[0], float):
            dtype = "float32"
        else:
            dtype = "int64"
    data = np.array(data, dtype=dtype)
    data_tensor = fluid.LoDTensor()
    data_tensor.set(data, place)
    data_tensor.set_recursive_sequence_lengths(lengths)
    return data_tensor


def pad_batch_data(insts, pad_id=0):
    """Pad the instances to the max sequence length in batch. """
    max_len = to_optimized_size(max(map(len, insts)))
    inst_data = np.array([list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts])
    return inst_data.astype("int64").reshape([-1, max_len, 1])


def convert_lodtensor_to_list(tensor):
    data = np.array(tensor)
    recursive_sequence_lengths = tensor.recursive_sequence_lengths()
    recursive_sequence_lengths.reverse()
    for i, lengths in enumerate(recursive_sequence_lengths):
        offset = 0
        new_data = []
        for j, l in enumerate(lengths):
            new_data.append(list(data[offset:offset + l]))
            offset += l
        data = new_data
    return data


def concatenate_lodtensors(tensors, place):
    """Concatenate LoD tensors."""
    data = []
    recursive_sequence_lengths = []
    for tensor in tensors:
        data.append(np.array(tensor))
        recursive_sequence_lengths.append(tensor.recursive_sequence_lengths())
    data = np.concatenate(data, axis=0)
    recursive_sequence_lengths = [sum(lens, []) for lens in zip(*recursive_sequence_lengths)]
    data_tensor = fluid.LoDTensor()
    data_tensor.set(data, place)
    data_tensor.set_recursive_sequence_lengths(recursive_sequence_lengths)
    assert data_tensor.has_valid_recursive_sequence_lengths()
    return data_tensor


def repeat_array_or_tensor(array_or_tensor, place, times):
    """Repeate numpy array or LoD tensor."""
    if isinstance(array_or_tensor, fluid.LoDTensor):
        data = [np.array(array_or_tensor)] * times
        recursive_sequence_lengths = [array_or_tensor.recursive_sequence_lengths()] * times
        data = np.concatenate(data, axis=0)
        recursive_sequence_lengths = [sum(lens, []) for lens in zip(*recursive_sequence_lengths)]
        data_tensor = fluid.LoDTensor()
        data_tensor.set(data, place)
        data_tensor.set_recursive_sequence_lengths(recursive_sequence_lengths)
        assert data_tensor.has_valid_recursive_sequence_lengths()
        return data_tensor
    elif isinstance(array_or_tensor, list):
        return list(chain(*([array_or_tensor] * times)))
    else:
        return np.concatenate([array_or_tensor] * times, axis=0)


def slice_array_or_tensor(array_or_tensor, place, begin, end):
    """Repeate numpy array or LoD tensor."""
    if isinstance(array_or_tensor, fluid.LoDTensor):
        data = convert_lodtensor_to_list(array_or_tensor)
        data = data[begin:end]
        if array_or_tensor._dtype() == core.VarDesc.VarType.FP32:
            dtype="float32"
        else:
            dtype="int64"
        return to_lodtensor(data, place, dtype)
    else:
        return array_or_tensor[begin:end]
