import re
import torch
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from tqdm import tqdm
from datasets import Dataset
from typing import Optional, List
from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers import set_seed
from tqdm import tqdm 
import json
import os 
import logging
import time 
from torch.utils.data.distributed import DistributedSampler
import sys
from torch.distributed import init_process_group  # Add this import
# Seed the torch (CPU and CUDA) and NumPy RNGs, and call transformers' set_seed,
# all with the same value so repeated runs start from the same state.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
set_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# Request deterministic cuDNN behavior and turn off its benchmark mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Optional thread caps for numeric libraries (currently disabled).
# NOTE(review): the original note says these must be set before importing
# libraries that initialize threading (e.g., NumPy), but NumPy and torch are
# already imported above this point -- move these above the imports if they
# are ever re-enabled.
# os.environ["NUMEXPR_MAX_THREADS"] = "64"
# os.environ["OMP_NUM_THREADS"] = "4"
# os.environ["MKL_NUM_THREADS"] = "4"
# os.environ["OPENBLAS_NUM_THREADS"] = "4"
# os.environ["VECLIB_MAXIMUM_THREADS"] = "4"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Set up logging: timestamped INFO-level records for the whole process.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Module-level run configuration. MODEL_DIR, DATA_SPLIT, OUTPUT_DIR and DATA are
# overwritten from the command-line arguments in the __main__ section below;
# ZERO_SHOT is only ever set here.
MODEL_DIR = ""
DATA_SPLIT = "train"
ZERO_SHOT = True
OUTPUT_DIR = ""
DATA = "halawi"

def add_idx_column(dataset: Dataset) -> Dataset:
    """
    Return `dataset.map(...)` applied with a function that yields an 'idx'
    field for every example, set to the row index that `with_indices=True`
    passes in.
    """
    def _attach_index(example, idx):
        # Only the new field is returned; the example itself is not touched here.
        return {'idx': idx}

    return dataset.map(_attach_index, with_indices=True)

def extract_final_answer0(generated_text: str) -> Optional[float]:
    """
    Find the first asterisk-wrapped probability (e.g. '*0.75*', '*0*', '*1*',
    '*1.0*') in the text and return it as a float; return None when the text
    contains no such substring.
    """
    found = re.search(r'\*(0(\.\d+)?|1(\.0+)?)\*', generated_text)
    if found is None:
        return None
    # Group 1 is the number itself, i.e. the match without the two asterisks.
    return float(found.group(1))

def extract_answer(completion: str) -> Optional[float]:
    """
    Extract the probability inside the first <answer>...</answer> pair.

    Args:
        completion: Text produced by the model.

    Returns:
        The number between the tags as a float when it parses and lies in
        [0, 1]; None when the tags are missing, the content is not a number,
        or the value is out of range (including NaN and infinities).
    """
    match = re.search(r"<answer>(.*?)<\/answer>", completion, re.DOTALL)
    if match is None:
        return None

    try:
        prediction = float(match.group(1).strip())
    except ValueError:
        # Content between the tags is not a number.
        return None

    # Written as a chained comparison so that NaN (for which every comparison
    # is False) is rejected instead of slipping through as a "valid" answer.
    if not 0 <= prediction <= 1:
        return None

    return prediction

# Compiled once at import time. Order matters: when two patterns match at the
# same offset, the one listed first wins (e.g. "0.5%" is read as a percentage,
# not as the decimal 0.5).
_PROBABILITY_PATTERNS = tuple(
    re.compile(pattern)
    for pattern in (
        # Asterisk-wrapped percentage (e.g., *75%*)
        r'\*\s*(\d{1,3}(?:\.\d+)?)\s*%\s*\*',
        # Standalone percentage (e.g., 75%)
        r'(?<!\w)(\d{1,3}(?:\.\d+)?)\s*%(?!\w)',
        # Asterisk-wrapped decimal (e.g., *0.75*)
        r'\*\s*(0\.\d+)\s*\*',
        # Standalone decimal (e.g., 0.75)
        r'(?<!\w)(0\.\d+)(?!\w)',
    )
)


