import torch
from torch import nn
from tqdm import tqdm

DOUBLE_INFO = torch.finfo(torch.double)
JITTERS = [0, DOUBLE_INFO.tiny] + [10 ** exp for exp in range(-308, 0, 1)]


def centered_cov_torch(x):
    """Return ``x^T x / (n - 1)`` for a 2-D tensor ``x`` with ``n`` rows.

    The function does not subtract the mean itself; the result is a sample
    covariance only when the caller passes rows that are already centered.
    """
    # Normalisation by (n - 1); computed as a Python float before scaling.
    scale = 1 / (x.shape[0] - 1)
    gram = torch.mm(x.t(), x)
    return scale * gram


def _forward_features(model, data):
    """Run ``model`` on ``data`` and return the features feeding its last layer.

    If ``model`` already has a ``feature`` attribute before the forward pass,
    that attribute is read back after the pass. Otherwise a temporary forward
    hook on ``model.linear`` captures a detached copy of that layer's first
    input; the hook is always removed again so repeated calls do not stack
    hooks on the module.
    """
    if hasattr(model, "feature"):
        model(data)
        return model.feature

    captured = {}

    def hook(module, inputs, output):
        captured["feature"] = inputs[0].clone().detach()

    handle = model.linear.register_forward_hook(hook)
    try:
        model(data)
    finally:
        # Without this, every call would leave one more hook on `linear`.
        handle.remove()
    return captured["feature"]


def get_embeddings(
    net, loader: torch.utils.data.DataLoader, num_dim: int, dtype, device, storage_device,
):
    """Collect the feature embeddings and labels for every sample in ``loader``.

    Args:
        net: Model to run. An ``nn.DataParallel`` wrapper is unwrapped and its
            ``module`` is called directly. The model must either expose a
            ``feature`` attribute or have a ``linear`` submodule whose input
            is taken as the embedding.
        loader: Yields ``(data, label)`` batches; ``len(loader.dataset)``
            determines the size of the output buffers.
        num_dim: Width of one embedding (second dimension of the output).
        dtype: dtype of the returned embedding buffer.
        device: Device the batches are moved to before the forward pass.
        storage_device: Device on which the output buffers are allocated.

    Returns:
        ``(embeddings, labels)`` where ``embeddings`` has shape
        ``(num_samples, num_dim)`` and ``labels`` is a ``torch.int`` tensor of
        length ``num_samples``, both filled in loader iteration order.
    """
    num_samples = len(loader.dataset)
    embeddings = torch.empty((num_samples, num_dim), dtype=dtype, device=storage_device)
    labels = torch.empty(num_samples, dtype=torch.int, device=storage_device)

    # Unwrap once instead of duplicating the extraction logic per wrapper type.
    model = net.module if isinstance(net, nn.DataParallel) else net

    with torch.no_grad():
        start = 0
        for data, label in tqdm(loader):
            data = data.to(device)
            label = label.to(device)

            out = _forward_features(model, data)

            end = start + len(data)
            embeddings[start:end].copy_(out, non_blocking=True)
            labels[start:end].copy_(label, non_blocking=True)
            start = end

    return embeddings, labels


def gmm_forward(net, gaussians_model, data_B_X):
    """Return the per-class log-probabilities of ``net``'s features for a batch.

    Args:
        net: Model to run. An ``nn.DataParallel`` wrapper is unwrapped and its
            ``module`` is called directly. The model must either expose a
            ``feature`` attribute or have a ``linear`` submodule whose input
            is taken as the feature vector.
        gaussians_model: Object with a ``log_prob`` method (e.g. the
            distribution returned by ``gmm_fit``).
        data_B_X: Input batch, already on the model's device.

    Returns:
        ``gaussians_model.log_prob`` evaluated on the features with a
        singleton axis inserted at position 1 (presumably yielding a
        ``(batch, num_classes)`` tensor — depends on ``gaussians_model``).
    """
    # Unwrap once instead of duplicating the extraction logic per wrapper type.
    model = net.module if isinstance(net, nn.DataParallel) else net

    if hasattr(model, "feature"):
        model(data_B_X)
        features_B_Z = model.feature
    else:
        # No stored feature: capture the input of the final linear layer.
        activation = {}

        def hook(module, inputs, output):
            activation["feature"] = inputs[0].clone().detach()

        handle = model.linear.register_forward_hook(hook)
        try:
            model(data_B_X)
        finally:
            # Remove the hook so repeated calls do not accumulate hooks on
            # `linear` (each leftover hook would run on every later forward).
            handle.remove()
        features_B_Z = activation["feature"]

    log_probs_B_Y = gaussians_model.log_prob(features_B_Z[:, None, :])

    return log_probs_B_Y


