# coding=utf-8

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from fairscale.fair_dev.testing.testing import dist_init, spawn_for_all_world_sizes
from fairscale.nn.model_parallel import initialize as mpu
from fairscale.nn.model_parallel import random
from fairscale.nn.model_parallel.random import get_cuda_rng_tracker, model_parallel_cuda_manual_seed


def run_test_set_cuda_rng_state(rank, model_parallel_size, filename, filename_rpc):
    """Worker: check that ``random._set_cuda_rng_state`` rewinds the CUDA RNG stream.

    Seeds the CUDA generator, captures its state, draws random numbers, then
    restores the captured state and draws again. The two draws must match, the
    generator state must have advanced in between, and the captured state
    tensor itself must never be modified.

    Args:
        rank: forwarded to ``dist_init``.
        model_parallel_size: forwarded to ``dist_init`` and used to initialize
            the model-parallel groups.
        filename: forwarded to ``dist_init``.
        filename_rpc: forwarded to ``dist_init``.

    Raises:
        AssertionError: if any of the RNG-state checks fails.
    """
    dist_init(rank, model_parallel_size, filename, filename_rpc)

    if torch.distributed.get_rank() == 0:
        print("> testing set_rng_state with size {} ...".format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    size = 123
    seed = 1234
    torch.cuda.manual_seed(seed)
    # Uninitialized float scratch buffer on the current CUDA device; every
    # torch.randn call below writes into it through ``out=``.
    tensor = torch.empty(size, dtype=torch.float, device="cuda")

    # Capture the generator state, plus an independent copy used later to
    # detect accidental in-place modification of the captured tensor.
    rng_state = torch.cuda.get_rng_state()
    rng_state_copy = rng_state.clone()

    # Advance the generator by drawing several times; keep the last draw.
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_1 = tensor.clone()

    # The captured snapshot is unchanged, while the live state has moved on.
    assert rng_state.sub(rng_state_copy).max() == 0
    assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0

    # State should be different.
    new_rng_state = torch.cuda.get_rng_state()
    max_diff = new_rng_state.sub(rng_state).max()
    print(
        "   max diff in rng state (should be non-zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), max_diff
        )
    )
    assert max_diff > 0

    # Reset the rng state and do the same stuff.
    random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    # Rewind a second time and repeat the draws; result_2 therefore comes from
    # a run that started from a state restored more than once.
    random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_2 = tensor.clone()

    # Results should be the same
    error = result_2.sub(result_1).abs().max()
    print(
        "   max error in generated tensors (should be zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), error
        )
    )
    assert error < 1.0e-6

    # Input state should have remained intact.
    error = rng_state.sub(rng_state_copy).max()
    print(
        "   max error in rng state (should be zero) on global rank {}: {}".format(torch.distributed.get_rank(), error)
    )
    assert error == 0

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")


