import argparse
# import jsonlines
import json
import os
# import multiprocessing
from re import A
import sys
sys.path.append('..')
from tqdm import tqdm
from termcolor import colored
import numpy as np
from itertools import combinations
import numpy as np

# MAX_PROCESSES = 16

def calculate_group_accuracy_matrix(reward_score, is_correct):
    """Compute each question's pairwise ranking accuracy for one reward signal.

    For question ``q`` every (correct answer, incorrect answer) pair is compared,
    and the fraction of pairs where the correct answer receives a strictly higher
    reward is recorded (ties count as failures). Questions whose answers are all
    correct or all incorrect have no such pairs and get the sentinel ``-1``.

    Args:
        reward_score: 2-D array-like; rows are the answers being ranked (one row
            per policy model in this script), columns are questions.
        is_correct: array-like of the same shape holding correctness labels
            (``bool`` or ``0``/``1``); they are interpreted by truthiness.

    Returns:
        A list with one entry per question: a float in [0, 1], or ``-1`` when the
        question is degenerate. ``[]`` for empty or ``None`` input.
    """
    if reward_score is None or is_correct is None:
        return []

    reward_arr = np.asarray(reward_score)
    # Cast to bool: with integer 0/1 labels, `~` would be bitwise NOT (giving -1/-2)
    # and indexing would be positional instead of a boolean mask.
    correct_arr = np.asarray(is_correct).astype(bool)
    if reward_arr.size == 0 or correct_arr.size == 0:
        return []

    num_questions = reward_arr.shape[1]
    acc = []
    for q in range(num_questions):
        rewards = reward_arr[:, q]
        correct = correct_arr[:, q]

        # No (correct, incorrect) pair exists -> accuracy undefined.
        if correct.all() or not correct.any():
            acc.append(-1)
            continue

        correct_rewards = rewards[correct]
        incorrect_rewards = rewards[~correct]

        # Outer comparison: every correct reward against every incorrect reward.
        comparisons = correct_rewards[:, None] > incorrect_rewards[None, :]
        acc.append(float(np.count_nonzero(comparisons) / comparisons.size))

    return acc

def cal_average_acc(res):
    """Average the ``reward_group_acc`` field over records that are not the ``-1`` sentinel.

    Args:
        res: iterable of dicts, each with a ``"reward_group_acc"`` key.

    Returns:
        ``(valid_count, mean)``. When no record is valid (empty input or all
        ``-1``), returns ``(0, 0.0)`` instead of raising ZeroDivisionError.
    """
    final = [x["reward_group_acc"] for x in res if x["reward_group_acc"] != -1]
    if not final:
        return 0, 0.0
    return len(final), sum(final) / len(final)

def cal_average_acc2(res):
    """Average a flat list of per-question accuracies, ignoring the ``-1`` sentinel.

    Args:
        res: iterable of numbers, where ``-1`` marks an undefined accuracy.

    Returns:
        ``(valid_count, mean)``. When no value is valid (empty input or all
        ``-1``), returns ``(0, 0.0)`` instead of raising ZeroDivisionError.
    """
    final = [x for x in res if x != -1]
    if not final:
        return 0, 0.0
    return len(final), sum(final) / len(final)

def get_combinations(lst, k):
    """Return all k-element combinations of ``lst`` (itertools order) as lists."""
    return list(map(list, combinations(lst, k)))

