import json
import os
from typing import List, Dict, Set
from collections import defaultdict
import pdb


def load_json_files(json_files):
    """Load and return the parsed contents of each JSON file, in order.

    Args:
        json_files: Iterable of paths to JSON files.

    Returns:
        A list with one parsed JSON value per path, in the same order as
        ``json_files``.

    Raises:
        OSError: If a file cannot be opened.
        json.JSONDecodeError: If a file does not contain valid JSON.
    """
    data = []
    for json_file in json_files:
        # JSON interchange text is UTF-8 (RFC 8259); decode explicitly rather
        # than relying on the platform's locale-dependent default encoding.
        with open(json_file, 'r', encoding='utf-8') as f:
            data.append(json.load(f))
    return data


def save_dict_to_json(data, json_file):
    """Write ``data`` as JSON to ``json_file``, replacing any existing file.

    Output uses ``json``'s default settings (compact, non-ASCII escaped).

    Args:
        data: A JSON-serializable object (typically a dict).
        json_file: Destination path.

    Raises:
        TypeError / ValueError: If ``data`` cannot be serialized. In that case
            the destination file is left untouched.
    """
    # Serialize before opening: opening with 'w' truncates the file
    # immediately, so an error raised mid-dump would leave an empty or
    # partial file in place of the previous contents.
    text = json.dumps(data)
    with open(json_file, 'w', encoding='utf-8') as f:
        f.write(text)


def extract_top_indices_by_param(data):
    """Group each input object's ``top_indices`` by parameter name.

    Args:
        data: Iterable of loaded JSON objects. Each object maps a parameter
            name to a dict containing a ``'top_indices'`` sequence.

    Returns:
        A ``defaultdict(list)`` mapping parameter name to a list of sets. There
        is one set per input object in which that parameter appears, in input
        order.
    """
    grouped = defaultdict(list)
    # Flatten (name, indices) pairs across all objects, preserving order.
    name_index_pairs = (
        (name, entry['top_indices'])
        for document in data
        for name, entry in document.items()
    )
    for name, indices in name_index_pairs:
        grouped[name].append(set(indices))
    return grouped


def extract_top_indices_scaler_values_and_params(data):
    """Flatten all (index, scaler value, parameter name) triples from ``data``.

    Duplicates across inputs are kept. Within one parameter entry,
    ``top_indices`` and ``top_scaler_values`` are paired positionally, and
    pairing stops at the shorter of the two sequences.

    Args:
        data: Iterable of loaded JSON objects. Each object maps a parameter
            name to a dict with ``'top_indices'`` and ``'top_scaler_values'``.

    Returns:
        A list of ``(index, value, param_name)`` tuples in input order.
    """
    triples = []
    for json_data in data:
        for param_name, layer_data in json_data.items():
            paired = zip(layer_data['top_indices'], layer_data['top_scaler_values'])
            triples += [(index, value, param_name) for index, value in paired]
    return triples


def extract_unique_indices_scaler_values_and_params(data):
    """Collect one (index, value, param_name) triple per (param_name, index).

    When the same (param_name, index) pair appears more than once across
    ``data``, the largest scaler value is kept. On ties, the first-seen value
    is kept. Results follow the order in which each pair was first seen.

    Args:
        data: Iterable of loaded JSON objects. Each object maps a parameter
            name to a dict with ``'top_indices'`` and ``'top_scaler_values'``.

    Returns:
        A list of ``(index, value, param_name)`` tuples.
    """
    best_by_key = {}
    for json_data in data:
        for param_name, layer_data in json_data.items():
            indices = layer_data['top_indices']
            values = layer_data['top_scaler_values']
            for index, value in zip(indices, values):
                key = (param_name, index)
                # Skip unless the pair is new or strictly beats the stored value.
                if key in best_by_key and not value > best_by_key[key]:
                    continue
                best_by_key[key] = value

    return [(index, value, param_name) for (param_name, index), value in best_by_key.items()]


def compute_union_and_intersection_by_param(param_indices):
    """Compute, per parameter, the union and intersection of its index sets.

    Args:
        param_indices: Mapping of parameter name to a list of index
            collections, e.g. the output of ``extract_top_indices_by_param``.
            The mapping and the collections inside it are not modified.

    Returns:
        A dict mapping parameter name to
        ``{"union": [...], "intersection": [...]}``. Both values are lists
        sorted ascending. A parameter with an empty list gets two empty lists.
    """
    results = {}
    for param_name, indices_list in param_indices.items():
        union_set = set().union(*indices_list)
        if indices_list:
            # Start from a fresh copy of the first collection so that
            # intersecting does not mutate the caller's data in place.
            intersection_set = set(indices_list[0]).intersection(*indices_list[1:])
        else:
            intersection_set = set()

        results[param_name] = {
            "union": sorted(union_set),
            "intersection": sorted(intersection_set),
        }

    return results


def get_top_indices_scaler_values_and_params(all_indices_scaler_values_and_params, top_num):
    """Return the ``top_num`` triples with the largest scaler value.

    Triples are ``(index, value, param_name)`` and are ranked by ``value``
    in descending order. The sort is stable, so equal values keep their input
    order. ``top_num`` is applied as a slice bound (``[:top_num]``).
    """
    def _scaler_value(triple):
        return triple[1]

    ranked = list(all_indices_scaler_values_and_params)
    ranked.sort(key=_scaler_value, reverse=True)
    return ranked[:top_num]


