# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import pytest
import torch
from torch import nn

from fairscale.nn.pipe.balance import balance_by_size, balance_by_time, blockpartition
from fairscale.nn.pipe.balance.profile import layerwise_sandbox

# Marker for tests that require a CUDA device; they are skipped when none is present.
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")

# Devices used by the device-parametrized tests: always CPU, plus CUDA when available.
devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]


def test_blockpartition():
    """Splitting 1..6 into two blocks balances the sums as [1..4] and [5, 6]."""
    partitioned = blockpartition.solve([1, 2, 3, 4, 5, 6], partitions=2)
    assert partitioned == [[1, 2, 3, 4], [5, 6]]


def test_blockpartition_zeros():
    """An all-zero cost sequence is still split into the requested partitions."""
    partitioned = blockpartition.solve([0, 0], partitions=2)
    assert partitioned == [[0], [0]]


def test_blockpartition_non_positive_partitions():
    """A partition count of zero or below is rejected with ValueError."""
    for invalid_count in (0, -1):
        with pytest.raises(ValueError):
            blockpartition.solve([42], partitions=invalid_count)


def test_blockpartition_short_sequence():
    """Fewer items than requested partitions (including empty input) raises ValueError."""
    too_short_cases = (
        ([], 1),
        ([42], 2),
    )
    for sequence, partition_count in too_short_cases:
        with pytest.raises(ValueError):
            blockpartition.solve(sequence, partitions=partition_count)


@pytest.mark.parametrize("device", devices)
def test_balance_by_time(device):
    """Layers are partitioned by measured forward time.

    Six pass-through layers sleep 10ms, 20ms, ..., 60ms. Balancing them into
    two partitions is expected to put the four fastest in the first partition
    and the two slowest in the second.
    """

    class Sleep(nn.Module):
        def __init__(self, seconds):
            super().__init__()
            self.seconds = seconds

        def forward(self, x):
            time.sleep(self.seconds)
            return x

    layers = [Sleep(step / 100) for step in range(1, 7)]
    model = nn.Sequential(*layers)
    sample = torch.rand(1)

    assert balance_by_time(2, model, sample, device=device) == [4, 2]


def test_balance_by_time_loop_resets_input():
    """Profiling succeeds on a model whose layers change the tensor shape.

    Per the test name, this guards against profiling loops that reuse a
    previous output as the next input. The conv expects 3 channels and the
    linear layer expects 128 features, so a reused intermediate would not fit.
    """

    # nn.Flatten only exists since PyTorch 1.2.0, so an equivalent is defined here.
    class Flatten(nn.Module):
        def forward(self, x):
            return x.flatten(1)

    conv = nn.Conv2d(3, 2, 1)  # (N, 3, 8, 8) -> (N, 2, 8, 8)
    flatten = Flatten()  # (N, 2, 8, 8) -> (N, 128)
    head = nn.Linear(2 * 8 * 8, 10)  # (N, 128) -> (N, 10)
    model = nn.Sequential(conv, flatten, head)

    sample = torch.rand(10, 3, 8, 8)
    assert balance_by_time(2, model, sample, device="cpu") == [1, 2]


@skip_if_no_cuda
def test_balance_by_size_latent():
    """Layers are partitioned by the activation memory they allocate.

    Each ``Expand`` layer allocates ``times`` new tensors shaped like its input.
    Heavier layers at the end push the split toward [4, 2]. The reversed order
    gives [2, 4].
    """

    class Expand(nn.Module):
        def __init__(self, times):
            super().__init__()
            self.times = times

        def forward(self, x):
            for _ in range(self.times):
                x = x + torch.rand_like(x, requires_grad=True)
            return x

    sample = torch.rand(10, 100, 100)

    cases = (
        ([1, 2, 3, 4, 5, 6], [4, 2]),
        ([6, 5, 4, 3, 2, 1], [2, 4]),
    )
    for expansions, expected in cases:
        model = nn.Sequential(*[Expand(count) for count in expansions])
        assert balance_by_size(2, model, sample) == expected


@skip_if_no_cuda
def test_balance_by_size_param():
    """With a large param_scale, parameter count drives the partitioning."""
    # Widths grow 1 -> 7, so the later Linear layers hold the most parameters.
    growing = nn.Sequential(*(nn.Linear(width, width + 1) for width in range(1, 7)))
    growing_balance = balance_by_size(2, growing, torch.rand(7, 1), param_scale=100)
    assert growing_balance == [4, 2]

    # Widths shrink 7 -> 1, so the earlier Linear layers hold the most parameters.
    shrinking = nn.Sequential(*(nn.Linear(width + 1, width) for width in range(6, 0, -1)))
    shrinking_balance = balance_by_size(2, shrinking, torch.rand(1, 7), param_scale=100)
    assert shrinking_balance == [2, 4]


