import pandas as pd
import numpy as np
import json
import random
import math
import rbo
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from datasets import load_dataset, Dataset, load_from_disk
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, AutoConfig, set_seed
from peft import PeftModel
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training, PeftConfig
import torch
from bon_eval_utils import eval_gsm8k, eval_math_prm
import sys, os
from tqdm import tqdm
import torch.nn.functional as F
import torch.nn as nn
import re
import deepspeed
from copy import deepcopy
from safetensors import safe_open
from collections import Counter
from scipy.stats import kendalltau, spearmanr

sys.path.append('../joint_train')
from modeling_custom_qwen import JointPRMModel, load_saved_joint_prm_model
import argparse
# Command-line interface. Parsing happens at import time, so `args` is a
# module-level global read by the functions below.
parser = argparse.ArgumentParser()

# Integer switches defaulting to 0. `--baseline` and `--combine` are tested for
# truthiness further down; `--local_rank` and `--orm` are not read by the code
# visible in this file.
for _int_flag in ("--local_rank", "--baseline", "--combine", "--orm"):
    parser.add_argument(_int_flag, type=int, default=0)

# Base checkpoints handed to `load_saved_joint_prm_model` / the tokenizer.
for _path_flag, _path_default in (
    ("--backbone-path", "/Qwen2.5-Math-7B-Instruct"),
    ("--bias-expert-path", "/qwen0.5B"),
):
    parser.add_argument(_path_flag, type=str, default=_path_default)

parser.add_argument("--joint-model-path", type=str, required=True, help="Path to saved joint model directory")
# NOTE(review): `--save-file` is parsed but not read anywhere in the visible code.
parser.add_argument("--save-file", type=str, default="./prm-data.json")
parser.add_argument("--c-param", type=float, default=0.5, help="Parameter c in the joint reward formula")

args = parser.parse_args()
print(args)

def seed_everything(seed=0):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs with ``seed``.

    Also switches cuDNN to deterministic mode and turns off its benchmark
    auto-tuner.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def instruction_format(s):
    """Wrap the question ``s`` in the instruction/response prompt template."""
    return (
        "Below is an instruction that describes a task.\n"
        "Write a response that appropriately completes the request.\n\n"
        f"### Instruction:\n{s}\n\n"
        "### Response: Let's think step by step"
    )

