import unittest

import numpy as np  # type: ignore
import torch
from torch import nn

from utils import (ece, ece_partial, ece_partial_final, flatten, full_hessian,
                   softmax_log_softmax_of_sample, unflatten_like)

# shorthand alias for torch.Tensor
T = torch.Tensor


class TestOther(unittest.TestCase):
    """Tests for helpers which do not belong to the ECE, flatten or hessian groups."""

    def test_softmax_log_softmax_of_sample(self) -> None:
        """
        softmax_log_softmax_of_sample should agree with a direct computation:
        the softmax averaged over the sample dimension and the log of that average
        """
        for _ in range(100):
            n_samples = np.random.randint(100, 1000)
            n = np.random.randint(32, 128)
            d = np.random.randint(10, 100)

            mu = torch.randn(n, d)
            # torch.rand draws from [0, 1) and can return exactly 0, which Normal
            # rejects as a scale; clamp so the scale is always strictly positive
            sigma = torch.rand(n, d).clamp_min(1e-6)
            samples = torch.distributions.Normal(mu, sigma).sample((n_samples,))

            exp_sm = samples.softmax(dim=-1).mean(dim=0)
            exp_lsm = exp_sm.log()

            # sanity check of the reference itself: an average of softmaxes must
            # still sum to one over the class dimension (torch.all on the raw sums
            # would only check that they are non-zero)
            row_sums = exp_sm.sum(dim=-1)
            self.assertTrue(torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-4))

            sm, lsm = softmax_log_softmax_of_sample(samples)

            self.assertTrue(torch.abs(exp_sm - sm).sum() < 1e-4)
            self.assertTrue(torch.abs(exp_lsm - lsm).sum() < 1e-2)


class TestECE(unittest.TestCase):
    """Checks that the incremental ECE helpers agree with the single call ece."""

    def test_partial_equals_regular_one_pass(self) -> None:
        """
        passing all of the data to ece_partial in one call and finalizing the
        result should give exactly the value returned by ece
        """
        scores = torch.rand(1000, 10)
        labels = torch.randint(0, 10, (1000,))

        expected, _, _ = ece(labels, scores)
        conf_sum, acc_sum, bin_counts, total = ece_partial(labels, scores)
        actual = ece_partial_final(conf_sum, acc_sum, bin_counts, total)
        self.assertEqual(expected, actual)

    def test_partial_equals_regular_partial_pass(self) -> None:
        """
        accumulating ece_partial over several contiguous chunks of the data and
        finalizing should match ece on the full data to 3 decimal places
        """
        for _trial in range(10):
            scores = torch.rand(1000, 100)
            labels = torch.randint(0, 100, (1000,))
            expected, _, _ = ece(labels, scores)

            # running totals; the length 15 presumably matches the number of bins
            # used by ece_partial -- TODO confirm against utils
            conf_sum = torch.zeros(15)
            acc_sum = torch.zeros(15)
            bin_counts = torch.zeros(15)
            total = 0

            # up to ten distinct random cut points in [1, 998], in ascending order
            cuts = sorted(c for c in torch.randperm(999)[:10].tolist() if c != 0)
            bounds = [0, *cuts, 1000]

            # walk consecutive (start, stop) pairs so every example is used once
            for start, stop in zip(bounds[:-1], bounds[1:]):
                c, a, nin, count = ece_partial(labels[start:stop], scores[start:stop])

                conf_sum += c
                acc_sum += a
                bin_counts += nin
                total += count

            actual = ece_partial_final(conf_sum, acc_sum, bin_counts, total)
            self.assertAlmostEqual(expected, actual, 3)


class TestFlatten(unittest.TestCase):
    """Round trip tests for flatten and unflatten_like on a small network."""

    def setUp(self) -> None:
        """build a fresh two layer network and count its scalar parameters"""
        self.net = nn.Sequential(nn.Linear(2, 4), nn.ReLU(inplace=True), nn.Linear(4, 1))
        self.n_params = sum(p.numel() for p in self.net.parameters())

    def test_flatten(self) -> None:
        """
        should flatten a list of tensors (parameters) into a flat list containing
        all the parameters in a model
        """
        out = flatten(self.net.parameters())
        self.assertEqual(out.size(0), self.n_params)

    def test_unflatten_like(self) -> None:
        """
        should unflatten a flattened list of parameters in the same way it was
        flattened. every unflattened tensor should equal the original parameter
        """
        flat = flatten(self.net.parameters())
        restored = list(unflatten_like(flat, self.net.parameters()))
        originals = list(self.net.parameters())

        # zip stops at the shorter input, so without this check a truncated or
        # empty result would pass without comparing anything
        self.assertEqual(len(restored), len(originals))

        for (p1, p2) in zip(restored, originals):
            # torch.equal also requires identical shapes, whereas `==` broadcasts
            self.assertTrue(torch.equal(p1, p2))


