import json
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
import numpy as np

# Load the CSV data
df = pd.read_csv("/fast/XXXX-3/forecasting/datasets/menge/tf_questions_2020-01-01_2024-12-31.csv")

# Remember each row's position in the raw CSV so examples can still be
# traced back after the filtering below.
df['idx'] = df.index

# Print all the column names
print(df.columns)

# Print df length
print("OG length:", len(df))

print(df['total_points'].value_counts())

# Only keep rows with the highest total_points. The explicit copy makes the
# column assignments below write to an independent frame instead of a slice
# of the original (avoids pandas' SettingWithCopyWarning).
df = df[df['total_points'] == df['total_points'].max()].copy()

# relevant columns: question, answer, date
# New column 'resolution': 1 if the answer contains "yes" (case-insensitive),
# 0 otherwise.
df['resolution'] = df['answer'].apply(lambda x: 1 if "yes" in x.lower() else 0)

# df['prefix'] = df['question'].apply(lambda x: x[:5])

# print(df['prefix'].value_counts())

# Flag questions that mention a year. The pattern matches any run of four
# digits; it is evaluated once and reused for both the count and the filter.
# na=False treats a missing question as "no year" so it is dropped instead of
# making the boolean negation fail.
has_year = df['question'].str.contains(r'\d{4}', na=False)
print("Questions without year:", int((~has_year).sum()))
# Remove these rows
df = df[has_year].copy()

# print resolution value counts
print(df['resolution'].value_counts())

print(df['total_points'].value_counts())

# Parse the 'date' column (ISO 8601 strings) into datetimes.
# NOTE(review): the cutoff below is a timezone-naive Timestamp, so this
# assumes the parsed dates are timezone-naive as well -- confirm against the CSV.
df['date'] = pd.to_datetime(df['date'], format='ISO8601')

# Temporal split: everything before the cutoff is train (+ validation),
# everything on or after it is test. Rows whose date failed to parse (NaT)
# compare False on both sides and therefore land in neither split.
split_date = pd.Timestamp('2024-07-01')
df_train = df[df['date'] < split_date]
df_test = df[df['date'] >= split_date]

# Validation is a 10% subsample of train. The fixed seed makes the split
# reproducible across runs.
VAL_SPLIT_SEED = 42
df_val = df_train.sample(frac=0.1, random_state=VAL_SPLIT_SEED)
df_train = df_train.drop(df_val.index)

# print length of each dataset
print(len(df_train), len(df_val), len(df_test))


# Print all the column names
print(df.columns)

# Remove rows whose questions contain the text "will my ", "in my " or
# "will i " (case insensitive).
personal_phrases = ("will my ", "in my ", "will i ")

# Collect the filtered frames and rebind the split variables afterwards:
# reassigning the loop variable alone would leave df_train / df_val / df_test
# untouched.
filtered_splits = []
for df1 in [df_train, df_val, df_test]:
    print("before", len(df1))
    for phrase in personal_phrases:
        # Plain substring match on the lower-cased question text.
        df1 = df1[~df1['question'].str.lower().str.contains(phrase, regex=False)]
    print("after", len(df1))
    filtered_splits.append(df1.copy())

df_train, df_val, df_test = filtered_splits

print(len(df_train), len(df_val), len(df_test))


# # Print df_final length
# print("Filtered length after ensuring at least three words in title:", len(df_final))

# Remove rows whose question or background contain ambiguous unicode characters
# df_train = df_train[~df_train['question'].str.contains("[\u200B-\u200D\uFEFF]")].copy()

# Print df_final length
# print("Filtered length after removing ambiguous unicode characters:", len(df_train), len(df_val), len(df_test))

# Show the label balance (counts per resolution value) of each split, in the
# order train, validation, test.
for split_frame in (df_train, df_val, df_test):
    print(split_frame['resolution'].value_counts())

