

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import torch.optim as optim
import torch.nn as nn

import torch.nn.functional as F
import numpy as np

from sklearn.cluster import SpectralClustering
from sklearn.metrics.cluster import adjusted_rand_score, rand_score


from torch.optim.lr_scheduler import ReduceLROnPlateau

from sklearn.svm import LinearSVC
from sklearn.metrics import accuracy_score


def softmax_without_diagonal_exp(input_matrix, dim=1):
    """Exponentiate ``input_matrix`` and normalize it with the diagonal removed.

    A softmax-like transform along ``dim``. Every entry goes through ``exp``,
    then ``normalize_with_diagonal_zero`` zeroes the main diagonal, applies
    its ``eps`` smoothing and divides by the sum along ``dim``.

    NOTE(review): no max-subtraction is performed before ``exp``, so very
    large inputs can overflow to ``inf``.
    """
    return normalize_with_diagonal_zero(torch.exp(input_matrix), dim=dim)


def normalize_with_diagonal_zero(x, dim = 1, eps = 1e-5):
    """Zero the main diagonal of ``x``, then normalize it along ``dim``.

    ``eps`` is added to every entry, including the zeroed diagonal, before
    dividing by the sum along ``dim``. As a result no slice sums to zero, and
    the diagonal ends up at a small positive value rather than exactly 0.

    Args:
        x: 2-D tensor. Non-square matrices are accepted; the masked entries
            are ``x[i, i]`` for ``i < min(n_rows, n_cols)``.
        dim: Dimension to normalize over (``1`` makes each row sum to 1).
        eps: Smoothing constant added to every entry before normalizing.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # Mask the diagonal instead of subtracting it. ``x - diag(diag(x))``
    # turns an inf/NaN diagonal (e.g. from exp() overflow in
    # softmax_without_diagonal_exp) into NaN, which then poisons the row sum.
    # masked_fill writes an exact 0 and also works for rectangular inputs.
    diag_mask = torch.eye(x.shape[-2], x.shape[-1], dtype=torch.bool, device=x.device)
    x = x.masked_fill(diag_mask, 0)

    shifted = x + eps
    return shifted / torch.sum(shifted, dim=dim, keepdim=True)


def normalize(x, dim = 1, eps = 1e-5):
    """Shift ``x`` by ``eps`` and rescale so it sums to 1 along ``dim``.

    The ``eps`` offset keeps every slice sum strictly positive for
    non-negative inputs, so the division never hits zero.
    """
    smoothed = x + eps
    totals = smoothed.sum(dim=dim, keepdim=True)
    return smoothed / totals



def create_adjacency_matrix(labels):
    """Build a same-label adjacency matrix.

    Entry ``[i, j]`` is 1 when ``labels[i] == labels[j]`` and 0 otherwise.

    Args:
        labels: 1-D sequence of labels: a list, NumPy array, or 1-D tensor.
            Labels that NumPy cannot hold as a flat array (e.g. tuples) are
            compared as whole Python objects.

    Returns:
        ``(n, n)`` CPU tensor of the default float dtype, ``n = len(labels)``.
    """
    if isinstance(labels, torch.Tensor):
        flat = labels.detach().cpu()
        # Broadcast equality builds the whole matrix in one vectorized op
        # instead of n**2 Python-level comparisons and element writes.
        matches = flat[:, None] == flat[None, :]
    else:
        try:
            arr = np.asarray(labels)
        except ValueError:
            # Ragged composite labels cannot form a regular array.
            arr = None
        if arr is None or arr.ndim != 1:
            # Keep each label as a single object so == compares whole labels
            # rather than numpy splitting them into extra axes.
            arr = np.empty(len(labels), dtype=object)
            for k, label in enumerate(labels):
                arr[k] = label
        matches = torch.from_numpy(arr[:, None] == arr[None, :])

    # Match the original torch.zeros((n, n)) dtype (the default float dtype).
    return matches.to(torch.get_default_dtype())

def spectral_clustering(affinity_matrix, true_labels, n_clusters, random_state=None):
    """Cluster a precomputed affinity matrix and score it against true labels.

    The affinity is symmetrized as ``A.T + A`` before being passed to
    scikit-learn's ``SpectralClustering`` with ``affinity='precomputed'``.

    Args:
        affinity_matrix: Square ``(n, n)`` affinity matrix. A torch tensor
            is accepted on any device, including one that requires grad; it
            is detached and moved to CPU NumPy first. Other inputs are passed
            through unchanged.
        true_labels: Ground-truth labels for the ``n`` samples. A torch
            tensor is converted to CPU NumPy first.
        n_clusters: Number of clusters to find.
        random_state: Forwarded to ``SpectralClustering`` for reproducible
            results. ``None`` (the default) keeps the previous unseeded
            behavior.

    Returns:
        ``(predicted_labels, ari, ri)``: the predicted cluster labels, the
        Adjusted Rand Index and the Rand Index against ``true_labels``.
    """
    # sklearn converts inputs via np.asarray, which fails for CUDA tensors
    # and tensors that require grad; convert those explicitly.
    if isinstance(affinity_matrix, torch.Tensor):
        affinity_matrix = affinity_matrix.detach().cpu().numpy()
    if isinstance(true_labels, torch.Tensor):
        true_labels = true_labels.detach().cpu().numpy()

    # Precomputed affinities are expected to be symmetric; A.T + A ensures it.
    symmetric = affinity_matrix.T + affinity_matrix

    clustering = SpectralClustering(
        n_clusters=n_clusters,
        affinity='precomputed',
        random_state=random_state,
    ).fit(symmetric)
    predicted_labels = clustering.labels_

    ari = adjusted_rand_score(true_labels, predicted_labels)
    ri = rand_score(true_labels, predicted_labels)

    return predicted_labels, ari, ri

def extract_feature(model, train_loader, device):
    """Run ``model`` over every batch of ``train_loader`` and collect results.

    Args:
        model: Callable applied to each input batch. Its output is indexed
            with ``[0]`` to get the embedding. NOTE(review): this assumes the
            model returns a tuple/list whose first element is the embedding;
            if it returns a plain tensor, ``[0]`` selects the first sample
            instead. Confirm against the model definition.
        train_loader: Iterable of ``(x, y)`` batches of tensors.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        ``(all_images, all_labels, all_embeddings)`` as NumPy arrays, each
        being the per-batch results concatenated along axis 0.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    y_array = []
    x_array = []
    embedding_array = []

    # Pure inference: no_grad avoids building an autograd graph per batch.
    # The model's train/eval mode is intentionally left to the caller.
    with torch.no_grad():
        for x, y in train_loader:
            y_array.append(y.detach().cpu().numpy())
            x_array.append(x.detach().cpu().numpy())

            output = model(x.to(device))
            embedding = output[0]
            embedding_array.append(embedding.detach().cpu().numpy())

    if not embedding_array:
        raise ValueError("extract_feature: train_loader yielded no batches")

    all_images = np.concatenate(x_array, axis=0)
    all_labels = np.concatenate(y_array, axis=0)
    all_embeddings = np.concatenate(embedding_array, axis=0)

    return all_images, all_labels, all_embeddings