def cal_weight_sum(comb, chosen, args):
    """Return the accuracy-weighted sum of the selected reward models' scores.

    Each reward model in ``comb`` is weighted by the ``average_acc`` stored in
    ``<args.LLM_first_single_rm_acc_dir>/<rm name>/sum_complex.json``:

    * ``args.weight_method == 'linear'``: accuracies normalised to sum to 1.
    * ``args.weight_method == 'softem'``: ``softmax(acc / args.temperature)``.
    * any other value: no weighting; the plain sum over the first axis is returned.

    Args:
        comb: indices into ``args.rm_model_list``; ``comb[i]`` corresponds to ``chosen[i]``.
        chosen: array-like whose first axis enumerates the reward models in ``comb``
            (``build_reward_comb`` passes a 3-D ``(rm, model, question)`` array).
        args: namespace providing ``LLM_first_single_rm_acc_dir``, ``rm_model_list``,
            ``weight_method`` and ``temperature``.

    Returns:
        Nested list holding the weighted sum over the first axis of ``chosen``.
        ``chosen`` itself is not modified.
    """
    single_acc = []
    for co in comb:
        # NOTE(review): always reads the 'complex' summary regardless of
        # args.prompt_mode -- confirm this is intended.
        sum_path = os.path.join(args.LLM_first_single_rm_acc_dir,
                                args.rm_model_list[co], 'sum_complex.json')
        with open(sum_path) as f:
            single_acc.append(json.load(f)['average_acc'])

    if args.weight_method == 'linear':
        total = sum(single_acc)
        weight = [acc / total for acc in single_acc]
    elif args.weight_method == 'softem':
        scaled_acc = np.asarray(single_acc, dtype=float) / args.temperature
        # Shifting by the max leaves the softmax unchanged but prevents exp()
        # overflowing to inf (and the weights becoming nan) at small temperatures.
        if scaled_acc.size:
            scaled_acc = scaled_acc - np.max(scaled_acc)
        exp_acc = np.exp(scaled_acc)
        weight = (exp_acc / np.sum(exp_acc)).tolist()
    else:
        # Unknown method: fall back to the unweighted sum (original behaviour).
        return np.sum(chosen, axis=0).tolist()

    # Work on a float copy: scaling an integer array in place would truncate
    # the weighted scores and would also mutate the caller's array.
    scores = np.asarray(chosen, dtype=float)
    w = np.asarray(weight, dtype=float).reshape((-1,) + (1,) * (scores.ndim - 1))
    return np.sum(scores * w, axis=0).tolist()

def build_reward_comb(args):
    """Evaluate every ``args.comb_num``-subset of reward models combined into one score.

    For each combination of reward models from ``args.rm_model_list`` the per-answer
    reward scores are summed (or weighted via ``cal_weight_sum`` when
    ``args.weighted`` is set), per-question pairwise accuracy is computed with
    ``calculate_group_accuracy_matrix``, and two files are written to
    ``<parent of rm_question_bank_dir>/<bank dir>/<rm names joined by '_'><suffix>/``:

    * ``result_<prompt_mode>.json``: list of per-question accuracies (``-1`` = undefined);
    * ``sum_<prompt_mode>.json``: ``{"valid_count": ..., "average_acc": ...}``.

    Inputs are read from
    ``<rm_question_bank_dir>/<model>/<rm>/result_<prompt_mode>.json``; each entry
    must carry ``reward_score`` and ``is_correct``.
    """
    bank_name = 'rm_acc_bank_comb_8k' if '8k' in args.rm_question_bank_dir else 'rm_acc_comb_bank'
    rm_acc_bank_dir = os.path.join(
        os.path.dirname(os.path.abspath(args.rm_question_bank_dir)), bank_name)
    os.makedirs(rm_acc_bank_dir, exist_ok=True)

    comb_lis = get_combinations(list(range(len(args.rm_model_list))), args.comb_num)

    # all_reward is indexed [rm][model][question]; all_correct is [model][question].
    all_reward = []
    all_correct = []
    for rm_idx, rm in enumerate(args.rm_model_list):
        rm_re = []
        for model in args.model_list:
            js_path = os.path.join(args.rm_question_bank_dir, model, rm,
                                   f'result_{args.prompt_mode}.json')
            with open(js_path, 'r') as f:
                js = json.load(f)
            rm_re.append([x['reward_score'] for x in js])
            # Correctness labels are collected from the first RM only (presumably
            # identical across RMs since they score the same answers -- verify).
            # This also holds when model_list has a single model.
            if rm_idx == 0:
                all_correct.append([x['is_correct'] for x in js])
        all_reward.append(rm_re)
    all_reward = np.array(all_reward)

    # The output directory suffix depends only on args, so compute it once.
    if args.weighted and args.weight_method == 'linear':
        suffix = '_' + args.weight_method
    elif args.weighted and args.weight_method == 'softem':
        suffix = '_' + args.weight_method + str(args.temperature)
    else:
        suffix = '_sum'

    for comb in tqdm(comb_lis):
        chosen = all_reward[comb, :, :]  # fancy indexing -> copy, shape (len(comb), model, question)
        if args.weighted:
            resl = cal_weight_sum(comb, chosen, args)
        else:
            resl = np.sum(chosen, axis=0).tolist()
        res_group_acc = calculate_group_accuracy_matrix(resl, all_correct)
        valid, average = cal_average_acc2(res_group_acc)

        comb_name = '_'.join(args.rm_model_list[i] for i in comb)
        comb_dir = os.path.join(rm_acc_bank_dir, comb_name + suffix)
        os.makedirs(comb_dir, exist_ok=True)

        final_res_path = os.path.join(comb_dir, 'result_' + args.prompt_mode + '.json')
        with open(final_res_path, 'w') as f:
            json.dump(res_group_acc, f)
        sum_res_path = os.path.join(comb_dir, 'sum_' + args.prompt_mode + '.json')
        with open(sum_res_path, 'w') as f:
            json.dump({'valid_count': valid, 'average_acc': average}, f)
        
