import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
import random

def find_neu(data, threshold=6, apply_abs=True):
    """Locate entries of ``data`` that lie far from its mean.

    An entry is selected when its deviation from the mean of ``data`` is
    strictly greater than ``threshold`` times the standard deviation of
    ``data`` (both statistics are taken over all elements).

    Args:
        data: Array of values to scan.
        threshold: Number of standard deviations an entry must exceed.
        apply_abs: When true, the absolute deviation is compared, so entries
            far above or far below the mean are both selected. When false,
            only entries far above the mean are selected.

    Returns:
        A list with one index list per selected entry, as produced by
        ``np.argwhere(...).tolist()``.
    """
    deviation = data - np.mean(data)
    cutoff = threshold * np.std(data)

    if apply_abs:
        deviation = np.abs(deviation)

    selected = np.argwhere(deviation > cutoff)
    return selected.tolist()

def get_data(ind_list, data_name="gsm8k", rato=100,
             ckpt_root="/share/projset/knowledge-neurons/ckpt"):
    """Average the arrays stored in a set of ``te{i}.pt`` checkpoint files.

    For every index ``i`` in ``ind_list`` the file
    ``{ckpt_root}/{data_name}/te{i}.pt`` is read with ``torch.load`` and
    converted with ``np.array``. Each array is multiplied by ``rato`` before
    being summed, and the sum is divided by ``len(ind_list) * rato``, so the
    result is the element-wise mean of the loaded arrays.

    Args:
        ind_list: Sized collection of file indices to average over.
        data_name: Sub-directory of ``ckpt_root`` holding the files.
        rato: Scale factor applied before summation and removed afterwards.
            NOTE(review): presumably intended to limit precision loss while
            accumulating small values -- confirm.
        ckpt_root: Directory that contains the per-dataset sub-directories.
            Defaults to the location the script has always used.

    Returns:
        The element-wise mean of the loaded arrays as a numpy array.

    Raises:
        ValueError: If ``ind_list`` is empty (there is nothing to average).
    """
    count = len(ind_list)
    if count == 0:
        # Previously this fell through to ``None / 0`` and raised an opaque
        # TypeError; fail early with a message that names the real problem.
        raise ValueError("ind_list must contain at least one index")

    data = None
    for i in ind_list:
        scaled = np.array(torch.load(f"{ckpt_root}/{data_name}/te{i}.pt")) * rato
        data = scaled if data is None else data + scaled
    return data / (count * rato)

def cal_acc(set_1, set_2):
    """Print overlap statistics between two collections of items.

    Items are compared by their string form (``f"{item}"``), which lets
    unhashable entries such as index lists be placed in a set; duplicates
    within one collection therefore count once.

    Four lines are printed:
      * the number of distinct items in each collection,
      * ``avg_rato``: the mean of the two per-collection overlap ratios,
      * ``IoU``: size of the intersection divided by size of the union,
      * ``rato``: the intersection size divided by each collection's size.

    Any ratio whose denominator is zero (an empty collection, or an empty
    union for IoU) is reported as 0.0 instead of raising
    ``ZeroDivisionError``.

    Args:
        set_1: First iterable of items.
        set_2: Second iterable of items.

    Returns:
        None. The statistics are only printed.
    """
    def _safe_ratio(numerator, denominator):
        # Empty collections are a normal outcome (no item selected), so
        # report zero overlap rather than crashing on division by zero.
        return numerator / denominator if denominator else 0.0

    te_1 = {f"{te}" for te in set_1}
    te_2 = {f"{te}" for te in set_2}
    n_and = len(te_1 & te_2)
    n_or = len(te_1 | te_2)

    rato_1 = _safe_ratio(n_and, len(te_1))
    rato_2 = _safe_ratio(n_and, len(te_2))

    print("num -> set 1: {} | set 2: {}".format(len(te_1), len(te_2)))
    print("avg_rato -> {:.4f}".format((rato_1 + rato_2) / 2))
    print("IoU: {:.4f}".format(_safe_ratio(n_and, n_or)))
    print("rato -> set 1: {:.4f} | set 2: {:.4f}".format(rato_1, rato_2))