class TestHessian(unittest.TestCase):
    """Tests for full_hessian on small fully connected networks."""

    def test_nn_hessian(self) -> None:
        """
        for one weight w and one bias b with loss (w * x + b - y) ** 2 the hessian
        has a closed form which full_hessian should reproduce
        """
        # second derivatives of the loss given the input x, laid out as
        #   [[d2/dw2, d2/dwdb], [d2/dbdw, d2/db2]] = [[2x^2, 2x], [2x, 2]]
        # this assumes full_hessian orders the parameters as (weight, bias)
        hessian_funcs = [
            [lambda x: 2 * x ** 2, lambda x: 2 * x],
            [lambda x: 2 * x, lambda x: 2]
        ]

        for _ in range(100):
            net = nn.Sequential(nn.Linear(1, 1, bias=True))
            for layer in net:
                if hasattr(layer, "weight") and layer.weight is not None:
                    nn.init.normal_(layer.weight)
                if hasattr(layer, "bias") and layer.bias is not None:
                    nn.init.normal_(layer.bias)

            # the check runs once per network, after every layer is initialized,
            # rather than from inside the initialization loop
            inputs = torch.randn(1, 1)
            out = net(inputs).squeeze()
            y = torch.randn(1)
            loss = ((out - y) ** 2).sum()

            hessian = full_hessian(loss, net, retain_graph=True)
            # zip below would silently skip entries if the hessian were too small
            self.assertEqual(tuple(hessian.size()), (2, 2))

            x = inputs.squeeze().item()
            for (row1, row2) in zip(hessian, hessian_funcs):
                for (col1, col2) in zip(row1, row2):
                    self.assertAlmostEqual(col1.item(), col2(x), places=4)

    def test_nn_hessian_size(self) -> None:
        """the hessian should be a square matrix with one row per scalar parameter"""
        for _ in range(5):
            in_one, out_one = torch.randint(low=8, high=16, size=(2,))
            in_one, out_one = in_one.item(), out_one.item()
            net = nn.Sequential(nn.Linear(in_one, out_one, bias=True), nn.ReLU(inplace=True), nn.Linear(out_one, 1))
            params = sum(p.numel() for p in net.parameters())

            inputs = torch.randn(32, in_one)
            out = net(inputs).squeeze()
            y = torch.randn(32)
            loss = ((out - y) ** 2).sum()

            hessian = full_hessian(loss, net, retain_graph=True)
            # compare the full shape so a tensor of the wrong rank cannot pass
            self.assertEqual(tuple(hessian.size()), (params, params))

    def test_nn_hessian_inverse(self) -> None:
        """
        the inverse of the damped hessian of a trained network should be usable
        as the covariance of a multivariate normal over the parameters
        """
        for _ in range(5):
            in_one, out_one = torch.randint(low=10, high=20, size=(2,))
            in_one, out_one = in_one.item(), out_one.item()
            net = nn.Sequential(nn.Linear(in_one, out_one), nn.ReLU(inplace=True), nn.Linear(out_one, 1))
            opt = torch.optim.Adam(net.parameters(), weight_decay=1e-3)

            inputs = torch.randn(32, in_one)
            y = torch.randn(32)

            # the test fails without first training the network.
            # NOTE(review): likely because the loss is non-convex in the parameters,
            # so at a random initialization the hessian can have negative eigenvalues
            # which the damping term below does not cancel; near a minimum it should
            # be approximately positive semi-definite -- confirm
            for _step in range(1000):
                out = net(inputs).squeeze()
                loss = ((out - y) ** 2).sum()
                opt.zero_grad()
                loss.backward()
                opt.step()

            opt.zero_grad()
            out = net(inputs).squeeze()
            loss = ((out - y) ** 2).sum()

            # NOTE:
            # - the hessian is not negated here because the loss plays the role of
            #   the negative log density; its second order Taylor expansion around
            #   the minimum gives exp{-1/2(xAx)} with A the hessian of the loss
            params = sum(p.numel() for p in net.parameters())
            hessian = full_hessian(loss, net, retain_graph=True)
            # damp the diagonal before inverting
            sigma = torch.inverse(hessian + torch.eye(params) * 1e-2 * params)

            mu = torch.zeros(params)
            # building the distribution raises if sigma is rejected as a covariance
            # matrix, so reaching the assertion below is the main check of this test
            N = torch.distributions.MultivariateNormal(mu, sigma)
            sample = N.sample()
            self.assertEqual(sample.size(0), params)
