# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     XXXX
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess our own forecasting dataset to parquet format
"""

import argparse
import logging
import os
import re
import json
from typing import List
import random
import datasets
import sys

from verl.utils.hdfs_io import copy, makedirs

# Set up logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

def format_forecasting_prompt_binary(
    question_title: str,
    background: str,
    resolution_criteria: str,
    question_start_date: str,
) -> str:
    """
    Build the user prompt for a binary (YES/NO) forecasting question.

    The four question fields are interpolated verbatim; no article context
    is included. The returned text asks for a final probability inside
    <probability> </probability> tags and ends with a newline.
    """
    # Task description. The whitespace-only second line is part of the
    # prompt text and is spelled out explicitly so it cannot be lost.
    task_intro = (
        "You will be asked a BINARY forecasting question.  You have to come up with the best estimate for whether the event asked in the question happens or happened. Please provide your reasoning before stating how likely is the event asked in the question to happen (your probability estimate of it resolving YES).\n"
        "        \n"
    )

    # The question itself, one labelled field per line.
    question_block = (
        f"Question Title: {question_title}\n"
        f"Question Start Date: {question_start_date}\n"
        f"Question Background: {background}\n"
        f"Resolution Criteria: {resolution_criteria}\n"
    )

    # Instructions on the reasoning and output format.
    reasoning_rules = "Think step by step about the information provided, reason about uncertainty and put your final probability for the event asked in the question to resolve YES in <probability> </probability> tags. The probability should be a number between 0 and 1.\n"

    # Explanation of the (negated) Brier score used as the reward.
    scoring_rules = 'You will be rewarded based on the probability (p) you assign to your answer. Your answer will be evaluated using the BRIER SCORING RULE which is basically - (1 - p)^2 if your answer is correct and (- (p^2)) if your answer is incorrect. For example, if p = 0.6, and the event resolves to NO, then your score will be (- (0.6^2)) = -0.36 whereas if the event resolves to YES, then your score would be - (1 - 0.6)^2 = -0.16. Thus, the range of the score is [-1, 0]. If you output probability more than 0.5, then it is assumed that you think the event will likely resolve to "YES" while if you output probability less than 0.5, then it is assumed that you think the event will likely resolve to "NO". YOU HAVE TO MAXIMIZE YOUR BRIER SCORE.\n'

    closing_rule = "Your final answer should be the probability that the event asked will resolve to YES and your response SHOULD STRICTLY END with <probability> </probability> tags.\n"

    # Sections are separated by a single blank line.
    return (
        task_intro
        + question_block
        + "\n"
        + reasoning_rules
        + "\n"
        + scoring_rules
        + "\n"
        + closing_rule
    )

def format_forecasting_prompt(
    question_title: str,
    background: str,
    resolution_criteria: str,
    answer_type: str,
) -> str:
    """Build the user prompt for a free-form, single-answer forecasting question.

    The model is asked for one answer in <answer> </answer> tags plus its
    confidence in <probability> </probability> tags. The returned text has
    no trailing newline.
    """
    # Each list item is one line of the prompt; empty strings are blank
    # lines and the whitespace-only second item is part of the prompt text.
    prompt_lines = [
        "You will be asked a forecasting question (which might be from the past). You have to come up with the best guess for the final answer. Please provide your reasoning before stating your final answer and also express how likely you think your answer is to be correct (your confidence in your answer).",
        "        ",
        f"Question Title: {question_title}",
        f"Question Background: {background}",
        f"Resolution Criteria: {resolution_criteria}",
        f"Expected Answer Type: {answer_type}",
        "",
        "Think step by step about the information provided, reason about uncertainty and put your final answer (in the format asked) in <answer> </answer> tags. You should also specify your confidence in your answer in <probability> </probability> tags. The probability should be a number between 0 and 1.",
        "",
        "You will be rewarded based on the probability (p) you assign to your answer. Your answer will be evaluated using the BRIER SCORING RULE which is basically (- (1 - p)^2) if your answer is correct and (- 1 - p^2) if your answer is incorrect. For example, if p = 0.5, and your answer is incorrect, then your score will be (-1 - 0.5^2) = (-1 - 0.25) = -1.25 whereas if the answer was correct, then your score would be (- (1 - 0.5)^2) = (- (0.5)^2) = -0.25. Thus, the range of the score is [-2, 0] where your score lies between [-2, -1] if the answer is incorrect and [-1, 0] if the answer is correct. If your answer is correct, your will be REWARDED more if your probability is higher whereas if your answer is incorrect, your will be PENALIZED more if your probability is higher. TRY HARD TO COME UP WITH THE BEST GUESS FOR THE FINAL ANSWER. YOU HAVE TO MAXIMIZE YOUR SCORE.",
        "",
        "Your final answer should be concise (NOT MORE THAN A FEW WORDS LONG) and your response SHOULD STRICTLY END with <answer> </answer> tags and <probability> </probability> tags.",
    ]

    return "\n".join(prompt_lines)



def format_forecasting_prompt_multiple_outcomes(
    question_title: str,
    background: str,
    resolution_criteria: str,
    answer_type: str,
    max_outcomes: int = 3,
) -> str:
    """Build the user prompt that asks for a ranked list of outcomes with probabilities.

    The prompt explains the multi-class Brier reward, shows a worked example,
    then states the actual question and the numbered
    <answerK>/<probabilityK> output format. ``max_outcomes`` is the maximum
    number of predictions the model is told it may list. The returned text
    ends with a newline.
    """
    # Plain (non-f) template filled in with str.format below. Doubled braces
    # are literal braces in the rendered prompt.
    template = """You will be asked a forecasting question. Please provide your reasoning before stating your final answer.

