import json
import math
import os.path as osp
import argparse

import matplotlib.pyplot as plt
import numpy as np

def sort_dict_by_value(d, largest_first=True):
    """Return a new dict holding the items of *d* ordered by value.

    Args:
        d: mapping whose values are mutually comparable.
        largest_first: True (default) for descending order, False for ascending.

    Returns:
        A new ``dict``; items with equal values keep their original relative
        order because ``sorted`` is stable.
    """
    ordered_items = sorted(d.items(), key=lambda pair: pair[1], reverse=largest_first)
    return dict(ordered_items)

def divide_dict(d, total_steps, ratio=0.5):
    """Keep only the entries of *d* whose value strictly exceeds ``total_steps * ratio``.

    Returns a new dict; insertion order of the kept entries is preserved.
    """
    threshold = total_steps * ratio
    return dict(item for item in d.items() if item[1] > threshold)

def print_dict(d):
    """Print every ``key: value`` pair of *d* on its own line, value to 8 decimals."""
    formatted_lines = (f'{k}: {v:.8f}' for k, v in d.items())
    for line in formatted_lines:
        print(line)

def print_two_dict(d1, d2):
    """Print two dicts side by side, pairing entries by position (not by key).

    Each output line is ``'<k1>: <v1> | <k2>: <v2>'`` with both values
    formatted to 8 decimal places. Pairing follows each dict's iteration
    order, and output stops at the end of the shorter dict instead of raising
    IndexError partway through the printout.
    """
    # zip walks both item views in lockstep: O(n), rather than rebuilding
    # list(d2.keys()) on every iteration (O(n^2)).
    for (k1, v1), (k2, v2) in zip(d1.items(), d2.items()):
        print(f'{k1}: {v1:.8f} | {k2}: {v2:.8f}')

def dict_values_avg(dict):
    """Return the arithmetic mean of every non-empty sequence value in a mapping.

    Keys whose value has length 0 are left out of the result.
    """
    # NOTE: the parameter name shadows the builtin ``dict``; kept so callers
    # passing it by keyword keep working.
    averages = {}
    for key, samples in dict.items():
        count = len(samples)
        if count == 0:
            continue
        averages[key] = sum(samples) / count
    return averages

def list_distance(list1, list2):
    """Measure how far shared elements moved between two rankings.

    For each element of *list1* that also occurs in *list2*, the displacement
    ``abs(position_in_list1 - position_in_list2)`` is computed, using the
    first occurrence in *list2* (as ``list.index`` does).

    Returns:
        ``(mean, largest, smallest)``:
        - mean: total displacement divided by ``len(list1)``; elements missing
          from *list2* contribute 0. An empty *list1* gives 0.0 instead of
          raising ZeroDivisionError.
        - largest / smallest: extreme displacements among shared elements.
          When nothing is shared they keep the sentinels -1 and 1000000, so
          existing callers see the same values as before.
    """
    if len(list1) == 0:
        return 0.0, -1, 1000000

    total = 0
    largest = -1
    smallest = 1000000
    for pos, value in enumerate(list1):
        if value not in list2:
            continue  # unshared elements add nothing to the total
        gap = abs(pos - list2.index(value))
        total += gap
        largest = max(largest, gap)
        smallest = min(smallest, gap)
    return total / len(list1), largest, smallest

def list_overlap(list1, list2):
    """Count the items of *list1* (duplicates counted each time) that also appear in *list2*."""
    return sum(1 for item in list1 if item in list2)

def read_content(file_content, whether_downtimes=False):
    """Rank and print the ``lora_score`` mapping of an already-loaded score file.

    Args:
        file_content: parsed JSON dict containing a ``lora_score`` mapping and,
            when *whether_downtimes* is true, a ``lowerbound_down_times`` mapping
            whose keys must all exist in ``lora_score``.
        whether_downtimes: also print ``lora_score - lowerbound_down_times`` per
            key (ranked largest first) and the ranking distance between the two.

    Returns:
        The ``lora_score`` mapping sorted by value, largest first.
    """
    ranked_scores = sort_dict_by_value(file_content['lora_score'])
    print_dict(ranked_scores)
    if not whether_downtimes:
        return ranked_scores

    separator = '-----------------'
    print(separator)
    margins = sort_dict_by_value({
        name: ranked_scores[name] - down_times
        for name, down_times in file_content['lowerbound_down_times'].items()
    })
    print_dict(margins)
    print(separator)
    margin_keys = list(margins.keys())
    score_keys = list(ranked_scores.keys())
    print(
        f'distance between lower_bound_down_times and lora_score: {list_distance(margin_keys, score_keys)}')
    return ranked_scores