def extract_final_answer(llm_output: str) -> Optional[float]:
    """
    Extract the first probability prediction from the LLM's output.

    The prediction can be:
    - A decimal of the form 0.xx, possibly wrapped in asterisks (e.g., *0.75*)
    - A percentage, possibly wrapped in asterisks (e.g., *75%*)

    Percentages above 100 are ignored. Decimals must start with "0." so values
    such as "1.5", ".75" or a bare "1" are not recognised.

    Args:
        llm_output: Text produced by the model.

    Returns:
        A float between 0 and 1 for the match that starts earliest in the
        text, or None if no valid prediction is found.
    """
    # (start offset, probability) for every in-range match of every pattern.
    candidates = []

    for pattern in _PROBABILITY_PATTERNS:
        for match in pattern.finditer(llm_output):
            try:
                number = float(match.group(1))
            except ValueError:
                continue  # If conversion fails, skip to the next match

            if '%' in match.group(0):
                # Convert percentage to decimal
                if 0 <= number <= 100:
                    candidates.append((match.start(), number / 100))
            elif 0 <= number <= 1:
                candidates.append((match.start(), number))

    if not candidates:
        return None

    # min() returns the first minimal element, so ties on the start offset are
    # resolved by pattern order, exactly as a stable sort would.
    return min(candidates, key=lambda candidate: candidate[0])[1]

def test_extract_final_prediction():
    """
    Run extract_final_answer over a fixed table of (text, expected) pairs,
    asserting each result and printing a line per passing case (numbered from 1).
    """
    test_cases = [
        # --- decimals ---
        ("... *0.75* ...", 0.75),                                    # wrapped in asterisks
        ("The prediction is 0.65 based on the analysis.", 0.65),    # standalone
        # --- percentages ---
        ("... *80%* ...", 0.80),                                     # wrapped in asterisks
        ("We estimate a 55% chance of success.", 0.55),             # standalone
        # --- the first of several predictions wins ---
        ("First prediction: *0.60*, second prediction: 70%.", 0.60),
        # --- prediction embedded in a sentence ---
        ("After careful consideration, *0.85* seems likely.", 0.85),
        # --- out-of-range or absent values yield None ---
        ("The chance is 150%.", None),
        ("Probability is 1.5.", None),
        ("There is a high likelihood of success.", None),
        # --- mixed formats: the earliest one wins ---
        ("Predicted outcome: 0.45 and later revised to *60%*.", 0.45),
        # --- whitespace between asterisks and the number ---
        ("... * 0.30 * ...", 0.30),
        ("... * 75 % * ...", 0.75),
    ]

    for i, case in enumerate(test_cases, start=1):
        text, expected = case
        result = extract_final_answer(text)
        assert result == expected, f"Test case {i} failed: expected {expected}, got {result}"
        print(f"Test case {i} passed: extracted {result}")

# Run tests
# test_extract_final_prediction()
# Extra (text, expected) pairs for extract_final_answer; consumed by
# test_additional_cases(), which numbers them starting at 13.
additional_test_cases = [
    # Percentage with a fractional part
    ("Estimated probability: 75.5%", 0.755),
    # Single-digit percentage
    ("Chance of occurrence is 5%", 0.05),
    # Decimal without leading zero (invalid, should not match)
    ("Probability is .75.", None),
    # Percentage followed by more text
    ("There is a 90% likelihood of approval.", 0.90),
    # Asterisk-wrapped percentage with no space
    ("Outcome: *65%*", 0.65),
    # Asterisk-wrapped decimal with multiple spaces
    ("... *   0.40   * ...", 0.40),
    # Multiple percentages, should extract the first one
    ("Probabilities: 60%, 70%, and 80%.", 0.60),
    # Percentage immediately followed by a comma still matches; first one wins
    ("Chance is 70%, not 80%, as previously thought.", 0.70),
    # Percentage at the end
    ("The success rate is 85%", 0.85),
    # Decimal at the end
    ("Final prediction is 0.95", 0.95),
]

def test_additional_cases():
    """
    Check extract_final_answer against every entry of additional_test_cases,
    numbering the cases from 13 and printing a line for each one that passes.
    """
    for i, case in enumerate(additional_test_cases, start=13):
        text, expected = case
        result = extract_final_answer(text)
        assert result == expected, f"Test case {i} failed: expected {expected}, got {result}"
        print(f"Test case {i} passed: extracted {result}")