Think step by step about the information provided. You are expected to reason about the possible outcomes and list your best estimate of how likely each of them are. Thus, you have to provide a list of mostly likely outcomes (UP TO {max_outcomes}) and their forecasted probability for each of them. YOUR PROBABILITIES MUST SUM LESS THAN OR EQUAL TO 1 (BUT NOT NECESSARILY 1).

Your will be rewarded based on your probability listed for the different outcomes in reference to the actual (true) outcome of the event. The rule to evaluate your answer will be the multi-class brier scoring rule which is basically: (- \\sum_k (p_k - y_k)^2) where p_k is the probability you assigned to the k^th outcome and y_k is 1 if the k^th outcome is the true outcome and 0 otherwise. If you don't assign list the actual true outcome (i.e., don't assign any probability to it), then it assumed that your probability for that outcome is 0. Thus, the best score you can achieve is 0 by assigning all the probability to the correct outcome (getting (1 - 1)^2 = 0 for the correct outcome and 0 for the incorrect outcomes, so total is 0). Whereas, the worst score you can get is -2 by assigning all the probability to a single incorrect outcome (getting -(1 - 0)^2 = -1 for the incorrect outcomes and -((0 - 1)^2) = -1 for the correct outcome, so total is -2). Hence, the range of your score is [-2, 0]. THINK HARD ABOUT THE POSSIBLE OUTCOMES AND THEIR CHANCE OF HAPPENING. YOU HAVE TO MAXIMIZE YOUR SCORE. YOU MAY PROVIDE A SINGLE PREDICTION OR MULTIPLE PREDICTIONS BUT NOT MORE THAN YOUR TOP {max_outcomes} PREDICTIONS (CAN BE LESS THAN {max_outcomes}). BUT ALSO ENSURE THAT YOUR PROBABILITIES DO NOT SUM MORE THAN 1.

**Example**

**Example Question**:
Question Title: Who will win the Nobel Prize in Literature in 2016?
Background: Question Start Date: 10th January 2016. The Nobel Prize in Literature is awarded annually by the Swedish Academy to authors for their outstanding contributions to literature.
Resolution Criteria: The question will resolve when the Swedish Academy publicly announces the official 2016 Nobel Prize in Literature laureate(s) typically via a press release on NobelPrize.org (expected on or about October 13, 2016). The full name of the laureate exactly as given in the announcement should be provided.

**Example Reasoning and Output**:
...Based on recent literary achievements and past Nobel Prize patterns, I'll consider the most likely candidates. The Nobel Prize in Literature often recognizes established authors with significant global impact, though it can sometimes surprise with unconventional choices.

