'''PyTorch implementation of CKA without the trace function, plus PNKA similarity helpers.'''
import os
from tqdm import tqdm
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader

def zero_center(features):
    """Subtract the mean taken over dim 0 from ``features``.

    Args:
        features: Tensor whose dim 0 is averaged out.

    Returns:
        A new tensor of the same shape in which every slice along dim 0 has
        had the dim-0 mean removed; the input is not modified.
    """
    column_means = features.mean(dim=0, keepdim=True)
    return features - column_means

def cka_notrace(features_x, features_y, path=None, logger=None, remove=None, return_norm=False):
    """Compute the un-traced CKA numerator ``(X Xt)(Y Yt)`` and its normaliser.

    Both inputs are zero-centered along dim 0, their Gram matrices are
    multiplied, and the product is returned without taking the trace. The
    trace divided by the normaliser is only logged (or printed).

    Args:
        features_x: 2-D tensor (required by ``torch.mm``).
        features_y: 2-D tensor with the same number of rows as ``features_x``
            so that the two Gram matrices can be multiplied.
        path: Optional directory. When given, the similarity matrix and the
            normaliser are saved there as ``final_sim_xxtyyt.pt`` and
            ``norm.pt``.
        logger: Optional logger used for the summary line; when ``None`` the
            summary is printed instead.
        remove: Optional ``list`` of indexes. For every index the matching
            column of ``X Xt`` and row of ``Y Yt`` are set to zero, after
            which both diagonals are zeroed. Values that are not a ``list``
            are ignored.
        return_norm: When truthy, return ``(similarity, norm)``.

    Returns:
        The ``(X Xt)(Y Yt)`` matrix, or ``(similarity, norm)`` when
        ``return_norm`` is set.
    """
    centered_x = zero_center(features_x)
    centered_y = zero_center(features_y)

    gram_x = torch.mm(centered_x, centered_x.transpose(0, 1))
    gram_y = torch.mm(centered_y, centered_y.transpose(0, 1))

    if isinstance(remove, list):
        dropped = [int(idx) for idx in remove]
        if dropped:
            # Columns of X Xt and rows of Y Yt are cleared so the removed
            # samples contribute nothing to the product below.
            gram_x[:, dropped] = 0
            gram_y[dropped, :] = 0
        gram_x.diagonal().zero_()
        gram_y.diagonal().zero_()

    similarity = torch.mm(gram_x, gram_y)

    # Normaliser (kept to cross-check against standard CKA): product of the
    # norms of Xt X and Yt Y computed on the centered features.
    xtx = torch.mm(centered_x.transpose(0, 1), centered_x)
    yty = torch.mm(centered_y.transpose(0, 1), centered_y)
    norm = torch.linalg.norm(xtx) * torch.linalg.norm(yty)

    if path is not None:
        print(f'Saving matrices under {path}')
        for tensor, filename in ((similarity, "final_sim_xxtyyt.pt"), (norm, "norm.pt")):
            torch.save(tensor.cpu().detach(), os.path.join(path, filename))

    summary = f'Our version of CKA: {torch.trace(similarity.cpu().detach())/norm.cpu().detach()}'
    if logger is None:
        print(summary)
    else:
        logger.info(summary)

    return (similarity, norm) if return_norm else similarity

def calculate_cosine_matrix(x, y):
    """Return pairwise row dot products of ``x`` and ``y`` and the matching
    products of their row norms.

    Dividing the first result by the second element-wise gives the cosine
    similarity between every row of ``x`` and every row of ``y``. The division
    is left to the caller so that zero norms can be handled there.

    Args:
        x: 2-D tensor of shape (n_x, d).
        y: 2-D tensor of shape (n_y, d).

    Returns:
        A ``(dot, norm)`` tuple, both of shape (n_x, n_y), where
        ``dot[i, j]`` is the dot product of ``x[i]`` and ``y[j]`` and
        ``norm[i, j]`` is ``||x[i]|| * ||y[j]||``.
    """
    dot = torch.mm(x, torch.transpose(y, 0, 1))
    x_norms = torch.linalg.norm(x, dim=1)
    y_norms = torch.linalg.norm(y, dim=1)
    # Broadcasting a column against a row yields the outer product directly.
    # This avoids materialising tiled copies of the norms and, unlike a
    # reshape to (n_y, n_y), also works when x and y have different row counts.
    norm = x_norms.unsqueeze(1) * y_norms.unsqueeze(0)
    return dot, norm

