import datasets
import numpy as np 
import datetime


def load_halawi_data(split="train", raw=False):
    """Load one split of the Halawi forecasting dataset and add a prompt column.

    Args:
        split: Key of the split to select from the loaded dataset dict.
        raw: When true, the ``_raw`` variant of the repository is loaded and
            only rows whose ``question_type`` is "binary" (case-insensitive)
            and whose ``resolution`` string is one of "1", "1.0", "0", "0.0"
            are kept.

    Returns:
        The selected split with an extra ``text`` column containing the
        prompt built by ``create_retreived_prompt`` with no retrieved
        articles.
    """
    repo_id = "YuehHanChen/forecasting" + ("_raw" if raw else "")
    ds = datasets.load_dataset(repo_id)[split]

    if raw:
        # Resolutions are stored as strings in the raw dump.
        accepted_resolutions = ("1", "1.0", "0", "0.0")

        def is_binary(row):
            return row["question_type"].lower() == "binary"

        def has_binary_resolution(row):
            return row["resolution"] in accepted_resolutions

        ds = ds.filter(is_binary)
        ds = ds.filter(has_binary_resolution)

    def add_text(row):
        prompt = create_retreived_prompt(
            row["question"],
            row["background"],
            row["resolution_criteria"],
            row["date_begin"],
            row["date_close"],
            [],
        )
        return {"text": prompt}

    return ds.map(add_text)


def load_cladder():
    """Load the CLadder dataset and return its ``full_v1.5_default`` entry.

    The column names of the returned dataset are printed as a side effect.
    """
    cladder = datasets.load_dataset("causal-nlp/CLadder")['full_v1.5_default']
    print("Column names: ", cladder.column_names)
    return cladder



def load_metaculus_data(split="train", nr_forecasters=1):
    """Load the Metaculus binary questions and add a prompt column.

    The data is always read from the repository's "train" split; ``split``
    only selects which rows are kept afterwards.

    Args:
        split: "train" keeps rows with ``date_begin`` before "2024-06-30";
            "test" keeps rows with ``date_begin`` on or after "2024-06-30"
            that also have at least ``nr_forecasters`` forecasters. Any
            other value applies no filtering.
        nr_forecasters: Minimum ``nr_forecasters`` value (inclusive) for the
            "test" split; ignored otherwise.

    Returns:
        The filtered dataset with an extra ``text`` column containing the
        prompt built by ``create_retreived_prompt`` with no retrieved
        articles.
    """
    ds = datasets.load_dataset("nikhilchandak/metaculus-binary")["train"]

    # Dates are compared as strings against this cutoff.
    cutoff = "2024-06-30"
    if split == "train":
        ds = ds.filter(lambda row: row["date_begin"] < cutoff)
    elif split == "test":
        ds = ds.filter(lambda row: row["date_begin"] >= cutoff)
        ds = ds.filter(lambda row: row["nr_forecasters"] >= nr_forecasters)

    def add_text(row):
        prompt = create_retreived_prompt(
            row["question"],
            row["background"],
            row["resolution_criteria"],
            row["date_begin"],
            row["date_close"],
            [],
        )
        return {"text": prompt}

    return ds.map(add_text)

def load_menge_data(split="validation", data_type="binary"):
    """Load a Menge JSON file named ``<data_type>_<split>.json``.

    Args:
        split: Split part of the file name.
        data_type: Data-type part of the file name.

    Returns:
        The dataset read from the JSON file. Its column names are printed as
        a side effect.
    """
    file_name = data_type + "_" + split + ".json"
    menge = datasets.Dataset.from_json(
        "/fast/XXXX-3/forecasting/datasets/menge/" + file_name
    )
    print("Menge Column names: ", menge.column_names)
    return menge


def load_manifold_data(split="validation", nr_forecasters=1):
    """Load one of the Manifold binary-question JSON files.

    Args:
        split: "distill" loads ``binary_mini.json``; "test" and "validation"
            both load ``manifold_binary_validation_set.json``.
        nr_forecasters: Currently unused by this loader; kept so the
            signature matches the other loaders.

    Returns:
        The dataset read from the selected JSON file. Its column names are
        printed as a side effect.

    Raises:
        ValueError: If ``split`` is not one of the supported names. Without
            this check the bare directory path would be handed to the JSON
            reader and fail with a much less helpful error.
    """
    base_dir = "/fast/XXXX-3/forecasting/datasets/manifold"
    if split == "distill":
        file_name = "/binary_mini.json"
    elif split == "test" or split == "validation":
        file_name = "/manifold_binary_validation_set.json"
    else:
        raise ValueError(
            f"Unsupported Manifold split {split!r}; "
            "expected 'distill', 'test' or 'validation'."
        )

    ds = datasets.Dataset.from_json(base_dir + file_name)
    print("Manifold Column names: ", ds.column_names)

    # NOTE: unlike the Metaculus loader, no date or forecaster-count
    # filtering is applied here.
    return ds

