import torch

from rtrank.losses import (
    BidirectionalRtRankLoss,
    BTLoss,
    DeterministicRTLoss,
    InitialNormalizedLoss,
)
from rtrank.metric_tracker import MetricTracker
from rtrank.synthetic_data import HyperbolicUdiffRtRel


def test_bt_loss():
    """Exercise BTLoss end to end.

    Covers: metric logging through a MetricTracker, numerical stability on
    extreme utility gaps, invariance to a common shift of both utilities,
    symmetry under swapping the options (with the label flipped), and that
    the loss grows when the preferred option's utility drops.
    """
    metric_tracker = MetricTracker()
    bt_loss = BTLoss()
    bt_loss.set_metric_tracker(metric_tracker)

    # Test basic functionality: a forward pass should log at least one metric.
    u1 = torch.tensor([1.0, 1.0, 2.0, 0.0])
    u2 = torch.tensor([1.0, 0.0, 1.0, 2.0])
    y = torch.tensor([1, 0, 1, 0])
    bt_loss(u1, u2, y)
    metric_tracker.finalize_epoch()
    metrics = metric_tracker.get_epoch_metrics()
    assert isinstance(metrics, list), (
        f"Expected metrics to be a list, but got {type(metrics)}"
    )
    assert len(metrics) > 0, "Metrics list is empty"

    # Test extreme cases: very large utility gaps must not produce inf/NaN
    # (e.g. from log(0) in a naive sigmoid-then-log formulation). Previously
    # this section ran the loss twice and asserted nothing.
    u1 = torch.tensor([100.0, -100.0])
    u2 = torch.tensor([0.0, 0.0])
    y = torch.tensor([1, 0])
    extreme_loss = bt_loss(u1, u2, y)
    assert torch.all(torch.isfinite(extreme_loss)), (
        f"BTLoss is not finite for extreme inputs: {extreme_loss}"
    )
    metric_tracker.finalize_epoch()

    # Test shift invariance: adding the same constant to both utilities must
    # leave the loss unchanged (only the difference should matter).
    u1 = torch.tensor([1.0, 2.0, 3.0])
    u2 = torch.tensor([0.0, 1.0, 2.0])
    y = torch.tensor([1, 0, 1])
    loss1 = bt_loss(u1, u2, y)

    shift = 10.0
    u1_shifted = u1 + shift
    u2_shifted = u2 + shift
    loss2 = bt_loss(u1_shifted, u2_shifted, y)

    assert torch.allclose(loss1, loss2, atol=1e-7)

    # Symmetry: swapping the two options and flipping the label is the same
    # comparison, so it must give the same loss.
    u1 = torch.tensor([1.0])
    u2 = torch.tensor([0.0])
    loss1 = bt_loss(u1, u2, torch.tensor([1]))
    loss2 = bt_loss(u2, u1, torch.tensor([0]))
    assert torch.allclose(loss1, loss2)

    # Check that loss increases when prediction gets worse
    y = torch.tensor([1])
    u1_good = torch.tensor([1.0])
    u2 = torch.tensor([0.0])
    u1_bad = torch.tensor([0.0])
    loss_good = bt_loss(u1_good, u2, y)
    loss_bad = bt_loss(u1_bad, u2, y)
    assert loss_bad > loss_good


def test_initial_normalized_loss():
    """InitialNormalizedLoss rescales outputs by a factor fixed on its first call.

    An identity "loss" is wrapped so that every normalized output can be
    predicted exactly from its input.
    """

    class DummyLoss:
        # Identity loss: returns its input untouched.
        def __call__(self, x):
            return x

        # Reports no metrics of its own.
        def get_metrics(self):
            return {}

    # Single-element case. The first call (input 2.0) sets the factor, so that
    # input normalizes to exactly 1.0.
    single = InitialNormalizedLoss(DummyLoss())
    first_out = single(torch.tensor([2.0]))
    assert torch.allclose(first_out, torch.tensor([1.0]))

    # Later calls keep the original factor of 2.0 instead of recomputing it.
    later_out = single(torch.tensor([4.0]))
    assert torch.allclose(later_out, torch.tensor([2.0]))

    # Batched case on a fresh wrapper. The first batch [2, 4] has mean 3.0,
    # which becomes the normalization factor.
    batched = InitialNormalizedLoss(DummyLoss())
    batched_out = batched(torch.tensor([2.0, 4.0]))
    assert torch.allclose(batched_out, torch.tensor([2.0, 4.0]) / 3.0)


