import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.decomposition import PCA
from sklearn.manifold import MDS
import torch
import torch.nn as nn
import scipy.stats as stats
from sklearn.metrics import pairwise_distances
from sklearn.neighbors import NearestNeighbors

def landmark_mds(X, n_components=25, landmark_fraction=0.2, n_neighbors=5):
    """Approximate an MDS embedding of ``X`` using a random subset of landmarks.

    A subset of rows is drawn as landmarks, embedded with metric MDS on their
    pairwise Euclidean distances, and every row of ``X`` is then placed at the
    mean embedding of its nearest landmarks (nearest in the original space).

    Args:
        X: Array of samples; rows are indexed with an integer index array, so
            a NumPy array is expected.
        n_components: Dimensionality of the MDS embedding.
        landmark_fraction: Fraction of the rows used as landmarks.
        n_neighbors: Number of nearest landmarks averaged for each row.

    Returns:
        Array with one row per row of ``X`` and ``n_components`` columns.

    Raises:
        ValueError: If ``X`` has no rows.

    Note:
        Landmarks are drawn with the global NumPy random state
        (``np.random.choice``), whereas the MDS itself is seeded with 42.
    """
    n_samples = len(X)
    if n_samples == 0:
        raise ValueError("landmark_mds requires at least one sample")

    # For small inputs the requested fraction can yield fewer landmarks than
    # neighbours (or none at all), which makes the neighbour query fail.
    # Raise the landmark count to the neighbour count, capped at the number
    # of samples. Inputs where the fraction already yields enough landmarks
    # are unaffected.
    n_landmarks = max(int(n_samples * landmark_fraction), min(n_neighbors, n_samples))
    n_landmarks = min(n_landmarks, n_samples)
    landmark_indices = np.random.choice(n_samples, n_landmarks, replace=False)
    X_landmarks = X[landmark_indices]

    dist_matrix = pairwise_distances(X_landmarks, metric="euclidean")

    mds = MDS(n_components=n_components, dissimilarity="precomputed", random_state=42)
    X_landmarks_transformed = mds.fit_transform(dist_matrix)

    # Named ``neighbor_index`` rather than ``nn`` so the module-level
    # ``torch.nn as nn`` import is not shadowed.
    neighbor_index = NearestNeighbors(
        n_neighbors=min(n_neighbors, n_landmarks), metric="euclidean"
    ).fit(X_landmarks)
    _, indices = neighbor_index.kneighbors(X)

    # indices has shape (n_samples, k); average the k landmark embeddings.
    return np.mean(X_landmarks_transformed[indices], axis=1)

def _dv_fit_pair_lda(data1, data2, dim, dim_red, shrink):
    """Fit a two-class LDA on one subject's data for a single category pair.

    Optionally reduces dimensionality first, then returns the LDA projections
    of both classes together with the training accuracy.
    """
    # Row count of the first class for THIS subject; used to split a jointly
    # embedded matrix back into its two classes.
    n_first = data1.shape[0]
    X = np.concatenate([data1, data2], axis=0)

    if dim is not None and X.shape[1] >= dim:
        if dim_red == 'pca':
            pca = PCA(n_components=dim)
            pca.fit(X)
            X = pca.transform(X)
            data1 = pca.transform(data1)
            data2 = pca.transform(data2)
        elif dim_red == 'mds':
            # MDS has no out-of-sample transform, so embed jointly and slice.
            X = MDS(n_components=dim).fit_transform(X)
            data1 = X[:n_first]
            data2 = X[n_first:]
        elif dim_red == 'landmark_mds':
            X = landmark_mds(X, n_components=dim)
            data1 = X[:n_first]
            data2 = X[n_first:]
        # Any other dim_red value leaves the data unreduced.

    y = np.concatenate([np.zeros(data1.shape[0]), np.ones(data2.shape[0])])
    if shrink:
        lda = LinearDiscriminantAnalysis(solver='eigen', shrinkage='auto')
    else:
        lda = LinearDiscriminantAnalysis()
    lda.fit(X, y)
    return lda.transform(data1), lda.transform(data2), lda.score(X, y)


