import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
from safetensors.numpy import load_file


def quantile_normalize(source, target):
    """Map ``source`` onto the empirical distribution of ``target``.

    Every element of ``source`` is replaced by the quantile of ``target``
    that matches the element's relative rank inside ``source``: the smallest
    source value maps to the minimum of ``target``, the largest to its
    maximum, and everything in between is interpolated by
    ``torch.quantile`` (linear interpolation by default).

    Args:
        source: 1-D tensor whose ordering is to be preserved.
        target: Non-empty 1-D floating point tensor supplying the output
            distribution.

    Returns:
        A tensor with one value per element of ``source``, in the dtype of
        ``target``.

    Note:
        Tied values in ``source`` still receive distinct consecutive ranks
        (double ``argsort``), so equal inputs can map to different outputs.
    """
    n = len(source)
    if n > 1:
        # A double argsort turns values into 0-based ranks; dividing by
        # n - 1 rescales them to [0, 1]. The ranks are built directly in
        # target's dtype/device because torch.quantile requires ``q`` to
        # match its input.
        order = source.argsort().argsort()
        ranks = order.to(device=target.device, dtype=target.dtype) / (n - 1)
    else:
        # A lone value has no relative rank (0 / 0 would be NaN); map it to
        # the median of the target distribution instead.
        ranks = torch.full_like(
            source, 0.5, dtype=target.dtype, device=target.device
        )
    # Use these ranks to sample from the target distribution.
    return torch.quantile(target, ranks)


# Collect the per-sentence simulation results into a single DataFrame.
# Each file "<i>all_data.csv" in SCORES_FOLDER is tagged with its index ``i``
# in a new "sentence_idx" column before being concatenated.

SCORES_FOLDER = "results/kl_div_activations/70b/transcoder/"
skip_data = []
for i in tqdm(range(0, 2000)):
    try:
        all_data = pd.read_csv(f"{SCORES_FOLDER}{i}all_data.csv")
    except FileNotFoundError:
        # Not every index has a results file; skip the gaps. Only the
        # missing-file case is ignored so that parse errors and Ctrl-C are
        # no longer silently swallowed.
        continue
    all_data["sentence_idx"] = i
    skip_data.append(all_data)
if not skip_data:
    # pd.concat([]) would fail with an unhelpful "No objects to concatenate".
    raise FileNotFoundError(
        f"No '<index>all_data.csv' files found in {SCORES_FOLDER}"
    )
skip_data = pd.concat(skip_data, ignore_index=True)



# Some predictions failed and were assigned a negative expected quantile;
# zero out the predicted activation and the quantile for exactly those rows.
# ``Series.mask(cond, 0)`` replaces values where ``cond`` is True. The earlier
# ``Series.where(cond, 0)`` did the opposite: it kept the failed rows and
# zeroed every valid prediction.
failed_prediction = skip_data["expected_quantile"] < 0
skip_data["predicted_activation"] = skip_data["predicted_activation"].mask(failed_prediction, 0)
skip_data["expected_quantile"] = skip_data["expected_quantile"].mask(failed_prediction, 0)
# (disabled) skip_data["activation"] = skip_data["activation"].mask(failed_prediction, 0)

# Quick sanity statistics about the collected data.
unique_sentence_ids = skip_data["sentence_idx"].unique()
number_examples = len(unique_sentence_ids)
total_rows = len(skip_data)

# number of distinct sentence ids
print(number_examples)

# total number of rows
print(total_rows)

# average number of rows (features) per example
print(total_rows / number_examples)

# Load the cached latents of one layer. The cache is split into several
# safetensors shards, each named "<first>_<last>" after the index range it
# covers.
layer = 15
all_locations = []
all_activations = []
ranges = ["0_3685","3686_7371","7372_11058","11059_14744","14745_18431"]

tokens = None
for valid_range in ranges:
    # The path is built from ``layer`` so that changing the constant above
    # actually selects a different layer (it used to be hard-coded to 15).
    split_data = load_file(f"results/transcoder/latents/layers.{layer}.mlp/{valid_range}.safetensors")
    activations = torch.tensor(split_data["activations"])
    locations = torch.tensor(split_data["locations"].astype(np.int32))
    # Shift column 2 by the first index of the shard's range
    # (presumably converting a shard-local feature index into a global one —
    # TODO confirm against the code that writes the shards).
    shard_offset = int(valid_range.split("_")[0])
    locations[:, 2] += shard_offset
    all_locations.append(locations)
    all_activations.append(activations)
    if tokens is None:
        # Tokens are taken from the first shard only.
        tokens = torch.tensor(split_data["tokens"])


locations = torch.cat(all_locations)
activation = torch.cat(all_activations)

# Split the rows by feature once so each feature's rows can be looked up
# directly inside the loop.
feature_groups = dict(tuple(skip_data.groupby("feature")))

start_idx = 0
new_data = []
# Use the GPU when present, but keep the script runnable on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for i in tqdm(range(0, 18432)):
    if i not in feature_groups:
        continue
    feature_x = feature_groups[i]

    real_activations = torch.tensor(
        feature_x["activation"].values, dtype=torch.float32, device=device
    )
    # Shuffle the real activations and keep a random 10% of them as the
    # reference distribution. torch.randperm expects the number of elements
    # (not the tensor itself) and returns indices. At least one value is
    # always kept because torch.quantile rejects an empty input.
    shuffled_idx = torch.randperm(len(real_activations), device=device)
    sample_size = max(1, int(len(real_activations) * 0.1))
    real_activations = real_activations[shuffled_idx[:sample_size]]
    predicted_activations = torch.tensor(
        feature_x["predicted_activation"].values, dtype=torch.float32, device=device
    )

    normalized_activations = quantile_normalize(predicted_activations, real_activations)
    # ``assign`` returns a new frame, so the groupby slice is not mutated
    # (avoids pandas' SettingWithCopyWarning).
    chosen_data = feature_x.assign(
        normalized_activation=normalized_activations.cpu().numpy()
    )
    new_data.append(chosen_data)
# NOTE(review): the rows sampled as the reference distribution are still
# part of the output; nothing is dropped here.
new_data = pd.concat(new_data)
# how many features are there?
print(len(new_data["feature"].unique()))

# drop text column if it exists
if "text" in new_data.columns:
    new_data = new_data.drop(columns=["text"])

new_data.to_csv("transcoder_activations.csv", index=False)