@skip_if_no_cuda
def test_balance_by_size_param_scale():
    """param_scale trades off parameter memory against activation memory.

    Parameter size grows along the model while activation allocations shrink.
    With param_scale=0 only activations count, giving [2, 4]. With
    param_scale=100 parameters dominate, giving [4, 2].
    """

    class Tradeoff(nn.Module):
        def __init__(self, param_size, latent_size):
            super().__init__()
            # Contributes parameters only; forward() never calls it.
            self.fc = nn.Linear(param_size, param_size)
            self.latent_size = latent_size

        def forward(self, x):
            for _ in range(self.latent_size):
                x = x + torch.rand_like(x, requires_grad=True)
            return x

    sizes = zip(range(1, 7), range(6, 0, -1))
    model = nn.Sequential(
        *[Tradeoff(param_size=params, latent_size=latents) for params, latents in sizes]
    )

    sample = torch.rand(1, requires_grad=True)

    for scale, expected in ((0, [2, 4]), (100, [4, 2])):
        assert balance_by_size(2, model, sample, param_scale=scale) == expected


@pytest.mark.parametrize("device", devices)
def test_layerwise_sandbox(device):
    """Sandboxed layers are in training mode on the target device; the original is untouched."""
    model = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3))
    model.eval()
    target = torch.device(device)

    for layer in layerwise_sandbox(model, target):
        assert layer.training
        assert all(param.device.type == device for param in layer.parameters())

    # The caller's model keeps eval mode and stays on the CPU.
    assert not any(layer.training for layer in model)
    assert all(param.device.type == "cpu" for param in model.parameters())


@pytest.mark.parametrize("device", devices)
def test_sandbox_during_profiling(device):
    """Time profiling must not change the model's state (e.g. BatchNorm running stats)."""
    model = nn.Sequential(nn.BatchNorm2d(3))

    snapshot = {name: tensor.clone() for name, tensor in model.state_dict().items()}

    balance_by_time(1, model, torch.rand(1, 3, 10, 10), device=device)

    current = model.state_dict()
    assert snapshot.keys() == current.keys()
    for name in snapshot:
        assert torch.allclose(current[name], snapshot[name]), name


def test_not_training():
    """Profiling runs layers in training mode but leaves the caller's model in eval mode."""

    class RequireTrainMode(nn.Module):
        def forward(self, x):
            assert self.training
            return x

    model = nn.Sequential(RequireTrainMode())
    model.eval()
    assert not model.training

    # Fails inside RequireTrainMode.forward if profiling ran the layer in eval mode.
    balance_by_time(1, model, torch.rand(1), device="cpu")

    assert not model.training


def test_balance_by_time_tuple():
    """balance_by_time must profile layers that pass a tuple between them.

    ``Twin`` returns a ``(tensor, detached tensor)`` pair and ``Add`` unpacks
    it, so profiling must cope with a non-tensor intermediate. With a single
    partition, both layers must end up in that partition.
    """

    class Twin(nn.Module):
        def forward(self, x):
            return x, x.detach()

    class Add(nn.Module):
        def forward(self, a_b):
            a, b = a_b
            return a + b

    model = nn.Sequential(Twin(), Add())
    sample = torch.rand(1, requires_grad=True)
    balance = balance_by_time(1, model, sample, device="cpu")

    # Previously the result was never checked. One partition holds both layers.
    assert balance == [2]


@skip_if_no_cuda
def test_balance_by_size_tuple():
    """balance_by_size must profile layers that pass a tuple between them.

    ``Twin`` returns a ``(tensor, detached tensor)`` pair and ``Add`` unpacks
    it, so profiling must cope with a non-tensor intermediate. With a single
    partition, both layers must end up in that partition.
    """

    class Twin(nn.Module):
        def forward(self, x):
            return x, x.detach()

    class Add(nn.Module):
        def forward(self, a_b):
            a, b = a_b
            return a + b

    model = nn.Sequential(Twin(), Add())
    sample = torch.rand(1, requires_grad=True)
    balance = balance_by_size(1, model, sample)

    # Previously the result was never checked. One partition holds both layers.
    assert balance == [2]


def test_already_has_grad():
    """balance_by_time rejects a model whose parameters already hold gradients."""
    model = nn.Sequential(nn.Conv2d(3, 3, 1))
    sample = torch.rand(1, 3, 32, 32)

    # A single forward/backward pass populates .grad on the conv parameters.
    loss = model(sample).norm()
    loss.backward()

    with pytest.raises(ValueError, match="some parameter already has gradient"):
        balance_by_time(1, model, sample, device="cpu")
