import torch
import torch.nn.functional as F
import numpy as np
# Load the per-task activation statistics. The task order below fixes the
# order of ``n`` and of the last axis of ``over_zero``.
n = []
over_zero = []
for lang in ('QA', 'MATH', 'RLHF'):
    data = torch.load(f'/home/ssliang/unlearning/data/7bhf_activation_{lang}')
    n += [data['n']]
    over_zero += [data['over_zero']]

n = torch.tensor(n)
# Stack the per-task tensors along a new trailing axis; the unpacking below
# requires the result to be 3-D: (num_layers, intermediate_size, lang_num).
over_zero = torch.stack(over_zero, dim=-1)

num_layers, intermediate_size, lang_num = over_zero.shape

def _kth_smallest(flat, rate):
    """Return the value at quantile ``rate`` of the 1-D tensor ``flat``.

    ``k`` is clamped to ``[1, numel]`` so tiny inputs or extreme rates do not
    make ``kthvalue`` raise.
    """
    k = min(max(1, round(flat.numel() * rate)), flat.numel())
    return flat.kthvalue(k).values.item()


def _normalize_columns(probs):
    """Rescale each language column (last axis) of ``probs`` to mean 1.

    Every value in column ``c`` is multiplied by
    ``num_layers * intermediate_size / probs[..., c].sum()``.
    The sums are accumulated in float64 for accuracy.

    Returns:
        (rescaled probs, float64 column sums before rescaling)
    """
    num_layers, intermediate_size, _ = probs.shape
    col_sums = probs.sum(dim=(0, 1), dtype=torch.float64)
    scale = (num_layers * intermediate_size / col_sums).to(probs.dtype)
    return probs * scale, col_sums


def _write_topk_reports(probs, top_k, lang_names, output_dir):
    """Write, per language, the ``top_k`` most active neurons of every layer.

    File ``<output_dir>/activation_results_<lang>.txt`` gets one line per
    (layer, rank): ``"<layer> <neuron> <p_lang0> <p_lang1> ..."``.
    The probabilities listed are those of *that* neuron for every language.
    """
    k = min(top_k, probs.size(1))
    # layer x k x lang: top neurons of each layer, ranked per language column.
    _, topk_indices = probs.topk(k, dim=1, largest=True, sorted=True)
    probs_np = probs.detach().cpu().numpy()
    idx_np = topk_indices.detach().cpu().numpy()
    for col, name in enumerate(lang_names):
        lines = []
        for layer in range(probs_np.shape[0]):
            for neuron in idx_np[layer, :, col]:
                values = ' '.join(str(v) for v in probs_np[layer, neuron])
                lines.append(f"{layer} {neuron} {values}\n")
        with open(f"{output_dir}/activation_results_{name}.txt", "w", encoding="utf-8") as f:
            f.writelines(lines)


def _select_neurons(probs, top_rate, filter_rate, activation_bar_ratio):
    """Pick low-entropy neurons and assign each to the languages it fires for.

    Returns:
        One entry per language. Each entry is a list (one item per layer) of
        sorted long tensors holding the selected neuron indices.

    Raises:
        ValueError: if the entropy computation produced NaN values.
    """
    num_layers, intermediate_size, lang_num = probs.shape

    # Entropy of each neuron's distribution over languages.
    normed = probs / probs.sum(dim=-1, keepdim=True)
    normed[torch.isnan(normed)] = 0  # neurons with zero total activation
    log_probs = torch.where(normed > 0, normed.log(), 0)
    entropy = -torch.sum(normed * log_probs, dim=-1)

    nan_count = int(torch.isnan(entropy).sum())
    if nan_count:
        print(nan_count)
        raise ValueError(f"entropy contains {nan_count} NaN value(s)")

    flattened_probs = probs.flatten()
    top_prob_value = _kth_smallest(flattened_probs, filter_rate)
    print(top_prob_value)
    # Dismiss a neuron if no language has a probability above the
    # ``filter_rate`` quantile: push its entropy to +inf so it is never picked.
    top_position = (probs > top_prob_value).sum(dim=-1)
    entropy[top_position == 0] = torch.inf

    flattened_entropy = entropy.flatten()
    num_selected = min(round(flattened_entropy.numel() * top_rate), flattened_entropy.numel())
    _, flat_index = flattened_entropy.topk(num_selected, largest=False)
    row_index = flat_index // intermediate_size
    col_index = flat_index % intermediate_size
    selected_probs = probs[row_index, col_index]  # n_selected x lang
    print(selected_probs.size(0), torch.bincount(selected_probs.argmax(dim=-1)))
    selected_probs = selected_probs.transpose(0, 1)  # lang x n_selected

    activation_bar = _kth_smallest(flattened_probs, activation_bar_ratio)
    print((selected_probs > activation_bar).sum(dim=1).tolist())
    lang, indice = torch.where(selected_probs > activation_bar)
    print('lang', lang, 'indice', indice)

    merged_index = torch.stack((row_index, col_index), dim=-1)
    final_indice = []
    # ``lang`` is sorted (row-major order from torch.where), so splitting by
    # per-language counts groups the positions by language. ``minlength``
    # keeps an (empty) group for languages that received no neurons.
    for positions in indice.split(torch.bincount(lang, minlength=lang_num).tolist()):
        layer_index = [[] for _ in range(num_layers)]
        for layer, neuron in sorted(tuple(row.tolist()) for row in merged_index[positions]):
            layer_index[layer].append(neuron)
        final_indice.append([torch.tensor(h, dtype=torch.long) for h in layer_index])
    return final_indice