def mix_lowerbound_down_times(file_dir, blk):
    """Load, rank and save the ``lowerbound_down_times`` mapping of a score file.

    Args:
        file_dir: path to a JSON file containing a ``lowerbound_down_times`` mapping.
        blk: when truthy, entries whose value is not > 0.0 are dropped first.

    Returns:
        The (optionally filtered) mapping sorted by value, largest first. The
        same mapping is written as JSON to ``mix_lowerbound_<basename>`` in the
        directory of *file_dir*.
    """
    with open(file_dir, 'r') as f:
        lowerbound_down_times = json.load(f)['lowerbound_down_times']
    if blk:
        # remove 0.0 (and any negative) entries
        lowerbound_down_times = {k: v for k, v in lowerbound_down_times.items() if v > 0.0}
    lowerbound_down_times = sort_dict_by_value(lowerbound_down_times)

    # The output path is derived from *file_dir*, so each input file gets its
    # own mix file and the function works without any module-level globals.
    root_path = osp.dirname(file_dir)
    basename = osp.basename(file_dir)
    new_file = osp.join(root_path, f'mix_lowerbound_{basename}')
    with open(new_file, 'w') as f:
        json.dump(lowerbound_down_times, f)

    return lowerbound_down_times

def read_weight_score(file):
    """Load the ``weight_score`` mapping from a JSON file, ranked largest first."""
    with open(file, 'r') as f:
        raw_scores = json.load(f)['weight_score']
    return sort_dict_by_value(raw_scores)

def mix_lora_score(file_dir, blk):
    """Merge, rank and save the scores stored in a score file.

    Args:
        file_dir: path to a JSON file with a ``lora_score`` mapping and, when
            *blk* is truthy, a ``block_score`` mapping.
        blk: when truthy, start from ``block_score`` and add/override it with
            every ``lora_score`` entry whose value is > 0.0; otherwise use
            ``lora_score`` alone.

    Returns:
        The resulting mapping sorted by value, largest first. The same mapping
        is written as JSON to ``mix_<basename>`` in the directory of *file_dir*.
    """
    with open(file_dir, 'r') as f:
        file_content = json.load(f)
    if blk:
        norm_score_selected = file_content['lora_score']
        block_score_selected = file_content['block_score']
        norm_score_selected = {k: v for k, v in norm_score_selected.items() if v > 0.0}
        block_score_selected.update(norm_score_selected)
    else:
        block_score_selected = file_content['lora_score']
    block_score_selected = sort_dict_by_value(block_score_selected)

    # The output path is derived from *file_dir*, so each input file gets its
    # own mix file and the function works without any module-level globals.
    root_path = osp.dirname(file_dir)
    basename = osp.basename(file_dir)
    new_file = osp.join(root_path, f'mix_{basename}')
    with open(new_file, 'w') as f:
        json.dump(block_score_selected, f)

    return block_score_selected