# Run additional tests
# test_additional_cases()


def format_forecasting_prompt(
    question: str,
    background: str,
    resolution_criteria: str,
    date_begin: str,
    date_close: str,
    zero_shot: bool = False
) -> str:
    """
    Build the forecasting prompt text for one question.

    With `zero_shot` false (the default) the prompt lists the question fields
    followed by a seven-step scratchpad; with `zero_shot` true it is a short
    prompt asking directly for an asterisk-wrapped probability.
    `date_begin` is accepted but does not appear in either prompt.
    """

    if not zero_shot:
        # Step-by-step scratchpad variant. The quadruple braces render as
        # literal "{{ ... }}" placeholders in the output.
        return f"""Question: {question}
    Question Background: {background}
    Resolution Criteria: {resolution_criteria}
    Question close date: {date_close}

    Instructions:
    1. Given the above question, rephrase and expand it to help you do better answering. Maintain all information in the original question.
    {{{{ Insert rephrased and expanded question. }}}}
    2. Using your own knowledge of the world and topic, provide a few
    reasons why the answer might be no. Rate the strength of each reason.
    {{{{ Insert your thoughts }}}}
    3. Using your knowledge of the world and topic, as well as the information provided, provide a few
    reasons why the answer might be yes. Rate the strength of each reason.
    {{{{ Insert your thoughts }}}}
    4. Aggregate your considerations. Think like a superforecaster (e.g. Nate Silver).
    {{{{ Insert your aggregated considerations }}}}
    5. Output an initial probability (prediction) given steps 1-4. It should be a number BETWEEN 0 and 100. For example, 
    if you are 75% confident the answer is yes, you would write 75.  
    {{{{ Insert initial probability }}}}
    6. Evaluate whether your calculated probability is excessively confident or not confident enough. Also,
    consider anything else that might affect the forecast that you did not before consider (e.g. base rate of
    the event).
    {{{{ Insert your thoughts }}}}
    7. Output your final prediction (a number between 0 and 1) with an asterisk at the beginning and end
    of the decimal. YOUR FINAL PREDICTION SHOULD STRICTLY BE BETWEEN 0 AND 1. For example, if you believe 
    the answer is 75% likely, you would write *0.75*. MAKE SURE TO FORMAT IT CORRECTLY AND PLACE BETWEEN ASTERISKS.
    {{{{ Insert your answer }}}}
    """

    # Zero-shot variant: question fields plus the output-format instruction.
    return f"""I will ask you a forecasting question. You have to come up with the best estimate for whether the event asked in the question happens or happened. 
        
Question: {question}
Question Background: {background}
Resolution Criteria: {resolution_criteria}
Question close date: {date_close}

Output your final prediction (a number between 0 and 1) with an asterisk at the beginning and end of the decimal. YOUR FINAL PREDICTION SHOULD STRICTLY BE BETWEEN 0 AND 1. For example, if you believe the answer is 75% likely, you would write *0.75*. MAKE SURE TO FORMAT IT CORRECTLY AND PLACE BETWEEN ASTERISKS.
"""