def pnka(features_x, features_y, path=None, logger=None):
    """Cosine similarity between the rows of ``X Xt`` and the rows of ``Y Yt``.

    Both inputs are zero-centered along dim 0 and turned into n x n Gram
    matrices. Entry ``[i, j]`` of the result is the cosine between row ``i``
    of ``X Xt`` and row ``j`` of ``Y Yt``; entries whose norm product is zero
    are set to 0 instead of being left as NaN. The trace of the result and
    the trace divided by n are logged (or printed).

    Args:
        features_x: 2-D tensor (required by ``torch.mm``).
        features_y: 2-D tensor with the same number of rows as ``features_x``.
        path: Optional directory; when given the scores are saved there as
            ``pnka_scores.pt``.
        logger: Optional logger for the summary lines; when ``None`` they are
            printed instead.

    Returns:
        The n x n float32 score matrix, on the same device as the inputs.
    """
    features_x = zero_center(features_x)
    features_y = zero_center(features_y)

    xxt = torch.mm(features_x, torch.transpose(features_x, 0, 1))  # n x n
    yyt = torch.mm(features_y, torch.transpose(features_y, 0, 1))  # n x n
    xxt_yyt, cos_norm = calculate_cosine_matrix(xxt, yyt)
    pnka_scores = (xxt_yyt / cos_norm).float()
    # zeros_like keeps the replacement values on the same device and dtype as
    # the scores; a plain torch.zeros(shape) is always allocated on the CPU
    # and makes torch.where fail for inputs living on another device.
    pnka_scores = torch.where(cos_norm == 0., torch.zeros_like(pnka_scores), pnka_scores)

    if path is not None:
        print(f'Saving matrices under {path}')
        torch.save(pnka_scores.cpu().detach(), os.path.join(path, "pnka_scores.pt"))

    # The diagonal holds the score of each sample against itself across the
    # two representations, so only the trace is reported.
    trace = torch.trace(pnka_scores.cpu().detach())
    messages = (
        f'Trace of cos: {trace}',
        f'Avg trace of cos: {trace/pnka_scores.shape[0]}',
    )
    for message in messages:
        if logger is not None:
            logger.info(message)
        else:
            print(message)

    return pnka_scores

def cosine_similarity(features_x, features_y, keep_only_idx=None, path=None):
    """Pairwise cosine similarity between the rows of two feature tensors.

    Args:
        features_x: 2-D tensor.
        features_y: 2-D tensor with the same number of columns as
            ``features_x``.
        keep_only_idx: Optional sequence of row indexes; when given, only
            those rows of both inputs are compared.
        path: Optional directory; when given the result is saved there as
            ``sim.pt``.

    Returns:
        Tensor whose entry ``[i, j]`` is the cosine between row ``i`` of
        ``features_x`` and row ``j`` of ``features_y``. No guard is applied
        for zero-norm rows, so those entries are non-finite.
    """
    if keep_only_idx is not None:
        index = torch.from_numpy(np.array(keep_only_idx))
        # from_numpy always yields a CPU tensor, while index_select needs the
        # index on the same device as the tensor being indexed.
        features_x = torch.index_select(features_x, 0, index.to(features_x.device))
        features_y = torch.index_select(features_y, 0, index.to(features_y.device))
    cos, norm = calculate_cosine_matrix(features_x, features_y)
    result = cos/norm
    if path is not None:
        torch.save(result.cpu().detach(), os.path.join(path, "sim.pt"))
    return result

def l2_distance(features_x, features_y, keep_only_idx=None, path=None):
    """Row-wise L2 (Euclidean) distance between two feature tensors.

    Args:
        features_x: Tensor with at least two dims.
        features_y: Tensor broadcastable against ``features_x``.
        keep_only_idx: Optional sequence of row indexes; when given, only
            those rows of both inputs are compared.
        path: Optional directory; when given the distances are saved there
            as ``l2_dist.pt``.

    Returns:
        Tensor holding the norm of ``features_x - features_y`` taken over
        dim 1 (for 2-D inputs: one distance per row).
    """
    if keep_only_idx is not None:
        index = torch.from_numpy(np.array(keep_only_idx))
        # from_numpy always yields a CPU tensor, while index_select needs the
        # index on the same device as the tensor being indexed.
        features_x = torch.index_select(features_x, 0, index.to(features_x.device))
        features_y = torch.index_select(features_y, 0, index.to(features_y.device))
    result = torch.linalg.norm(features_x - features_y, dim=1)
    if path is not None:
        torch.save(result.cpu().detach(), os.path.join(path, "l2_dist.pt"))
    # The distances were previously computed (and optionally saved) but never
    # handed back to the caller.
    return result



