import json
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
import numpy as np
import sys 

# Load the raw multiple-choice question dump.
df = pd.read_csv("/fast/XXXX-3/forecasting/datasets/menge/mc_questions_2020-01-01_2024-12-31.csv")

# Remember each row's original index label so that examples can still be
# traced back to the CSV after the filtering steps below.
df['idx'] = df.index

# Print all the column names
print(df.columns)

# Print df length
print("OG length:", len(df))

print(df['total_points'].value_counts())

# Only keep rows with the highest total_points. The copy makes the column
# assignments below write to an independent frame instead of a slice of the
# original one (avoids pandas' SettingWithCopyWarning).
df = df[df['total_points'] == df['total_points'].max()].copy()

# New column `resolution`: the upper-cased answer letter.
df['resolution'] = df['answer'].str.upper()

# New column `options`: [choice_a, choice_b, choice_c, choice_d] as one list per row.
df['options'] = df[['choice_a', 'choice_b', 'choice_c', 'choice_d']].apply(lambda choices: list(choices), axis=1)

print(df['answer'].value_counts())
# relevant columns: question, answer, date

# Print the first 2 rows of df, one "column value" pair per line.
# head(2) is used instead of comparing the index label against a constant:
# after the total_points filter the labels are no longer 0, 1, 2, ..., so a
# label comparison does not count rows.
for _, row in df.head(2).iterrows():
    for col in df.columns:
        print(col, row[col])
    print("-"*100)
# df['prefix'] = df['question'].apply(lambda x: x[:5])

# print(df['prefix'].value_counts())

# Keep only questions containing a run of four digits (used here as a proxy
# for "mentions a year"). na=False treats a missing question as "no year"
# instead of producing a mask that cannot be negated.
has_year = df['question'].str.contains(r'\d{4}', na=False)
print("Questions without year:", int((~has_year).sum()))
# Remove these rows (copy again because a new column is assigned below)
df = df[has_year].copy()

# print resolution value counts
print(df['resolution'].value_counts())

print(df['total_points'].value_counts())

# Parse the 'date' column (ISO 8601 strings) into pandas datetimes.
# NOTE(review): whether the result is timezone-aware depends on the strings in
# the CSV; the split below compares against a naive Timestamp -- confirm.
df['date'] = pd.to_datetime(df['date'], format='ISO8601')

# Split chronologically: everything dated before the cutoff is train (+ val),
# everything on or after it is test.
split_date = pd.Timestamp('2024-07-01')
df_train = df[df['date'] < split_date]
df_test = df[df['date'] >= split_date]

# Validation is a random 10% subsample of the pre-cutoff rows, which are then
# removed from train. The fixed seed keeps the split identical across runs.
df_val = df_train.sample(frac=0.1, random_state=42)
df_train = df_train.drop(df_val.index)

# print length of each dataset
print(len(df_train), len(df_val), len(df_test))


# Print all the column names
print(df.columns)

# Remove rows whose question contains the text "will my ", "in my " or "will i "
# (case insensitive).
print("Removing rows with 'will my', 'in my', or 'will i'")
personal_pattern = "will my |in my |will i "
filtered_splits = []
for split in (df_train, df_val, df_test):
    print("before", len(split))
    is_personal = split['question'].str.lower().str.contains(personal_pattern, na=False)
    split = split[~is_personal].copy()
    print("after", len(split))
    # Rebinding the loop variable alone would leave df_train / df_val / df_test
    # untouched, so collect the filtered frames and reassign them afterwards.
    filtered_splits.append(split)
df_train, df_val, df_test = filtered_splits

print(len(df_train), len(df_val), len(df_test))


# # Print df_final length
# print("Filtered length after ensuring at least three words in title:", len(df_final))

# Remove rows whose question or background contain ambiguous unicode characters
# df_train = df_train[~df_train['question'].str.contains("[\u200B-\u200D\uFEFF]")].copy()

# Print df_final length
# print("Filtered length after removing ambiguous unicode characters:", len(df_train), len(df_val), len(df_test))