def format_forecasting_prompt_old(
    question: str,
    background: str,
    resolution_criteria: str,
    date_begin: str,
    date_close: str
) -> str:
    """
    Older prompt builder that also includes "Today's date" (`date_begin`).

    Which variant is produced depends on the module-level ZERO_SHOT flag
    rather than on an argument: a short direct prompt when it is truthy, the
    seven-step scratchpad prompt otherwise.
    """

    if not ZERO_SHOT:
        # Step-by-step scratchpad variant. The quadruple braces render as
        # literal "{{ ... }}" placeholders in the output.
        return f"""Question: {question}
    Question Background: {background}
    Resolution Criteria: {resolution_criteria}
    Today's date: {date_begin}
    Question close date: {date_close}

    Instructions:
    1. Given the above question, rephrase and expand it to help you do better answering. Maintain all information in the original question.
    {{{{ Insert rephrased and expanded question. }}}}
    2. Using your own knowledge of the world and topic, provide a few
    reasons why the answer might be no. Rate the strength of each reason.
    {{{{ Insert your thoughts }}}}
    3. Using your knowledge of the world and topic, as well as the information provided, provide a few
    reasons why the answer might be yes. Rate the strength of each reason.
    {{{{ Insert your thoughts }}}}
    4. Aggregate your considerations. Think like a superforecaster (e.g. Nate Silver).
    {{{{ Insert your aggregated considerations }}}}
    5. Output an initial probability (prediction) given steps 1-4. It should be a number BETWEEN 0 and 100. For example, 
    if you are 75% confident the answer is yes, you would write 75.  
    {{{{ Insert initial probability }}}}
    6. Evaluate whether your calculated probability is excessively confident or not confident enough. Also,
    consider anything else that might affect the forecast that you did not before consider (e.g. base rate of
    the event).
    {{{{ Insert your thoughts }}}}
    7. Output your final prediction (a number between 0 and 1) with an asterisk at the beginning and end
    of the decimal. YOUR FINAL PREDICTION SHOULD STRICTLY BE BETWEEN 0 AND 1. For example, if you believe 
    the answer is 75% likely, you would write *0.75*. MAKE SURE TO FORMAT IT CORRECTLY AND PLACE BETWEEN ASTERISKS.
    {{{{ Insert your answer }}}}
    """

    # Zero-shot variant: question fields plus the output-format instruction.
    return f"""Question: {question}
        Question Background: {background}
        Resolution Criteria: {resolution_criteria}
        Today's date: {date_begin}
        Question close date: {date_close}
        
        Output your final prediction (a number between 0 and 1) with an asterisk at the beginning and end
        of the decimal. YOUR FINAL PREDICTION SHOULD STRICTLY BE BETWEEN 0 AND 1. For example, if you believe 
        the answer is 75% likely, you would write *0.75*. MAKE SURE TO FORMAT IT CORRECTLY AND PLACE BETWEEN ASTERISKS.
        """

def load_model_and_tokenizer(model_path: str, model_name : str = None):
    """
    Load a causal LM (in float16) and its tokenizer from `model_path`.

    If loading from `model_path` fails, one more attempt is made from
    `model_path + "/snapshots/model/"`; an error from that second attempt
    propagates to the caller.

    Args:
        model_path: Directory (or identifier) handed to `from_pretrained`.
        model_name: Name used only in the log message; defaults to the last
            path component of `model_path`.

    Returns:
        A `(model, tokenizer)` tuple.
    """
    if model_name is None:
        model_name = model_path.rstrip("/").split("/")[-1]
    logger.info(f"Using model_name: {model_name}")

    def _load(path: str):
        """Load the tokenizer and the float16 model from a single location."""
        # NOTE: trust_remote_code=True runs Python code shipped with the
        # checkpoint; only point this at model directories you trust.
        tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,
            torch_dtype=torch.float16
        )
        return model, tokenizer

    logger.info(f"Loading model and tokenizer from local directory: {model_path}")
    try:
        return _load(model_path)
    except Exception as err:
        # Deliberately broad (the fallback is best-effort), but no longer a bare
        # `except:` so Ctrl-C / SystemExit are not swallowed, and the original
        # failure is logged instead of being silently discarded.
        fallback_path = model_path + "/snapshots/model/"
        logger.warning(
            f"Loading from {model_path} failed ({err!r}); retrying from {fallback_path}"
        )
        return _load(fallback_path)



def format_superforecasting_prompt(
    question: str,
    background: str,
    resolution_criteria: str,
    date_begin: str,
    date_close: str,
    zero_shot: bool = False
) -> str:
    """
    Build the bare question block (question, background, resolution criteria,
    close date), one field per line, with a leading and a trailing newline.
    `date_begin` and `zero_shot` are accepted but not used.
    """
    # The empty first and last entries produce the leading/trailing newline.
    prompt_lines = [
        "",
        f"Question: {question}",
        f"Question Background: {background}",
        f"Resolution Criteria: {resolution_criteria}",
        f"Question close date: {date_close}",
        "",
    ]
    return "\n".join(prompt_lines)
