from collections import defaultdict
from verl import DataProto
from verl.utils.reward_score import _default_compute_score
import torch
import json

from verl.utils.reward_score.searchrl import compute_score_parallel

def remove_surrogates(s):
    """Drop UTF-16 surrogate code points (U+D800-U+DFFF) from a string.

    The string is round-tripped through UTF-8. ``surrogatepass`` lets the
    surrogates be encoded at all. Decoding with ``ignore`` then discards them,
    because strict UTF-8 rejects encoded surrogates. Any value that is not a
    ``str`` is returned unchanged.
    """
    if not isinstance(s, str):
        return s
    encoded = s.encode('utf-8', 'surrogatepass')
    return encoded.decode('utf-8', 'ignore')

class SearchRLRewardManagerWithSaveParallel:
    """Reward manager that scores a whole batch with ``compute_score_parallel``.

    For each sample it:
    - decodes prompt + response;
    - scores all samples in one ``compute_score_parallel`` call;
    - optionally adds a linear overlong-response penalty;
    - writes the scalar reward onto the last valid response token.

    Optionally, it also appends one JSON line per sample to a save file and
    returns aggregate response/F1 statistics.
    """

    def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key='data_source', save_path=None, max_resp_len=None, reward_cfg=None, is_validate=False) -> None:
        """Store configuration.

        Args:
            tokenizer: used for ``decode`` and ``eos_token``.
            num_examine: max number of samples printed per data source per call.
            compute_score: stored on the instance; not used within this class,
                which calls ``compute_score_parallel`` instead.
            reward_fn_key: non-tensor field holding each sample's data source.
            save_path: default JSONL file that per-sample records are appended to (None disables saving).
            max_resp_len: maximum response length; required when an overlong buffer is configured.
            reward_cfg: config forwarded to ``compute_score_parallel``. Its
                ``overlong_buffer`` (if any) controls the length penalty.
            is_validate: selects the ``val_f1``/``train_f1`` metric prefix and is forwarded to scoring.
        """
        self.tokenizer = tokenizer
        self.num_examine = num_examine
        self.compute_score = compute_score or _default_compute_score
        self.max_resp_len = max_resp_len
        # reward_cfg defaults to None, so guard the attribute access.
        self.overlong_buffer_cfg = reward_cfg.overlong_buffer if reward_cfg is not None else None
        self.reward_cfg = reward_cfg
        self.reward_fn_key = reward_fn_key
        self.save_path = save_path
        self.is_validate = is_validate
        if self.overlong_buffer_cfg is not None:
            assert self.max_resp_len is not None, f"max_resp_len must be provided if {self.overlong_buffer_cfg=}, but got None"

    def __call__(self, data: DataProto, curr_save_path=None, return_dict=False):
        """Compute token-level rewards for ``data``.

        Args:
            data: batch. Tensor fields read: 'prompts', 'responses',
                'attention_mask', 'loss_mask' (or 'rm_scores', which
                short-circuits). Non-tensor fields read: 'question',
                'extra_info', 'reward_model', and ``self.reward_fn_key``.
            curr_save_path: per-call override of ``self.save_path``.
            return_dict: when True, return a dict instead of the bare tensor.

        Returns:
            When ``return_dict`` is False, the float32 reward tensor shaped like
            ``data.batch['responses']``. When True, a dict with 'reward_tensor',
            'reward_extra_info' and, for a non-empty batch, 'response_analysis'.
        """
        save_path = curr_save_path if curr_save_path is not None else self.save_path

        if 'rm_scores' in data.batch.keys():
            if return_dict:
                return {"reward_tensor": data.batch['rm_scores']}
            return data.batch['rm_scores']

        reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32)
        reward_extra_info = defaultdict(list)

        batch = self._decode_batch(data)
        scores = compute_score_parallel(
            eos_token=self.tokenizer.eos_token,
            qids=batch['qids'],
            solution_strs=batch['solution_strs'],
            ground_truths=batch['ground_truths'],
            questions=batch['questions'],
            question_decompositions=batch['question_decompositions'],
            reward_cfg=self.reward_cfg,
            valid_response_lengths=batch['valid_response_lengths'],
            max_resp_len=self.max_resp_len,
            is_validate=self.is_validate,
            model_generated_tokens=batch['model_generated_tokens'],
            model_generated_tokens_str=batch['model_generated_tokens_str'],
            # `_optimal_ref` is never assigned in this class. Presumably it is set
            # externally (TODO confirm); fall back to None instead of raising AttributeError.
            optimal_ref=getattr(self, '_optimal_ref', None),
            sub_hypotheses=batch['sub_hypotheses'],
        )
        # One score dict per sample; IndexError here means the scorer returned too few.
        score_dicts = [scores[i] for i in range(len(data))]

        overlong_enabled = self.overlong_buffer_cfg is not None and self.overlong_buffer_cfg.enable
        already_print_data_sources = {}

        # Explicit UTF-8 because records are dumped with ensure_ascii=False;
        # try/finally so the handle is closed even if a sample fails.
        save_file = open(save_path, 'a', encoding='utf-8') if save_path is not None else None
        try:
            for i, score_dict in enumerate(score_dicts):
                valid_response_length = batch['valid_response_lengths'][i]
                data_source = batch['data_sources'][i]
                sequences_str = batch['solution_strs'][i]
                ground_truth = batch['ground_truths'][i]
                score = score_dict['score']
                reason = score_dict['reason']
                score_breakdown = score_dict['score_breakdown']

                overlong_reward = 0
                if overlong_enabled:
                    overlong_reward = self._overlong_reward(valid_response_length)
                    score += overlong_reward
                    score_breakdown['overlong_reward'] = overlong_reward
                    if self.overlong_buffer_cfg.log:
                        reward_extra_info["overlong_reward"].append(overlong_reward)
                        reward_extra_info["overlong"].append(overlong_reward < 0)

                # The sequence-level score is placed on the last valid response token.
                reward_tensor[i, valid_response_length - 1] = score

                if save_file is not None:
                    save_json_line = {
                        'id': batch['qids'][i],
                        'data_source': data_source,
                        'sequences_str': remove_surrogates(sequences_str),
                        'ground_truth': ground_truth,
                        'score': float(score),
                        'reason': remove_surrogates(reason),
                        'overlong_reward': overlong_reward,
                        'overlong': overlong_reward < 0,
                        'question_decomposition': batch['question_decompositions'][i],
                        'score_breakdown': score_breakdown,
                        'sub_hypothesis': batch['sub_hypotheses'][i],
                    }
                    save_file.write(json.dumps(save_json_line, ensure_ascii=False) + '\n')

                if data_source not in already_print_data_sources:
                    already_print_data_sources[data_source] = 0

                if already_print_data_sources[data_source] < self.num_examine:
                    already_print_data_sources[data_source] += 1
                    print('-' * 20)
                    print(f"data_source: \n{data_source}")
                    print(f"sequences_str: \n{sequences_str}")
                    print(f"ground_truth: \n{ground_truth}")
                    print(f"score: \n{score}")
                    print(f"reason: \n{reason}")
                    print(f"score_breakdown: \n{score_breakdown}")
                    print('-' * 20)
        finally:
            if save_file is not None:
                save_file.close()

        if not return_dict:
            return reward_tensor

        return_data = {
            "reward_tensor": reward_tensor,
            "reward_extra_info": reward_extra_info,
        }
        if score_dicts:
            metrics = self._response_analysis(score_dicts)
            print(metrics)
            return_data["response_analysis"] = metrics
        return return_data

    def _decode_batch(self, data):
        """Decode every sample into the per-sample lists consumed by ``compute_score_parallel``.

        Returns a ``defaultdict(list)`` keyed by 'data_sources', 'solution_strs',
        'ground_truths', 'questions', 'qids', 'valid_response_lengths',
        'question_decompositions', 'model_generated_tokens',
        'model_generated_tokens_str' and 'sub_hypotheses'. An empty batch yields
        empty lists.
        """
        fields = defaultdict(list)
        for i in range(len(data)):
            data_item = data[i]
            non_tensor = data_item.non_tensor_batch
            extra_info = non_tensor['extra_info']

            prompt_ids = data_item.batch['prompts']
            prompt_length = prompt_ids.shape[-1]
            attention_mask = data_item.batch['attention_mask']

            # Keep the trailing valid prompt tokens and the leading valid response tokens.
            valid_prompt_length = attention_mask[:prompt_length].sum()
            valid_prompt_ids = prompt_ids[-valid_prompt_length:]

            response_ids = data_item.batch['responses']
            valid_response_length = attention_mask[prompt_length:].sum()
            valid_response_ids = response_ids[:valid_response_length]

            # Response tokens with loss_mask == 1; presumably the model-generated
            # (non-injected) tokens — confirm against the rollout code.
            loss_mask = data_item.batch['loss_mask']
            generated_ids = valid_response_ids[loss_mask[:valid_response_length] == 1].tolist()

            fields['valid_response_lengths'].append(valid_response_length)
            fields['solution_strs'].append(self.tokenizer.decode(torch.cat((valid_prompt_ids, valid_response_ids))))
            fields['questions'].append(non_tensor['question'])
            fields['question_decompositions'].append(extra_info['question_decomposition'])
            fields['ground_truths'].append(non_tensor['reward_model']['ground_truth'])
            fields['sub_hypotheses'].append(extra_info.get('sub_hypothesis', None))
            fields['qids'].append(extra_info['id'])
            fields['model_generated_tokens'].append(generated_ids)
            fields['model_generated_tokens_str'].append(self.tokenizer.decode(generated_ids))
            fields['data_sources'].append(non_tensor[self.reward_fn_key])
        return fields

    def _overlong_reward(self, valid_response_length):
        """Return the non-positive length penalty for one response, as a Python float.

        Zero while the response stays within ``max_resp_len - buffer_len``.
        Beyond that it falls linearly, reaching ``-penalty_factor`` at
        ``max_resp_len``. Converting to float keeps tensors out of
        ``score_breakdown``, which is serialized with ``json.dumps``.
        """
        cfg = self.overlong_buffer_cfg
        expected_len = self.max_resp_len - cfg.len
        exceed_len = float(valid_response_length) - expected_len
        return min(-exceed_len / cfg.len * cfg.penalty_factor, 0.0)

    def _response_analysis(self, score_dicts):
        """Aggregate per-sample counters and F1 metrics into a flat metrics dict.

        Counters get min/mean/max under ``response/``. The precision, recall and
        F1 means go under ``val_f1/`` or ``train_f1/`` depending on
        ``self.is_validate``.
        """
        def _mean(values):
            return sum(values) / len(values) if values else 0

        def _max_mean_min(name, values):
            return {
                f"response/{name}_max": max(values) if values else 0,
                f"response/{name}_mean": _mean(values),
                f"response/{name}_min": min(values) if values else 0,
            }

        # NOTE(review): assumes compute_score_parallel may report 'think_count'
        # and 'search_count' per sample; missing keys default to 0 — confirm.
        think_counts = [s.get('think_count', 0) for s in score_dicts]
        search_counts = [s.get('search_count', 0) for s in score_dicts]
        multiple_answer_counts = [s['multiple_answer_count'] for s in score_dicts]
        breakdowns = [s['score_breakdown'] for s in score_dicts]

        metrics = {}
        metrics.update(_max_mean_min('think_count', think_counts))
        metrics.update(_max_mean_min('search_count', search_counts))
        metrics.update(_max_mean_min('multiple_answer_count', multiple_answer_counts))

        prefix = 'val_f1' if self.is_validate else 'train_f1'
        metrics[f"{prefix}/f1_score_mean"] = _mean([b.get('f1_score', 0) for b in breakdowns])
        metrics[f"{prefix}/precision_mean"] = _mean([b.get('precision', 0) for b in breakdowns])
        metrics[f"{prefix}/recall_mean"] = _mean([b.get('recall', 0) for b in breakdowns])
        return metrics