import torch
import pandas as pd
# https://github.com/RUCAIBox/Language-Specific-Neurons/blob/main/identify.py

# The two languages whose activation statistics are compared.
lang1, lang2 = 'python', 'php'
langs = [lang1, lang2]

# Identifier used below to build the result directory and file names.
modelpath = str(lang2)

def _load_layer_activations(file_path):
    """Load a sequence of per-layer tensors and stack it along a new dim 0."""
    # NOTE(review): torch.load unpickles the file, so only point this at
    # result files produced by a trusted run.
    return torch.stack(list(torch.load(file_path)), dim=0)


def getresult(lang1, lang2, modelpath, *, top_rate=0.01, filter_rate=0.95,
              activation_bar_ratio=0.95):
    """Select low-entropy neurons and group their indices by language and layer.

    Reads ``nerons/result_<modelpath>/ave_<modelpath>_<lang>.result`` for both
    languages. Each file is expected to hold one tensor per layer
    (presumably 1-D average activation values per neuron -- TODO confirm);
    the code below indexes the stacked data as (layer, neuron, language).

    Args:
        lang1: Name of the first language (index 0 in the result).
        lang2: Name of the second language (index 1 in the result).
        modelpath: Identifier used in the result directory and file names.
        top_rate: Fraction of all neurons to select, lowest entropy first.
        filter_rate: Quantile of all activation values; a neuron whose values
            never exceed it for any language is excluded from the ranking.
        activation_bar_ratio: Quantile of all activation values; a selected
            neuron is assigned to every language whose value exceeds it.

    Returns:
        A list with exactly one entry per language (lang1 first, then lang2).
        Each entry is a list with one long tensor per layer holding the
        selected neuron indices of that layer in ascending order. Languages
        or layers without any selected neuron get empty tensors.

    Raises:
        ValueError: If the computed entropy contains NaN values.
    """
    path = 'nerons/result_' + modelpath
    ave1 = _load_layer_activations(path + '/' + 'ave_' + modelpath + f'_{lang1}.result')
    ave2 = _load_layer_activations(path + '/' + 'ave_' + modelpath + f'_{lang2}.result')

    activation_probs = torch.stack([ave1, ave2], dim=-1)  # layer x neuron x lang
    num_layers = activation_probs.size(0)
    num_langs = activation_probs.size(-1)

    # Normalise over languages; neurons that never fire give 0/0 = NaN -> 0.
    normed_activation_probs = activation_probs / activation_probs.sum(dim=-1, keepdim=True)
    normed_activation_probs[torch.isnan(normed_activation_probs)] = 0
    # Mask log(0) = -inf so that the 0 * log(0) terms contribute 0 to the sum.
    log_probs = torch.where(
        normed_activation_probs > 0,
        normed_activation_probs.log(),
        torch.zeros_like(normed_activation_probs),
    )
    entropy = -torch.sum(normed_activation_probs * log_probs, dim=-1)
    largest = False  # rank by the smallest entropy

    nan_count = torch.isnan(entropy).sum()
    if nan_count:
        print(nan_count)
        raise ValueError('entropy contains NaN values')

    flattened_probs = activation_probs.flatten()

    def quantile_value(rate):
        # kthvalue requires k >= 1, so clamp for very small inputs.
        k = max(1, round(len(flattened_probs) * rate))
        return flattened_probs.kthvalue(k).values.item()

    top_prob_value = quantile_value(filter_rate)
    print(top_prob_value)
    # Dismiss the neuron if no language has an activation value above the
    # filter_rate quantile: push its entropy to the losing end of the ranking.
    top_position = (activation_probs > top_prob_value).sum(dim=-1)
    entropy[top_position == 0] = -torch.inf if largest else torch.inf

    flattened_entropy = entropy.flatten()
    top_entropy_value = round(len(flattened_entropy) * top_rate)
    _, index = flattened_entropy.topk(top_entropy_value, largest=largest)
    # Convert flat positions back to (layer, neuron) coordinates.
    row_index = index // entropy.size(1)
    col_index = index % entropy.size(1)
    selected_probs = activation_probs[row_index, col_index]  # n x lang

    print(selected_probs.size(0), torch.bincount(selected_probs.argmax(dim=-1)))
    selected_probs = selected_probs.transpose(0, 1)  # lang x n
    activation_bar = quantile_value(activation_bar_ratio)
    print((selected_probs > activation_bar).sum(dim=1).tolist())
    lang, indice = torch.where(selected_probs > activation_bar)

    merged_index = torch.stack((row_index, col_index), dim=-1)
    # torch.where yields rows in ascending order, so splitting by per-language
    # counts groups the hits by language. minlength keeps a (possibly empty)
    # slot for every language, so positions in the result always line up.
    per_lang_counts = torch.bincount(lang, minlength=num_langs).tolist()
    final_indice = []
    for lang_hits in indice.split(per_lang_counts):
        lang_index = sorted(tuple(row.tolist()) for row in merged_index[lang_hits])
        layer_index = [[] for _ in range(num_layers)]
        for layer, neuron in lang_index:
            layer_index[layer].append(neuron)
        final_indice.append([torch.tensor(neurons).long() for neurons in layer_index])
    return final_indice




# Count the neurons selected for each language (summed over all layers)
# and write the totals as a single CSV row next to the input results.
results = []
final_indice = getresult(lang1, lang2, modelpath)
n = [sum(map(len, per_language)) for per_language in final_indice]
result = list(n)
results.append(result)
df = pd.DataFrame(results)
csv_path = 'nerons/result_' + modelpath + '/result.csv'
df.to_csv(csv_path)