import unittest
from argparse import Namespace
from functools import partial
from itertools import islice
from typing import Tuple

import numpy as np
import torch
from data.get import get_dataset
from torch.nn import functional as F
from torch.utils.data import DataLoader

from mahalanobis.models import (LowRankCovEncoder,
                                batched_sherman_morrison_rank_one_inverse,
                                proto_ddu_cnn, proto_ddu_linear,
                                proto_mahalanobis_cnn,
                                proto_mahalanobis_linear, proto_sngp_cnn,
                                proto_sngp_linear, protonet_cnn,
                                protonet_linear)

# Short alias for ``torch.Tensor``.
T = torch.Tensor


def get_moons() -> Tuple[Namespace, DataLoader, DataLoader]:
    """Build the 2-way, 10-shot toy "moons" few-shot configuration.

    Returns the configuration namespace together with the first and third
    loaders produced by ``get_dataset``; the middle one (presumably the
    validation loader) is discarded.
    """
    moons_args = Namespace(
        seed=0,
        batch_size=2,
        dataset="few-shot-toy-moons",
        n_way=2,
        k_shot=10,
    )
    train_loader, _val_loader, test_loader = get_dataset(moons_args)
    return moons_args, train_loader, test_loader


def get_conv() -> Tuple[Namespace, DataLoader, DataLoader]:
    """Build the 5-way, 5-shot Omniglot configuration used by the CNN models.

    Data is read from ``/home/datasets``. Returns the configuration namespace
    plus the first and third loaders produced by ``get_dataset``; the middle
    one (presumably the validation loader) is discarded.
    """
    settings = dict(
        seed=0,
        batch_size=10,
        dataset="omniglot",
        n_way=5,
        k_shot=5,
        data_root="/home/datasets",
        train_query_shots=5,
        val_query_shots=5,
        num_workers=1,
        split="train",
        ood_test=False,
        run=0,
        corrupt_test=False,
    )
    conv_args = Namespace(**settings)

    train_loader, _val_loader, test_loader = get_dataset(conv_args)
    return conv_args, train_loader, test_loader


class TestSetTransformer(unittest.TestCase):
    """Smoke tests for ``LowRankCovEncoder``."""

    def test_smoketest_low_rank_cov_encoder(self) -> None:
        """Run the encoder on random inputs; check output shapes and finiteness.

        With ``n_way`` centroids and ``n_way * k_shot`` inputs of width ``dim``,
        the encoder must return a ``[n_way, dim, dim]`` precision tensor and a
        ``[n_way]`` log-determinant tensor, all of whose entries are finite.
        """
        n_way, k_shot, dim = 5, 15, 128
        encoder = LowRankCovEncoder(dim_input=dim, rank=3, dim_output=dim)

        inputs = torch.randn(n_way * k_shot, dim)
        centroids = torch.randn(n_way, dim)
        precision, logdet = encoder(inputs, None, centroids=centroids, n_way=n_way, k_shot=k_shot, lambd=0.01)

        self.assertListEqual(list(precision.size()), [n_way, dim, dim])
        self.assertListEqual(list(logdet.size()), [n_way])
        # Require *every* entry to be finite: an ``all(isnan(...))`` check would
        # only fail when all entries are NaN, letting partially-NaN output pass.
        self.assertTrue(torch.isfinite(logdet).all().item())
        self.assertTrue(torch.isfinite(precision).all().item())


