from typing import Optional

import torch
import pandas as pd
import numpy as np

from utils.representations.representation_tracker import (
    RepresentationTracker,
    _concatenate_batches,
    _split_samples,
)


def _mock_get_random_subset(
    t: torch.Tensor, n_samples: Optional[int]
) -> torch.Tensor:
    """Deterministic stand-in for a random-subset function.

    Returns ``t`` itself when ``n_samples`` is ``None``; otherwise returns
    the leading ``n_samples`` entries along the first dimension, so tests
    get reproducible "subsets" without any randomness.
    """
    if n_samples is None:
        return t
    return t[:n_samples]


def test_concatenate_batches():
    """Two batches of 2 and 1 rows are combined, in order, into 3 rows."""
    batches = [
        torch.Tensor([[1, 2], [3, 4]]),
        torch.Tensor([[0, 0]]),
    ]
    expected = torch.Tensor([[1, 2], [3, 4], [0, 0]])

    combination = _concatenate_batches(batches)

    assert (combination == expected).all()

def test_split_samples():
    """Splitting 32 all-zero samples yields two all-zero halves of 16 each.

    The input is all zeros, and ``==`` against a zero tensor broadcasts, so
    the value comparison alone would accept any broadcast-compatible shape.
    The shapes are therefore asserted explicitly before the values.
    """
    samples = torch.zeros(size=(32, 10, 8), dtype=torch.float32)
    t1, t2 = _split_samples(samples, _mock_get_random_subset)

    half_shape = (16, 10, 8)
    assert tuple(t1.shape) == half_shape
    assert tuple(t2.shape) == half_shape

    expected_half = torch.zeros(size=half_shape, dtype=torch.float32)
    assert torch.all(t1 == expected_half)
    assert torch.all(t2 == expected_half)


def test_inter_group_tracking():
    """Inter-group "l2" mean distance between g1 and g2 on a single layer.

    The tracker's subset function is replaced by the deterministic mock so
    the sampled representations are reproducible.
    """
    rng = torch.Generator()
    # NOTE(review): the second argument is presumably the number of samples
    # per group - confirm against RepresentationTracker.
    tracker = RepresentationTracker(rng, 2)
    tracker.get_random_subset = _mock_get_random_subset

    g1 = torch.Tensor([[0, 0], [1, 1]])
    tracker.add_rep_samples("g1", {"l1": g1})
    g2 = torch.Tensor([[1, 1], [1, 1]])
    tracker.add_rep_samples("g2", {"l1": g2})
    mean_dist = tracker.compute_mean_dist("inter_group", "l2")

    index = pd.Index(["l1"], name="layer")
    # sqrt(2) dist between the first pair and 0 between the second
    # -> mean = sqrt(2) / 2
    expected = pd.DataFrame({"g1-g2": [np.sqrt(2) / 2]}, index=index)
    # Reduce explicitly instead of relying on the truthiness of a
    # one-element array; same form as the two-layer inter-group test.
    assert np.all(np.isclose(mean_dist.values, expected.values, atol=1e-4))

def test_inter_group_two_layers_tracking():
    """Inter-group "l2" mean distance when both groups report two layers.

    Each layer receives the same representations, so both layers are
    expected to yield the same mean distance.
    """
    layers = ["l1", "l2"]

    tracker = RepresentationTracker(torch.Generator(), 2)
    tracker.get_random_subset = _mock_get_random_subset

    reps_g1 = torch.Tensor([[0, 0], [1, 1]])
    reps_g2 = torch.Tensor([[1, 1], [1, 1]])
    tracker.add_rep_samples("g1", {layer: reps_g1 for layer in layers})
    tracker.add_rep_samples("g2", {layer: reps_g2 for layer in layers})

    mean_dist = tracker.compute_mean_dist("inter_group", "l2")

    # Per layer: sqrt(2) between the first pair of rows, 0 between the
    # second pair -> mean of sqrt(2) / 2.
    per_layer = np.sqrt(2) / 2
    expected = pd.DataFrame(
        {"g1-g2": [per_layer] * len(layers)},
        index=pd.Index(layers, name="layer"),
    )
    matches = np.isclose(mean_dist.values, expected.values, atol=1e-4)
    assert np.all(matches)

def test_across_all_one_group_tracking():
    """"across_all" "l2" mean distance with a single group and one layer."""
    rng = torch.Generator()
    tracker = RepresentationTracker(rng, 4)
    tracker.get_random_subset = _mock_get_random_subset

    g1 = torch.Tensor([[0, 0], [1, 1], [1, 1], [1, 1]])
    tracker.add_rep_samples("g1", {"l1": g1})
    mean_dist = tracker.compute_mean_dist("across_all", "l2")

    index = pd.Index(["l1"], name="layer")
    # NOTE(review): presumably the four rows are split into two halves and
    # compared pairwise (sqrt(2) and 0 -> mean sqrt(2) / 2) - confirm
    # against the tracker implementation.
    expected = pd.DataFrame({"across_all": [np.sqrt(2) / 2]}, index=index)
    # Reduce explicitly instead of relying on the truthiness of a
    # one-element array; same form as the two-layer inter-group test.
    assert np.all(np.isclose(mean_dist.values, expected.values, atol=1e-4))

def test_across_all_two_groups_tracking():
    """"across_all" "l2" mean distance with two groups and one layer."""
    rng = torch.Generator()
    tracker = RepresentationTracker(rng, 4)
    tracker.get_random_subset = _mock_get_random_subset

    g1 = torch.Tensor([[0, 0], [1, 1], [1, 1], [1, 1]])
    tracker.add_rep_samples("g1", {"l1": g1})
    g2 = torch.Tensor([[2, 2], [1, 1], [1, 1], [1, 1]])
    tracker.add_rep_samples("g2", {"l1": g2})
    mean_dist = tracker.compute_mean_dist("across_all", "l2")

    index = pd.Index(["l1"], name="layer")
    # NOTE(review): how samples from both groups are pooled and paired to
    # reach sqrt(2) / 2 is not visible here - confirm against the tracker.
    expected = pd.DataFrame({"across_all": [np.sqrt(2) / 2]}, index=index)
    # Reduce explicitly instead of relying on the truthiness of a
    # one-element array; same form as the two-layer inter-group test.
    assert np.all(np.isclose(mean_dist.values, expected.values, atol=1e-4))

def test_intra_group_tracking_cka_same():
    """"linear_cka" between two groups with identical representations is 1.

    NOTE(review): the function name says "intra_group" but the comparison
    requested below is "inter_group"; the name is kept so existing test
    selections keep working - consider renaming.
    """
    rng = torch.Generator()
    tracker = RepresentationTracker(rng, 2)
    tracker.get_random_subset = _mock_get_random_subset

    g1 = torch.Tensor([[0, 0], [1, 1]])
    tracker.add_rep_samples("g1", {"l1": g1})
    g2 = torch.Tensor([[0, 0], [1, 1]])
    tracker.add_rep_samples("g2", {"l1": g2})
    mean_dist = tracker.compute_mean_dist("inter_group", "linear_cka")

    index = pd.Index(["l1"], name="layer")
    expected = pd.DataFrame({"g1-g2": [1]}, index=index)
    # Reduce explicitly instead of relying on the truthiness of a
    # one-element array; same form as the two-layer inter-group test.
    assert np.all(np.isclose(mean_dist.values, expected.values, atol=1e-4))