# print resolution value counts
print("Train resolution value counts:", df_train['resolution'].value_counts())
print("Val resolution value counts:", df_val['resolution'].value_counts())
print("Test resolution value counts:", df_test['resolution'].value_counts())

# Print the earliest and latest date of each split
print("Earliest date in df_train:", df_train['date'].min())
print("Latest date in df_train:", df_train['date'].max())

print("Earliest date in df_val:", df_val['date'].min())
print("Latest date in df_val:", df_val['date'].max())

print("Earliest date in df_test:", df_test['date'].min())
print("Latest date in df_test:", df_test['date'].max())

def format_forecasting_prompt(
    question: str,
    background: str,
    resolution_criteria: str,
    date_begin: str,
    date_close: str,
    zero_shot: bool = True
) -> str:
    """
    Build the binary (probability) forecasting prompt for one question.

    Only `question` and `zero_shot` influence the result; `background`,
    `resolution_criteria`, `date_begin` and `date_close` are accepted but not
    used in the text.

    With `zero_shot=True` the question is wrapped in instructions asking for a
    probability between asterisks; otherwise only the bare question line is
    returned.

    NOTE: a later definition with the same name in this file replaces this one
    at import time.
    """
    question_line = f"Question: {question}\n"

    # Bare variant: just the question, preceded by an empty line.
    if not zero_shot:
        return "\n" + question_line

    task_intro = (
        "You will be asked a forecasting question. "
        "You have to come up with the best estimate for whether the event asked in the question happens or happened. \n"
    )
    answer_format = (
        "Output your final prediction (a number between 0 and 1) with an asterisk at the beginning and end of the decimal. "
        "YOUR FINAL PREDICTION SHOULD STRICTLY BE BETWEEN 0 AND 1. "
        "For example, if you believe the answer is 75% likely, you would write *0.75*. "
        "MAKE SURE TO FORMAT IT CORRECTLY AND PLACE BETWEEN ASTERISKS.\n"
    )
    # Sections are separated by one empty line each.
    return task_intro + "\n" + question_line + "\n" + answer_format




def format_forecasting_prompt(
    question: str,
    background: str,
    resolution_criteria: str,
    date_begin: str,
    date_close: str,
    zero_shot: bool = False,
    options: "list[str] | None" = None,
) -> str:
    """
    Build the multiple-choice forecasting prompt for one question.

    Args:
        question: Question text inserted after "Question: ".
        background: Accepted for signature compatibility; not used in the text.
        resolution_criteria: Accepted for signature compatibility; not used.
        date_begin: Accepted for signature compatibility; not used.
        date_close: Accepted for signature compatibility; not used.
        zero_shot: If True, prepend the full answering instructions
            (<answer1>/<answer2> format); if False, return only the question
            followed by its options.
        options: Answer choices in order. They are labelled "A.", "B.", ...
            by position. None or an empty list produces no option lines.

    Returns:
        The prompt string.
    """
    # None stands in for "no options" so that no mutable default is shared
    # between calls.
    if options is None:
        options = []

    # One "<letter>. <option>" line per choice, lettered from 'A' by position.
    middle_text = "".join(
        f"{chr(i + ord('A'))}. {option}\n" for i, option in enumerate(options)
    )

    if zero_shot:
        return f"""You will be asked a forecasting question in multiple choice format. You have to choose the most likely option from the given options and also report your confidence level in your answer.

Think thoroughly about each of the options and finally format your answer in the following format:

<answer1>
Provide exactly one option number from the choices above (e.g., A, B, C, etc.)
</answer1>
<answer2>
Provide your confidence level in this answer as a decimal between 0 and 1 (e.g., 0.7 for 70% confidence)
</answer2>

IMPORTANT:
- Your <answer1> MUST be exactly one of the option numbers listed above.
- Your <answer2> MUST be a decimal between 0 and 1 representing your confidence.
- Format your response exactly as shown with the <answer1> and <answer2> tags.

Question: {question}
{middle_text}
"""
    else:
        # If not zero_shot, you can modify the prompt as needed.
        return f"""
Question: {question}
{middle_text}
"""


# Create dataset for huggingface
import random
import re

import pandas as pd

# Randomly sample 1000 rows from df_final
# df_final = df_final.sample(n=1000)

