import torch
import os
import numpy as np
import ripser
import pickle
from multiprocessing import Pool

# NOTE: Update to your directory where the activation files are stored
ACTIVATION_DIR = os.path.join(os.getcwd(), 'data/llama_3_8B/test')
CLEAN_ACTIVATION_FILE = 'clean_hidden_states_0_1000_20240508_040518.pt'
POISONED_ACTIVATION_FILE = 'poisoned_hidden_states_0_1000_20240508_101212.pt'
# Persistence diagrams are written here, one pickle per consecutive layer pair.
OUTPUT_DIR = os.path.join('model_PD', 'llama_3_8B', 'Original')
# Number of worker processes used for the ripser computations.
NUM_WORKERS = 32


def load_activation_differences(path):
    """Load an activation file and return ``entry[1] - entry[0]`` as float32 numpy.

    The file is read with ``torch.load`` onto the CPU, so ``.numpy()`` also
    works when the tensors were saved from a GPU. The loaded object must
    support indexing with 0 and 1. The entries are assumed (not checked here)
    to be tensors shaped ``[num_samples, num_layers, hidden_dim]``, e.g.
    ``[1000, 32, 4096]`` for the default files. TODO confirm for new data.

    Args:
        path: path to a file written with ``torch.save``.

    Returns:
        ``numpy.ndarray`` of dtype float32 holding the difference of the two
        entries.
    """
    activations = torch.load(path, map_location='cpu')
    return (activations[1] - activations[0]).float().numpy()


def layer_point_clouds(data, start, end):
    """Build one 2-D point cloud per sample from two layers of ``data``.

    For sample ``k`` the result ``out[k]`` has shape ``(hidden_dim, 2)``.
    Its row ``j`` is the point ``(data[k, start, j], data[k, end, j])``, i.e.
    the same values as ``np.vstack([data[k, start], data[k, end]]).T``.

    Args:
        data: array of shape ``(num_samples, num_layers, hidden_dim)``.
        start: index of the first layer of the pair.
        end: index of the second layer of the pair.

    Returns:
        Array of shape ``(num_samples, hidden_dim, 2)``.
    """
    return np.stack([data[:, start, :], data[:, end, :]], axis=-1)


def persistence_diagrams(points):
    """Return ripser's persistence diagrams (the ``"dgms"`` entry) for ``points``.

    This is defined at module level and receives its data as an argument, so
    ``multiprocessing`` can pickle it by name under every start method
    (fork, spawn, forkserver) without relying on globals captured at fork time.
    """
    return ripser.ripser(points)["dgms"]


def main():
    """Compute and pickle persistence diagrams for every consecutive layer pair.

    For each pair (i, i+1), one file ``layers_<i+1>_<i+2>_pd_colletions.pkl``
    is written to ``OUTPUT_DIR``. It contains the tuple
    ``(clean_diagrams, poisoned_diagrams)``, each a list with one ripser
    ``"dgms"`` result per sample.

    Raises:
        ValueError: if the clean and poisoned data have different layer counts.
    """
    clean_data_np = load_activation_differences(
        os.path.join(ACTIVATION_DIR, CLEAN_ACTIVATION_FILE))
    poisoned_data_np = load_activation_differences(
        os.path.join(ACTIVATION_DIR, POISONED_ACTIVATION_FILE))

    # Derive the layer count from the data instead of hard-coding 31 pairs,
    # and fail fast rather than hitting an IndexError hours into the run.
    num_layers = clean_data_np.shape[1]
    if poisoned_data_np.shape[1] != num_layers:
        raise ValueError(
            f"clean data has {num_layers} layers but poisoned data has "
            f"{poisoned_data_np.shape[1]}"
        )

    # Create the output directory up front so a missing folder does not
    # crash the run after the first (expensive) layer pair is computed.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # A single pool serves the whole run. Each task carries its own point
    # cloud, so workers never need to see updated globals.
    with Pool(NUM_WORKERS) as pool:
        for i in range(num_layers - 1):
            label = f"layers {i + 1} to {i + 2}"

            col_clean_pd = pool.map(
                persistence_diagrams, layer_point_clouds(clean_data_np, i, i + 1))
            print(f"{label} clean data complete.")

            col_poisoned_pd = pool.map(
                persistence_diagrams, layer_point_clouds(poisoned_data_np, i, i + 1))
            print(f"{label} poisoned data complete.")

            # "colletions" (sic) is kept so existing readers still find the files.
            output_path = os.path.join(
                OUTPUT_DIR, f"layers_{i + 1}_{i + 2}_pd_colletions.pkl")
            with open(output_path, "wb") as file:
                pickle.dump((col_clean_pd, col_poisoned_pd), file)

            print(f"{label} complete.")


# Required for multiprocessing under spawn/forkserver: child processes
# re-import this module and must not re-run the computation.
if __name__ == "__main__":
    main()