def calculate_dv_correlation(activation1, activation2, categories, dim=None, mode='pearson', dim_red='pca', shrink=False):
    """Correlate pairwise LDA decision variables between two sets of activations.

    For every unordered pair of categories, a two-class LDA is fitted
    separately on ``activation1`` and on ``activation2``. The samples of each
    category are projected onto the discriminant, and the projections from
    the two sets are correlated sample by sample.

    Args:
        activation1: Mapping from category to a 2-D array; rows are treated as
            samples and columns as features.
        activation2: Same structure as ``activation1``. It must provide the
            same number of rows per category, because the projections are
            correlated element-wise.
        categories: Iterable of category keys present in both mappings.
        dim: Target dimensionality for the optional reduction. Reduction is
            applied only when ``dim`` is not None and the feature count is at
            least ``dim``.
        mode: ``'pearson'`` or ``'spearman'``. Any other value leaves the
            off-diagonal entries at zero.
        dim_red: ``'pca'``, ``'mds'`` or ``'landmark_mds'``. Any other value
            skips the reduction.
        shrink: If true, use the eigen solver with automatic shrinkage.

    Returns:
        Tuple ``(dv_correlations, lda_score1, lda_score2)``.
        ``dv_correlations[i, j]`` with ``i < j`` is the correlation for the
        samples of category ``i`` in the (i, j) discrimination, and
        ``dv_correlations[j, i]`` is the one for the samples of category
        ``j``. The diagonal is NaN. The two scores are the mean LDA training
        accuracies over all pairs, or NaN when there are fewer than two
        categories.
    """
    # Materialise once so indexing is cheap and any iterable is accepted.
    category_list = list(categories)
    num_categories = len(category_list)
    dv_correlations = np.zeros((num_categories, num_categories))
    lda_score1 = 0.0
    lda_score2 = 0.0

    for i in range(num_categories):
        for j in range(i + 1, num_categories):
            category1 = category_list[i]
            category2 = category_list[j]

            subject1_data1_lda, subject1_data2_lda, score1 = _dv_fit_pair_lda(
                activation1[category1], activation1[category2], dim, dim_red, shrink
            )
            subject2_data1_lda, subject2_data2_lda, score2 = _dv_fit_pair_lda(
                activation2[category1], activation2[category2], dim, dim_red, shrink
            )
            lda_score1 += score1
            lda_score2 += score2

            if mode == 'pearson':
                dv_correlations[i, j] = np.corrcoef(subject1_data1_lda.T, subject2_data1_lda.T)[0, 1]
                dv_correlations[j, i] = np.corrcoef(subject1_data2_lda.T, subject2_data2_lda.T)[0, 1]
            elif mode == 'spearman':
                dv_correlations[i, j] = stats.spearmanr(subject1_data1_lda, subject2_data1_lda)[0]
                dv_correlations[j, i] = stats.spearmanr(subject1_data2_lda, subject2_data2_lda)[0]

    num_pairs = num_categories * (num_categories - 1) // 2
    if num_pairs > 0:
        lda_score1 /= num_pairs
        lda_score2 /= num_pairs
    else:
        # No pair was evaluated, so a mean accuracy is undefined.
        lda_score1 = lda_score2 = float('nan')

    np.fill_diagonal(dv_correlations, np.nan)
    return dv_correlations, lda_score1, lda_score2

def split_representations(data, categories, split=None, random=True):
    """Split every category's activations into two groups of columns.

    Args:
        data: Mapping from category to a 2-D array; the split is made along
            axis 1.
        categories: Iterable of keys of ``data`` to split.
        split: Column index at which to cut. When None, each array is cut at
            half of its column count (integer division).
        random: If true, one random column permutation is drawn from the
            global NumPy random state and applied to every category before
            cutting. Its length comes from the first entry of ``data``.

    Returns:
        Tuple ``(split1, split2)`` of dicts keyed by category, holding the
        columns before and from the cut respectively.

    Note:
        ``data`` is not modified; the permutation is applied to a local
        reference only.
    """
    split1 = {}
    split2 = {}

    shuffled_order = None
    if random and len(data) > 0:
        num_neurons = next(iter(data.values())).shape[1]
        shuffled_order = np.random.permutation(num_neurons)

    for category in categories:
        activations = data[category]
        if shuffled_order is not None:
            # Fancy indexing returns a new array, so the caller's dict and
            # arrays stay untouched.
            activations = activations[:, shuffled_order]
        cut = activations.shape[1] // 2 if split is None else split
        split1[category] = activations[:, :cut]
        split2[category] = activations[:, cut:]
    return split1, split2

def calculate_normalized_dv_correlation(activation1, activation2, categories, dim=None, mode='pearson', dim_red='pca',shrink=False):
    """Compute decision-variable correlations normalised by split-half values.

    Each activation mapping is divided into two column groups with
    ``split_representations``. The numerator is the geometric mean of the
    absolute correlations over the four cross-mapping split pairings. The
    denominator is the geometric mean of the absolute within-mapping
    split-half correlations.

    Args:
        activation1: First mapping from category to activation array.
        activation2: Second mapping from category to activation array.
        categories: Category keys to compare.
        dim, mode, dim_red, shrink: Passed unchanged to
            ``calculate_dv_correlation``.

    Returns:
        Square array of normalised correlations. Entries whose denominator is
        NaN or zero are NaN.
    """
    first_a, first_b = split_representations(activation1, categories)
    second_a, second_b = split_representations(activation2, categories)

    def correlation_matrix(left, right):
        # Only the correlation matrix is needed; the LDA scores are dropped.
        matrix, _, _ = calculate_dv_correlation(left, right, categories, dim, mode, dim_red, shrink)
        return matrix

    # Cross-mapping pairings first, then the within-mapping split halves.
    cross_aa = correlation_matrix(first_a, second_a)
    cross_ab = correlation_matrix(first_a, second_b)
    cross_ba = correlation_matrix(first_b, second_a)
    cross_bb = correlation_matrix(first_b, second_b)
    within_first = correlation_matrix(first_a, first_b)
    within_second = correlation_matrix(second_a, second_b)

    cross_product = cross_aa * cross_bb * cross_ab * cross_ba
    numerator = np.power(np.abs(cross_product), 0.25)
    denominator = np.power(np.abs(within_first * within_second), 0.5)

    # Divide only where the denominator is usable; the rest stays NaN.
    usable = ~np.isnan(denominator) & (denominator != 0)
    normalized = np.full_like(numerator, np.nan)
    np.divide(numerator, denominator, out=normalized, where=usable)
    return normalized