import torch
import os
import numpy as np
import ripser
import pickle
from multiprocessing import Pool # for multiprocessing

# NOTE: Update to your directory where the activation files are stored
ACTIVATION_DIR = os.path.join(os.getcwd(), 'activations')

# Loading mistral 7B activations.
# map_location="cpu" lets the files load on machines without the device they
# were saved from, and guarantees the .numpy() calls below receive CPU tensors
# (.numpy() raises on CUDA tensors).
clean_activations = torch.load(os.path.join(ACTIVATION_DIR, 'mistral_7B', 'clean_hidden_states_0_1000_20240422_185013.pt'), map_location="cpu")
poisoned_activations = torch.load(os.path.join(ACTIVATION_DIR, 'mistral_7B', 'poisoned_hidden_states_0_1000_20240423_011421.pt'), map_location="cpu")


# Difference between the second and first stored entries of each file
# (entry [1] minus entry [0]).
clean_data = clean_activations[1] - clean_activations[0]  # Expected shape: [1000, 32, 4096]
poisoned_data = poisoned_activations[1] - poisoned_activations[0]  # Expected shape: [1000, 32, 4096]

# Convert tensors to numpy arrays for compatibility with TDA libraries.
# Casting to float32 first guarantees a numpy-representable dtype
# (numpy has no bfloat16, for example).
clean_data_np = clean_data.float().numpy()
poisoned_data_np = poisoned_data.float().numpy()

def standardize_matrix(matrix):
    """Standardize each row of ``matrix`` to zero mean and unit variance.

    Means and (population, ddof=0) standard deviations are computed along
    axis 1, i.e. independently for every row.

    Rows with zero variance (all entries equal) would otherwise yield 0/0 =
    NaN, which silently corrupts downstream computations such as ripser.
    Their standard deviation is treated as 1, so such rows map to all zeros.

    Args:
        matrix: array-like with at least two dimensions; rows lie along axis 0.

    Returns:
        An ndarray of the same shape with each row standardized.
    """
    row_means = np.mean(matrix, axis=1, keepdims=True)
    row_stds = np.std(matrix, axis=1, keepdims=True)
    # Avoid division by zero for constant rows.
    row_stds = np.where(row_stds == 0, 1.0, row_stds)
    return (matrix - row_means) / row_stds

def permute_matrix_rows_independently(matrix):
    """Shuffle every row of ``matrix`` on its own, then row-standardize the result.

    Each row is shuffled with a separate call to ``np.random.permutation``,
    taken in row order, so results depend on numpy's global random state.
    """
    shuffled_rows = []
    for original_row in matrix:
        shuffled_rows.append(np.random.permutation(original_row))
    shuffled = np.array(shuffled_rows)
    return standardize_matrix(shuffled)

# Gap between the two layers compared for each sample (layer i vs. layer i+4,
# i.e. three intermediate layers are skipped).
LAYER_GAP = 4


def _persistence_diagrams(start_row, end_row):
    """Return ripser persistence diagrams for one sample's pair of layer rows.

    The two activation rows are stacked into a 2 x d matrix and standardized
    per row. The transpose is then passed to ripser as a cloud of d points
    in R^2.

    This lives at module level (rather than inside the loop) so multiprocessing
    can pickle it by reference under every start method.
    """
    layers = standardize_matrix(np.vstack([start_row, end_row]))
    return ripser.ripser(layers.T)["dgms"]


def _layer_pair_diagrams(pool, data, init, end):
    """Compute diagrams for every sample of ``data`` between layers ``init`` and ``end``.

    ``data`` is indexed as ``data[sample, layer, :]``. The rows are passed to
    the workers as arguments, so workers do not depend on inherited global
    state. Returns one ``dgms`` entry per sample, in sample order.
    """
    return pool.starmap(_persistence_diagrams, zip(data[:, init, :], data[:, end, :]))


# Guard the driver: under the "spawn"/"forkserver" start methods the workers
# re-import this module, and an unguarded loop would try to start pools from
# inside them. NOTE(review): module-level loading code above this point still
# re-runs in each spawned worker; consider guarding it too.
if __name__ == "__main__":
    num_layers = clean_data_np.shape[1]
    # Create the output directory up front so a missing directory does not
    # crash the run after the first (expensive) layer pair has been computed.
    os.makedirs("Non_consec/3step/Scaled", exist_ok=True)

    # One pool for the whole run instead of two freshly forked pools per layer pair.
    with Pool(8) as pool:
        for i in range(num_layers - LAYER_GAP):
            init = i
            end = i + LAYER_GAP

            col_clean_pd = _layer_pair_diagrams(pool, clean_data_np, init, end)
            print(f"layers {init} to {end} clean data complete.")

            col_poisoned_pd = _layer_pair_diagrams(pool, poisoned_data_np, init, end)
            print(f"layers {init} to {end} poisoned data complete.")

            collection_pds = col_clean_pd, col_poisoned_pd

            # File name kept verbatim (including "colletions") so existing readers still find it.
            with open(f"Non_consec/3step/Scaled/layers_{init}_{end}_pd_colletions.pkl", "wb") as file:
                pickle.dump(collection_pds, file)

            print(f"layers {init} to {end} complete.")