def filter_halawi_data(ds, begin_date="2023-01-01", end_date="2023-06-01"):
    """Keep rows that began after ``begin_date`` and resolved before ``end_date``.

    Both bounds are exclusive and are compared directly against the rows'
    ``date_begin`` and ``date_resolve_at`` values.

    Args:
        ds: Object exposing a ``filter(predicate)`` method over row dicts.
        begin_date: Exclusive lower bound for ``date_begin``.
        end_date: Exclusive upper bound for ``date_resolve_at``.

    Returns:
        Whatever ``ds.filter`` returns for the window predicate.
    """
    def in_window(row):
        if not row["date_begin"] > begin_date:
            return False
        return row["date_resolve_at"] < end_date

    return ds.filter(in_window)



def load_infinitegames_data(split="train", nr_forecasters=1):
    """Load the InfiniteGames binary test set from JSON.

    Args:
        split: If the string contains "balanced", the balanced test file is
            loaded; otherwise the regular test file is loaded.
        nr_forecasters: Unused by this loader.

    Returns:
        The dataset read from the selected JSON file. Its column names are
        printed as a side effect.
    """
    json_path = (
        "/fast/XXXX-3/forecasting/datasets/infinitegames/binary_balanced_test.json"
        if "balanced" in split
        else "/fast/XXXX-3/forecasting/datasets/infinitegames/binary_test.json"
    )
    games = datasets.Dataset.from_json(json_path)
    print("Column names: ", games.column_names)
    return games


def load_paleka(split="spanned"):
    """Load the Paleka JSONL dump, add prompts, and split it by resolution date.

    Args:
        split: When equal to "spanned" or containing "gpt4o", the
            ``_gpt-4o_spanned_resolved`` file is loaded; otherwise the plain
            ``20240701_20240831.jsonl`` file is used.

    Returns:
        A ``(train, test)`` tuple: rows with ``resolution_date`` before
        2024-08-01 and rows on or after it. Both carry a ``prompt`` column
        built from the ``title`` and ``body`` fields.

    Dataset statistics and the first few prompts are printed as a side
    effect.
    """
    wants_spanned = split == "spanned" or "gpt4o" in split
    jsonl_path = "/fast/XXXX-3/forecasting/datasets/paleka/20240701_20240831"
    if wants_spanned:
        jsonl_path += "_gpt-4o_spanned_resolved"
    jsonl_path += ".jsonl"

    ds = datasets.Dataset.from_json(jsonl_path)

    # Summary of what was loaded.
    print("Paleka Column names: ", ds.column_names)
    print(f"Length of dataset: {len(ds)}")
    print("Resolution value counts: ", np.unique(ds["resolution"], return_counts=True))
    print("Resolution date value counts: ", np.unique(ds["resolution_date"], return_counts=True))

    def add_prompt(row):
        return {"prompt": f"Question: {row['title']}\nResolution Criteria: {row['body']}"}

    ds = ds.map(add_prompt)

    # Preview the first seven prompts (indices 0 through 6).
    for idx, row in enumerate(ds):
        print(row["prompt"])
        print("-"*100)
        if idx >= 6:
            break

    # NOTE(review): assumes resolution_date values are comparable with a
    # datetime object -- confirm against the JSONL schema.
    cutoff = datetime.datetime(2024, 8, 1)
    train = ds.filter(lambda row: row["resolution_date"] < cutoff)
    test = ds.filter(lambda row: row["resolution_date"] >= cutoff)

    print(f"Length of train subset: {len(train)}")
    print(f"Length of test subset: {len(test)}")
    print("Train Resolution value counts: ", np.unique(train["resolution"], return_counts=True))
    print("Test Resolution value counts: ", np.unique(test["resolution"], return_counts=True))

    return train, test