### Landmark-based, batched (efficient) version of our method (PNKA)

# NOTE: identical re-declaration of the zero_center helper defined near the
# top of this module.
def zero_center(features):
    """Return ``features`` with its dim-0 mean subtracted (input untouched)."""
    return features.sub(features.mean(0, keepdim=True))

def compute_kernel(X_centered, device, L=None):
    """Linear kernel between ``X_centered`` and a set of landmark rows.

    Args:
        X_centered: 2-D tensor; used as given (no centering is applied here).
        device: Device both operands are moved to before multiplying.
        L: Optional landmark tensor. It is zero-centered with its own dim-0
            mean before use. When ``None``, ``X_centered`` itself serves as
            the landmarks.

    Returns:
        ``X_centered @ landmarks.T`` on ``device``.
    """
    landmarks = X_centered if L is None else zero_center(L)
    return torch.matmul(X_centered.to(device), landmarks.to(device).T)

def pointwise_similarity(loader_x, loader_y, device, landmarks_x, landmarks_y):
    """Per-sample cosine similarity between landmark kernels, batch by batch.

    The two loaders are iterated in lockstep. For each pair of batches the
    kernel against the respective landmarks is computed and the cosine
    similarity is taken along the last dim, giving one score per sample.

    Args:
        loader_x: Iterable of batches for the first representation; must
            support ``len`` (used for the progress bar total).
        loader_y: Iterable of batches for the second representation.
        device: Device on which the kernels are computed.
        landmarks_x: Landmarks passed to ``compute_kernel`` for ``loader_x``.
        landmarks_y: Landmarks passed to ``compute_kernel`` for ``loader_y``.

    Returns:
        The per-batch score tensors concatenated along dim 0.
    """
    paired_batches = tqdm(zip(loader_x, loader_y), total=len(loader_x), desc="Loader 1")
    per_batch_scores = [
        torch.cosine_similarity(
            compute_kernel(batch_x, device, landmarks_x),
            compute_kernel(batch_y, device, landmarks_y),
            dim=-1,
        )
        for batch_x, batch_y in paired_batches
    ]
    return torch.cat(per_batch_scores)

def efficient_pnka(features_x, features_y, batch_size, device,
                          landmarks_x, landmarks_y, path=None, logger=None):
    """Batched, landmark-based PNKA scores for two feature tensors.

    Both inputs are zero-centered along dim 0, wrapped in non-shuffling
    ``DataLoader`` objects so that batches stay aligned, and scored with
    ``pointwise_similarity``. The sum of the scores and the sum divided by
    the number of scores are logged (or printed).

    Args:
        features_x: Feature tensor of the first representation.
        features_y: Feature tensor of the second representation.
        batch_size: Number of samples per batch.
        device: Device on which the kernels are computed.
        landmarks_x: Landmarks for ``features_x``.
        landmarks_y: Landmarks for ``features_y``.
        path: Optional directory; when given the scores are saved there as
            ``pnka_scores.pt``.
        logger: Optional logger for the summary lines; when ``None`` they are
            printed instead.

    Returns:
        Detached float32 CPU tensor of scores.
    """
    loader_x = DataLoader(zero_center(features_x), batch_size=batch_size, shuffle=False)
    loader_y = DataLoader(zero_center(features_y), batch_size=batch_size, shuffle=False)

    raw_scores = pointwise_similarity(loader_x, loader_y, device, landmarks_x, landmarks_y)
    pnka_scores = raw_scores.detach().cpu().to(torch.float32)

    if path is not None:
        print(f'Saving matrices under {path}')
        torch.save(pnka_scores, os.path.join(path, "pnka_scores.pt"))

    total = torch.sum(pnka_scores.cpu().detach())
    report = (
        f'Sum: {total}',
        f'Avg sum: {total/pnka_scores.shape[0]}',
    )
    emit = print if logger is None else logger.info
    for line in report:
        emit(line)

    return pnka_scores