def split_query(completions, n, N=16):
    """Group completions by query and keep the ``n`` most likely ones per query.

    Args:
        completions: iterable-with-len of dicts carrying at least ``"idx"``
            (the query index) and ``"logprobs"`` (the sort key).
        n: number of completions to keep for each query.
        N: number of completions generated per query; the number of queries is
            taken to be ``len(completions) // N``.

    Returns:
        A list with one entry per query index ``0 .. len(completions) // N - 1``.
        Each entry holds that query's completions sorted by ``"logprobs"``
        (descending, ties keep input order) and truncated to ``n`` items.
        Completions whose ``"idx"`` falls outside that range are ignored.
    """
    # Bucket once instead of rescanning the full list for every query index.
    by_query = {}
    for sample in completions:
        by_query.setdefault(sample["idx"], []).append(sample)

    splitted_completions = []
    for idx in range(len(completions) // N):
        ranked = sorted(by_query.get(idx, ()), key=lambda x: x["logprobs"], reverse=True)
        splitted_completions.append(ranked[:n])
    return splitted_completions

def best_of_n(splitted_completions):
    """Pick the highest-reward completion from every per-query group.

    Args:
        splitted_completions: iterable of lists of dicts, each dict carrying a
            ``"reward"`` entry.

    Returns:
        A list with one completion per group: the one with the largest
        ``"reward"``. On ties the completion that appears first in the group
        wins.

    Raises:
        IndexError: if a group is empty.
    """
    selected_completions = []
    for group in splitted_completions:
        if not group:
            raise IndexError("best_of_n received an empty completion group")
        # max() returns the first maximal element, which is what a stable
        # descending sort followed by [0] yields, without sorting the group.
        selected_completions.append(max(group, key=lambda x: x["reward"]))
    return selected_completions

def compute_metrics(dataset_name, scored_results):
    """Compute answer accuracy of the selected completions.

    Args:
        dataset_name: ``'gsm8k'`` loads the GSM-Plus ``testmini`` split as the
            reference; any other value loads the MATH500 jsonl file.
        scored_results: one selected completion (dict with a ``'response'``
            entry) per reference problem, in reference order.

    Returns:
        ``{8: accuracy}``, or an empty dict when ``--baseline`` or
        ``--combine`` is set. Reads the module-level ``args`` and
        ``accelerator`` globals.
    """
    # The reference set is loaded before the flags are checked, as before.
    if dataset_name == 'gsm8k':
        original_dataset = load_dataset('qintongli/GSM-Plus')['testmini']
    else:
        # NOTE(review): test_joint_model reads MATH500 from '/MATH500.jsonl';
        # confirm that both paths refer to the same file.
        path = '/mnt/workspace/jiachenzhu/data/MATH500.jsonl'
        with open(path) as f:
            original_dataset = [json.loads(line) for line in f]

    metrics = {}
    if args.baseline or args.combine:
        return metrics

    assert len(original_dataset) == len(scored_results)
    responses = [{'response': query['response']} for query in scored_results]
    if dataset_name == 'math':
        problems = [{'solution': data['solution'], 'question': data['problem']} for data in original_dataset]
        acc, _, _ = eval_math_prm(responses, all_problems=problems, is_extract=False)
    else:
        gold_answers = [data['answer'] for data in original_dataset]
        acc, _, _ = eval_gsm8k(responses, answers=gold_answers, is_extract=True)

    # NOTE(review): the key is hard-coded to 8 although test_joint_model
    # selects among the first 16 responses of each question; confirm the label.
    n = 8
    metrics[n] = acc
    if accelerator.is_local_main_process:
        print('*********')
        print(n, acc)
        print('*********')
    return metrics

# Weight of the per-step token-position penalty subtracted from the joint
# reward in test_joint_model.
alpha = 0.0001

# Derive the results-file name from the joint model path by dropping the
# training directory and the known checkpoint suffixes, in this order.
MODEL_PATH = args.joint_model_path
for _dropped_fragment in ("/trained_models/", "/checkpoint-1578", "/checkpoint-3207"):
    MODEL_PATH = MODEL_PATH.replace(_dropped_fragment, "")

# JSON-lines file that test_joint_model appends its metrics to.
RESULT_FILE = f"/{MODEL_PATH}_alpha{alpha}.jsonl"
# NOTE(review): not referenced anywhere in the visible code.
NPY_SAVE_PATH = "/joint_model"

def compute_joint_reward(prm_reward, bias_score, c_param=0.5):
    """Return ``(prm_reward - c_param) * bias_score``.

    Works on anything supporting ``-`` and ``*`` (scalars, arrays, tensors);
    the operations are applied element-wise by the operand types themselves.
    """
    centred_reward = prm_reward - c_param
    return centred_reward * bias_score

# Evaluation-protocol constants (previously inlined as magic numbers).
_RESPONSES_PER_QUESTION = 128  # block size used to walk and index `queries`
_BEST_OF_N = 16                # only the first 16 responses of a block are scored
_HALF = 8                      # size of the front/back groups that are compared
_AGGREGATIONS = ('min', 'mean', 'last')
_AGREEMENT_STATS = ('kendalltau', 'spearman', 'rbo', 'reward_diff')


def _load_queries(data_name, data_file):
    """Flatten an eval file into one record per (question, response) pair.

    The eval file is checked against the reference dataset: same number of
    questions, identical question text, and 128 responses per question for
    gsm8k (a multiple of 128 for math).
    """
    if data_name == 'gsm8k':
        origin_dataset = load_dataset('qintongli/GSM-Plus')['testmini']
        question_key, solution_key = 'question', 'answer'
    elif data_name == 'math':
        path = '/MATH500.jsonl'
        with open(path) as f:
            origin_dataset = [json.loads(line) for line in f]
        question_key, solution_key = 'problem', 'solution'
    else:
        raise ValueError(f"Unsupported data_name: {data_name!r}")

    with open(data_file) as f:
        cur_queries = json.load(f)
    assert len(origin_dataset) == len(cur_queries), (len(origin_dataset), len(cur_queries))

    queries = []
    for idx, (data, ori) in enumerate(zip(cur_queries, origin_dataset)):
        assert data['question'] == ori[question_key]
        if data_name == 'gsm8k':
            assert len(data['responses']) == _RESPONSES_PER_QUESTION
        else:
            assert len(data['responses']) % _RESPONSES_PER_QUESTION == 0
        for response_dict in data['responses']:
            queries.append({
                'idx': idx,
                'prompt': data['question'],
                'response': response_dict['text'],
                'solution': ori[solution_key],
                'logprobs': 0,
            })
    return queries


def _front_back_orderings(reward_list):
    """Return ``np.argsort`` of the first and of the last ``_HALF`` rewards.

    NOTE(review): argsort yields the ascending *ordering* (item indices), not
    the rank of each item. The correlation helpers below therefore compare two
    index permutations rather than the rewards' ranks. Kept as is so reported
    numbers stay comparable; confirm this is the intended statistic.
    """
    return np.argsort(reward_list[:_HALF]), np.argsort(reward_list[-_HALF:])


def _front_back_kendalltau(reward_list):
    """Kendall tau between the front and back orderings."""
    value, _ = kendalltau(*_front_back_orderings(reward_list))
    return value


def _front_back_spearman(reward_list):
    """Spearman correlation between the front and back orderings."""
    value, _ = spearmanr(*_front_back_orderings(reward_list))
    return value


def _front_back_rbo(reward_list):
    """Rank-biased overlap (p=0.85) between the front and back orderings."""
    front, back = _front_back_orderings(reward_list)
    return rbo.RankingSimilarity(front, back).rbo(p=0.85)


def _front_back_reward_diff(reward_list):
    """Mean element-wise difference: first ``_HALF`` minus last ``_HALF`` rewards."""
    return np.mean(np.array(reward_list[:_HALF]) - np.array(reward_list[-_HALF:]))


_AGREEMENT_FUNCS = {
    'kendalltau': _front_back_kendalltau,
    'spearman': _front_back_spearman,
    'rbo': _front_back_rbo,
    'reward_diff': _front_back_reward_diff,
}


def test_joint_model(data_name, data_file, c_param=0.5):
    """Run best-of-16 selection with the joint reward and log the metrics.

    For every question the first 16 responses are scored step by step. Each
    step reward is ``compute_joint_reward(prm, bias, c_param)`` minus
    ``alpha`` times the step marker's token position. The step rewards of a
    response are reduced in three ways (min, mean, last); for each reduction
    the best response is selected, and its accuracy is computed by
    ``compute_metrics``. Agreement statistics compare the rewards of the first
    8 and last 8 scored responses (presumably two variants of the same 8
    responses, given the ``*_double.json`` file names; confirm).

    One JSON line per reduction is appended to ``RESULT_FILE``.

    Relies on module-level state: ``args``, ``alpha``, ``RESULT_FILE``,
    ``model``, ``tokenizer``, ``prm_token``, ``data_collator``.

    Args:
        data_name: ``'gsm8k'`` or ``'math'``.
        data_file: path to the JSON file holding the sampled responses.
        c_param: the ``c`` offset of the joint reward formula.

    Raises:
        ValueError: if ``data_name`` is not one of the two supported names.
    """
    queries = _load_queries(data_name, data_file)

    # Accumulators over all questions, one list per reduction. With --baseline
    # nothing is scored, so these stay empty and the logged statistics are NaN.
    selected = {agg: [] for agg in _AGGREGATIONS}
    # Step-marker token positions of the selected responses.
    selected_positions = {agg: [] for agg in _AGGREGATIONS}
    agreement = {stat: {agg: [] for agg in _AGGREGATIONS} for stat in _AGREEMENT_STATS}

    if not args.baseline:
        step_pattern = re.compile(r'Step \d+:')
        for idx, data in enumerate(queries):
            data['reward_idx'] = idx
            data["query"] = instruction_format(data["prompt"])
            pieces = step_pattern.split(data['response'])
            # NOTE(review): the step number comes from the position in the raw
            # split, so an empty leading piece (response starting with
            # "Step 1:") makes the first real step "Step 2:". Left unchanged
            # because the model may have been trained on this format; confirm.
            steps = [f'Step {pos + 1}: ' + piece.strip()
                     for pos, piece in enumerate(pieces) if piece.strip() != '']
            data["answer"] = f" {prm_token}".join(steps) + f" {prm_token}"

        dataset = Dataset.from_pandas(pd.DataFrame.from_records(queries))
        dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=data_collator)

        good_token = '+'
        bad_token = '-'
        candidate_tokens = tokenizer.encode(f" {good_token} {bad_token}")
        has_bias_expert = hasattr(model, 'prm_model') and hasattr(model, 'bias_expert_model')

        question_num = 0   # number of 128-response blocks already selected from
        answer_num = 0     # responses seen so far in the current block
        stop_flag = False  # True once the current block's selection is done
        # Per-block state: one reduced reward per scored response, plus the
        # step-marker positions of each scored response.
        rewards = {agg: [] for agg in _AGGREGATIONS}
        step_positions = []

        for inputs in tqdm(dataloader):
            answer_num += inputs['input_ids'].shape[0]

            if stop_flag:
                if answer_num < _RESPONSES_PER_QUESTION:
                    continue
                if answer_num == _RESPONSES_PER_QUESTION:
                    # Last response of the block: its scores would be thrown
                    # away, so skip the forward pass and start the next block.
                    answer_num = 0
                    stop_flag = False
                    continue

            with torch.no_grad():
                if has_bias_expert:
                    prm_outputs = model.prm_model(
                        input_ids=inputs['input_ids'],
                        attention_mask=inputs['attention_mask']
                    )
                    bias_outputs = model.bias_expert_model(
                        input_ids=inputs['input_ids'],
                        attention_mask=inputs['attention_mask']
                    )
                    # Softmax over the two candidate tokens only; index 0 is
                    # the probability of the first one (' +').
                    prm_scores = prm_outputs.logits[:, :, candidate_tokens].softmax(dim=-1)[:, :, 0]
                    bias_raw_scores = bias_outputs.logits[:, :, candidate_tokens].softmax(dim=-1)[:, :, 0]
                else:
                    logits = model(input_ids=inputs['input_ids'],
                                   attention_mask=inputs['attention_mask']).logits[:, :, candidate_tokens]
                    prm_scores = logits.softmax(dim=-1)[:, :, 0]
                    # NOTE(review): a zero bias score makes the joint reward 0
                    # for every step, so only the position penalty
                    # differentiates responses in this branch; confirm intent.
                    bias_raw_scores = torch.zeros_like(prm_scores)

            # -100 marks padding in `special_tokens`; map it to 0 and drop zeros.
            cur_index = torch.where(inputs['special_tokens'] == -100, 0, inputs['special_tokens'])
            for i, row in enumerate(cur_index.tolist()):
                temp_list = [x for x in row if x != 0]
                prm_reward = prm_scores[i, temp_list]
                bias_score = bias_raw_scores[i, temp_list]

                joint_reward = compute_joint_reward(prm_reward, bias_score, c_param)
                joint_reward = joint_reward.to(torch.float32).cpu().numpy().tolist()
                # Penalise each step by its token position.
                joint_reward = [r - alpha * pos for r, pos in zip(joint_reward, temp_list)]

                rewards['min'].append(np.min(joint_reward))
                rewards['mean'].append(np.mean(joint_reward))
                rewards['last'].append(joint_reward[-1])
                step_positions.append(temp_list)

            if answer_num == _BEST_OF_N:
                block_start = question_num * _RESPONSES_PER_QUESTION
                for agg in _AGGREGATIONS:
                    scores = rewards[agg]
                    best = scores.index(max(scores))  # first maximum wins ties
                    selected_positions[agg] += step_positions[best]
                    selected[agg].append(queries[block_start + best])
                    for stat in _AGREEMENT_STATS:
                        agreement[stat][agg].append(_AGREEMENT_FUNCS[stat](scores))

                rewards = {agg: [] for agg in _AGGREGATIONS}
                step_positions = []
                question_num += 1
                stop_flag = True
            if answer_num == _RESPONSES_PER_QUESTION:
                answer_num = 0
                rewards = {agg: [] for agg in _AGGREGATIONS}
                step_positions = []
                stop_flag = False

    for agg in _AGGREGATIONS:
        m = compute_metrics(data_name, selected[agg])
        print(m)
        m['type'] = agg
        for stat in _AGREEMENT_STATS:
            m[stat] = float(np.mean(agreement[stat][agg]))
        # Mean token position of the step markers of the selected responses.
        m['avg_length'] = float(np.mean(selected_positions[agg]))
        m['data_file'] = data_file
        m['c_param'] = c_param
        m['model_type'] = 'joint_prm_bias'

        with open(RESULT_FILE, 'a') as f:
            f.write(json.dumps(m) + '\n')