def cal_with_types(input_dicts, layers_split, type, largest_first=True):
    """Average per-module scores into (layer-range, module-type) buckets.

    Every key of ``input_dicts`` is split on '.'. For keys with at least 7
    components, component [4] is parsed as an integer layer index and
    component [6] as the module type (presumably PEFT-style names such as
    ``base_model.model.model.layers.<i>.self_attn.q_proj`` -- confirm against
    the score files). Such a key goes into bucket ``f'{r}_{t}'`` when
    ``l <= layer < r`` for consecutive entries ``l, r`` of ``layers_split`` and
    its type equals ``t`` from ``type``. Keys that match no range/type pair are
    dropped.

    Keys with fewer than 7 components are not parsed. Each one gets a bucket
    named after its last component that holds only its own value (a later key
    with the same last component overwrites an earlier one). These buckets are
    only created if the loops run at all, i.e. ``layers_split`` has >= 2
    entries and ``type`` is non-empty.

    Args:
        input_dicts: mapping of module key -> numeric score.
        layers_split: layer boundaries, expected to be increasing; consecutive
            pairs form half-open ranges [l, r).
        type: module-type names to bucket on (shadows the builtin ``type``
            inside this function).
        largest_first: sort order of the returned averages.

    Returns:
        ``(output, output_keys_dict_o)``: ``output`` maps bucket name -> mean
        score, sorted by value. ``output_keys_dict_o`` maps the same bucket
        names, in the same order, to the list of original keys that were
        averaged.

    Raises:
        ValueError: if component [4] of a key with 7+ components is not an int.

    Prints a separator line and the sorted averages as a side effect.
    """
    output = {}
    output_keys_dict = {}
    for i in range(len(layers_split) - 1):
        layers_split_value_l = layers_split[i]
        layers_split_value_r = layers_split[i + 1]
        for type_value in type:
            # Bucket name uses the right (exclusive) edge of the range.
            string = f'{layers_split_value_r}_{type_value}'
            for key, value in input_dicts.items():
                # Short (non layer-scoped) keys: one bucket per last component.
                # Re-executed on every (range, type) pass; the reassignment
                # yields the same final contents each time.
                if len(key.split('.')) < 7:
                    output[key.split('.')[-1]] = [value]
                    output_keys_dict[key.split('.')[-1]] = [key]
                    continue

                # Layer index and module type sit at fixed positions of the dotted key.
                layers_num = int(key.split('.')[4])
                t = key.split('.')[6]
                if layers_num < layers_split_value_r and layers_num >= layers_split_value_l and t == type_value:
                    if output.get(string, None) is None:
                        output[string] = []
                    output[string].append(value)
                    if output_keys_dict.get(string, None) is None:
                        output_keys_dict[string] = []
                    output_keys_dict[string].append(key)

    # Replace each collected list with its mean (existing keys only, so
    # assigning during iteration does not change the dict's size).
    for key, value in output.items():
        output[key] = sum(value) / len(value)

    print('------------------------------------------------')
    output = sort_dict_by_value(output, largest_first=largest_first)
    print_dict(output)

    # Re-order the bucket -> keys mapping to follow the sorted averages.
    output_keys_dict_o = {}

    for key in output.keys():
        output_keys_dict_o[key] = output_keys_dict[key]

    return output, output_keys_dict_o

def get_freeze_layers_from_constant(args):
    """Collect layer names from the first ``args.topK`` groups of a grouped layer file.

    ``args.layer_json_file`` must hold a JSON object mapping a group name to a
    list of layer names (presumably the ``*-constant.json`` files the
    ``__main__`` block writes from ``cal_with_types`` output -- confirm).
    Groups are taken in file order, and their layer lists are concatenated.
    """
    with open(args.layer_json_file, 'r') as f:
        groups = json.load(f)
    selected_groups = list(groups)[:args.topK]
    return [layer for group in selected_groups for layer in groups[group]]

def get_layers_prob_distribution(args):
    """Load and return the parsed JSON content of ``args.layer_json_file``.

    The content is presumably a ``{layer_name: probability}`` mapping such as
    the one built by ``transfer_the_distribution`` -- confirm with the caller.
    The parsed object is returned instead of being discarded, so the file
    read is no longer a no-op.
    """
    with open(args.layer_json_file, 'r') as f:
        permutations = json.load(f)
    return permutations

def draw_the_distribution(values, title):
    """Plot *values* as a 32-bin histogram titled *title* and display it.

    Draws on the current pyplot figure, labels the axes 'Value' / 'Frequency'
    and calls ``plt.show()``.
    """
    plt.hist(values, bins=32, color='blue', edgecolor='black')

    decorations = (
        (plt.title, title),
        (plt.xlabel, 'Value'),
        (plt.ylabel, 'Frequency'),
    )
    for apply_text, text in decorations:
        apply_text(text)

    plt.show()

def transform_the_distribution(values, desired_mean, desired_variance):
    """Linearly rescale *values* to a target mean and variance.

    The values are z-score standardised with the population standard deviation
    (``np.std`` default, ddof=0) and mapped to
    ``z * sqrt(desired_variance) + desired_mean``.

    Args:
        values: array-like of numbers (lists are converted with ``np.asarray``).
        desired_mean: mean of the returned values.
        desired_variance: population variance of the returned values (>= 0).

    Returns:
        numpy array with the same shape as *values*. If every input value is
        equal (zero spread), all outputs equal *desired_mean* rather than
        NaN/inf from dividing by a zero standard deviation.
    """
    values = np.asarray(values)
    current_mean = np.mean(values)
    current_std = np.std(values)

    if current_std == 0:
        # No spread to rescale: every value sits exactly at the mean.
        standardized_values = np.zeros_like(values, dtype=float)
    else:
        standardized_values = (values - current_mean) / current_std

    desired_std = np.sqrt(desired_variance)
    return standardized_values * desired_std + desired_mean

import random