### Landmark-based, batched (efficient) version of CKA

def cka_pointwise_similarity(loader_x, loader_y, device, landmarks_x, landmarks_y):
    """Per-batch product of the two landmark kernels, concatenated over batches.

    The loaders are iterated in lockstep. For each pair of batches the kernel
    against the respective landmarks is computed and the first is multiplied
    by the transpose of the second. The shapes of the intermediate tensors
    are printed for every batch.

    NOTE(review): each batch appears to yield a (batch, batch) block, and the
    blocks are concatenated along dim 0, so the result looks like
    (N, batch) rather than (N, N) — confirm this is the intended layout.

    Args:
        loader_x: Iterable of batches for the first representation; must
            support ``len`` (used for the progress bar total).
        loader_y: Iterable of batches for the second representation.
        device: Device on which the kernels are computed.
        landmarks_x: Landmarks passed to ``compute_kernel`` for ``loader_x``.
        landmarks_y: Landmarks passed to ``compute_kernel`` for ``loader_y``.

    Returns:
        The per-batch product matrices concatenated along dim 0.
    """
    blocks = []
    progress = tqdm(zip(loader_x, loader_y), total=len(loader_x), desc="Loader 1")
    for batch_x, batch_y in progress:
        kernel_x = compute_kernel(batch_x, device, landmarks_x)
        kernel_y = compute_kernel(batch_y, device, landmarks_y)
        block = torch.mm(kernel_x, kernel_y.T)
        shape_report = (
            ('similarity_X', kernel_x),
            ('similarity_Y', kernel_y),
            ('similarity', block),
        )
        for label, tensor in shape_report:
            print(label, tensor.shape)
        blocks.append(block)

    return torch.cat(blocks)

def efficient_cka(features_x, features_y, batch_size, device,
                          landmarks_x, landmarks_y, path=None, logger=None):
    """Batched, landmark-based variant of the un-traced CKA computation.

    Both inputs are zero-centered along dim 0, wrapped in non-shuffling
    ``DataLoader`` objects and passed to ``cka_pointwise_similarity``. A
    normaliser is built from the centered features and the landmarks, and
    the trace of the similarity divided by that normaliser is logged (or
    printed).

    NOTE(review): the normaliser multiplies the transposed features with the
    landmarks as passed in (not centered here), which ``torch.mm`` only
    accepts when the landmark tensor has as many rows as the features —
    confirm this is the intended normalisation.

    Args:
        features_x: 2-D feature tensor of the first representation.
        features_y: 2-D feature tensor of the second representation.
        batch_size: Number of samples per batch.
        device: Device on which the kernels are computed.
        landmarks_x: Landmarks for ``features_x``.
        landmarks_y: Landmarks for ``features_y``.
        path: Optional directory; when given the similarity and normaliser
            are saved there as ``final_sim_xxtyyt.pt`` and ``norm.pt``.
        logger: Optional logger for the summary line; when ``None`` it is
            printed instead.

    Returns:
        Detached float32 CPU tensor produced by ``cka_pointwise_similarity``.
    """
    centered_x = zero_center(features_x)
    centered_y = zero_center(features_y)

    loader_x = DataLoader(centered_x, batch_size=batch_size, shuffle=False)
    loader_y = DataLoader(centered_y, batch_size=batch_size, shuffle=False)
    raw_similarity = cka_pointwise_similarity(
        loader_x, loader_y, device, landmarks_x, landmarks_y)
    cka = raw_similarity.detach().cpu().to(torch.float32)
    print('cka mx', cka.shape)

    cross_x = torch.mm(centered_x.transpose(0, 1), landmarks_x)
    cross_y = torch.mm(centered_y.transpose(0, 1), landmarks_y)
    norm = torch.linalg.norm(cross_x) * torch.linalg.norm(cross_y)

    if path is not None:
        print(f'Saving matrices under {path}')
        for tensor, filename in ((cka, "final_sim_xxtyyt.pt"), (norm, "norm.pt")):
            torch.save(tensor, os.path.join(path, filename))

    summary = f'CKA: {torch.trace(cka)/norm}'
    if logger is None:
        print(summary)
    else:
        logger.info(summary)

    return cka