def _find_bank_file(model_bank_dir, prompt_mode):
    """Pick the result file to read from ``model_bank_dir`` for ``prompt_mode``.

    A directory holding exactly one file uses that file. Otherwise the last
    file (in ``os.listdir`` order) whose name contains ``prompt_mode`` is used.

    Raises:
        FileNotFoundError: if no file name contains ``prompt_mode``.
    """
    js_files = os.listdir(model_bank_dir)
    if len(js_files) == 1:
        return os.path.join(model_bank_dir, js_files[0])
    matches = [name for name in js_files if prompt_mode in name]
    if not matches:
        raise FileNotFoundError(
            f"no result file containing '{prompt_mode}' in {model_bank_dir}")
    return os.path.join(model_bank_dir, matches[-1])


def build_reward_bank(args):
    """Build the per-reward-model accuracy bank.

    For every reward model in ``args.rm_model_list``, read each policy model's
    results from ``<rm_question_bank_dir>/<model>/<rm>/``, compute per-question
    pairwise accuracy with ``calculate_group_accuracy_matrix``, and write into
    ``<parent of rm_question_bank_dir>/<bank dir>/<rm>/``:

    * ``result_<prompt_mode>.json``: list of
      ``{"question_id", "question", "reward_group_acc"}`` records;
    * ``sum_<prompt_mode>.json``: ``{"valid_count": ..., "average_acc": ...}``.

    Question text is taken from the last model's file, which assumes every model
    file lists the questions in the same order (not checked here).
    """
    bank_name = 'LLM_first_rm_acc_bank_8k' if '8k' in args.rm_question_bank_dir else 'LLM_first_rm_acc_bank'
    rm_acc_bank_dir = os.path.join(
        os.path.dirname(os.path.abspath(args.rm_question_bank_dir)), bank_name)
    os.makedirs(rm_acc_bank_dir, exist_ok=True)

    for rm in tqdm(args.rm_model_list):
        reward_score = []
        is_correct = []
        js = None
        print(colored('Begin to build ' + rm, 'green'))
        rm_dir = os.path.join(rm_acc_bank_dir, rm)
        os.makedirs(rm_dir, exist_ok=True)

        for model in args.model_list:
            model_bank_dir = os.path.join(args.rm_question_bank_dir, model, rm)
            model_bank_js = _find_bank_file(model_bank_dir, args.prompt_mode)
            with open(model_bank_js) as f:
                js = json.load(f)
            reward_score.append([sc["reward_score"] for sc in js])
            is_correct.append([co["is_correct"] for co in js])

        res_group_acc = calculate_group_accuracy_matrix(reward_score, is_correct)
        res = [
            {
                "question_id": idx,
                "question": js[idx]['question'],
                "reward_group_acc": acc,
            }
            for idx, acc in enumerate(res_group_acc)
        ]

        final_res_path = os.path.join(rm_dir, 'result_' + args.prompt_mode + '.json')
        with open(final_res_path, 'w') as f:
            json.dump(res, f)

        valid, average = cal_average_acc(res)
        final_sum_path = os.path.join(rm_dir, 'sum_' + args.prompt_mode + '.json')
        with open(final_sum_path, 'w') as f:
            json.dump({'valid_count': valid, 'average_acc': average}, f)