class TestMahalanobis(unittest.TestCase):
    """Tests for the Mahalanobis-style prototype models and their helpers."""

    def test_batched_sherman_morrisson_inverse(self) -> None:
        """Check the batched Sherman-Morrison inverse and log-determinant.

        For random matrices ``A = diag(a) + B^T B`` (one per batch element), the
        returned inverse must satisfy ``A @ A^-1 ~= I`` element-wise, and the
        returned log-determinant must match ``torch.logdet(A)``.
        """
        # Local RNGs make the random cases reproducible without touching the
        # global numpy / torch RNG state that other tests may rely on.
        rng = np.random.RandomState(0)
        gen = torch.Generator().manual_seed(0)

        for _ in range(10):
            dim = int(rng.randint(2, 32))
            batch = int(rng.randint(2, 32))
            rank = int(rng.randint(1, 10))
            # Keep the diagonal bounded away from zero: a near-zero diagonal makes
            # A badly conditioned, and float32 round-off would dominate the check.
            Adiag = (0.1 + rng.rand()) * torch.ones(dim)
            B_factors = torch.randn(batch, rank, dim, generator=gen)

            inverse, logdet = batched_sherman_morrison_rank_one_inverse(Adiag.view(1, dim).repeat(batch, 1), B_factors)
            matrix = torch.diag(Adiag).unsqueeze(0).repeat(batch, 1, 1) + torch.bmm(B_factors.transpose(1, 2), B_factors)

            logdet_expected = torch.logdet(matrix)
            result = torch.bmm(matrix, inverse)
            expected = torch.eye(dim).unsqueeze(0).expand(batch, dim, dim)

            logdet_err = torch.abs(logdet.squeeze(-1) - logdet_expected).sum().item()
            self.assertLess(logdet_err, 1e-2)
            # Use the largest absolute deviation from the identity so that errors
            # of opposite sign cannot cancel each other out.
            inverse_err = torch.abs(result - expected).max().item()
            self.assertLess(inverse_err, 1e-2)

    def test_smoketest_mahalanobis(self) -> None:
        """Briefly train every model, then check the shapes returned by ``inference``.

        Each model is paired with a data factory (toy moons for the linear
        models, Omniglot for the CNNs). It is trained with SGD on the first
        ``n_batches`` batches of the training loader. ``model.inference`` is then
        run under ``torch.no_grad()`` on the same number of batches, and only
        output shapes are asserted.
        """
        dims = ((1, 1, 28, 28), (1, 64, 14, 14), (1, 64, 7, 7), (1, 64, 4, 4))
        linear_kwargs = dict(n_layers=6, in_dim=2, h_dim=64, classes=2, p=0.01, ctype="none")
        cnn_kwargs = dict(dims=dims, in_ch=1, h_dim=64, classes=5, p=0.01, ctype="none")
        cases = [
            (partial(proto_mahalanobis_linear, **linear_kwargs), get_moons),
            (partial(proto_ddu_linear, **linear_kwargs), get_moons),
            (partial(protonet_linear, **linear_kwargs), get_moons),
            (partial(proto_sngp_linear, **linear_kwargs), get_moons),
            (partial(proto_mahalanobis_cnn, **cnn_kwargs), get_conv),
            (partial(proto_ddu_cnn, **cnn_kwargs), get_conv),
            (partial(protonet_cnn, **cnn_kwargs), get_conv),
            (partial(proto_sngp_cnn, **cnn_kwargs), get_conv),
        ]
        # Number of loader batches consumed in each phase (keeps the smoke test short).
        n_batches = 6

        for net_func, get_data in cases:
            # subTest reports which model failed instead of stopping at the first one.
            with self.subTest(model=net_func.func.__name__):
                model = net_func()
                opt = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)
                # NOTE(review): the test loader is not used; inference below runs
                # on training batches. Confirm whether that is intended.
                args, train, _test = get_data()

                model.train()
                for x_tr, y_tr, x_te, y_te in islice(train, n_batches):
                    # Each loader batch holds several tasks; take one SGD step per task.
                    for xs, ys, xq, yq in zip(x_tr, y_tr, x_te, y_te):
                        opt.zero_grad()
                        logits = model(xs, ys, xq, n_way=args.n_way, k_shot=args.k_shot)
                        loss = -F.log_softmax(logits, dim=-1)[torch.arange(yq.size(0)), yq].mean()
                        loss.backward()
                        opt.step()

                model.eval()
                with torch.no_grad():
                    for x_tr, y_tr, x_te, y_te in islice(train, n_batches):
                        for xs, ys, xq, _yq in zip(x_tr, y_tr, x_te, y_te):
                            aleatoric, epistemic, energy = model.inference(xs, ys, xq, n_way=args.n_way, k_shot=args.k_shot)
                            n_query = xq.size(0)
                            self.assertEqual(aleatoric.size(1), args.n_way)
                            self.assertEqual(epistemic.size(1), args.n_way)
                            self.assertEqual(energy.dim(), 1)
                            self.assertListEqual(
                                [aleatoric.size(0), epistemic.size(0), energy.size(0)],
                                [n_query] * 3,
                            )