# Output your final prediction with an asterisk at the beginning and end of the decimal. YOUR FINAL PREDICTION SHOULD STRICTLY BE BETWEEN 0 AND 1 UNDER ALL CIRCUMSTANCES. For example, if you believe the answer is 75% likely, you would write *0.75*. MAKE SURE TO FORMAT IT CORRECTLY AND PLACE BETWEEN ASTERISKS.
# """



def evaluate_model(
    model_name: str,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    dataset,
    max_new_tokens: int = 8192,
    batch_size: int = 32,
    format_prompt_fn : callable = format_forecasting_prompt,
    max_prompt_length: int = 4096,
    num_generations: int = 8,  # Added parameter for number of generations
):
    """
    Run batched, sampled generation over `dataset` and score the answers.

    Each row is formatted with `format_prompt_fn`, wrapped in a chat template
    that pre-fills the assistant turn, and sampled `num_generations` times.
    The probability is read from the <answer> tags of each response; responses
    without a usable answer are recorded as 0.5 and flagged as skipped.
    All generations are gathered, written to a JSON file under OUTPUT_DIR, and
    the Brier score and accuracy are computed from generation 0 of each row.

    Side effects: sets `tokenizer.padding_side` to 'left', assigns the EOS token
    as pad token when none is defined, and writes the results file.

    Args:
        model_name: Label stored with every result and used in the file name.
        model: Causal LM used for generation.
        tokenizer: Tokenizer matching `model`.
        dataset: Rows must provide "question", "background",
            "resolution_criteria", "date_begin", "date_close", "resolution"
            and "idx".
        max_new_tokens: Generation budget per sample.
        batch_size: DataLoader batch size.
        format_prompt_fn: Called with the row fields as keyword arguments plus
            `zero_shot=ZERO_SHOT`, so it must accept a `zero_shot` keyword.
        max_prompt_length: Prompts are truncated to this many tokens.
        num_generations: Number of samples drawn per prompt.

    Returns:
        `(brier_score, accuracy, skipped_count)` on the main process and None
        on every other process. `skipped_count` counts generations (not
        distinct questions) that fell back to 0.5.
    """
    instruction = "You will be asked a forecasting question. You have to come up with the best estimate for whether the event asked in the question happens or happened. Show your work (reasoning) in <think> </think> tags. And return only the final answer (probability) in <answer> </answer> tags, for example if you think the event asked is 83% likely, then output <answer>0.83</answer>. YOUR FINAL PREDICTION SHOULD STRICTLY BE BETWEEN 0 AND 1. Think step by step inside <think> tags."
    # Pre-filled start of the assistant turn; also the marker used later to
    # split the decoded text into prompt and response.
    response_prefix = "Let me solve this step by step.\n<think>"

    # Left padding keeps every prompt adjacent to its generated continuation.
    tokenizer.padding_side = 'left'

    # Make sure pad token is defined
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    def worker_init_fn(worker_id):
        # Only invoked if num_workers is raised above 0.
        np.random.seed(SEED + worker_id)

    def local_collate(samples):
        """Format, chat-template and tokenize one batch of dataset rows."""
        local_prompts = [
            format_prompt_fn(
                question=row["question"],
                background=row["background"],
                resolution_criteria=row["resolution_criteria"],
                date_begin=row["date_begin"],
                date_close=row["date_close"],
                zero_shot=ZERO_SHOT,
            )
            for row in samples
        ]

        try:
            messages = [
                tokenizer.apply_chat_template(
                    [
                        {"role": "user", "content": instruction},
                        {"role": "user", "content": prompt},
                        {"role": "assistant", "content": response_prefix},
                    ],
                    tokenize=False,
                    continue_final_message=True,
                )
                for prompt in local_prompts
            ]
        except Exception as e:
            # Best-effort: fall back to the raw prompts for the whole batch.
            logger.warning(f"Error in tokenizer.apply_chat_template: {e}")
            messages = local_prompts

        batched_encodings = tokenizer(
            messages,
            padding=True,
            truncation=True,
            max_length=max_prompt_length,
            return_tensors="pt"
        )

        # Per-sample metadata is kept as plain Python lists, not tensors.
        batched_encodings["resolutions"] = [float(row["resolution"]) for row in samples]
        # Character length of each formatted prompt (not read again below).
        batched_encodings["input_lengths"] = [len(prompt) for prompt in local_prompts]
        batched_encodings["idx"] = [row["idx"] for row in samples]

        return batched_encodings

    # No shuffling, single worker.
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=local_collate,
        shuffle=False,
        num_workers=0,
        worker_init_fn=worker_init_fn
    )

    accelerator = Accelerator()
    accelerator.state.seed = SEED

    model, dataloader = accelerator.prepare(model, dataloader)
    model.eval()
    accelerator.wait_for_everyone()

    # generate() is called on the wrapped model when `prepare` returned a wrapper
    # exposing `.module`.
    generator = model.module if hasattr(model, 'module') else model

    all_results = []  # One dict per (sample, generation) on this process

    for batch in tqdm(dataloader, desc="Generating"):
        batch_resolutions = batch["resolutions"]
        idxs = batch["idx"]

        # For each prompt, generate multiple times
        for gen_idx in range(num_generations):
            with torch.no_grad():
                outputs = generator.generate(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=0.6,
                    top_p=0.95,
                    pad_token_id=tokenizer.pad_token_id
                )
                outputs = accelerator.pad_across_processes(
                    outputs, dim=1, pad_index=tokenizer.pad_token_id)

            batch_decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)

            for i, actual in enumerate(batch_resolutions):
                decoded_text = batch_decoded[i]

                # NOTE: both counts are padded sequence lengths, so
                # completion_tokens includes any trailing pad tokens.
                prompt_tokens = len(batch["input_ids"][i])
                completion_tokens = len(outputs[i]) - prompt_tokens

                marker_pos = decoded_text.find(response_prefix)
                if marker_pos != -1:
                    prompt_text = decoded_text[:marker_pos]
                    answer = decoded_text[marker_pos:]
                else:
                    # Marker missing (e.g. the chat-template fallback was used).
                    # Slicing with -1 would keep only the last character, so
                    # split at the token boundary between prompt and completion.
                    prompt_text = tokenizer.decode(
                        outputs[i][:prompt_tokens], skip_special_tokens=True)
                    answer = tokenizer.decode(
                        outputs[i][prompt_tokens:], skip_special_tokens=True)

                final_prob = extract_answer(answer)
                used_fallback = final_prob is None
                if used_fallback:
                    final_prob = 0.5

                all_results.append({
                    "model": model_name,
                    "prompt": prompt_text,
                    "split": DATA_SPLIT,
                    "data_type": DATA,
                    "idx": idxs[i],
                    "generation_idx": gen_idx,
                    "response": answer,
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "final_answer": float(final_prob),
                    "resolution": float(actual),
                    "skipped": used_fallback,
                })

    # Gather results from all processes
    all_results = accelerator.gather_for_metrics(all_results)
    accelerator.wait_for_everyone()

    # Only the main process scores and saves.
    if not accelerator.is_main_process:
        return None

    # A multi-process run may contain the same (idx, generation_idx) more than
    # once (presumably from samples repeated to even out the final batches --
    # TODO confirm). Keep the first occurrence so predictions and resolutions
    # stay aligned and the saved file has one entry per generation.
    unique_results = {}
    for result in all_results:
        unique_results.setdefault((result["idx"], result["generation_idx"]), result)
    if len(unique_results) != len(all_results):
        logger.info(f"Dropped {len(all_results) - len(unique_results)} duplicate generations")
    all_results = list(unique_results.values())

    skipped_questions = sum(1 for result in all_results if result["skipped"])
    logger.info(f"Skipped questions: {skipped_questions}")

    # Metrics use the first generation of each prompt, ordered by idx. Taking
    # prediction and resolution from the same record keeps them paired.
    first_gens = sorted(
        (result for result in all_results if result["generation_idx"] == 0),
        key=lambda result: result["idx"],
    )
    predictions = np.array([result["final_answer"] for result in first_gens], dtype=float)
    actuals = np.array([result["resolution"] for result in first_gens], dtype=float)

    brier_score = np.mean((predictions - actuals) ** 2)
    predicted_binary = (predictions > 0.5).astype(int)
    accuracy = np.mean(predicted_binary == actuals)

    logger.info(f"Brier Score: {brier_score:.4f}")
    logger.info(f"Accuracy:    {accuracy:.4f}")

    output_file = f"{OUTPUT_DIR}{model_name}_{DATA_SPLIT}_size_{len(dataset)}_generations_{num_generations}.json"
    # Create the target directory if needed so the write below cannot fail on
    # a missing folder after all generations have been computed.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(output_file, "w") as f:
        json.dump(all_results, f, indent=2)

    logger.info(f"Saved {len(all_results)} generations to {output_file}")

    return brier_score, accuracy, skipped_questions


