import unittest

import numpy as np  # type: ignore
import torch
from sklearn.datasets import load_iris  # type: ignore
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis as QDA
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset

from deep_deterministic_uncertainty.model import DDU, DDU_WideResNet28_cifar


class TestDDU(unittest.TestCase):
    """Smoke and consistency tests for ``DDU`` and ``DDU_WideResNet28_cifar``."""

    @unittest.skipUnless(torch.cuda.is_available(), "this test moves the model and data to CUDA")
    def test_smoketest_ddu(self) -> None:
        """Train the CIFAR-sized model on one random batch, then run inference.

        Only checks that the logits have shape ``(batch, classes)``, that the
        centroids become non-zero after being computed, and that
        ``inference`` returns three values whose logits are usable with
        ``F.cross_entropy``. Skipped when no CUDA device is available.
        """
        batch_size, num_classes = 32, 10
        x = torch.randn(batch_size, 3, 32, 32).cuda()
        y = torch.randint(0, num_classes, (batch_size,)).cuda()

        model = DDU_WideResNet28_cifar(
            resnet_kwargs=dict(num_classes=num_classes, widen_factor=2),
            ddu_kwargs=dict(num_classes=num_classes),
        ).cuda()
        opt = torch.optim.SGD(model.parameters(), lr=4e-2, momentum=0.9, weight_decay=5e-4)

        model.train()
        for _ in range(50):
            opt.zero_grad()
            logits = model(x)
            self.assertEqual(logits.size(0), batch_size)
            self.assertEqual(logits.size(1), num_classes)
            loss = F.cross_entropy(logits, y)
            loss.backward()
            opt.step()

        def identity_stack() -> torch.Tensor:
            """Return one identity matrix per class, stacked along dim 0, on CUDA."""
            return torch.stack([torch.eye(model.h_dim) for _ in range(model.classes)]).cuda()

        model.eval()
        with torch.no_grad():
            # Centroids are expected to be all-zero until they are computed.
            self.assertEqual(model.centroids.sum().item(), 0)
            model.update_centroids(x, y)
            model.compute_centroids()
            self.assertNotEqual(model.centroids.sum().item(), 0)

            self.assertEqual(model.cov.sum().item(), 0)
            # Setting the covariance (and its inverse) to I for testing since
            # there is not enough training to make a proper covariance matrix.
            model.cov = identity_stack()
            model.prec = identity_stack()
            # Placeholder value: log|I| would be 0, but the density outputs
            # are not asserted on in this smoke test.
            model.cov_logdets = torch.ones(model.classes, 1).cuda()

            # x already lives on the GPU, so no further transfer is needed.
            logits, _log_px, _m_log_px_mask = model.inference(x)
            # Raises if the inference logits are incompatible with the labels.
            F.cross_entropy(logits, y)

    def test_ddu_covariance_against_sklearn(self) -> None:
        """Compare DDU's per-class statistics on Iris with sklearn's QDA.

        Checks that ``ddu.centroids`` matches ``qda.means_``, that each
        ``ddu.cov[i]`` matches ``qda.covariance_[i]``, and that the softmax of
        ``ddu.log_px`` on held-out points matches the softmax of QDA's
        ``predict_log_proba``.
        """
        n_classes = 3

        dataset = load_iris()
        x, y = dataset["data"], dataset["target"]
        # NOTE(review): the permutation is unseeded, so the split differs
        # between runs.
        perm = np.random.permutation(x.shape[0])
        x, y = x[perm], y[perm]

        train_n = int(x.shape[0] * 0.9)
        xtr, ytr, xte = x[:train_n], y[:train_n], x[train_n:]

        qda = QDA(store_covariance=True)
        qda.fit(xtr, ytr)
        qda_log_p_xy = qda.predict_log_proba(xte)

        module = nn.Sequential()
        module.name = "dummy"  # type: ignore

        ddu = DDU(module, n_classes, xtr.shape[1])
        xtr, ytr = torch.from_numpy(xtr).float(), torch.from_numpy(ytr).float()
        xte = torch.from_numpy(xte).float()

        ddu.update_centroids(xtr, ytr)
        ddu.compute_centroids()

        ddu.update_covariance(xtr, ytr)
        ddu.compute_covariance()
        ddu.invert_covariance()

        ddu_log_pxy = ddu.log_px(xte)

        self.assertTrue(torch.allclose(torch.from_numpy(qda.means_).float(), ddu.centroids))

        for cls in range(n_classes):
            expected_cov = torch.from_numpy(qda.covariance_[cls]).float()
            self.assertTrue(torch.allclose(expected_cov, ddu.cov[cls]))

        # Bound the absolute element-wise difference. A signed sum is ~0 for
        # ANY two softmax outputs (every row sums to 1 in both), so it cannot
        # detect a mismatch.
        qda_posterior = torch.from_numpy(qda_log_p_xy).softmax(dim=-1)
        ddu_posterior = ddu_log_pxy.softmax(dim=-1)
        max_abs_diff = (qda_posterior - ddu_posterior).abs().max().item()
        # NOTE(review): tolerance chosen because the DDU inputs are float32
        # (see the .float() casts above) while sklearn works in float64;
        # adjust if it proves too tight or too loose.
        self.assertLess(max_abs_diff, 1e-4)

    def test_ddu_tune_and_inference(self) -> None:
        """Fit DDU on a small Iris split, call ``tune``, then ``inference``.

        Smoke test only: it verifies that the calls complete and that
        ``inference`` returns three values for every held-out batch.
        """
        dataset = load_iris()
        x, y = dataset["data"], dataset["target"]
        # NOTE(review): the permutation is unseeded, so the split differs
        # between runs.
        perm = np.random.permutation(x.shape[0])
        x, y = torch.from_numpy(x[perm]).float(), torch.from_numpy(y[perm]).float()

        # Only 10% of the data is used for fitting; the rest is held out.
        train_n = int(x.shape[0] * 0.1)
        xtr, ytr, xte, yte = x[:train_n], y[:train_n], x[train_n:], y[train_n:]

        train_loader = DataLoader(TensorDataset(xtr, ytr), shuffle=True, batch_size=10)
        test_loader = DataLoader(TensorDataset(xte, yte), shuffle=True, batch_size=10)

        module = nn.Identity()
        module.name = "dummy"  # type: ignore
        ddu = DDU(module, 3, xtr.shape[1])

        ddu.update_centroids(xtr, ytr)
        ddu.compute_centroids()
        ddu.update_covariance(xtr, ytr)
        ddu.compute_covariance()
        ddu.invert_covariance()

        ddu.tune(train_loader)
        for batch_x, _batch_y in test_loader:
            # Unpacking doubles as a check that three values are returned.
            _logits, _epistemic, _ind = ddu.inference(batch_x)