if __name__ == "__main__":
    # Script overview: for several datasets, average two disjoint groups of
    # checkpoint files, compare the outlier sets found in each group with
    # cal_acc, and print the fraction of outliers found in the average over
    # all of that dataset's files.

    # Denominator shared by every "neuron rato" line below.
    # NOTE(review): presumably the model's total neuron count, expected to
    # match the `data_1.size` value printed further down -- confirm.
    NEURON_COUNT = 352256

    # ---- "emotion", 24 files, split in index order -------------------------
    data_name_1 = "emotion"
    te_lis = list(range(24))
    half_ = len(te_lis) // 2
    # random.shuffle(te_lis)
    print(te_lis)

    data_1 = get_data(ind_list=te_lis[:half_], data_name=data_name_1)
    data_2 = get_data(ind_list=te_lis[half_:2 * half_], data_name=data_name_1)

    # Average over all 24 files; reused below instead of being re-read.
    data_emo = get_data(ind_list=te_lis, data_name=data_name_1)

    # ---- "code2k", 22 files, random split ----------------------------------
    te_med = list(range(22))
    random.shuffle(te_med)
    # te_med = [0,5,4,1,2,3]
    # "code2k"
    # "emotion"
    rat = 50
    data_name = "code2k"
    # Average over all 22 files; reused below instead of being re-read.
    data_med = get_data(ind_list=te_med, data_name=data_name, rato=rat)

    half_ = len(te_med) // 2
    data_3 = get_data(ind_list=te_med[:half_], data_name=data_name, rato=rat)
    data_4 = get_data(ind_list=te_med[half_:2 * half_], data_name=data_name, rato=rat)

    print(f"total neurrons: {data_1.size}")
    print("=========")

    print(data_name_1)
    cal_acc(find_neu(data_1), find_neu(data_2))
    # Same arguments as the data_emo load above, so reuse that result rather
    # than loading all 24 files a second time.
    all_data = data_emo
    print(f"neuron rato: {len(find_neu(all_data)) / NEURON_COUNT:.4f}")

    print("=========")
    print(data_name)
    cal_acc(find_neu(data_3), find_neu(data_4))
    # te_med has not been reshuffled since data_med was computed, so the
    # earlier average is identical to what a fresh load would return.
    all_data = data_med
    print(f"neuron rato: {len(find_neu(all_data)) / NEURON_COUNT:.4f}")

    # ---- "gsm8k_cho", 16 files, random split -------------------------------
    print("=========")
    name_math_cho = "gsm8k_cho"
    print(name_math_cho)
    te_li_math_cho = list(range(16))
    random.shuffle(te_li_math_cho)
    half_mc = len(te_li_math_cho) // 2
    data_7 = get_data(ind_list=te_li_math_cho[:half_mc], data_name=name_math_cho, rato=100)
    data_8 = get_data(ind_list=te_li_math_cho[half_mc:2 * half_mc], data_name=name_math_cho, rato=100)
    cal_acc(find_neu(data_7), find_neu(data_8))
    all_data = get_data(ind_list=te_li_math_cho, data_name=name_math_cho)
    print(f"neuron rato: {len(find_neu(all_data)) / NEURON_COUNT:.4f}")

    # ---- "gsm8k_cho_new", 14 files, random split (no overall ratio) --------
    print("=========")
    name_math_cho = "gsm8k_cho_new"
    print(name_math_cho)
    te_li_math_cho = list(range(14))
    random.shuffle(te_li_math_cho)
    half_mc = len(te_li_math_cho) // 2
    data_7 = get_data(ind_list=te_li_math_cho[:half_mc], data_name=name_math_cho, rato=100)
    data_8 = get_data(ind_list=te_li_math_cho[half_mc:2 * half_mc], data_name=name_math_cho, rato=100)
    cal_acc(find_neu(data_7), find_neu(data_8))

    # ---- "meta_math", 7 files, random split --------------------------------
    # The second slice is open-ended, so with an odd file count the second
    # group holds one more file (4) than the first (3).
    print("=========")
    name_math_cho = "meta_math"
    print(name_math_cho)
    te_li_math_cho = list(range(7))
    random.shuffle(te_li_math_cho)
    half_mc = len(te_li_math_cho) // 2
    data_7 = get_data(ind_list=te_li_math_cho[:half_mc], data_name=name_math_cho, rato=100)
    data_8 = get_data(ind_list=te_li_math_cho[half_mc:], data_name=name_math_cho, rato=100)
    cal_acc(find_neu(data_7), find_neu(data_8))
    data_9 = get_data(ind_list=te_li_math_cho, data_name=name_math_cho, rato=100)
    print(f"neuron rato: {len(find_neu(data_9)) / NEURON_COUNT:.4f}")

    # ---- "emotion" again, first 8 files only, random split -----------------
    print("=========")
    name_math_cho = "emotion"
    print(name_math_cho)
    te_li_math_cho = list(range(8))
    random.shuffle(te_li_math_cho)
    half_mc = len(te_li_math_cho) // 2
    data_7 = get_data(ind_list=te_li_math_cho[:half_mc], data_name=name_math_cho, rato=100)
    data_8 = get_data(ind_list=te_li_math_cho[half_mc:], data_name=name_math_cho, rato=100)
    cal_acc(find_neu(data_7), find_neu(data_8))
    data_9 = get_data(ind_list=te_li_math_cho, data_name=name_math_cho, rato=100)
    print(f"neuron rato: {len(find_neu(data_9)) / NEURON_COUNT:.4f}")