def _run_per_param(json_data, union_json, intersect_json):
    """Write per-parameter union/intersection JSON files and print them."""
    param_indices = extract_top_indices_by_param(json_data)
    results = compute_union_and_intersection_by_param(param_indices)
    union_dict = {param_name: sets["union"] for param_name, sets in results.items()}
    intersection_dict = {param_name: sets["intersection"] for param_name, sets in results.items()}
    save_dict_to_json(union_dict, union_json)
    save_dict_to_json(intersection_dict, intersect_json)

    for param_name, sets in results.items():
        print(f"Parameter: {param_name}")
        print(f"  Union of top_indices: {sets['union']}")
        print(f"  Intersection of top_indices: {sets['intersection']}")


def _run_global(json_data, global_top_json, top_num):
    """Write the global top-``top_num`` (param, row, value) entries and print them."""
    all_indices_scaler_values_and_params = extract_unique_indices_scaler_values_and_params(json_data)
    results = get_top_indices_scaler_values_and_params(all_indices_scaler_values_and_params, top_num=top_num)
    # Keys are the 0-based rank as a string; values are [param_name, row_index, scaler_value].
    top_global_dict = {
        str(rank): [param_name, row_index, scaler_value]
        for rank, (row_index, scaler_value, param_name) in enumerate(results)
    }
    save_dict_to_json(top_global_dict, global_top_json)

    for idx, top_indices_with_params in top_global_dict.items():
        print(f"Top {idx} parameter row indices with highest scaler values for parameter:\nparam_name: {top_indices_with_params[0]}\nrow_index: {top_indices_with_params[1]}\nscaler_value: {top_indices_with_params[2]}")


def main(json_files, union_json, intersect_json, global_top_json, top_num, is_global=True):
    """Aggregate top-index JSON files in one of two modes.

    When ``is_global`` is False, write per-parameter union and intersection
    index lists to ``union_json`` and ``intersect_json``. When ``is_global``
    is True, deduplicate (param, index) pairs (keeping the max scaler value),
    then write the ``top_num`` highest-valued entries to ``global_top_json``.

    Args:
        json_files: Paths of the input JSON files.
        union_json: Output path for unions (required when ``is_global`` is False).
        intersect_json: Output path for intersections (required when
            ``is_global`` is False).
        global_top_json: Output path for the global ranking (required when
            ``is_global`` is True).
        top_num: Number of global entries to keep (global mode only).
        is_global: Selects the mode (see above).

    Raises:
        ValueError: If an output path required by the selected mode is
            missing. This is checked before any file is read or written.
    """
    # Validate up front so a bad configuration fails before reading inputs
    # and before any partial output is written.
    if is_global:
        if not global_top_json:
            raise ValueError("global_top_json cannot be None when is_global is True")
    elif not union_json or not intersect_json:
        raise ValueError("union_json and intersect_json cannot be None when is_global is False")

    json_data = load_json_files(json_files)

    if is_global:
        _run_global(json_data, global_top_json, top_num)
    else:
        _run_per_param(json_data, union_json, intersect_json)



if __name__ == "__main__":

    # Candidate input sets: one file per sampling seed (42-46) for each
    # model/task combination. Select one of them as json_files below.
    llama3_code_top_params_files = [
        "llama3_top_param_info/code_top/infinity_code_sample_50_seed_42_all_top_100_info.json",
        "llama3_top_param_info/code_top/infinity_code_sample_50_seed_43_all_top_100_info.json",
        "llama3_top_param_info/code_top/infinity_code_sample_50_seed_44_all_top_100_info.json",
        "llama3_top_param_info/code_top/infinity_code_sample_50_seed_45_all_top_100_info.json",
        "llama3_top_param_info/code_top/infinity_code_sample_50_seed_46_all_top_100_info.json",
    ]

    mistral_math_top_param_files = [
        "mistral_top_param_info/math_top/infinity_math_sample_50_seed_42_all_top_info.json",
        "mistral_top_param_info/math_top/infinity_math_sample_50_seed_43_all_top_info.json",
        "mistral_top_param_info/math_top/infinity_math_sample_50_seed_44_all_top_info.json",
        "mistral_top_param_info/math_top/infinity_math_sample_50_seed_45_all_top_info.json",
        "mistral_top_param_info/math_top/infinity_math_sample_50_seed_46_all_top_info.json",
    ]

    qwen2_5_14b_csqa_top_param_files = [
        "qwen2_5_14b_top_param_info/csqa_top/csqa_sample_50_seed_42_all_top_100_info.json",
        "qwen2_5_14b_top_param_info/csqa_top/csqa_sample_50_seed_43_all_top_100_info.json",
        "qwen2_5_14b_top_param_info/csqa_top/csqa_sample_50_seed_44_all_top_100_info.json",
        "qwen2_5_14b_top_param_info/csqa_top/csqa_sample_50_seed_45_all_top_100_info.json",
        "qwen2_5_14b_top_param_info/csqa_top/csqa_sample_50_seed_46_all_top_100_info.json",
    ]

    top_num = 100
    # task_name labels the output files, so it must describe the inputs in
    # json_files. Change both together when switching model/task.
    task_name = "code"
    json_files = llama3_code_top_params_files
    # Write union/intersection outputs next to the selected inputs, so the
    # output directory always matches the model/task actually processed.
    model = os.path.dirname(json_files[0])
    union_json = f"{model}/{task_name}_top_{top_num}_params_union.json"
    intersect_json = f"{model}/{task_name}_top_{top_num}_params_intersect.json"
    global_json = f"{task_name}_global_top_{top_num}_params.json"
    main(json_files, union_json, intersect_json, global_top_json=global_json, top_num=top_num, is_global=False)