if __name__ == '__main__':
    seed_everything(0)
    accelerator = Accelerator()

    if not args.baseline:
        # String appended after every step; the reward is read at the token
        # position just before each occurrence of its last token.
        prm_token = '\n\n\n\n\n'
        tokenizer = AutoTokenizer.from_pretrained(args.backbone_path)

        print(f"Loading joint model from: {args.joint_model_path}")
        model = load_saved_joint_prm_model(
            saved_model_path=args.joint_model_path,
            prm_base_model_path=args.backbone_path,
            bias_expert_base_model_path=args.bias_expert_path,
            USE_8bit=False,
            device_map='auto',
        )

        prm_token_id = tokenizer.encode(f" {prm_token}")[-1]

        ds_engine = deepspeed.init_inference(
            model,
            tensor_parallel={"tp_size": 1},
            dtype=torch.bfloat16,
        )
        model = ds_engine.module
        model.eval()

        def data_collator(example, tokenizer=tokenizer):
            """Tokenize a list of records and locate their step-marker tokens.

            Each record supplies 'query', 'answer', 'idx' and 'reward_idx'.
            Returns padded tensors on ``accelerator.device``:
            'special_tokens' holds, per record, the position preceding every
            ``prm_token_id`` (padded with -100) and 'orm_tokens' the last
            such position.
            """
            template = '{query}\n{answer}'
            token_rows, marker_rows, last_markers = [], [], []
            query_ids, reward_ids = [], []
            for record in example:
                text = template.format(query=record['query'], answer=record['answer'])
                token_ids = tokenizer.encode(text, add_special_tokens=False)
                token_rows.append(torch.tensor(token_ids))
                marker_positions = [
                    pos - 1 for pos, token in enumerate(token_ids) if token == prm_token_id
                ]
                marker_rows.append(torch.tensor(marker_positions))
                last_markers.append(marker_positions[-1])
                query_ids.append(record['idx'])
                reward_ids.append(record['reward_idx'])

            pad_id = tokenizer.pad_token_id
            padded_tokens = pad_sequence(token_rows, padding_value=pad_id, batch_first=True)
            attention_mask = (padded_tokens != pad_id)
            padded_markers = pad_sequence(marker_rows, padding_value=-100, batch_first=True)

            device = accelerator.device
            return {
                'input_ids': padded_tokens.int().to(device),
                'attention_mask': attention_mask.int().to(device),
                'special_tokens': padded_markers.to(device),
                'orm_tokens': torch.tensor(last_markers).to(device),
                'idx': torch.tensor(query_ids).to(device),
                'reward_idx': torch.tensor(reward_ids).to(device),
            }

        math_files = [
            '/math-llama3-70b-inst-128_double.json',
            '/math-muggle-128_double.json',
            '/math-metamath-mistral-128_double.json',
        ]
        gsm8k_files = [
            '/gsm8k-plus-llama3-70b-inst-128_double.json',
            '/gsm8k-plus-muggle-128_double.json',
            '/gsm8k-plus-metamath-mistral-128_double.json',
        ]

        data_names = ['math', 'gsm8k']
        for data_name in data_names:
            data_files = math_files if data_name == 'math' else gsm8k_files
            for data_file in data_files:
                print(f"Testing {data_name} with {data_file}, c_param={args.c_param}")
                test_joint_model(data_name, data_file, c_param=args.c_param)