if __name__ == "__main__":
    # Command-line entry point: parse arguments, load the chosen dataset and
    # model, then run evaluate_model once.
    import argparse
    # Provides load_halawi_data / load_metaculus_data used below.
    from data_utils import *

    parser = argparse.ArgumentParser()
    parser.add_argument('--base_save_dir', default="/fast/XXXX-3/forecasting/evals/manual/halawi/", help="Where to save outputs of the model")
    parser.add_argument('--model_dir', type=str, default="/fast/rolmedo/models/qwen2.5-7b-it", help="Model directory")
    # The literal string "None" means "derive the name from --model_dir".
    parser.add_argument('--model', type=str, default="None", help="Model name")
    parser.add_argument('--batch_size', type=int, default=32, help="Batch size for generation")
    parser.add_argument('--max_new_tokens', type=int, default=4096, help="Maximum number of new tokens for generation")
    parser.add_argument('--data_split', type=str, default="test", help="Data split to use")
    parser.add_argument('--data', type=str, default="halawi", choices=['metaculus', 'halawi'],
                      help="Which dataset to use")
    parser.add_argument('--num_generations', type=int, default=1, help="Number of generations to use per prompt")

    args = parser.parse_args()

    logger.info(f"Number of GPUs available: {torch.cuda.device_count()}")

    # Overwrite the module-level configuration read by evaluate_model.
    MODEL_DIR = args.model_dir
    DATA_SPLIT = args.data_split
    DATA = args.data
    OUTPUT_DIR = args.base_save_dir
    # Guard against an empty --base_save_dir, for which indexing [-1] raised
    # IndexError; an empty value means "write to the current directory".
    if OUTPUT_DIR and not OUTPUT_DIR.endswith("/"):
        OUTPUT_DIR += "/"

    if DATA == "halawi":
        dataset = load_halawi_data(split=DATA_SPLIT)
    elif DATA == "metaculus":
        dataset = load_metaculus_data(split=DATA_SPLIT)
    else:
        # Unreachable while argparse `choices` is in place; kept so `dataset`
        # can never be silently undefined if the choices are extended.
        raise ValueError(f"Unsupported dataset: {DATA}")

    logger.info(f"Data split: {DATA_SPLIT}")
    logger.info(f"Data type: {DATA}")
    logger.info(f"Dataset size: {len(dataset)}")

    # evaluate_model reads the 'idx' field of every row.
    dataset = add_idx_column(dataset)

    logger.info(f"Batch size: {args.batch_size}")
    logger.info(f"Number of generations: {args.num_generations}")
    logger.info(f"Max new tokens: {args.max_new_tokens}")
    logger.info(f"Output directory: {OUTPUT_DIR}")
    logger.info(f"Model directory: {MODEL_DIR}")

    model_name = args.model
    if model_name == "None":
        # Derive the name from the last path component; for checkpoint paths
        # prefix it with the parent directory (e.g. "run__checkpoint-100").
        path_parts = MODEL_DIR.rstrip("/").split("/")
        model_name = path_parts[-1]
        if "checkpoint" in MODEL_DIR and len(path_parts) >= 2:
            model_name = path_parts[-2] + "__" + path_parts[-1]

    logger.info(f"Model name: {model_name}")

    model, tokenizer = load_model_and_tokenizer(args.model_dir, model_name)
    evaluate_model(
        model_name,
        model,
        tokenizer,
        dataset,
        max_new_tokens=args.max_new_tokens,
        batch_size=args.batch_size,
        format_prompt_fn=format_superforecasting_prompt,
        num_generations=args.num_generations,
    )
    