def build_arg_parser():
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_list",
        # nargs='+' instead of type=list: type=list would split a CLI string
        # into single characters.
        nargs='+',
        type=str,
        # default=['Qwen2.5-72b-Instruct', 'R1-distill-llama70b', 'QwQ-32B', 'Qwen2.5-32b-Instruct']
        # default=['TeleChat2-35B-32K','Qwen2.5-32b-Instruct','Qwen2.5-Coder-32b-Instruct', 'Qwen2.5-72b-Instruct', 'Meta-Llama-3.3-70B-Instruct', 'QwQ-32B',
        #                'gemma_3_27b_it']
        default=['Qwen2.5-32b-Instruct', 'Meta-Llama-3.3-70B-Instruct', 'HuatuoGPT-o1-72B',
                 'Llama-3_3-Nemotron-Super-49B-v1', 'internlm2_5-20b-chat', 'R1-distill-llama32b',
                 'Qwen2.5-72b-Instruct', 'Qwen2.5-Coder-32b-Instruct', 'QwQ-32B', 'Qwen3-32B',
                 'GLM-Z1-32B-0414', 'R1-distill-llama70b', 'TeleChat2-35B-32K', 'EXAONE-Deep-32B',
                 'gemma_3_27b_it'],
        help='policy model names (space separated)',
    )
    parser.add_argument(  # pay attention to the GPU resource in case not enough to run all the RMs in a single path.
        "--rm_model_list",
        nargs='+',
        type=str,
        # default=['Qwen2.5-Math-PRM-7B',"Qwen2.5-Math-RM-72B-VLLM"]
        # default=['QRM-Gemma-2-27B-VLLM']
        default=['AceCodeRM-32B', 'Qwen2.5-Math-PRM-7B', "Qwen2.5-Math-RM-72B-VLLM",
                 "Skywork-27B-Reward-Models-VLLM", "INF-ORM-Llama3.1-70B-VLLM",
                 "LDL-Reward-Gemma-2-27B-VLLM", "QRM-Gemma-2-27B-VLLM",
                 "Skywork-Reward-V2-Llama-3.1-8B-40M"],
        help='reward model names (space separated)',
    )
    parser.add_argument(
        "--rm_question_bank_dir",
        type=str,
        default='./result/LLM_first_rm_single_agent_8k',
    )
    parser.add_argument(
        "--LLM_first_single_rm_acc_dir",
        type=str,
        default='./result/LLM_first_rm_acc_bank_8k',
    )
    parser.add_argument(
        "--comb_num",
        type=int,
        default=2,
    )
    parser.add_argument(
        "--temperature",
        type=float,
        # NOTE: the default stays the int 1 (argparse does not convert non-string
        # defaults), so the default softem output dir name remains '..._softem1';
        # an explicit '--temperature 1' yields '..._softem1.0'.
        default=1,
    )
    parser.add_argument(
        "--comb",
        action='store_true',
    )
    parser.add_argument(
        "--weighted",
        action='store_true',
    )
    parser.add_argument(
        "--weight_method",
        type=str,
        default='softem',
        choices=['linear', 'softem'],
        help='linear, softmax with temperature=softem',
    )
    parser.add_argument(
        "--prompt_mode",
        type=str,
        default='complex',
        help='complex or simple, may support more in the future...',
    )
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()
    if args.comb:
        build_reward_comb(args)
    else:
        build_reward_bank(args)