def run_test_cuda_rng_tracker(rank, model_parallel_size, filename, filename_rpc):
    """Worker: check that a forked tracker stream is independent of the default stream.

    Reference tensors are first generated sequentially from ``seed_1`` and from
    ``seed_2``. The default generator is then re-seeded with ``seed_1`` and a
    tracker state named ``"test"`` is registered with ``seed_2``. Draws
    alternate between the default stream and the forked ``"test"`` stream, and
    each stream must reproduce its own reference tensors.

    Args:
        rank: forwarded to ``dist_init``.
        model_parallel_size: forwarded to ``dist_init`` and used to initialize
            the model-parallel groups.
        filename: forwarded to ``dist_init``.
        filename_rpc: forwarded to ``dist_init``.

    Raises:
        AssertionError: if the two streams produce identical tensors, or if
            either stream fails to reproduce its reference tensors.
    """
    dist_init(rank, model_parallel_size, filename, filename_rpc)

    if torch.distributed.get_rank() == 0:
        print("> testing cuda rng tracker with size {} ...".format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed_1 = 1234
    seed_2 = 4321
    size = [12, 21]
    # Allocate the scratch buffer with its final shape. A legacy typed
    # constructor given a list would treat it as data (a 2-element tensor),
    # so each randn(..., out=tensor) would have to resize a non-empty output,
    # which PyTorch has deprecated.
    tensor = torch.empty(size, dtype=torch.float, device="cuda")

    # Set to seed_1 and generate two tensors.
    torch.cuda.manual_seed(seed_1)
    torch.randn(size, out=tensor)
    target_11 = tensor.clone()
    torch.randn(size, out=tensor)
    target_12 = tensor.clone()

    # Set to seed_2 and generate two tensors.
    torch.cuda.manual_seed(seed_2)
    torch.randn(size, out=tensor)
    target_21 = tensor.clone()
    torch.randn(size, out=tensor)
    target_22 = tensor.clone()

    # Now if we interleave seed_1 and seed_2,
    # we should still get the same tensors
    torch.cuda.manual_seed(seed_1)
    get_cuda_rng_tracker().add("test", seed_2)

    # First draw from the default (seed_1) stream.
    torch.randn(size, out=tensor)
    result_11 = tensor.clone()

    # First draw from the forked (seed_2) stream.
    with get_cuda_rng_tracker().fork("test"):
        torch.randn(size, out=tensor)
        result_21 = tensor.clone()

    # Second draw from the default stream.
    torch.randn(size, out=tensor)
    result_12 = tensor.clone()

    # Second draw from the forked stream.
    with get_cuda_rng_tracker().fork("test"):
        torch.randn(size, out=tensor)
        result_22 = tensor.clone()

    # The two streams must differ: take the smaller of the two pairwise
    # max-differences so that both pairs have to be distinct.
    diff = result_11.sub(result_21).abs().max()
    diff = min(diff, result_12.sub(result_22).abs().max())
    print(
        "   max diff in generated tensors (should be non-zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), diff
        )
    )
    assert diff > 1.0e-6

    # Each interleaved draw must match its sequential reference: take the
    # largest deviation across all four pairs.
    error = max(result_11.sub(target_11).abs().max(), result_12.sub(target_12).abs().max())
    error = max(error, result_21.sub(target_21).abs().max())
    error = max(error, result_22.sub(target_22).abs().max())
    print(
        "   max error in generated tensors (should be zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), error
        )
    )
    assert error < 1.0e-6

    # Reset the tracker
    get_cuda_rng_tracker().reset()

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")


def run_test_model_parallel_cuda_manual_seed(rank, model_parallel_size, filename, filename_rpc):
    """Worker: verify the seeds installed by ``model_parallel_cuda_manual_seed``.

    After seeding with 12345, the default CUDA generator must report that seed,
    and inside the tracker's default fork the generator must report
    ``12345 + 2718 + <model parallel rank>``.

    Args:
        rank: forwarded to ``dist_init``.
        model_parallel_size: forwarded to ``dist_init`` and used to initialize
            the model-parallel groups.
        filename: forwarded to ``dist_init``.
        filename_rpc: forwarded to ``dist_init``.

    Raises:
        AssertionError: if either reported seed differs from the expected value.
    """
    dist_init(rank, model_parallel_size, filename, filename_rpc)

    if torch.distributed.get_rank() == 0:
        banner = "> testing model parallel cuda manual seed with size {} ..."
        print(banner.format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    base_seed = 12345
    model_parallel_cuda_manual_seed(base_seed)

    # Outside any fork the generator reports the base seed unchanged.
    default_seed = torch.cuda.initial_seed()
    assert default_seed == base_seed

    # Inside the default fork the seed is offset by 2718 plus this worker's
    # model-parallel rank.
    with get_cuda_rng_tracker().fork():
        forked_seed = torch.cuda.initial_seed()
        expected_forked_seed = base_seed + 2718 + mpu.get_model_parallel_rank()
        assert forked_seed == expected_forked_seed

    # Clear the tracker, then tear down the model-parallel groups.
    get_cuda_rng_tracker().reset()
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")


def test_set_cuda_rng_state():
    """Test entry point: hand the RNG-state worker to ``spawn_for_all_world_sizes``."""
    # NOTE(review): presumably the spawner launches the worker once per
    # supported world size - confirm against the fairscale testing helpers.
    worker = run_test_set_cuda_rng_state
    spawn_for_all_world_sizes(worker)


def test_cuda_rng_tracker():
    """Test entry point: hand the RNG-tracker worker to ``spawn_for_all_world_sizes``."""
    # NOTE(review): presumably the spawner launches the worker once per
    # supported world size - confirm against the fairscale testing helpers.
    worker = run_test_cuda_rng_tracker
    spawn_for_all_world_sizes(worker)


def test_model_parallel_cuda_manual_seed():
    """Test entry point: hand the manual-seed worker to ``spawn_for_all_world_sizes``."""
    # NOTE(review): presumably the spawner launches the worker once per
    # supported world size - confirm against the fairscale testing helpers.
    worker = run_test_model_parallel_cuda_manual_seed
    spawn_for_all_world_sizes(worker)