<answer1> Haruki Murakami </answer1> <probability1> 0.35 </probability1>
<answer2> Philip Roth </answer2> <probability2> 0.25 </probability2>
<answer3> Joyce Carol Oates </answer3> <probability3> 0.18 </probability3>
<answer4> Bob Dylan </answer4> <probability4> 0.11 </probability4>
<answer5> Ngugi wa Thiong'o </answer5> <probability5> 0.09 </probability5>

**Example Score**:
The sum of the probabilities is 0.35 + 0.25 + 0.18 + 0.11 + 0.09 = 0.6 + 0.29 + 0.09 = 0.98. The correct answer to the question is Bob Dylan. The probability assigned to Bob Dylan is 0.11. Hence, the score is - ( (1 - 0.11)^2 + (0 - 0.35)^2 + (0 - 0.25)^2 + (0 - 0.18)^2 + (0 - 0.09)^2 ) = - ( (0.89)^2 + (-0.35)^2 + (-0.25)^2 + (-0.18)^2 + (-0.09)^2 ) = - ( 0.7921 + 0.1225 + 0.0625 + 0.0324 + 0.0081 ) = - 1.0176.

**ACTUAL QUESTION**:

Question Title: {question_title}
Question Background: {background}
Resolution Criteria: {resolution_criteria}
Expected Answer Type: {answer_type}

**Output Format**
... {{reasoning}} ...
<answer1> outcome1 </answer1> <probability1> probability1 </probability1>
<answer2> outcome2 </answer2> <probability2> probability2 </probability2>
<answer3> outcome3 </answer3> <probability3> probability3 </probability3>
...
<answerN> outcomeN </answerN> <probabilityN> probabilityN </probabilityN>

{{ IMPORTANT: probability1 + probability2 + ... + probabilityN <= 1 ; N <= {max_outcomes} }}

