import torch 
import torch.nn.functional as F
import numpy as np

from utils.experiments_utils import calculate_tsne, get_2d_plot, get_2d_seq_points_plot, plot_tsne

def off_diagonal_values(x):
    """Return the off-diagonal entries of a square 2-D tensor as a 1-D tensor.

    Entries come out in row-major order with the diagonal skipped; for a 3x3
    input the order is x[0,1], x[0,2], x[1,0], x[1,2], x[2,0], x[2,1].

    Args:
        x: Square tensor of shape (n, n).

    Returns:
        1-D tensor with n * (n - 1) elements, same dtype/device as ``x``.

    Raises:
        ValueError: If ``x`` is not a square 2-D tensor.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if x.ndim != 2 or x.shape[0] != x.shape[1]:
        raise ValueError(
            f"expected a square 2-D tensor, got shape {tuple(x.shape)}"
        )
    n = x.shape[0]
    if n == 0:
        # view(-1, 1) below cannot infer a size for 0 elements; an empty
        # matrix simply has no off-diagonal entries.
        return x.flatten()
    # Dropping the last element (the final diagonal entry) leaves
    # n*n - 1 == (n - 1) * (n + 1) values. Viewed as (n - 1, n + 1) rows,
    # every remaining diagonal entry lands in column 0, so slicing that
    # column off leaves exactly the off-diagonal values.
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()


def calculate_covariance_average(representations):
    """Score how strongly the feature dimensions of ``representations`` co-vary.

    Builds the (D, D) sample covariance with ``_calculate_cov_matrix`` and
    scores it with ``_calculate_covariance_average``. The score is the sum of
    absolute off-diagonal entries of ``torch.corrcoef(cov)`` divided by
    ``2 * D * (D - 1)``.

    Args:
        representations: 2-D tensor; dimension 0 is averaged over as the
            sample axis.

    Returns:
        0-d tensor holding the score.
    """
    # NOTE(review): torch.corrcoef treats each *row* of `cov` as a variable.
    # This therefore correlates covariance rows with one another, rather than
    # normalising `cov` into a feature-correlation matrix. Also, there are
    # D*(D-1) off-diagonal values, so the extra factor of 2 in the divisor
    # makes this half the mean. Confirm both are intended.
    cov = _calculate_cov_matrix(representations)
    # Delegate so the formula lives in one place (mirrors calculate_covariance_rank).
    return _calculate_covariance_average(cov)


def _calculate_covariance_average(cov):
    """Reduce a (D, D) covariance matrix to a single co-variation score.

    Applies ``torch.corrcoef`` to ``cov``, takes the absolute value of every
    off-diagonal entry, sums them, and divides by ``2 * D * (D - 1)``.

    Returns:
        0-d tensor holding the score.
    """
    # NOTE(review): corrcoef correlates the rows of `cov` with each other; it
    # does not convert `cov` into a feature-correlation matrix. The divisor is
    # twice the number of off-diagonal values, so this is half their mean.
    # Confirm both choices are intended.
    dim = cov.shape[0]
    normaliser = 2 * dim * (dim - 1)
    off_diag = off_diagonal_values(torch.corrcoef(cov))
    return off_diag.abs().sum() / normaliser


def _calculate_log_singular_value(cov):
    return np.log(np.linalg.svd(cov, compute_uv=False))


def _calculate_cov_matrix(representations):
    representations = representations - representations.mean(dim=0) 
    return (representations.T @ representations) / (representations.shape[0] - 1)


def calculate_covariance_rank(representations):
    """Return the numerical rank of the sample covariance of ``representations``.

    Args:
        representations: 2-D torch tensor; dimension 0 is the sample axis.
            It may live on any device and may require grad.

    Returns:
        Integer rank from ``np.linalg.matrix_rank``, using the tolerance
        fixed in ``_calculate_covariance_rank``.
    """
    # NumPy converts its input via Tensor.__array__ -> Tensor.numpy(), which
    # raises for CUDA tensors and for tensors that require grad. Only the small
    # (D, D) covariance is moved to host memory; the matmul stays on-device.
    cov = _calculate_cov_matrix(representations).detach().cpu()
    return _calculate_covariance_rank(cov)


def _calculate_covariance_rank(cov):
    return np.linalg.matrix_rank(cov, np.exp(-13))


def calculate_avg_std(representations):
    """Return the mean over columns of the per-column standard deviation.

    ``Tensor.std`` is used with its default (unbiased) correction, and the
    result is returned as a Python float.
    """
    per_feature_std = representations.std(dim=0)
    return per_feature_std.mean().item()

def calculate_log_singular_value(representations):
    """Return the log singular values of the sample covariance of ``representations``.

    Args:
        representations: 2-D torch tensor; dimension 0 is the sample axis.
            It may live on any device and may require grad.

    Returns:
        NumPy array of natural-log singular values in descending order; zero
        singular values map to ``-inf``.
    """
    # NumPy's SVD converts its input via Tensor.numpy(), which raises for CUDA
    # tensors and for tensors requiring grad. Move only the (D, D) covariance.
    cov = _calculate_cov_matrix(representations).detach().cpu()
    return _calculate_log_singular_value(cov)
    

def encoder_singular_values_plot(encoder, batch):
    """Encode ``batch`` and return the log singular values of its covariance.

    Despite the name, nothing is plotted here. The returned array, from
    ``calculate_log_singular_value``, is meant to be plotted by the caller.
    The encoder forward pass and the metric both run under
    ``torch.no_grad()``, with the encoder output moved to CPU first.
    """
    with torch.no_grad():
        encoded = encoder(batch).cpu()
        log_spectrum = calculate_log_singular_value(encoded)
    return log_spectrum

def eval_encoder(encoder, batch):
    """Encode ``batch`` without gradients and compute summary statistics.

    Returns:
        Dict with keys:
          - ``"covariance"``: 0-d tensor from ``_calculate_covariance_average``.
          - ``"std"``: float from ``calculate_avg_std``.
          - ``"rank"``: integer from ``_calculate_covariance_rank``.
        The covariance matrix is computed once and shared by the
        ``"covariance"`` and ``"rank"`` entries.
    """
    with torch.no_grad():
        feats = encoder(batch).cpu()
        cov = _calculate_cov_matrix(feats)
        return {
            "covariance": _calculate_covariance_average(cov),
            "std": calculate_avg_std(feats),
            "rank": _calculate_covariance_rank(cov),
        }

def visualize_representations(representations, n_epochs):
    """Embed concatenated representations with t-SNE and build a labelled 2-D plot.

    Args:
        representations: Sequence of arrays joined with ``np.concatenate``,
            presumably one per epoch (TODO confirm against callers).
        n_epochs: Number of ``"Epoch i"`` labels passed to ``get_2d_plot``.

    Returns:
        Whatever ``get_2d_plot`` returns.
    """
    stacked = np.concatenate(representations)
    embedded = calculate_tsne(stacked)
    labels = [f"Epoch {epoch}" for epoch in range(n_epochs)]
    return get_2d_plot(embedded, labels)

def visualize_sequential_representations(representations, n_steps):
    """Embed concatenated representations with t-SNE and plot them as a sequence.

    Args:
        representations: Sequence of arrays joined with ``np.concatenate``.
        n_steps: Forwarded unchanged to ``get_2d_seq_points_plot``.

    Returns:
        Whatever ``get_2d_seq_points_plot`` returns.
    """
    stacked = np.concatenate(representations)
    embedded = calculate_tsne(stacked)
    return get_2d_seq_points_plot(embedded, n_steps)