def get_freeze_layers(args):
    """Select the layer names to freeze from ``args.layer_json_file``.

    Routing is by substring of the file path:
    - 'mix' or 'lowerbound': scores ranked largest first;
    - 'gs': scores ranked smallest first;
    - 'constant': delegated to ``get_freeze_layers_from_constant``;
    - anything else raises ValueError.

    For score files (a JSON ``{layer: score}`` mapping), ``topK =
    int(len(scores) * args.divide_ratio)`` layers are returned:
    - ``args.Random``: a random topK subset;
    - ``args.first_ratio > 0``: the first topK keys in file order (unranked);
    - ``args.last_ratio > 0``: the last topK keys in file order (unranked);
    - otherwise: the topK best-ranked keys.
    Only the sign of first_ratio / last_ratio is used, not its magnitude.

    Returns:
        A list of layer names; ``[]`` when ``args.layer_json_file`` is None.

    Raises:
        ValueError: for an unrecognised layer file name.
    """
    if args.layer_json_file is None:
        return []

    layer_file = args.layer_json_file
    if 'mix' in layer_file or 'lowerbound' in layer_file:
        largest_first = True
    elif 'gs' in layer_file:
        largest_first = False
    elif 'constant' in layer_file:
        return get_freeze_layers_from_constant(args)
    else:
        raise ValueError(f'Unknown layer json file: {args.layer_json_file}')

    with open(layer_file, 'r') as f:
        score = json.load(f)

    topK = int(len(score) * args.divide_ratio)
    if args.first_ratio <= 0.0 and args.last_ratio <= 0.0:
        score = sort_dict_by_value(score, largest_first=largest_first)
    # Always define the candidate list (ranked only when neither ratio is set),
    # so the Random branch also works together with first/last_ratio.
    fixed_layers = list(score.keys())

    if args.Random:
        random.shuffle(fixed_layers)
        return fixed_layers[:topK]
    if args.first_ratio > 0.0:
        return fixed_layers[:topK]
    if args.last_ratio > 0.0:
        # Explicit start index: ``[-topK:]`` would return the whole list when topK == 0.
        return fixed_layers[max(len(fixed_layers) - topK, 0):]
    return fixed_layers[:topK]

def transfer_the_distribution(keys_dict):
    """Assign a probability level to every layer name in a grouped, ordered mapping.

    The groups of *keys_dict* (group -> list of layer names), taken in
    iteration order, are split into 8 consecutive slots sized proportionally
    to the weights ``[4, 3, 3, 3, 3, 3, 3, 3]``. Sizes are floored, and the
    leftover groups are handed one each to the earliest slots. Every layer
    name in slot ``s`` (0-based) is mapped to ``(s + 1) / 8``. If a name
    appears in several groups, the later slot wins.

    Returns:
        dict mapping layer name -> probability level.
    """
    low, high = 0.0, 1.0
    n_slots = 8
    levels = [low + (high - low) * (slot + 1) / n_slots for slot in range(n_slots)]

    # weights = [7, 4, 4, 2, 2, 2, 2, 2]
    weights = [4, 3, 3, 3, 3, 3, 3, 3]
    per_weight = len(keys_dict) / sum(weights)
    sizes = [math.floor(w * per_weight) for w in weights]

    leftover = len(keys_dict) - sum(sizes)
    for slot in range(leftover):
        sizes[slot] += 1

    ordered_groups = list(keys_dict)
    result = {}
    offset = 0
    for level, size in zip(levels, sizes):
        for group in ordered_groups[offset:offset + size]:
            for member in keys_dict[group]:
                result[member] = level
        offset += size
    return result

