# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generate solutions for a dataset of prompts in batches, score each generation,
and save the successful ones to disk.
"""
import json
from verl.utils.reward_score.livecodebench.compute_score import compute_score as compute_score_yr
from verl.workers.reward_manager.yr_code import parallel_compute_score
import pandas as pd
import argparse
import os
from tabulate import tabulate
from functools import partial
import multiprocessing
from openai import OpenAI
from rlhf_utils.online_server import multithread_openai_chat_completions_call
import random
from tqdm import tqdm


class SampleTracker:
    """Track the processing status and attempt count of every dataset sample.

    Samples are identified by their positional index in ``dataset``. Each one
    is in exactly one of three states: ``'pending'`` (still to be tried),
    ``'success'`` or ``'failed'`` (gave up after ``max_attempts`` failures).
    The ``pending_count`` / ``success_count`` / ``failure_count`` attributes
    are kept in sync with those states.
    """

    def __init__(self, dataset, max_attempts=5, output_dir=None):
        """Create a tracker with every sample initially pending.

        Args:
            dataset: Sized collection of samples. When ``output_dir`` is
                given it must also support
                ``dataset.iloc[i]['reward_model']['ground_truth']['question_id']``.
            max_attempts: Number of failed attempts after which a sample is
                marked ``'failed'``.
            output_dir: Optional directory of previously saved solutions;
                samples that already have a ``<question_id>.json`` file there
                are marked successful up front.
        """
        self.dataset = dataset
        self.max_attempts = max_attempts
        # Per-sample bookkeeping: {index: {'attempts': 0, 'status': 'pending'}}
        self.samples = {
            i: {'attempts': 0, 'status': 'pending'}
            for i in range(len(dataset))
        }
        self.pending_count = len(self.samples)
        self.success_count = 0
        self.failure_count = 0

        # Skip samples that a previous run already solved.
        if output_dir:
            self._check_existing_files(output_dir)

    def _check_existing_files(self, output_dir):
        """Mark pending samples as successful when their solution file exists.

        A sample counts as already solved when ``<question_id>.json`` exists
        in ``output_dir``. Samples whose ``question_id`` cannot be read are
        left pending.
        """
        existing_count = 0
        for idx, info in self.samples.items():
            # Only pending samples may transition; this keeps the counters
            # consistent even if the method is invoked more than once.
            if info['status'] != 'pending':
                continue
            try:
                question_id = self.dataset.iloc[idx]['reward_model']['ground_truth']['question_id']
            except (KeyError, IndexError, TypeError):
                # If getting question_id fails, continue to the next
                continue

            solution_file = os.path.join(output_dir, f"{question_id}.json")
            if os.path.exists(solution_file):
                info['status'] = 'success'
                existing_count += 1

        # Update counts
        self.success_count += existing_count
        self.pending_count -= existing_count
        print(
            f"Found {existing_count} existing successful solutions, will skip processing.")

    def get_batch(self, batch_size):
        """Return up to ``batch_size`` randomly chosen pending sample indices.

        Returns an empty list when no sample is pending.
        """
        pending_indices = [idx for idx, info in self.samples.items()
                           if info['status'] == 'pending']

        if not pending_indices:
            return []  # No more samples to be processed

        return random.sample(
            pending_indices, min(batch_size, len(pending_indices)))

    def update_status(self, idx, succeeded):
        """Record the outcome of one attempt on sample ``idx``.

        Args:
            idx: Sample index as returned by :meth:`get_batch`.
            succeeded: Whether this attempt produced a correct solution.

        Returns:
            ``True`` if the sample just became successful, ``False`` if it
            just exhausted ``max_attempts`` and was marked failed, ``None``
            if it stays pending, is unknown, or was already finalized.
        """
        sample = self.samples.get(idx)
        # Ignore unknown indices and samples that are no longer pending:
        # counting them again would push the counters out of sync with the
        # per-sample states (e.g. a negative pending_count).
        if sample is None or sample['status'] != 'pending':
            return None

        if succeeded:
            sample['status'] = 'success'
            self.success_count += 1
            self.pending_count -= 1
            return True

        sample['attempts'] += 1
        if sample['attempts'] >= self.max_attempts:
            sample['status'] = 'failed'
            self.failure_count += 1
            self.pending_count -= 1
            return False
        return None  # Continue trying

    def get_stats(self):
        """Return the pending/success/failure/total counts as a dict."""
        return {
            'pending': self.pending_count,
            'success': self.success_count,
            'failure': self.failure_count,
            'total': len(self.samples)
        }

    def is_finished(self):
        """Return True when no sample is left pending."""
        return self.pending_count == 0


def _load_dataset(data_path):
    """Load the dataset at ``data_path`` (.pkl, .jsonl or .parquet) as a DataFrame.

    Raises:
        ValueError: If the file extension is not one of the supported formats.
    """
    if data_path.endswith(".pkl"):
        # NOTE(review): unpickling can execute arbitrary code; only point
        # --data_path at pickle files from a trusted source.
        dataset = pd.read_pickle(data_path)
    elif data_path.endswith(".jsonl"):
        # Context manager so the file handle is closed; blank lines are
        # skipped instead of crashing json.loads.
        with open(data_path, encoding="utf-8") as f:
            dataset = [json.loads(line) for line in f if line.strip()]
    elif data_path.endswith(".parquet"):
        dataset = pd.read_parquet(data_path)
    else:
        raise ValueError(f'Unsupported file format: {data_path}')

    if not isinstance(dataset, pd.DataFrame):
        dataset = pd.DataFrame(dataset)
    return dataset


def main():
    """Generate, score and save solutions for every prompt in a dataset.

    Command-line arguments select the dataset (``--data_path``), the output
    root (``--output_path``), the model (``--model_name``), the batch size
    (``--batch_size``) and the retry budget per sample (``--max_attempts``).
    Pending samples are processed in random batches: each prompt is sent to
    the chat-completion endpoint, the generation is scored, and generations
    with a score of 1 are written to
    ``<output_path>/<model_name>/<question_id>.json``. Samples that already
    have such a file are skipped, so an interrupted run can be resumed.
    A summary table is printed at the end.

    Raises:
        ValueError: If ``--data_path`` has an unsupported extension.
    """
    parser = argparse.ArgumentParser(
        description='Generate solutions for a dataset using batch processing')
    parser.add_argument('--data_path', type=str,
                        default='/path/to/folder/data/livecodebench_2408_2502_tagged_public.pkl',
                        help='Path to the dataset')
    parser.add_argument('--output_path', type=str,
                        default='/path/to/file/eval',
                        help='Path to save successful solutions')
    parser.add_argument('--model_name', type=str,
                        default='MaaS_o3',
                        help='Model name for generation')
    parser.add_argument('--batch_size', type=int,
                        default=6,
                        help='Number of samples to process in each batch')
    parser.add_argument('--max_attempts', type=int,
                        default=2,
                        help='Maximum attempts per sample before giving up')
    args = parser.parse_args()

    # Ensure the output directory exists
    output_dir = os.path.join(args.output_path, args.model_name)
    os.makedirs(output_dir, exist_ok=True)

    # Initialize API client
    # NOTE(review): the endpoint and API key are hard-coded in source control.
    # The key should be rotated and supplied via configuration/environment
    # instead; left unchanged here to preserve current behavior.
    client = OpenAI(
        base_url="https://genaiapi.cloudsway.net/v1/ai/OtasDekxkGHaKvMQ",
        api_key='XDNhQFIHX6tLpWHjvCCg'
    )

    # read dataset. Note that the dataset should directly contain chat template format (e.g., a list of dictionary)
    dataset = _load_dataset(args.data_path)

    # Prepare the list of prompts. Array-like prompts (objects exposing
    # .tolist()) are converted to plain lists; prompts that are already plain
    # Python objects (e.g. loaded from .jsonl) are used as-is instead of
    # raising AttributeError.
    chat_lst = [
        prompt.tolist() if hasattr(prompt, 'tolist') else prompt
        for prompt in dataset['prompt'].tolist()
    ]

    # Initialize sample tracker, now passing output_dir
    tracker = SampleTracker(
        dataset, max_attempts=args.max_attempts, output_dir=output_dir)

    # Function for parallel score calculation
    compute_score_func = partial(
        compute_score_yr,
        model_type="Qwen",
    )

    # Use ~80% of the cores for scoring, but never fewer than one worker
    # (int(1 * 0.8) would otherwise be 0 on a single-core machine).
    max_score_workers = max(1, int(multiprocessing.cpu_count() * 0.8))

    # Processing loop. `initial` accounts for samples already solved by a
    # previous run so the bar can actually reach its total.
    with tqdm(total=len(dataset), initial=tracker.get_stats()['success']) as pbar:
        while not tracker.is_finished():
            # Get a batch of samples
            batch_indices = tracker.get_batch(args.batch_size)
            if not batch_indices:
                break

            # Prepare batch prompts
            batch_prompts = [chat_lst[idx] for idx in batch_indices]

            # presumably returns one response per prompt, in prompt order;
            # verify against rlhf_utils.online_server
            completions = multithread_openai_chat_completions_call(
                client,
                batch_prompts,
                model_name=args.model_name,
                max_workers=args.batch_size
            )

            # Extract generated content
            generation_texts = []
            for i, response in enumerate(completions):
                try:
                    generation_texts.append(
                        response.choices[0].message.content)
                except (AttributeError, IndexError, TypeError) as e:
                    print(f"Error processing response at index {i}: {str(e)}")
                    # Add an empty string as a placeholder so positions stay
                    # aligned with batch_indices
                    generation_texts.append("")

            # Parallel score calculation
            ground_truths = [dataset.iloc[idx]['reward_model']['ground_truth']
                             for idx in batch_indices]
            scores, metadata = parallel_compute_score(
                compute_score_func,
                generation_texts,
                ground_truths,
                [1] * len(ground_truths),
                max_workers=min(len(batch_indices), max_score_workers),
                return_metadata=True,
                timeout=10
            )

            # Process results
            updates_made = False
            for idx, score, meta, text in zip(batch_indices, scores, metadata, generation_texts):
                # Check if successful (score is 1)
                succeeded = (score == 1)

                if succeeded:
                    # Save successful solutions
                    question_id = dataset.iloc[idx]['reward_model']['ground_truth']['question_id']
                    solution_file = os.path.join(
                        output_dir, f"{question_id}.json")
                    # Write to a temporary file and rename it into place: the
                    # resume logic treats the mere existence of
                    # <question_id>.json as success, so a partially written
                    # file must never appear under the final name.
                    tmp_file = solution_file + ".tmp"
                    with open(tmp_file, 'w') as f:
                        json.dump({
                            'question_id': question_id,
                            'solution': text,
                            'score': score,
                            'metadata': meta
                        }, f, indent=2)
                    os.replace(tmp_file, solution_file)

                # Update sample status
                status = tracker.update_status(idx, succeeded)
                if status is not None:  # Sample is removed from the pending queue
                    updates_made = True
                    pbar.update(1)

            # Print current status
            if updates_made:
                stats = tracker.get_stats()
                pbar.set_postfix({
                    'pending': stats['pending'],
                    'success': stats['success'],
                    'failure': stats['failure']
                })

    # Print final statistics
    final_stats = tracker.get_stats()
    total = final_stats['total']
    # Guard against an empty dataset, which would otherwise divide by zero.
    success_rate = final_stats['success'] / total * 100 if total else 0.0
    print("\nProcessing completed!")
    print(tabulate([
        ["Total samples", total],
        ["Successful", final_stats['success']],
        ["Failed", final_stats['failure']],
        ["Success rate", f"{success_rate:.2f}%"]
    ], headers=["Metric", "Value"], tablefmt="grid"))


# Run the batch generation pipeline only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
