import re
import torch
import pandas as pd
from transformers import AutoModel, AutoTokenizer
import argparse
import json
import ray
import os


def read_dataset(data_path):
    """Load a dataset file, choosing the reader from the file extension.

    Supported extensions are ``.pkl`` (pickle), ``.parquet`` and ``.json``.

    Args:
        data_path: Path to the dataset file.

    Returns:
        The loaded object. A pickled plain ``list`` is wrapped in a
        ``pandas.DataFrame``; anything else is returned as the reader
        produced it.

    Raises:
        ValueError: If ``data_path`` does not end with a supported extension.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    if data_path.endswith('.pkl'):
        dataset = pd.read_pickle(data_path)
    elif data_path.endswith('.parquet'):
        dataset = pd.read_parquet(data_path)
    elif data_path.endswith('.json'):
        dataset = pd.read_json(data_path)
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            f"Unsupported dataset format: {data_path!r} "
            "(expected .pkl, .parquet or .json)")

    # A pickle may hold a raw list of records rather than a DataFrame.
    if isinstance(dataset, list):
        dataset = pd.DataFrame(dataset)

    return dataset


def log_message(message, file=None, append=True):
    """
    Prints a message to console and writes it to a file if provided.

    Args:
        message: The message to print and write. Any object is accepted;
            it is converted with ``str()`` before being written to the file.
        file: File object to write to (optional)
        append: Whether to add a newline after the message in the file
            (default: True)
    """
    print(message)
    if file:
        # print() accepts any object, so stringify here as well; plain
        # concatenation would raise TypeError for non-string messages.
        text = str(message)
        file.write(text + '\n' if append else text)


def get_k_results(dataset, all_items_scores, k, output_file):
    """
    Log two metrics computed over the first ``k`` responses of every item.

    "Score" is the average outcome of picking, per item, the response with
    the highest reward-model score and checking its ``private_scores`` entry.
    When every reward score in the window is ``-inf`` the mean of the item's
    private scores is used instead. "Private score" is the average of the
    per-item maximum of ``private_scores`` in the same window.

    Args:
        dataset: DataFrame with a ``private_scores`` column, row-aligned
            with ``all_items_scores``.
        all_items_scores: One list of reward-model scores per item.
        k: Number of leading responses to consider.
        output_file: File object passed through to ``log_message``.
    """
    log_message(
        f"======== Scores for the top {k} responses: =========", output_file)
    total_items = len(all_items_scores)
    neg_inf = float('-inf')

    score = 0
    for row_idx, item_scores in enumerate(all_items_scores):
        candidate_scores = item_scores[:k]
        flags = dataset.iloc[row_idx]['private_scores'][:k]
        best_reward = max(candidate_scores)
        if best_reward != neg_inf:
            # Credit the item only when the top-ranked response is correct.
            chosen = candidate_scores.index(best_reward)
            score += 1 if flags[chosen] == True else 0
            continue
        # No usable reward score: fall back to the expected value of a
        # uniformly random pick.
        score += sum(flags)/len(flags)
    log_message(f"Score: {score/total_items}", output_file)

    # Upper bound: was any of the first k responses correct?
    private_score = sum(
        max(dataset.iloc[row_idx]['private_scores'][:k])
        for row_idx in range(total_items))
    log_message(
        f"Private score: {private_score/total_items}", output_file)


@ray.remote(num_gpus=1)
class RewardScorer:
    """Ray actor holding one reward model and scoring prompt/response pairs.

    The ``num_gpus=1`` resource request reserves one GPU per actor.
    """

    def __init__(self, model_name):
        """Load the reward model and tokenizer found at ``model_name``.

        Args:
            model_name: Model identifier or local path handed to
                ``from_pretrained``.
        """
        # Ray restricts CUDA_VISIBLE_DEVICES to the GPUs assigned to this
        # actor, and inside the actor process those GPUs are re-indexed from
        # 0, so the single reserved GPU is always addressed as index 0.
        self.device_str = "cuda:0"

        # Reward model, loaded in half precision directly onto the device.
        model_kwargs = dict(
            torch_dtype=torch.float16,
            device_map=self.device_str,
            trust_remote_code=True,
        )
        self.rm = AutoModel.from_pretrained(model_name, **model_kwargs)

        self.rm_tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True)

        # Attribute names for the reward model; not referenced by any other
        # method of this class.
        self.attributes = ['helpsteer-helpfulness', 'helpsteer-correctness', 'helpsteer-coherence',
                           'helpsteer-complexity', 'helpsteer-verbosity']
        print(f"RewardScorer initialized on {self.device_str}")

    def score_response(self, prompt: str, response_text: str):
        """Return the reward-model score for one prompt/response exchange.

        Args:
            prompt: The user turn.
            response_text: The assistant turn to be scored.

        Returns:
            The model's score converted to a Python ``float``.
        """
        # Single-turn conversation in chat-message form.
        # NOTE(review): get_score is presumably supplied by the model's
        # remote code (trust_remote_code=True) - confirm against the model.
        conversation = [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": response_text},
        ]
        with torch.no_grad():
            raw_score = self.rm.get_score(self.rm_tokenizer, conversation)

        # Release cached GPU memory after every call.
        if self.device_str.startswith('cuda'):
            torch.cuda.empty_cache()
        return float(raw_score)


# Helpers for extracting code from response metadata and evaluating scores

def extract_code(meta_item):
    """
    Pull the code string out of a metadata item.

    Two layouts are tried in order: a JSON string at
    ``meta_item['metadata'][0][0]`` with a ``code`` key, then a mapping at
    ``meta_item[1]`` with a ``code`` key. Returns the code, or None when
    neither layout applies.
    """
    extractors = (
        lambda item: json.loads(item['metadata'][0][0])['code'],
        lambda item: item[1]['code'],
    )
    for extractor in extractors:
        try:
            return extractor(meta_item)
        except Exception:
            # Best-effort: any failure just means "not this layout".
            continue
    return None


def evaluate_scores(dataset, all_items_scores, ks, strategy_name, output_file):
    """
    Log reward accuracy and private-best rate for each top-k window.

    For every ``k`` in ``ks`` two numbers are logged via ``log_message``:

    * Reward Accuracy: per item, the ``private_scores`` entry of the
      response with the highest reward score among the first ``k``. When
      all of those reward scores are ``-inf`` the mean of the item's first
      ``k`` private scores is used instead. Averaged over all items.
    * Private Best: per item, the maximum of the first ``k`` private
      scores. Averaged over all items.

    Items with no responses count as 0 for both metrics. If
    ``all_items_scores`` is empty, a notice is logged and nothing else is
    computed.

    Args:
        dataset: DataFrame with a ``private_scores`` column, row-aligned
            with ``all_items_scores``.
        all_items_scores: One list of reward-model scores per item.
        ks: Iterable of window sizes to evaluate.
        strategy_name: Label inserted into the log lines.
        output_file: File object passed through to ``log_message``.
    """
    log_message(f"==== {strategy_name} on Top-k ====", output_file)
    num_items = len(all_items_scores)
    if num_items == 0:
        # Avoid dividing by zero below when there is nothing to evaluate.
        log_message(
            f"-- No items to evaluate for {strategy_name}.", output_file)
        return

    # Read the column once; dataset.iloc[i][...] builds a row object on
    # every access and was repeated for each k and each metric.
    private_column = dataset['private_scores'].tolist()
    neg_inf = float('-inf')

    for k in ks:
        correct = 0.0
        best_total = 0
        for i, scores in enumerate(all_items_scores):
            topk = scores[:k]
            private_flags = private_column[i][:k]
            if len(topk) == 0 or len(private_flags) == 0:
                # No candidates: nothing can be selected, so this item
                # contributes 0 to both metrics.
                continue

            best_total += max(private_flags)

            top_score = max(topk)
            if top_score == neg_inf:
                # Nothing was scored: expected value of a random pick.
                correct += sum(private_flags) / len(private_flags)
            else:
                correct += float(private_flags[topk.index(top_score)])

        reward_acc = correct / num_items
        private_best = best_total / num_items
        log_message(
            f"-- Top {k} {strategy_name} Reward Accuracy: {reward_acc}", output_file)
        log_message(
            f"-- Top {k} {strategy_name} Private Best: {private_best}", output_file)


# Check old format and convert to new format


def convert_old_path_to_new(path):
    """
    Rewrite a path ending in ``final_<digits>.pkl`` to ``<digits>_final.pkl``.

    Paths that do not end in the old pattern are returned unchanged. A
    notice is printed whenever a conversion happens.
    """
    legacy = re.search(r'final_(\d+)\.pkl$', path)
    if legacy is None:
        return path

    run_idx = legacy.group(1)
    # Swap only the matched suffix; everything around it is kept as is.
    new_path = path[:legacy.start()] + f'{run_idx}_final.pkl' + path[legacy.end():]
    print(f"[INFO] Detected old format path. Converted to: {new_path}")
    return new_path


if __name__ == '__main__':
    # Parse command-line options.
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str,
                        default="/path/to/file/eval/7B_useful/7B_adversarial_300_7B_adversarial_300_code_output_final.pkl")
    parser.add_argument('--num_workers', type=int, default=0,
                        help="Number of Ray workers. Defaults to number of GPUs if available, else 1.")

    args = parser.parse_args()

    args.data_path = convert_old_path_to_new(args.data_path)

    model_name = "/path/to/folder/data/internlm2-7b-reward"
    model_tag = model_name.split('/')[-1]

    # Derive both output paths from the input path minus its extension.
    # The previous str.replace(".pkl", ...) left .parquet/.json inputs
    # unchanged, so the "output" paths were the input file itself: it was
    # read back as a pickle cache and then truncated when opened with 'w'.
    data_root, _ = os.path.splitext(args.data_path)
    output_df_path = f"{data_root}_{model_tag}_reward_scores.pkl"
    output_result_path = f"{data_root}_{model_tag}_reward_scores.txt"

    if not os.path.exists(output_df_path):
        # No cached scores yet: score every response with the reward model.
        ray.init(ignore_reinit_error=True)
        num_workers = args.num_workers
        if num_workers <= 0:
            # NOTE(review): RewardScorer requests num_gpus=1, so the
            # single-worker fallback presumably still needs a GPU somewhere
            # in the Ray cluster - confirm.
            if torch.cuda.is_available():
                num_workers = torch.cuda.device_count()
            else:
                num_workers = 1

        print(f"Using {num_workers} Ray workers.")
        actors = [RewardScorer.remote(model_name) for _ in range(num_workers)]
        actor_pool_size = len(actors)
        current_actor_idx = 0

        dataset = read_dataset(args.data_path)
        all_unfiltered_scores = []  # Scores for all available codes
        all_public_scores = []      # Scores after public filtering

        for i in range(len(dataset)):
            row = dataset.iloc[i]
            prompt = row['prompt'][0]['content']
            responses_metadata = row['metadata']
            public_scores_flags = row['public_scores']

            # Responses whose code cannot be extracted keep a score of -inf.
            scores_all = [float('-inf')] * len(responses_metadata)
            tasks = []
            # Keep track of which original index 'j' each task corresponds to
            task_to_original_idx_map = []

            for j, meta in enumerate(responses_metadata):
                code_str = extract_code(meta)
                if code_str is not None:
                    # Round-robin across actors; the counter carries over
                    # between items so load stays balanced.
                    actor = actors[current_actor_idx % actor_pool_size]
                    tasks.append(actor.score_response.remote(prompt, code_str))
                    task_to_original_idx_map.append(j)
                    current_actor_idx += 1

            if tasks:
                # Get results for the tasks launched for this item
                task_results = ray.get(tasks)
                for task_idx, score in enumerate(task_results):
                    j = task_to_original_idx_map[task_idx]
                    scores_all[j] = score
                    print(f"Score for response {i}-{j}: {score}")

            # Build public-filtered scores: responses failing the public
            # check are masked to -inf so they can never be selected.
            scores_public = [scores_all[j] if public_scores_flags[j] else float(
                '-inf') for j in range(len(scores_all))]
            all_unfiltered_scores.append(scores_all)
            all_public_scores.append(scores_public)

        # Both score lists hold one inner list per dataset row; store them
        # as columns and cache the result so later runs skip scoring.
        dataset['reward_model_scores_ray_all'] = all_unfiltered_scores
        dataset['reward_model_scores_ray_public'] = all_public_scores
        dataset.to_pickle(output_df_path)
        print(f"DataFrame with Ray scores saved to {output_df_path}")
    else:
        # Reuse the scores cached by a previous run.
        dataset = pd.read_pickle(output_df_path)
        all_unfiltered_scores = dataset['reward_model_scores_ray_all'].tolist()
        all_public_scores = dataset['reward_model_scores_ray_public'].tolist()

    ks = [1, 2, 4, 8, 16, 32]

    # Evaluate both scoring strategies; the context manager makes sure the
    # report is flushed and closed.
    with open(output_result_path, 'w') as output_file:
        evaluate_scores(dataset, all_public_scores, ks=ks,
                        strategy_name="Public-Filtered", output_file=output_file)
        evaluate_scores(dataset, all_unfiltered_scores, ks=ks,
                        strategy_name="Unfiltered", output_file=output_file)

    ray.shutdown()