def test_deterministic_rt_loss():
    """DeterministicRTLoss returns one non-negative value per comparison."""
    relation = HyperbolicUdiffRtRel(0, 10)
    loss_fn = DeterministicRTLoss(assumed_udiff_rt_rel=relation)

    utilities_a = torch.tensor([1.0, 2.0, 3.0])
    utilities_b = torch.tensor([0.0, 1.0, 2.0])
    choices = torch.tensor([1, 1, 0])
    response_times = torch.tensor([0.6, 0.7, 0.8])

    per_item = loss_fn(utilities_a, utilities_b, choices, response_times)

    # Output is per comparison (not reduced), and no entry may be negative.
    assert per_item.shape == (3,)
    assert (per_item >= 0).all()


def test_rt_ranking_loss():
    """BidirectionalRtRankLoss: reduced output shape, rt/udiff ordering, unreduced mode."""
    metric_tracker = MetricTracker()
    rt_ranking_loss = BidirectionalRtRankLoss()
    rt_ranking_loss.set_metric_tracker(metric_tracker)

    # Smoke cases, run in this exact order: simple input, equal utilities,
    # and utilities spanning several orders of magnitude. In default
    # (reduced) mode each call must return a 0-dim tensor.
    smoke_cases = (
        ([1.0, 2.0, 3.0], [0.5, 1.5, 2.5], [1, 1, 1], [0.6, 0.8, 1.0]),
        ([1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1, 1, 1], [0.6, 0.8, 1.0]),
        (
            [1.0, 10.0, 100.0, 1000.0, 10000.0, 100000.0],
            [0.1, 1.0, 10.0, 100.0, 1000.0, 10000.0],
            [1, 1, 1, 1, 1, 1],
            [0.6, 0.8, 1.0, 1.2, 1.4, 1.6],
        ),
    )
    for first_vals, second_vals, choice_vals, rt_vals in smoke_cases:
        reduced = rt_ranking_loss(
            torch.tensor(first_vals),
            torch.tensor(second_vals),
            torch.tensor(choice_vals),
            torch.tensor(rt_vals),
        )
        assert isinstance(reduced, torch.Tensor)
        assert reduced.ndim == 0

    # Ordering check: pairing the largest response time with the smallest
    # utility difference must score lower than the reversed pairing.
    udiffs = torch.tensor([1.0, 2.0, 3.0])
    udiffs_reversed = udiffs.flip(-1)
    rts = udiffs_reversed
    baseline = torch.zeros(3)
    ys = torch.zeros(3, dtype=torch.long)

    loss_correct = rt_ranking_loss(baseline, baseline + udiffs, ys, rts)
    metric_tracker.finalize_epoch()
    correct_metrics = metric_tracker.get_epoch_metrics()

    loss_reversed = rt_ranking_loss(baseline, baseline + udiffs_reversed, ys, rts)
    metric_tracker.finalize_epoch()
    reversed_metrics = metric_tracker.get_epoch_metrics()

    assert torch.all(loss_correct < loss_reversed), (
        f"{loss_correct} >= {loss_reversed};\nCorrect metrics\n{correct_metrics}\n\nReversed metrics\n{reversed_metrics}"
    )

    # unreduce=True: one loss value per comparison, same shape as the rts.
    per_item_loss_fn = BidirectionalRtRankLoss(unreduce=True)
    rts = torch.tensor([0.6, 0.8, 1.0])
    per_item = per_item_loss_fn(
        torch.tensor([1.0, 2.0, 3.0]),
        torch.tensor([0.5, 1.5, 2.5]),
        torch.tensor([1, 1, 1]),
        rts,
    )
    assert isinstance(per_item, torch.Tensor)
    assert per_item.ndim == 1
    assert per_item.shape == rts.shape