def _overlap_score(first, second):
    """Fraction of neurons selected in ``first`` that are also in ``second``.

    Both arguments are per-layer lists of index tensors. Returns 0.0 when
    ``first`` selects no neurons.
    """
    total = sum(len(layer) for layer in first)
    if total == 0:
        return 0.0
    shared = sum(len(set(a.tolist()) & set(b.tolist())) for a, b in zip(first, second))
    return shared / total


def activation(*, top_rate=0.01, filter_rate=0.95, activation_bar_ratio=0.90,
               top_k=50, counts=None, totals=None,
               lang_names=('QA', 'MATH', 'RLHF'),
               output_dir='/home/ssliang/unlearning/data'):
    """Select language-specific neurons by activation entropy and score overlap.

    Steps:
      1. ``probs = counts / totals`` (layer x inter x lang); each language
         column is then rescaled to mean 1.
      2. For every language, write the ``top_k`` neurons of each layer to
         ``<output_dir>/activation_results_<lang>.txt``.
      3. Keep the lowest-entropy ``top_rate`` fraction of neurons. Neurons with
         no language above the ``filter_rate`` quantile are discarded first.
      4. Assign each kept neuron to every language whose probability exceeds
         the ``activation_bar_ratio`` quantile.
      5. Print and return the fraction of language-0 neurons also assigned to
         language 1.

    Args:
        top_rate: fraction of all neurons to keep by lowest entropy.
        filter_rate: quantile a neuron must exceed in at least one language.
        activation_bar_ratio: quantile a kept neuron must exceed for a language.
        top_k: neurons per layer written to each report file (clamped to the
            layer width).
        counts: 3-D tensor (layer x inter x lang). Defaults to the
            module-level ``over_zero``.
        totals: divisor broadcast over the last axis. Defaults to the
            module-level ``n`` (presumably per-language sample counts —
            confirm).
        lang_names: one name per language, used in report file names.
        output_dir: directory the report files are written to.

    Returns:
        float: the language-0 / language-1 overlap score (``score1``).

    Raises:
        ValueError: if ``lang_names`` does not match the language axis, fewer
            than two languages are given, or the entropy contains NaN.
    """
    if counts is None:
        counts = over_zero
    if totals is None:
        totals = n
    activation_probs = counts / totals  # layer x inter x lang_num
    lang_num = activation_probs.size(-1)
    if lang_num != len(lang_names) or lang_num < 2:
        raise ValueError(
            f"expected >= 2 languages matching lang_names, got {lang_num} "
            f"columns and {len(lang_names)} names"
        )
    print('prob', activation_probs)

    activation_probs, col_sums = _normalize_columns(activation_probs)
    print('sum', *col_sums.tolist())

    _write_topk_reports(activation_probs, top_k, lang_names, output_dir)

    final_indice = _select_neurons(activation_probs, top_rate, filter_rate, activation_bar_ratio)

    score1 = _overlap_score(final_indice[0], final_indice[1])
    print('score1', score1)
    return score1
# Run the analysis only when executed as a script, so importing this module
# does not rerun ``activation`` or rewrite its result files.
if __name__ == "__main__":
    activation()