def load_retreived_data(
    split="train",
    data_type="retrieval_metaculus",
    nr_forecasters=1,
    date_close_min="2018-01-01",
    date_close_max="2021-12-31",
    min_articles=3,
):
    """Load a retrieval-augmented dataset from disk and build a prompt per row.

    Args:
        split: If the loaded object has this key, that entry is selected.
            "train" then keeps rows with ``date_begin`` before "2024-06-30";
            "test" keeps rows with ``date_begin`` on or after "2024-06-30"
            and at least ``nr_forecasters`` forecasters.
        data_type: Sub-directory of the retrieval folder to load. If the
            name contains "without", prompts are built with no articles.
        nr_forecasters: Minimum ``nr_forecasters`` (inclusive) for "test".
        date_close_min: Inclusive lower bound on ``date_close``; ``None``
            disables the lower bound.
        date_close_max: Inclusive upper bound on ``date_close``; ``None``
            disables the upper bound.
        min_articles: Minimum number of ``retrieved_articles`` a row must
            have to be kept.

    Returns:
        The filtered dataset with an extra ``prompt`` column.

    Note:
        With the default ``date_close`` window, the "test" split keeps only
        rows whose ``date_close`` precedes their ``date_begin`` cutoff; pass
        different bounds (or ``None``) when a later period is wanted.

    Progress information is printed as a side effect.
    """
    path = "/fast/XXXX-11/forecasting/news/retrieval/" + data_type + "/"
    dataset = datasets.load_from_disk(path)

    # A dataset dict exposes keys(); a single dataset is used as-is.
    if hasattr(dataset, 'keys') and split in dataset:
        print("Split found in dataset: ", split)
        dataset = dataset[split]

    ds = dataset
    print("Length before split: ", len(ds))

    if split == "train":
        ds = ds.filter(lambda x: x["date_begin"] < "2024-06-30")
    elif split == "test":
        ds = ds.filter(lambda x: x["date_begin"] >= "2024-06-30")
        ds = ds.filter(lambda x: x["nr_forecasters"] >= nr_forecasters)

    def closes_in_window(row):
        closes = row["date_close"]
        if date_close_min is not None and closes < date_close_min:
            return False
        if date_close_max is not None and closes > date_close_max:
            return False
        return True

    if date_close_min is not None or date_close_max is not None:
        ds = ds.filter(closes_in_window)

    print("Length after split: ", len(ds))

    # Keep only rows with enough retrieved articles.
    ds = ds.filter(lambda x: len(x["retrieved_articles"]) >= min_articles)

    print("Column names: ", ds.column_names)
    print("Length of dataset only keeping rows with retrieved articles: ", len(ds))

    use_retrieval = "without" not in data_type
    print(f"Using retrieval: {use_retrieval}")

    def create_prompt_for_row(row):
        return create_retreived_prompt(
            row["question"],
            row["background"],
            row["resolution_criteria"],
            row["date_begin"],
            row["date_close"],
            row["retrieved_articles"] if use_retrieval else [],
        )

    ds = ds.map(lambda row: {"prompt": create_prompt_for_row(row)})

    return ds


def create_retreived_prompt(
    question: str,
    background: str,
    resolution_criteria: str,
    date_begin: str,
    date_close: str,
    retrieved_articles: "list[dict] | None" = None,
) -> str:
    """Format the forecasting prompt for one question.

    Args:
        question: Question text.
        background: Background text for the question.
        resolution_criteria: How the question resolves.
        date_begin: Accepted for interface compatibility; not included in
            the prompt.
        date_close: Close date shown in the prompt.
        retrieved_articles: Optional articles to append. Each must provide
            the keys ``title``, ``url``, ``date_publish`` and ``maintext``.
            ``None`` is treated like an empty list.

    Returns:
        The prompt string. It starts with a newline and, when no articles
        are given, ends with a single trailing newline.
    """
    # None instead of a shared mutable [] default.
    articles = [] if retrieved_articles is None else retrieved_articles

    parts = [
        f"\nQuestion: {question}",
        f"\nQuestion Background: {background}",
        f"\nResolution Criteria: {resolution_criteria}",
        f"\nQuestion close date: {date_close}",
    ]

    if len(articles) > 0:
        parts.append(
            "\n\nWe have retrieved the following articles for this question from cleaned Common Crawl news data using the BM25 ranking algorithm (so it is possible that some of them might not be too relevant to the question):"
        )
        for article in articles:
            parts.append(f"\n\nTitle: {article['title']}")
            parts.append(f"\nURL: {article['url']}")
            parts.append(f"\nDate Published: {article['date_publish']}")
            parts.append(f"\nContent: {article['maintext']}")
    else:
        parts.append("\n")

    return "".join(parts)

if __name__ == "__main__":
    # Manual smoke test: run one loader and rely on its printed summary.
    # Other loaders can be swapped in here for ad-hoc inspection, e.g.
    # load_halawi_data("train", raw=True), load_metaculus_data(split="train"),
    # load_manifold_data(split="test") or load_cladder().
    paleka_train, paleka_test = load_paleka(split="spanned")