# URL pattern applied to each row's "text" field; compiled once instead of
# once per row.
URL_PATTERN = re.compile(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')

# Names of the splits, in the same order as the frames iterated below.
SPLIT_NAMES = ("train", "validation", "test")

data_lists = []

for index, df1 in enumerate([df_train, df_val, df_test]):
    data_list = []
    for idx, row in df1.iterrows():
        question = row["question"]
        # This source provides no separate background or resolution criteria.
        background = "Not available"
        resolution_criteria = "Not available"

        # The resolution is the upper-cased answer letter (e.g. "A").
        resolution = row["resolution"]

        if pd.isna(row["date"]):
            print(f"Date resolve is None for row {idx}")
            continue
        # Stringify the timestamp (JSON cannot serialize Timestamp objects)
        # and keep only the part before the first space, i.e. the date.
        date_resolve = str(row["date"]).split(" ")[0]

        # Extract URLs from the row's text. A non-string value (e.g. a missing
        # cell) yields no URLs instead of raising inside the regex engine.
        text = row["text"]
        urls = URL_PATTERN.findall(text) if isinstance(text, str) else []

        options = row["options"]

        # Map the answer letter to a position in `options`. Rows whose letter
        # is missing, not a single character, or outside the available options
        # are skipped: a letter below 'A' would otherwise produce a negative
        # index and silently select the wrong option.
        if not isinstance(resolution, str) or len(resolution) != 1:
            print(f"Invalid resolution {resolution!r} for row {idx}")
            continue
        answer_idx = ord(resolution) - ord('A')
        if not 0 <= answer_idx < len(options):
            print(f"Resolution {resolution!r} is out of range for row {idx}")
            continue

        # Create dictionary for this example
        example_dict = {
            'date_resolve_at': date_resolve,
            'extracted_urls': urls,
            # NOTE(review): these are multiple-choice questions, yet the type
            # is recorded as "binary" -- confirm what downstream code expects.
            'question_type': "binary",
            'url': row["url"],
            'background': background, # row["body"],
            'resolution_criteria': resolution_criteria,
            'is_resolved': True,
            'date_close': date_resolve,
            'question': question,
            'data_source': "menge_mcq",
            'resolution': resolution,
            'idx': row["idx"],
            'options': options,
            'answer_idx': answer_idx,
            'answer': options[answer_idx],
        }

        # Arguments shared by both prompt variants.
        prompt_kwargs = dict(
            question=question,
            background=background,
            resolution_criteria=resolution_criteria,
            date_begin=date_resolve,
            date_close=date_resolve,
            options=options,
        )

        # 'prompt' is the bare question + options; 'full_prompt' additionally
        # carries the answering instructions.
        prompt = format_forecasting_prompt(zero_shot=False, **prompt_kwargs)
        example_dict['prompt'] = prompt
        example_dict['full_prompt'] = format_forecasting_prompt(zero_shot=True, **prompt_kwargs)

        data_list.append(example_dict)

    data_lists.append(data_list)

    # Shuffle in place (this also reorders the list just stored in data_lists)
    # and take the first two examples for a quick visual check.
    random.shuffle(data_list)
    sample_data_list = data_list[:2]

    # Split name used for the output file
    suffix = SPLIT_NAMES[index]

    # print(f"Example prompts of {suffix} set:\n\n")

    for example in sample_data_list:
        print(example['prompt'])
        print("Date resolve at:", example['date_resolve_at'])
        print("Resolution:", example['resolution'])
        # print(example['full_prompt'])
        # print url,
        print("URL:", example['url'])
        # print(example['idx'])
        print("-"*100)

    # file_path = f"/fast/XXXX-3/forecasting/datasets/menge/mcq_{suffix}.json"
    # # print(f"Saving to {file_path} with data_list length {len(data_list)}")
    # with open(file_path, 'w') as f:
    #     print(f"Saving to {file_path} with data_list length {len(data_list)}")
    #     json.dump(data_list, f, indent=4, ensure_ascii=False)

# # Only keep a random 1000 rows
# # import random
# # random.shuffle(data_list)
# # data_list = data_list[:1000]