EACH ANSWER YOU LIST SHOULD BE AN ACTUAL POSSIBLE ANSWER/OUTCOME OF THE QUESTION/EVENT ASKED AND IT SHOULD NOT BE "other", "rest", "etc.". EACH (DISTINCT) ANSWER YOU OUTPUT WILL BE COUNTED AS A SEPARATE OUTCOME CLASS. YOU WILL BE SCORED NEGATIVELY IF YOU SAY "other", "rest", etc. OR PROVIDE MORE THAN {max_outcomes} PREDICTIONS.
"""

    # Substituted values are inserted as-is; braces inside them are not
    # re-interpreted by str.format.
    fields = {
        "question_title": question_title,
        "background": background,
        "resolution_criteria": resolution_criteria,
        "answer_type": answer_type,
        "max_outcomes": max_outcomes,
    }
    return template.format(**fields)



def load_manifold_and_metaculus_data(split="train", vol_filter_limit=4000, max_prompt_chars=10000):
    """Load the Manifold binary questions, filtered and aligned to the Metaculus schema.

    Args:
        split: Name used to pick the Metaculus file
            (``binary_raw_<split>.json``). The Manifold file is always the
            training file.
        vol_filter_limit: Manifold rows whose ``volume`` is below this value
            are dropped.
        max_prompt_chars: Rows whose combined ``question`` + ``background`` +
            ``resolution_criteria`` length exceeds this many characters are
            dropped.

    Returns:
        A ``datasets.Dataset`` holding the filtered Manifold rows, restricted
        to the columns that also exist in the Metaculus file. The Metaculus
        rows themselves are NOT included: concatenation is currently disabled.
    """
    manifold_path = "/fast/XXXX-3/forecasting/datasets/manifold/manifold_binary_train_with_r1queries.json"
    metaculus_path = "/fast/XXXX-3/forecasting/datasets/metaculus/old-2024/binary_raw_" + split + ".json"

    # Load with Dataset.from_json() instead of load_dataset()
    ds1 = datasets.Dataset.from_json(manifold_path)
    ds2 = datasets.Dataset.from_json(metaculus_path)

    print("Length of dataset 1: ", len(ds1))
    print("Length of dataset 2: ", len(ds2))

    # Keep only Manifold markets with enough trading volume
    ds1 = ds1.filter(lambda x: x["volume"] >= vol_filter_limit)
    print("Length of dataset 1 after filtering: ", len(ds1))

    # Reduce both datasets to their shared columns so they stay
    # concatenation-compatible.
    shared_columns = set(ds1.column_names) & set(ds2.column_names)
    ds1 = ds1.remove_columns([col for col in ds1.column_names if col not in shared_columns])
    ds2 = ds2.remove_columns([col for col in ds2.column_names if col not in shared_columns])

    # Concatenation is intentionally disabled; only the Manifold rows are
    # returned. The log label below is kept unchanged for continuity.
    # ds = datasets.concatenate_datasets([ds1, ds2])
    ds = ds1
    print("Length of concatenated dataset: ", len(ds))

    # Drop rows whose combined prompt text is longer than max_prompt_chars
    ds = ds.filter(
        lambda x: len(x["question"]) + len(x["background"]) + len(x["resolution_criteria"]) <= max_prompt_chars
    )
    print(f"Length of dataset after removing rows with prompt length > {max_prompt_chars}: ", len(ds))

    return ds
    


def load_questions_from_jsonl(file_path: str) -> List[dict]:
    """Load question records from a JSONL file and normalise their fields.

    Each non-blank line is parsed as one JSON object. A line is skipped when:
      * it is not valid JSON or is not a JSON object (a warning is logged),
      * its ``no_good_question`` flag equals 1, or
      * it has no non-blank string question title.

    A malformed value in a single record (for example a null or non-numeric
    ``resolution``) does not abort the load; the affected field falls back to
    its default instead.

    Args:
        file_path: Path to the UTF-8 encoded JSONL file.

    Returns:
        A list of dicts with the keys ``idx``, ``question_title``,
        ``background``, ``resolution_criteria``, ``answer``, ``answer_type``,
        ``resolution_date``, ``question_start_date``, ``question_idx``,
        ``url``, ``data_source``, ``news_source`` and ``resolution``.
        ``idx`` is the 0-based line number in the file, and ``resolution`` is
        an int that is -1 when missing or unparsable.
    """

    def _to_int(value, fallback):
        # int() that never raises: null / non-numeric values use the fallback.
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return fallback

    questions_data = []

    with open(file_path, 'r', encoding='utf-8') as f:
        for line_idx, raw_line in enumerate(f):
            line = raw_line.strip()
            if not line:
                continue

            try:
                article = json.loads(line)
            except json.JSONDecodeError as e:
                logger.warning("Failed to parse line %s: %s", line_idx, e)
                continue

            if not isinstance(article, dict):
                logger.warning("Skipping line %s: expected a JSON object", line_idx)
                continue

            # Records explicitly flagged as having no usable question are dropped.
            if _to_int(article.get('no_good_question'), 0) == 1:
                continue

            # Prefer 'question_title'; fall back to 'question' when the key is absent.
            question_title = article.get('question_title', article.get('question', ''))
            if not isinstance(question_title, str) or not question_title.strip():
                continue

            questions_data.append({
                'idx': line_idx,
                'question_title': question_title,
                'background': article.get('background', ''),
                'resolution_criteria': article.get('resolution_criteria', ''),
                'answer': article.get('answer', ''),
                'answer_type': article.get('answer_type', ''),
                'resolution_date': article.get('resolution_date', ''),
                'question_start_date': article.get('question_start_date', article.get('date_begin', '')),
                'question_idx': article.get('question_idx', line_idx),
                'url': article.get('url', ''),
                'data_source': article.get('data_source', ''),
                'news_source': article.get('news_source', ''),
                'resolution': _to_int(article.get('resolution', -1), -1),
            })

    logger.info("Loaded %s valid questions from %s", len(questions_data), file_path)
    return questions_data


if __name__ == "__main__":
    # Command-line interface: input JSONL, local output directory and an
    # optional HDFS directory to mirror the output to.
    parser = argparse.ArgumentParser()
    parser.add_argument('--questions_file', type=str, default="/fast/XXXX-3/forecasting/datasets/synthetic/freeform/cnn_dw_forbes/cnndwforbescombined_20000.jsonl",
                       help="Path to JSONL file containing articles with question fields")
    parser.add_argument("--local_dir", default="/fast/XXXX-3/forecasting/datasets/verl/list-outcomes/cnn_dw_forbes")
    parser.add_argument("--hdfs_dir", default=None)

    args = parser.parse_args()

    # The split is inferred from substrings of the input path. "test" is
    # checked last, so it wins when both "validation" and "test" occur.
    split = "train"
    if "validation" in args.questions_file:
        split = "validation"
    if "test" in args.questions_file:
        split = "test"

    questions_data = load_questions_from_jsonl(args.questions_file)

    # Fail with a clear message instead of an IndexError on the previews below.
    if not questions_data:
        raise SystemExit(f"No valid questions were loaded from {args.questions_file}")

    if "train" in args.questions_file:
        # Order training questions by their resolution date
        questions_data.sort(key=lambda x: x["resolution_date"])

    # Preview the first and last loaded question
    print(questions_data[0])
    print(questions_data[-1])

    # if "train" in args.questions_file:
    #     binary_data = load_manifold_and_metaculus_data(split="train", vol_filter_limit=4000)
    #     # sort binary_data by start date
    #     binary_data = binary_data.sort("date_begin")
    #     binary_data = binary_data.select(range(len(binary_data) - 2000, len(binary_data)))
    #     # print columns of questions_data

    #     print(binary_data[0])
    #     print(binary_data[-1])
    #     print(binary_data.column_names)

    #     # convert binary_data to list of dicts
    #     binary_data = binary_data.to_list()

    #     # combine questions_data and binary_data
    #     questions_data.extend(binary_data)

    if "validation" in args.questions_file or "test" in args.questions_file:
        # Evaluation sets additionally include the Metaculus and Guardian test questions
        metaculus_test_data = load_questions_from_jsonl("/fast/XXXX-3/forecasting/datasets/metaculus/fromMay2025/binary_test.jsonl")
        theguardian_test_data = load_questions_from_jsonl("/fast/XXXX-3/forecasting/newsdata/theguardian/qgen/cleaned/o4-mini-high_theguardian_207_free_3_cleaned.jsonl")

        questions_data.extend(metaculus_test_data)
        # print(len(metaculus_test_data))
        questions_data.extend(theguardian_test_data)
        # print(len(theguardian_test_data))

        # Fixed seed so the mixed evaluation set is shuffled reproducibly
        random.seed(42)
        random.shuffle(questions_data)
        # print("length of questions_data: ", len(questions_data))

    # else :
    #     metaculus_train_data = load_questions_from_jsonl("/fast/XXXX-3/forecasting/datasets/synthetic/freeform/binary_train_2k.jsonl")
    #     questions_data.extend(metaculus_train_data)

    print("length of questions_data: ", len(questions_data))

    # Module-level accumulators updated by the per-example mapper defined below
    check_metaculus = []
    num_binary = 0
    num_freeform = 0

    # Helper to safely cast to int, returns fallback if fails
    def safe_int(val, fallback=-1):
        """Convert ``val`` to ``int``, returning ``fallback`` when that is impossible.

        ``None``, non-numeric strings, unsupported types, NaN and infinite
        floats all yield ``fallback`` instead of raising.
        """
        if val is None:
            return fallback
        try:
            return int(val)
        except (ValueError, TypeError, OverflowError):
            # OverflowError: int(float("inf")) / int(float("-inf"))
            return fallback

    # Build the per-example mapper that turns a raw question dict into a training record carrying a unique index
    def make_map_fn(split):
        """Return a mapper that converts one raw question dict into a training record.

        The returned ``process_fn(example, idx)`` builds the multiple-outcomes
        prompt and returns a dict with ``data_source``, ``prompt``,
        ``ability``, ``reward_model`` and ``extra_info``. ``split`` is stored
        in ``extra_info``. As side effects it increments the module-level
        ``num_binary`` / ``num_freeform`` counters and appends Metaculus
        examples to ``check_metaculus``.
        """
        def process_fn(example, idx):
            global num_binary, num_freeform
            question_raw = example.get("question_title", example.get("question", ""))
            background = example.get("background", "")
            resolution_criteria = example.get("resolution_criteria", "")

            # Parse the resolution once. 0/1 means a resolved binary question;
            # any negative value means "no binary resolution", in which case
            # the free-form answer is the ground truth.
            resolution = safe_int(example.get("resolution", -1), -1)
            if resolution >= 0:
                answer = "YES" if resolution == 1 else "NO"
            else:
                answer = example.get("answer", "UNKNOWN")

            # answer_type must be a string. A null value is treated like a
            # missing key, mirroring data_source / news_source below.
            answer_type = example.get("answer_type", "binary (YES or NO)")
            if answer_type is None:
                answer_type = "binary (YES or NO)"
            answer_type = str(answer_type)

            resolution_date = example.get(
                "resolution_date",
                example.get("date_resolve_at", example.get("question_close_date", ""))
            )
            question_start_date = example.get(
                "question_start_date",
                example.get("date_begin", "")
            )
            question_idx = example.get("question_idx", example.get("url", ""))

            data_source = example.get("data_source", "manifold")
            if data_source is None:
                data_source = "manifold"
            data_source = str(data_source).lower()

            if "binary" in answer_type.lower():
                num_binary += 1
            else:
                num_freeform += 1

            news_source = example.get("news_source", "Unknown")
            if news_source is None:
                news_source = "Unknown"
            news_source = str(news_source).lower()

            # Every question, binary or free-form, uses the multiple-outcomes prompt
            prompt = format_forecasting_prompt_multiple_outcomes(question_raw, background, resolution_criteria, answer_type)

            # Label the record by origin; the first matching rule wins
            if "manifold" in data_source:
                data_field = "binary/manifold-train"
            elif "metaculus" in data_source:
                data_field = "binary/metaculus-test"
                check_metaculus.append(example)
            elif "theguardian" in news_source:
                data_field = "freeform/theguardian-test"
            elif "forbes24" in data_source:
                data_field = "freeform/forbes24"
            else:
                data_field = "freeform/cnn_dw_forbes"

            # Cast every extra_info field explicitly so Arrow sees one
            # consistent type per column: index/resolution are ints, the rest
            # are strings.
            data = {
                "data_source": data_field,
                "prompt": [
                    {
                        "role": "user",
                        "content": prompt,
                    }
                ],
                "ability": "forecasting",
                "reward_model": {"style": "rule", "ground_truth": answer},
                "extra_info": {
                    "split": str(split),
                    "index": int(idx),
                    "answer_type": str(answer_type),
                    "question_idx": str(question_idx),
                    "answer": str(answer),
                    "question": str(question_raw),
                    "background": str(background),
                    "resolution_criteria": str(resolution_criteria),
                    "resolution_date": str(resolution_date),
                    "question_source": str(data_field),
                    "resolution": int(resolution),
                    "question_start_date": str(question_start_date),
                    "multiple_outcomes": True,
                },
            }
            return data
        return process_fn

    # 'questions_data' is a plain list, so apply the mapper with a list
    # comprehension instead of Dataset.map
    process_example = make_map_fn(split)
    processed_dataset = [process_example(example, idx) for idx, example in enumerate(questions_data)]

    # sort the list by resolution date
    # processed_dataset.sort(key=lambda x: x["extra_info"]["resolution_date"])

    processed_dataset = datasets.Dataset.from_list(processed_dataset)

    # Preview the first four and the last two rows. Indices that do not exist
    # in a short dataset are skipped instead of raising IndexError.
    num_rows = len(processed_dataset)
    for preview_idx in (0, 1, 2, 3, -2, -1):
        if -num_rows <= preview_idx < num_rows:
            print(processed_dataset[preview_idx])
    print("num_binary: ", num_binary)
    print("num_freeform: ", num_freeform)
    # print(check_metaculus[0])
    # print(check_metaculus[-1])

    local_dir = args.local_dir
    hdfs_dir = args.hdfs_dir

    # exist_ok=True already covers the "directory exists" case
    os.makedirs(local_dir, exist_ok=True)

    # Save under the same file name as the input file.
    # NOTE(review): the name keeps the input's extension (e.g. ".jsonl")
    # although the content written is Parquet. Confirm downstream readers
    # rely on this name before changing it.
    input_filename = os.path.basename(args.questions_file)
    # input_filename = "both_filtered_train_12k.jsonl" # os.path.basename(args.questions_file)
    output_path = os.path.join(local_dir, input_filename)
    processed_dataset.to_parquet(output_path)

    # Optionally mirror the whole local output directory to HDFS
    if hdfs_dir is not None:
        makedirs(hdfs_dir)

        copy(src=local_dir, dst=hdfs_dir)