def format_forecasting_prompt(
    question: str,
    background: str,
    resolution_criteria: str,
    date_begin: str,
    date_close: str,
    zero_shot: bool = True
) -> str:
    """
    Build the prompt text for a single forecasting question.

    Only ``question`` is inserted into the text. ``background``,
    ``resolution_criteria``, ``date_begin`` and ``date_close`` are accepted
    but not used by either template.

    With ``zero_shot=True`` the question is wrapped in task instructions that
    ask for a probability written between asterisks; otherwise only the bare
    "Question: ..." line (with a leading and trailing newline) is returned.
    """
    question_line = f"Question: {question}\n"

    # Bare variant: just the question, surrounded by newlines.
    if not zero_shot:
        return "\n" + question_line

    # Zero-shot variant: instructions, question, then the answer-format rules,
    # separated by blank lines.
    task_intro = "You will be asked a forecasting question. You have to come up with the best estimate for whether the event asked in the question happens or happened. \n"
    answer_rules = "Output your final prediction (a number between 0 and 1) with an asterisk at the beginning and end of the decimal. YOUR FINAL PREDICTION SHOULD STRICTLY BE BETWEEN 0 AND 1. For example, if you believe the answer is 75% likely, you would write *0.75*. MAKE SURE TO FORMAT IT CORRECTLY AND PLACE BETWEEN ASTERISKS.\n"
    return task_intro + "\n" + question_line + "\n" + answer_rules

# Create dataset for huggingface
import random
import re

import pandas as pd

# Compiled once instead of once per row: matches http(s) URLs in free text.
URL_PATTERN = re.compile(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')

# Number of randomly chosen examples printed per split for a visual check.
PREVIEW_SIZE = 4

# Randomly sample 1000 rows from df_final
# df_final = df_final.sample(n=1000)

# One list of example dicts per split, in the order train, validation, test.
data_lists = []

for index, df1 in enumerate([df_train, df_val, df_test]):
    data_list = []
    for _, row in df1.iterrows():
        # Rows without a date are skipped.
        if pd.isna(row["date"]):
            continue

        # Keep only the calendar-date part of the timestamp; a plain string
        # also avoids JSON serialization issues with Timestamp objects.
        date_resolve = str(row["date"]).split(" ")[0]

        question = row["question"]
        # This dataset provides no background or resolution criteria.
        background = "Not available"
        resolution_criteria = "Not available"
        # Binary label (1 for yes, 0 for no)
        resolution = row["resolution"]

        # Extract URLs from the row's text; a missing (non-string) text
        # simply yields no URLs instead of raising a TypeError.
        raw_text = row["text"]
        urls = URL_PATTERN.findall(raw_text) if isinstance(raw_text, str) else []

        # Create dictionary for this example
        example_dict = {
            'date_resolve_at': date_resolve,
            'extracted_urls': urls,
            'question_type': "binary",
            'url': row["url"],
            'background': background, # row["body"],
            'resolution_criteria': resolution_criteria,
            'is_resolved': True,
            'date_close': date_resolve,
            'question': question,
            'data_source': "menge_binary",
            'resolution': resolution,
            'idx': row["idx"],
        }

        # 'prompt' holds the bare question, 'full_prompt' the zero-shot
        # variant with instructions. The resolve date is passed as both the
        # begin and the close date.
        example_dict['prompt'] = format_forecasting_prompt(
            question=question,
            background=background,
            resolution_criteria=resolution_criteria,
            date_begin=date_resolve,
            date_close=date_resolve,
            zero_shot=False,
        )
        example_dict['full_prompt'] = format_forecasting_prompt(
            question=question,
            background=background,
            resolution_criteria=resolution_criteria,
            date_begin=date_resolve,
            date_close=date_resolve,
            zero_shot=True,
        )

        data_list.append(example_dict)

    data_lists.append(data_list)

    # Save data_list in proper format
    suffix = "train" if index == 0 else "validation" if index == 1 else "test"

    # Draw the preview as a separate sample so that data_list (and the entry
    # just stored in data_lists) keeps every example in its original order.
    preview = random.sample(data_list, k=min(PREVIEW_SIZE, len(data_list)))

    # print(f"Example prompts of {suffix} set:\n\n")

    for example in preview:
        print(example['prompt'])
        print("Date resolve at:", example['date_resolve_at'])
        print("Resolution:", example['resolution'])
        # print(example['full_prompt'])
        # print url,
        print("URL:", example['url'])
        # print(example['idx'])
        print("-"*100)

    # file_path = f"/fast/XXXX-3/forecasting/datasets/menge/binary_{suffix}.json"
    # # print(f"Saving to {file_path} with data_list length {len(data_list)}")
    # with open(file_path, 'w') as f:
    #     print(f"Saving to {file_path} with data_list length {len(data_list)}")
    #     json.dump(data_list, f, indent=4, ensure_ascii=False)

# # Only keep a random 1000 rows
# # import random
# # random.shuffle(data_list)
# # data_list = data_list[:1000]
