# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import random

import pytest
import torch

from xformers.factory.model_factory import xFormer, xFormerConfig

BATCH = 2
SEQ = 64
EMB = 48
VOCAB = 16
DEVICES = (
    [torch.device("cpu")]
    if not torch.cuda.is_available()
    else [torch.device("cuda")]  # save a bit on CI, we have seperate cpu and gpu jobs
)

_test_config_encoder = {
    "reversible": False,
    "block_type": "encoder",
    "dim_model": EMB,
    "position_encoding_config": {
        "name": "vocab",
        "seq_len": SEQ,
        "vocab_size": VOCAB,
        "dim_model": EMB,
    },
    "num_layers": 3,
    "multi_head_config": {
        "num_heads": 4,
        "residual_dropout": 0,
        "attention": {
            "name": "linformer",
            "dropout": 0,
            "causal": True,
            "seq_len": SEQ,
        },
        "dim_model": EMB,
    },
    "feedforward_config": {
        "name": "MLP",
        "dropout": 0,
        "activation": "relu",
        "hidden_layer_multiplier": 4,
        "dim_model": EMB,
    },
}

_test_config_decoder = {
    "block_type": "decoder",
    "dim_model": EMB,
    "position_encoding_config": {
        "name": "vocab",
        "seq_len": SEQ,
        "vocab_size": VOCAB,
        "dim_model": EMB,
    },
    "num_layers": 2,
    "multi_head_config_masked": {
        "num_heads": 4,
        "residual_dropout": 0,
        "dim_model": EMB,
        "attention": {
            "name": "linformer",
            "dropout": 0,
            "causal": True,
            "seq_len": SEQ,
        },
    },
    "multi_head_config_cross": {
        "num_heads": 4,
        "residual_dropout": 0,
        "dim_model": EMB,
        "attention": {
            "name": "linformer",
            "dropout": 0,
            "causal": True,
            "seq_len": SEQ,
        },
    },
    "feedforward_config": {
        "name": "MLP",
        "dropout": 0,
        "activation": "relu",
        "hidden_layer_multiplier": 4,
        "dim_model": EMB,
    },
}

# Test a pure encoder, a pure decoder, an encoder/decoder stack
_test_configs = [
    [_test_config_encoder, _test_config_decoder],
    [_test_config_encoder],
]


def _rev_config(config, flag: bool):
    for c in filter(
        lambda x: x["block_type"] == "encoder",
        config,
    ):
        c["reversible"] = flag

    return config


@pytest.mark.parametrize("config", _test_configs)
@pytest.mark.parametrize("device", DEVICES)
def test_reversible_runs(config, device):

    # Build both a reversible and non-reversible model
    model_non_reversible = xFormer.from_config(
        xFormerConfig(_rev_config(config, False))
    ).to(device)
    model_reversible = xFormer.from_config(xFormerConfig(_rev_config(config, True))).to(
        device
    )

    # Dummy inputs, test a forward
    inputs = (torch.rand((BATCH, SEQ), device=device) * 10).abs().to(torch.int)
    _ = model_non_reversible(inputs)
    _ = model_reversible(inputs)


@pytest.mark.parametrize("device", DEVICES)
def test_reversible_no_alternate(device):

    # Check that we cannot build a non-coherent stack
    with pytest.raises(AssertionError):
        rev = dict(_test_config_encoder)  # we need to make a copy
        rev["reversible"] = True
        non_rev = dict(_test_config_encoder)
        non_rev["reversible"] = False

        _ = xFormer.from_config(xFormerConfig([rev, non_rev])).to(device)


@pytest.mark.parametrize("config", _test_configs)
@pytest.mark.parametrize("device", DEVICES)
def test_reversible_train(config, device):
    torch.manual_seed(0)
    random.seed(0)

    # Dummy inputs, test some training to make sure that we both can approximate the same thing to some extent
    # This is not super scientific, more of a foolproof catch
    def data():
        input_a = torch.zeros((BATCH, SEQ), device=device).to(torch.int)
        input_b = (torch.rand((BATCH, SEQ), device=device) * VOCAB).abs().to(torch.int)

        target_a = torch.zeros((BATCH, SEQ), device=device)
        target_b = torch.ones((BATCH, SEQ), device=device)

        if random.random() > 0.5:
            return torch.cat([input_a, input_b], dim=0), torch.cat(
                [target_a, target_b], dim=0
            )

        return torch.cat([input_b, input_a], dim=0), torch.cat(
            [target_b, target_a], dim=0
        )

    def step(model: torch.nn.Module, optim: torch.optim.Optimizer):
        batch, target = data()
        model.train()
        optim.zero_grad()

        outputs = model(batch)
        loss = torch.norm(torch.mean(outputs, dim=-1) - target)
        loss.backward()

        # Clip grad and error out if we're producing NaNs, part of the unit test
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), 10.0, norm_type=2.0, error_if_nonfinite=True
        )
        optim.step()

        return loss.item()

    def evaluate(model: torch.nn.Module):
        batch, target = data()
        model.eval()
        outputs = model(batch)
        return torch.norm(torch.mean(outputs, dim=-1) - target).item()

    # Build both a reversible and non-reversible model
    model_non_reversible = xFormer.from_config(
        xFormerConfig(_rev_config(config, False))
    ).to(device)
    model_reversible = xFormer.from_config(xFormerConfig(_rev_config(config, True))).to(
        device
    )

    optim_rev = torch.optim.SGD(model_reversible.parameters(), lr=1e-3, momentum=0.9)
    optim_non_rev = torch.optim.SGD(
        model_non_reversible.parameters(), lr=1e-3, momentum=0.9
    )

    # Check that both models can be trained to comparable results
    eval_start_rev = evaluate(model_reversible)
    eval_start_non_rev = evaluate(model_non_reversible)

    for i in range(100):
        print(i, " reversible: ", step(model_reversible, optim_rev))
        print(i, " non reversible: ", step(model_non_reversible, optim_non_rev))

    # Check that we can classify this dummy example
    # Arbitrary threshold
    eval_stop_rev = evaluate(model_reversible)
    eval_stop_non_rev = evaluate(model_non_reversible)
    if len(config) < 2:  # only check the encoder case
        train_ratio_rev = eval_start_rev / eval_stop_rev
        train_ratio_non_rev = eval_start_non_rev / eval_stop_non_rev

        # Assert that train ratio > 1 (we trained),
        # and reversible is not much worse than non-reversible (it's actually better on this dummy test)

        assert train_ratio_rev > 1
        assert train_ratio_non_rev > 1
        assert train_ratio_rev > train_ratio_non_rev
