"""
Can be used to reproduce Figures A.8 and A.9.
"""

import torch
import matplotlib.pyplot as plt
import scipy
import numpy as np
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm

def get_effective_rank(matrix, return_singular_values=False):
    """Return the effective rank of ``matrix``.

    The singular values are normalised to sum to one and the effective rank
    is ``e ** H``, where ``H`` is the Shannon entropy (natural logarithm) of
    that normalised distribution.

    Args:
        matrix: Tensor accepted by ``torch.linalg.svdvals``.
        return_singular_values: If true, also return the un-normalised
            singular values.

    Returns:
        The effective rank as a NumPy scalar, with NaN (e.g. for an all-zero
        matrix, whose singular values cannot be normalised) replaced by 0.
        When ``return_singular_values`` is true, a tuple
        ``(effective_rank, singular_values)`` where ``singular_values`` is a
        detached tensor.
    """
    singular_values = torch.linalg.svdvals(matrix).detach()
    # Normalise out-of-place so the raw singular values stay intact for the
    # optional second return value.
    normalised = singular_values / torch.sum(singular_values)
    # SciPy operates on host memory; ``.cpu()`` is a no-op for CPU tensors and
    # makes the function usable with tensors living on an accelerator.
    erank = np.nan_to_num(torch.e ** scipy.stats.entropy(normalised.cpu().numpy()))
    if return_singular_values:
        return erank, singular_values
    return erank

# Preprocessing: convert each image to a tensor, then normalise every one of
# the three channels with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR-10 training split. ``download=False`` means the data must already be
# present under the given root directory.
trainset = torchvision.datasets.CIFAR10(
    '~/Documents/Torch_Dataset/',
    train=True,
    transform=transform,
    download=False,
)

# One image per batch, served in dataset order (no shuffling).
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False)

# Experiment: for 1x1 convolutions with a growing number of output channels,
# compare the normalised effective rank of the full feature matrix with the
# average normalised effective rank of small random row windows of it, and
# record the Spearman correlation between both curves for every trial.

# Swept numbers of output channels (also used as x-axis by the plots below).
x_axis = range(1, 14)#[1, 3, 5, 7, 9]
kernel_size = 1
num_windows = 32  # random windows averaged per channel count

correlations = []
# Only forward passes are needed, so skip building autograd graphs.
with torch.no_grad():
    for _trial in tqdm(range(10_000)):
        full = []
        sampled = []
        # NOTE(review): the loader is built with shuffle=False, so a fresh
        # iterator always yields the first training sample. Trials therefore
        # differ only in the random convolution weights and window offsets.
        # Iterate over ``trainloader`` directly if a distinct image per trial
        # is intended.
        X, _ = next(iter(trainloader))

        for out_channel in x_axis:
            conv1 = torch.nn.Conv2d(3, out_channel, kernel_size)
            output = conv1(X)
            size = output.shape[-1]
            # (batch, channels, H, W) -> (batch * W * H, channels): one row
            # per spatial position, one column per output channel.
            features = torch.transpose(output, 1, 3).flatten(0, 2)
            full.append(get_effective_rank(features) / out_channel)

            # The input and the convolution are fixed for this channel count,
            # so the features are computed once and only the window offset is
            # redrawn (previously the same sample was re-fetched and
            # re-convolved for every window).
            max_offset = features.shape[0] // size - 1
            rank = 0
            for _ in range(num_windows):
                # NOTE(review): the factor 3 is kept from the original code;
                # its intent is not evident from this file - confirm.
                start = np.random.randint(0, max_offset) * 3 if max_offset > 0 else 0
                # Square window: ``out_channel`` consecutive rows.
                window = features[start:start + out_channel]
                rank += get_effective_rank(window) / out_channel
            sampled.append(rank / num_windows)

        correlations.append(scipy.stats.spearmanr(sampled, full).statistic)

print("Average correlation:", np.mean(correlations))

# Plot the curves of the last trial: first the effective rank of the full
# feature matrix, then the window-sampled estimate.
#
# ``x_axis`` holds the swept numbers of output channels (the kernel size is
# fixed at 1), so the axis is labelled accordingly. Switch the label to
# "Kernel size" only if the sweep is changed to vary the kernel size instead.
x_label = "Number of output channels"
y_label = "Effective rank / max effective rank"

plt.xlabel(x_label, fontsize=14)
plt.ylabel(y_label, fontsize=14)
plt.scatter(x_axis, full, color="#0077BB")
plt.plot(x_axis, full, color="#0077BB")
# Saving has to happen before show(); afterwards the figure may be gone.
#plt.savefig("full_rank.pdf")
plt.show()
plt.clf()

plt.xlabel(x_label, fontsize=14)
plt.ylabel(y_label, fontsize=14)
plt.scatter(x_axis, sampled, color="#EE7733")
plt.plot(x_axis, sampled, color="#EE7733")
#plt.savefig("sampled_rank.pdf")
plt.show()