def gmm_evaluate(net, gaussians_model, loader, device, num_classes, storage_device):
    """Run ``gmm_forward`` over every batch of ``loader`` and gather the results.

    Returns ``(logits_N_C, labels_N)``: a float64 ``(num_samples, num_classes)``
    buffer with the values produced by ``gmm_forward`` and a ``torch.int``
    vector of the labels, both allocated on ``storage_device`` and filled in
    loader iteration order.
    """
    total = len(loader.dataset)
    # Deliberately float64: keeping these values in float32 produced NaNs
    # further down the pipeline.
    logits_N_C = torch.empty((total, num_classes), dtype=torch.float64, device=storage_device)
    labels_N = torch.empty(total, dtype=torch.int, device=storage_device)

    cursor = 0
    with torch.no_grad():
        for batch, targets in tqdm(loader):
            batch = batch.to(device)
            targets = targets.to(device)

            batch_logits = gmm_forward(net, gaussians_model, batch)

            # Rows of the output buffers that belong to this batch.
            window = slice(cursor, cursor + len(batch))
            logits_N_C[window].copy_(batch_logits, non_blocking=True)
            labels_N[window].copy_(targets, non_blocking=True)
            cursor = window.stop

    return logits_N_C, labels_N


def gmm_get_logits(gmm, embeddings):
    """Evaluate ``gmm.log_prob`` on ``embeddings`` with a singleton axis at dim 1.

    The extra axis lets each embedding broadcast against the batch of
    distributions held by ``gmm``.
    """
    expanded = embeddings[:, None, :]
    return gmm.log_prob(expanded)


def gmm_fit(embeddings, labels, num_classes):
    """Fit one Gaussian per class to ``embeddings`` and return it with the jitter used.

    Per-class means and covariances are computed from the rows of
    ``embeddings`` whose label equals the class index. The covariances are
    then regularised with ``jitter_eps * I`` for increasing ``jitter_eps``
    taken from ``JITTERS`` until ``MultivariateNormal`` accepts them.

    Args:
        embeddings: 2-D tensor with one embedding per row.
        labels: Tensor of class indices aligned with the rows of ``embeddings``.
        num_classes: Number of classes; labels ``0 .. num_classes - 1`` are fitted.

    Returns:
        ``(gmm, jitter_eps)``: a ``torch.distributions.MultivariateNormal``
        batched over the classes, and the first jitter value that worked.

    Raises:
        RuntimeError: If no value in ``JITTERS`` yields a valid covariance, or
            for any ``RuntimeError`` unrelated to the Cholesky factorisation.
        ValueError: For validation failures that do not concern
            ``covariance_matrix`` (these are not retried).
    """
    with torch.no_grad():
        classwise_mean_features = torch.stack(
            [torch.mean(embeddings[labels == c], dim=0) for c in range(num_classes)]
        )
        classwise_cov_features = torch.stack(
            [
                centered_cov_torch(embeddings[labels == c] - classwise_mean_features[c])
                for c in range(num_classes)
            ]
        )

        # Build the identity in a dtype at least as wide as the covariances.
        # A default-dtype (float32) identity would flush the tiny jitters to
        # zero before they are added to float64 covariances. Promoting with
        # the default dtype keeps the dtype of the sum the same as before.
        jitter_dtype = torch.promote_types(
            classwise_cov_features.dtype, torch.get_default_dtype()
        )
        identity = torch.eye(
            classwise_cov_features.shape[1],
            dtype=jitter_dtype,
            device=classwise_cov_features.device,
        ).unsqueeze(0)

        for jitter_eps in JITTERS:
            try:
                gmm = torch.distributions.MultivariateNormal(
                    loc=classwise_mean_features,
                    covariance_matrix=classwise_cov_features + jitter_eps * identity,
                )
            except RuntimeError as e:
                # Only a failed Cholesky factorisation means "try more jitter".
                if "cholesky" not in str(e):
                    raise
            except ValueError as e:
                # Matches both the old wording ("The parameter covariance_matrix
                # has invalid values") and the newer one ("Expected parameter
                # covariance_matrix ... but found invalid values").
                message = str(e)
                if "parameter covariance_matrix" not in message or "invalid values" not in message:
                    raise
            else:
                return gmm, jitter_eps

    raise RuntimeError(
        "gmm_fit: could not obtain a positive-definite covariance matrix with any jitter value"
    )