if __name__ == '__main__':

    # output_dir_2 = 'skilled_scores/localization-further-llama2-7b-no_robots-bf16-ALL+LM-lr0.01-ep0.2-lrs-s1-usesigmoidTrue-iv5.0/lora_score_60.json'
    # output_dir = 'skilled_scores/localization-further-llama2-7b-no_robots-bf16-ALL+LM-lr0.001-ep0.2-lrs-s1-usesigmoidTrue-iv5.0/lora_score_60.json'
    output_dir = './results/localization/Localization-LIMA-llama2-7B-bf16-FTFalse-ALL+LM-lr0.001-ep20-s1997-localization-si0-lgi4-sgFalse-bwFalse/lora_score_640.json'
    # output_dir = './results/llama2-7b-LIMA-bf16-FTFalse-ALL+LM-lr0.001-ep20-s1997-localization-si8-lgi4-sgFalse-constant/lora_score_320.json'
    # output_dir_2 = './skilled_scores/llama2-7b-LIMA-bf16-FTFalse-ALL+LM-lr0.001-ep20-lrsconstant-s1997-localization-lgi4-further/lora_score_320.json'
    # output_dir_2 = './results/llama2-7b-LIMA-bf16-FTFalse-ALL+LM-lr0.001-ep20-s111-localization-si8-lgi2-sgFalse-bwFalse/lora_score_320.json'
    # output_dir_2 = './results/localization/Localization-alpaca-gpt4-llama2-7B-bf16-FTFalse-ALL+LM-lr0.001-ep2-s1997-localization-si0-lgi8-sgFalse-bwFalse/lora_score_3250.json'
    output_dir_2 = './results/localization/Localization-no_robots-llama2-7B-bf16-FTFalse-ALL+LM-lr0.001-ep4-s1997-localization-si0-lgi16-sgFalse-bwFalse/lora_score_1184.json'

    # output_dir_2 = 'skilled_scores/localization-further-llama2-7b-LIMA-bf16-ALL+LM-lr0.001-ep5-lrs-s1997-usesigmoidTrue-iv5.0-LIMA/lora_score_160.json'

    # output_dir = 'skilled_scores/localization-further-llama2-7b-LIMA-bf16-ALL+LM-lr0.001-ep6-lrs-s1997-usesigmoidFalse-iv1.0/lora_score_192.json'
    # output_dir_2 = 'skilled_scores/localization-further-llama2-7b-LIMA-bf16-ALL+LM-lr0.001-ep6-lrs-s1997-usesigmoidFalse-iv0.5-LIMA/lora_score_192.json'

    parser = argparse.ArgumentParser()
    parser.add_argument('--layer_json_file', type=str, default='./skilled_scores/llama2-7b-no_robots-si148-lgi12-ep6_constant.json')
    parser.add_argument('--topK', type=int, default=2)
    parser.add_argument('--slice', type=int, default=32)
    # get_freeze_layers() (called at the end) reads these attributes for any
    # non-"constant" layer file; without them it fails with AttributeError.
    # NOTE(review): the divide_ratio default is a placeholder -- confirm the intended value.
    parser.add_argument('--divide_ratio', type=float, default=0.5)
    parser.add_argument('--first_ratio', type=float, default=0.0)
    parser.add_argument('--last_ratio', type=float, default=0.0)
    parser.add_argument('--Random', action='store_true')
    args = parser.parse_args()

    largest_first_1 = False
    largest_first_2 = False

    # save_path = f'./skilled_scores/llama2-7b-LIMA-further-ep5-lr0.001-iv5.0-slice{args.slice}-constant.json'
    # save_path_2 = f'./skilled_scores/llama2-7b-no_robots-further-ep5-lr0.001-iv5.0-slice{args.slice}-constant.json'

    # Use the parsed --slice value (a bare ``{slice}`` would format the builtin slice type).
    save_path = f'./skilled_scores/llama2-7b-LIMA-si0-lgi4-ep20-slice{args.slice}-constant.json'
    save_path_2 = f'./skilled_scores/llama2-7b-no_robots-si0-lgi16-ep4-slice{args.slice}-constant.json'

    # save_path_distribution = f'./skilled_scores/llama2-7b-LIMA-si0-lgi16-ep2-slice{args.slice}-uniform-distribution.json'
    # save_path_distribution_2 = f'./skilled_scores/llama2-7b-no_robots-si0-lgi16-ep4-slice{args.slice}-uniform-distribution.json'
    # output_dir_2 = './skilled_scores/llama2-7b-no_robots-bf16-FTFalse-ALL+LM-lr0.001-ep6-lrscosine-s1997-localization-lgi12-further/lora_score_888.json'
    #'./skilled_scores/llama2-13b-no_robots-bf16-FTFalse-ALL-lr0.001-ep6-s1997-localization-si0-lgi6-sgFalse-bwFalse/lowerbound_445.json'
    #'./skilled_scores/llama2-7b-no_robots-bf16-FTFalse-ALL-lr0.001-ep6-s1997-localization-si0-lgi6-sgFalse-bwFalse/lowerbound_445.json'
    #'./skilled_scores/llama2-7b-no_robots-bf16-FTFalse-ALL-lr0.001-ep6-s1997-localization-si0-lgi6-sgFalse-bwFalse/lowerbound_297.json'

    # output_dir = './skilled_scores/llama2-7b-LIMA-bf16-FTFalse-ALL-lr0.001-ep20-s1997-localization-si0-lgi6-sgFalse-bwFalse/lowerbound_'
    # output_dir_2 = './skilled_scores/llama2-7b-LIMA-bf16-FTFalse-ALL-lr0.001-ep20-s1997-localization-si0-lgi6-sgFalse-bwFalse-seed111/lowerbound_'

    # iteration = [65, 130, 195, 243, 320]
    #
    # for i in range(len(iteration)):
    #     for j in range(i + 1, len(iteration)):
    #         first = iteration[i]
    #         second = iteration[j]
    #         file = output_dir + str(first) + '.json'
    #         file_2 = output_dir_2 + str(second) + '.json'
    #         with open(file, 'r') as f:
    #             score_selected = sort_dict_by_value(json.load(f))
    #
    #         with open(file_2, 'r') as f:
    #             score_selected2 = sort_dict_by_value(json.load(f))
    #
    #         print(f'distance between {first} and {second}: {list_distance(list(score_selected.keys()), list(score_selected2.keys()))}')

    with open(output_dir, 'r') as f:
        score_selected = json.load(f)
        score_selected = score_selected['lora_score']
        # score_selected = score_selected['lowerbound_down_times']
        score_selected = sort_dict_by_value(score_selected)
        nonzero_ratio = len([v for v in score_selected.values() if v > 0.0]) / len(score_selected)

    with open(output_dir_2, 'r') as f:
        score_selected2 = json.load(f)
        score_selected2 = score_selected2['lora_score']
        # score_selected2 = score_selected2['lowerbound_down_times']
        score_selected2 = sort_dict_by_value(score_selected2)
        nonzero_ratio_2 = len([v for v in score_selected2.values() if v > 0.0]) / len(score_selected2)

    print_two_dict(score_selected, score_selected2)
    print(f'distance between glr2.0 and glr4.0: {list_distance(list(score_selected.keys()), list(score_selected2.keys()))}')
    print(f'nonzero ratio of glr2.0: {nonzero_ratio}')
    print(f'nonzero ratio of glr4.0: {nonzero_ratio_2}')

    print(args.slice)
    # Split the 32 decoder layers into ``args.slice`` equal contiguous ranges,
    # e.g. slice=4 -> [0, 8, 16, 24, 32]; slice=32 -> [0, 1, ..., 32].
    if args.slice not in (1, 2, 4, 8, 16, 32):
        raise ValueError('Unknown slice')
    layers_split = list(range(0, 33, 32 // args.slice))
    # layers_split = [4, 8, 12, 16, 20, 24, 28, 32]

    # Named proj_types so the module-level ``type`` builtin is not shadowed.
    proj_types = ['v_proj', 'q_proj', 'k_proj', 'o_proj', 'down_proj', 'up_proj']

    # largest_first_* == False ranks buckets in ascending score order.
    score_selected = sort_dict_by_value(score_selected, largest_first=largest_first_1)
    types1, keys_dict1 = cal_with_types(score_selected, layers_split, proj_types, largest_first=largest_first_1)

    score_selected2 = sort_dict_by_value(score_selected2, largest_first=largest_first_2)
    types2, keys_dict2 = cal_with_types(score_selected2, layers_split, proj_types, largest_first=largest_first_2)

    with open(save_path, 'w') as f:
        json.dump(keys_dict1, f)

    with open(save_path_2, 'w') as f:
        json.dump(keys_dict2, f)

    for k, v in keys_dict2.items():
        print(f'{k}: {v}')

    print('------------------------------------------------')
    print(f'length of keys_dict1: {len(keys_dict1)}')
    print_two_dict(types1, types2)
    print(f'distance between glr2.0 and glr4.0: {list_distance(list(types1.keys()), list(types2.keys()))}')

    fixed_layers = get_freeze_layers(args)
    print(fixed_layers)

    # types1 = np.array(list(types1.values()))
    # types2 = np.array(list(types2.values()))
    #
    # draw_the_distribution(types1, 'LIMA_T1')
    # draw_the_distribution(types2, 'No_Robots_T1')
    #
    # types1 = transform_the_distribution(types1, 0.5, 0.2)
    # types2 = transform_the_distribution(types2, 0.5, 0.2)
    #
    # draw_the_distribution(types1, 'LIMA_T2')
    # draw_the_distribution(types2, 'No_Robots_T2')

    # print('------------------------------------------------')
    # output_2 = transfer_the_distribution(keys_dict2)
    # print_dict(output_2)
    #
    # with open(save_path_distribution_2, 'w') as f:
    #     json.dump(output_2, f)

    # weight_score = read_weight_score(output_dir)
    # weight_score_2 = read_weight_score(output_dir_2)
    # print_two_dict(weight_score, weight_score_2)
    # print(
    #     f'distance between glr2.0 and glr4.0: {list_distance(list(weight_score.keys()), list(weight_score_2.keys()))}')
    # print(f'length of weight_score: {len(weight_score)}')

    # if 'blkTrue' in output_dir:
    #     blk = True
    # elif 'blkFalse' in output_dir:
    #     blk = False
    # else:
    #     raise ValueError('blk not found')
    #
    #
    # prefix = 'lora_score'
    # # content_prefix = 'block_score'
    # content_prefix = 'lowerbound_down_times'
    # filename = f'{prefix}.json'
    #
    # block_score_selected = mix_lora_score(output_dir, blk)
    # block_score_selected_2 = mix_lora_score(output_dir_2, blk)
    # #
    # print_two_dict(block_score_selected, block_score_selected_2)
    # print(
    #     f'distance between glr2.0 and glr4.0: {list_distance(list(block_score_selected.keys()), list(block_score_selected_2.keys()))}')
    #
    # lowbound_down_times = mix_lowerbound_down_times(output_dir, blk)
    # lowbound_down_times_2 = mix_lowerbound_down_times(output_dir_2, blk)
    #
    # keys = lowbound_down_times.keys()
    # p = [pos for pos, value in enumerate(keys) if 'norm' in value]
    #
    #
    # print_two_dict(lowbound_down_times, lowbound_down_times_2)
    # print(
    #     f'distance between glr2.0 and glr4.0: {list_distance(list(lowbound_down_times.keys()), list(lowbound_down_times_2.keys()))}')
    # print(f'Last Norm Layer {p[-1]}')
    # print(f'total length: {len(lowbound_down_times)}')
    # if 'lora' not in output_dir:
    #     file = osp.join(output_dir, filename)
    # else:
    #     file = output_dir
    #
    # if 'lora' not in output_dir_2:
    #     file_2 = osp.join(output_dir_2, filename)
    # else:
    #     file_2 = output_dir_2
    # # file_3 = osp.join(output_dir_3, filename)
    #
    # with open(file, 'r') as f:
    #     file_content = json.load(f)
    #
    #     score_selected = file_content[content_prefix]
    #     if content_prefix == 'lora_gamma_value_list':
    #         score_selected = dict_values_avg(score_selected)
    #         score_selected = sort_dict_by_value(score_selected, largest_first=False)
    #     else:
    #         score_selected = sort_dict_by_value(score_selected)
    #
    # root_path = osp.dirname(output_dir)
    # basename = osp.basename(output_dir)
    # new_file = osp.join(root_path, f'lowerbound_{basename}')
    # with open(new_file, 'w') as f:
    #     json.dump(score_selected, f)
    #
    # with open(file_2, 'r') as f:
    #     file_content_2 = json.load(f)
    #     score_selected2 = file_content_2[content_prefix]
    #     if content_prefix == 'lora_gamma_value_list':
    #         score_selected2 = dict_values_avg(score_selected2)
    #         score_selected2 = sort_dict_by_value(score_selected2, largest_first=False)
    #     else:
    #         score_selected2 = sort_dict_by_value(score_selected2)
    #
    # root_path = osp.dirname(output_dir_2)
    # basename = osp.basename(output_dir_2)
    # new_file = osp.join(root_path, f'lowerbound_{basename}')
    # with open(new_file, 'w') as f:
    #     json.dump(score_selected2, f)
    #
    # print(list(score_selected.keys())[42:])
    # #
    # print(f'length of score_selected: {len(score_selected)}')
    # print_two_dict(score_selected, score_selected2)
    # print('-----------------')
    # print(f'distance between glr2.0 and glr4.0: {list_distance(list(score_selected.keys()), list(score_selected2.keys()))}')

        # print_dict(score_selected)
        # print('-----------------')
        # print(f'divide_ration:{0.1}: {len(divide_dict(score_selected, 500, ratio=0.1))}')
        # print(f'divide_ration:{0.2}: {len(divide_dict(score_selected, 500, ratio=0.2))}')
        # print(f'divide_ration:{0.25}: {len(divide_dict(score_selected, 500, ratio=0.25))}')
        # print(f'divide_ration:{0.15}: {len(divide_dict(score_selected, 500, ratio=0.15))}')
        # print(f'divide_ration:{0.12}: {len(divide_dict(score_selected, 500, ratio=0.12))}')
        # print(f'divide_ration:{0.17}: {len(divide_dict(score_selected, 500, ratio=0.17))}')
        # print(f'divide_ration:{0.05}: {len(divide_dict(score_selected, 500, ratio=0.05))}')
        # print(f'divide_ration:{0.07}: {len(divide_dict(score_selected, 500, ratio=0.07))}')
        # print(f'divide_ration:{0.09}: {len(divide_dict(score_selected, 500, ratio=0.09))}')
        # print(f'divide_ration:{0.0}: {len(divide_dict(score_selected, 500, ratio=0.0))}')

    # with open(file_2, 'r') as f:
    #     file_content_2 = json.load(f)
    #     score_selected2 = file_content_2[prefix]
    #     score_selected2 = sort_dict_by_value(score_selected2)
    #
    # print_two_dict(score_selected, score_selected2)
    # print('-----------------')
    # print(f'distance between lgi4 and lgi2: {list_distance(list(score_selected.keys()), list(score_selected2.keys()))}')
    # print(f'overlap between lgi4 and lgi2 of topK 30: {list_overlap(list(score_selected.keys())[:30], list(score_selected2.keys())[:30])}')
        # score_selected = divide_dict(score_selected, 500, ratio=0.13)
        # score_selected = list(score_selected.keys())
        # print('-----------------')
        # print(len(score_selected))
    #     if file_content.get('lora_gamma_value_list', None):
    #         lora_gamma_value_list = file_content['lora_gamma_value_list']
    #         lora_gamma_value_list = dict_values_avg(lora_gamma_value_list)
    #         lora_gamma_value_list = sort_dict_by_value(lora_gamma_value_list, largest_first=False)
    #         if len(lora_gamma_value_list) > 0:
    #             score_dict = lora_gamma_value_list
    #             print_dict(score_dict)
    #         else:
    #             score_dict = read_content(file_content, whether_downtimes=True)
    #         # score_dict = file_content['lora_score']
    #         # score_dict = sort_dict_by_value(score_dict)
    #         # print_dict(score_dict)
    #
    #     else:
    #         score_dict = read_content(file_content, whether_downtimes=True)
    #
    #     # key_list = list(score_dict.keys())
    #     # key_list = sorted(key_list)
    #     # import random
    #     # random.shuffle(key_list)
    #     # key_list = {k: 0.0 for k in key_list}
    #     # save_path = './output/Random/lora_score.json'
    #     #
    #     # with open(save_path, 'w') as f:
    #     #     file_content['lora_score'] = key_list
    #     #     json.dump(key_list, f, indent=4)
    #
    #     keys_list_1 = list(score_dict.keys())
    #     # half_dict = half_dict(score_dict, steps)
    #     print('-----------------')
    #     # print(f'The length of half dict is {len(half_dict)}')
    #     print('-----------------')
    #
    #     # diff_up = file_content['diff_up']
    #     # diff_down = file_content['diff_down']
    #     #
    #     # print_dict(dict_values_avg(file_content['diff_up']))
    #     # print('-----------------')
    #     # print_dict(dict_values_avg(file_content['diff_down']))
    #
    # with open(file_2, 'r') as f:
    #     file_content_2 = json.load(f)
    #     score_dict_2 = file_content_2['lora_score']
    #     score_dict_2 = sort_dict_by_value(score_dict_2)
    #     keys_list_2 = list(score_dict_2.keys())
    #
    # with open(file_3, 'r') as f:
    #     file_content_3 = json.load(f)
    #     score_dict_3 = file_content_3['lora_score']
    #     score_dict_3 = sort_dict_by_value(score_dict_3)
    #     keys_list_3 = list(score_dict_3.keys())
    #
    # print('-----------------')
    # print(f'length of list 1: {len(keys_list_1)}')
    # print(f'distance between 10000 and 5000: {list_distance(keys_list_1, keys_list_2)}')
    # print(f'distance between 10000 and 2000: {list_distance(keys_list_1, keys_list_3)}')
    # print(f'distance between 5000 and 2000: {list_distance(keys_list_2, keys_list_3)}')
    #
    # print(f'overlap between 10000 and 5000 of top 40: {list_overlap(keys_list_1[:40], keys_list_2[:40])}')
    # print(f'overlap between 10000 and 2000 of top 40: {list_overlap(keys_list_1[:40], keys_list_3[:40])}')
    # print(f'overlap between 5000 and 2000 of top 40: {list_overlap(keys_list_2[:40], keys_list_3[:40])}')
    #
    # print(f'overlap between 10000 and 5000 of top 100: {list_overlap(keys_list_1[:100], keys_list_2[:100])}')
    # print(f'overlap between 10000 and 2000 of top 100: {list_overlap(keys_list_1[:100], keys_list_3[:100])}')
    # print(f'overlap between 5000 and 2000 of top 100: {list_overlap(keys_list_2[:100], keys_list_3[:100])}')