import json, os, random, sys
import argparse
import yaml

from topiocqa import compute_score
from topiocqa import filter_standard
# Module-level settings read by the functions below.
# SMN (model short name; names the response field read by get_score) and
# ROLLOUT_N (rollouts per prompt; used by get_avg_score) are NOT defined here:
# they are assigned from the YAML config in the `__main__` block at the bottom.
FILTER = False  # when True, process_file keeps only items accepted by filter_standard

def load_json(input_path):
    """Load a JSON Lines file into a list of objects.

    Each non-blank line of *input_path* is decoded with ``json.loads``.
    Blank or whitespace-only lines (e.g. a stray trailing empty line) are
    skipped instead of raising ``json.JSONDecodeError``.

    Args:
        input_path: Path to a UTF-8 encoded JSONL file.

    Returns:
        List of the decoded objects, in file order.
    """
    with open(input_path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

def get_avg_score(prompt_list, rollout_n=None):
    """Aggregate per-rollout scores into one record per prompt.

    Items sharing the same ``prompt_idx`` are treated as rollouts of the same
    prompt. The first item seen for each ``prompt_idx`` is kept (in first-seen
    order) and annotated in place with:

    * ``_avg_score`` - mean ``_score`` over that prompt's rollouts,
    * ``_max_score`` / ``_min_score`` - extreme ``_score`` values,
    * ``_diff_score_count`` - number of distinct ``_score`` values
      (compared by their string form).

    Also prints the number of prompts and the fraction of kept items whose
    ``_rank`` is not -1 (reported as the Recall@100 rate).

    Args:
        prompt_list: Scored items. Each needs ``prompt_idx``, ``_score`` and
            ``_rank`` keys, as set by ``get_score``.
        rollout_n: Expected rollouts per prompt. Defaults to the module-level
            ``ROLLOUT_N`` assigned by the command-line entry point.

    Returns:
        List with one (mutated) item per distinct ``prompt_idx``.

    Raises:
        AssertionError: if the number of distinct prompts differs from
            ``len(prompt_list) // rollout_n``.
    """
    if rollout_n is None:
        rollout_n = ROLLOUT_N

    # prompt_idx -> running stats. Dicts keep insertion order, so the output
    # follows the order in which each prompt_idx first appears.
    groups = {}
    for item in prompt_list:
        score = item['_score']
        stats = groups.get(item['prompt_idx'])
        if stats is None:
            groups[item['prompt_idx']] = {
                'item': item,
                'total': score,
                'count': 1,
                'max': score,
                'min': score,
                'distinct': {f'{score}'},
            }
            continue
        stats['total'] += score
        stats['count'] += 1
        stats['max'] = max(stats['max'], score)
        stats['min'] = min(stats['min'], score)
        stats['distinct'].add(f'{score}')

    original_length = len(groups)
    # Explicit raise (not `assert`) so the sanity check survives `python -O`.
    if original_length != len(prompt_list) // rollout_n:
        raise AssertionError(
            f"expected {len(prompt_list)} // {rollout_n} distinct prompts, "
            f"found {original_length} distinct prompt_idx values"
        )

    new_json_list = []
    for stats in groups.values():
        item = stats['item']
        # Divide by the rollouts actually seen for this prompt, not the
        # nominal rollout_n: the total-count check above cannot detect uneven
        # groups (e.g. 3 + 5 rollouts with rollout_n=4).
        item['_avg_score'] = stats['total'] / stats['count']
        item['_max_score'] = stats['max']
        item['_min_score'] = stats['min']
        item['_diff_score_count'] = len(stats['distinct'])
        new_json_list.append(item)

    # Fraction of prompts whose kept (first-seen) item has a valid rank;
    # guard the empty case, which previously raised ZeroDivisionError.
    hits = sum(1 for item in new_json_list if item['_rank'] != -1)
    ratio = hits / original_length if original_length else 0.0
    print("Total length:", original_length)
    print("Recall@100 recall rate: ", ratio)
    return new_json_list

# def get_score(prompt_list):
#     for item in prompt_list:
#         res = compute_score("", item[SMN], item['gt'])
#         # if res != 0:
#         #     print(res)
#         item['_score'] = res['score']
#     return prompt_list
# Concurrent implementation of get_score (replaces the sequential version commented out above).
import concurrent.futures

def get_score(prompt_list, max_workers=500):
    """Score every rollout concurrently and annotate each item in place.

    ``compute_score`` is called with the item's model response (the field
    named by the module-level ``SMN``), ``item['reward_model']['ground_truth']``
    and ``item['extra_info']``. The result's ``score``, ``rank`` and
    ``more_returns`` entries are copied onto the item as ``_score``, ``_rank``
    and ``_more_returns``.

    Args:
        prompt_list: Items to score; they are mutated in place.
        max_workers: Size of the thread pool.

    Returns:
        List of the same (mutated) item objects, in input order.
    """
    field_map = (
        ('_score', 'score'),
        ('_rank', 'rank'),
        ('_more_returns', 'more_returns'),
    )

    def _annotate(entry):
        result = compute_score("", entry[SMN], entry['reward_model']['ground_truth'], entry['extra_info'])
        for dst_key, src_key in field_map:
            entry[dst_key] = result[src_key]
        return entry

    pool = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
    with pool:
        # Executor.map yields results in input order, not completion order.
        return list(pool.map(_annotate, prompt_list))

# Filter
def filter_data(prompt_list):
    """Return only the items accepted by ``filter_standard``, preserving order."""
    return [entry for entry in prompt_list if filter_standard(entry)]

def process_file(input_path, output_path):
    """Score, aggregate and optionally filter one rollout file.

    Pipeline:
    1. Load JSONL from *input_path*.
    2. Run ``get_score``.
    3. Run ``get_avg_score`` (one record per prompt).
    4. Run ``filter_data`` if the module-level ``FILTER`` flag is set.
    5. Write the result as JSONL to *output_path*.

    Args:
        input_path: JSONL file of rollouts.
        output_path: Destination JSONL file; overwritten if it exists.
    """
    prompt_list = load_json(input_path)
    # Get scores
    prompt_list = get_score(prompt_list)
    # Calculate average, highest, and lowest scores
    print("Calculating average, highest, and lowest scores")
    new_json_list = get_avg_score(prompt_list)
    # Record the per-prompt count before filtering so the summary below
    # reports the pre- and post-filter sizes rather than the same number twice.
    original_count = len(new_json_list)
    if FILTER:
        print("Filtering")
        new_json_list = filter_data(new_json_list)
    else:
        print("Not filtered")
    print(f"Original data has {original_count} entries, after filtering has {len(new_json_list)} entries, saved to {output_path}")
    with open(output_path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(item, ensure_ascii=False) + '\n' for item in new_json_list)



if __name__ == '__main__':
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='Score, aggregate and optionally filter rollouts with a config file')
    parser.add_argument('--config', type=str, default='rldata_filter_topiocqa/config.yaml',
                        help='Path to configuration file (default: %(default)s)')
    parser.add_argument('--model-name', type=str, default=None,
                        help='Model name to override config')
    parser.add_argument('--model-smn', type=str, default=None,
                        help='Model short name to override config')
    args = parser.parse_args()

    # Read configuration file (path is used as given, relative to the CWD)
    config_path = args.config
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # Command line arguments override configuration file
    if args.model_name is not None:
        config['model']['name'] = args.model_name
    if args.model_smn is not None:
        config['model']['smn'] = args.model_smn

    # Plain module-scope assignments: get_avg_score and get_score read these
    # globals. (A `global` statement is a no-op at module scope.)
    ROLLOUT_N = config['sample']['rollout_n']
    SMN = config['model']['smn']

    # This script's input is the previous stage's output directory, per model
    input_dir = os.path.join(config['data']['output_dir'], SMN)
    # File names are the same as the previous stage's input files
    input_files = config['data']['input_files']
    # Results go to a sibling "<input_dir>_filtered" directory
    output_dir = f"{input_dir}_filtered"
    os.makedirs(output_dir, exist_ok=True)
    # Process each input file, keeping its name in the output directory
    for input_file in input_files:
        input_path = os.path.join(input_dir, input_file)
        output_path = os.path.join(output_dir, input_file)
        process_file(input_path, output_path)
        
        
        