import os
import torch
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import re
import argparse


def get_llm_choice(model, tokenizer, base_prompt, max_retries=2, temperature=0.7, top_p=0.9, max_new_tokens=50):
    """
    Ask a causal language model to pick option A or B for one question.

    The question is appended to a fixed two-example few-shot prompt that ends
    in "Answer:". A continuation is sampled from the model, lower-cased, and
    scanned with a list of regexes ordered from strictest to loosest; the
    first match decides the choice. Unparseable or failed generations are
    retried.

    Args:
        model: Object exposing ``.device`` and ``.generate(**kwargs)``.
        tokenizer: Callable returning the encoded prompt (a mapping with an
            ``input_ids`` entry and a ``.to(device)`` method) and exposing
            ``.decode(...)`` and ``.eos_token_id``.
        base_prompt: Question text inserted after the few-shot examples.
        max_retries: Extra attempts after the first one. Negative values are
            treated as 0, so at least one attempt is always made.
        temperature: Sampling temperature passed to ``generate``.
        top_p: Nucleus-sampling threshold passed to ``generate``.
        max_new_tokens: Generation length cap passed to ``generate``.

    Returns:
        ``"A"`` or ``"B"`` on success; ``"generation_error"`` if the final
        attempt raised; ``"parsing_failed"`` if the final attempt produced
        text that none of the patterns matched.
    """
    few_shot_examples = """
Follow this format and choose either A or B based on the options provided.
Question:
A: A P% chance to win $X and a (100-P)% chance to win $Y.
B: A Q% chance to win $Z and a (100-Q)% chance to win $W.

Answer: B

Question:
A: A I% chance to win $S and a (100-I)% chance to win $T.
B: A J% chance to win $M and a (100-J)% chance to win $N.

Answer: A
"""
    full_prompt = (
        f"{few_shot_examples.strip()}\n\n"
        f"{base_prompt}\n\n"
        "Answer:"
    )

    # Patterns run against lower-cased text, strictest first.
    patterns = [
        # Leading standalone letter. The trailing \b is essential: without it
        # a reply such as "answer: b" or "both ..." is read as its first letter.
        r"^\s*([ab])\b",
        r"(?:choice|option|answer|select|selection)\s*(?:is|would be)?:?\s*'?\"?([ab])\b",
        r"['\"\(]([ab])['\"\)]",
        # Last resort: any standalone a/b. NOTE: this also matches the
        # article "a", so it is deliberately tried only after the others.
        r"\b([ab])\b"
    ]

    # Guarantee at least one attempt even if a caller passes a negative value.
    max_retries = max(max_retries, 0)

    # Pass an explicit pad id (falling back to EOS) rather than relying on
    # generate() to pick one implicitly.
    pad_token_id = getattr(tokenizer, "pad_token_id", None)
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    inputs = None
    response_text = ""

    for attempt in range(max_retries + 1):
        is_last_attempt = attempt >= max_retries
        try:
            # The prompt never changes between attempts, so encode it once.
            # Kept inside the try so encoding/device errors are reported the
            # same way as generation errors.
            if inputs is None:
                inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
            prompt_length = inputs["input_ids"].shape[1]

            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=pad_token_id
            )
            # Decode only the newly generated tokens, not the echoed prompt.
            response_text = tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            ).strip().lower()
        except Exception as e:
            if is_last_attempt:
                print(f"      > Error during model generation: {e}. No retries left.")
                return "generation_error"
            print(f"      > Error during model generation: {e}. Retrying...")
            continue

        for pattern in patterns:
            match = re.search(pattern, response_text)
            if match:
                return match.group(1).upper()

        if not is_last_attempt:
            print(f"      > Warning: Could not parse choice from '{response_text}'. Retrying ({attempt + 1}/{max_retries})...")
            time.sleep(1)

    print(f"      > Error: All retries failed. Final response was '{response_text}'.")
    return "parsing_failed"


def main():
    """
    Command-line entry point: query a model for every question in a CSV.

    Reads the input CSV (which must contain ``question_id`` and
    ``prompt_text`` columns), loads the requested model and tokenizer on a
    CUDA GPU, calls ``get_llm_choice`` for each row, and writes the rows that
    produced a valid "A"/"B" answer, plus an ``llm_choice`` column, to the
    output CSV. Rows whose answer could not be generated or parsed are
    skipped.

    Raises:
        SystemExit: If required arguments are missing, the input file does
            not exist or lacks a required column, or no CUDA GPU is available.
    """
    parser = argparse.ArgumentParser(description="Query a base LLM with financial choice questions.")
    parser.add_argument("--model_id", type=str, default="", help="Model ID from Hugging Face Hub.")
    parser.add_argument("--input_csv", type=str, default='', help="Path to the input CSV file with questions.")
    parser.add_argument("--output_csv", type=str, default=None, help="Path to save the output CSV. If omitted, a name is generated.")
    parser.add_argument("--gpu_id", type=str, default="0", help="GPU index to use.")
    parser.add_argument("--max_retries", type=int, default=5, help="Number of times to retry if a valid choice is not generated.")
    parser.add_argument("--max_new_tokens", type=int, default=50, help="Maximum new tokens for the model to generate.")
    args = parser.parse_args()

    # Both options default to an empty string, which can never work; reject
    # it here instead of failing later with a less obvious error.
    if not args.model_id:
        parser.error("--model_id must be provided.")
    if not args.input_csv:
        parser.error("--input_csv must be provided.")

    # Set before the first torch.cuda call made in this function.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id

    if args.output_csv is None:
        input_basename = os.path.splitext(os.path.basename(args.input_csv))[0]
        safe_model_name = args.model_id.replace("/", "_")
        args.output_csv = f"{input_basename}_{safe_model_name}_choices.csv"

    # --- Input Validation (done before the expensive model load) ---
    print(f"\nLoading questions from '{args.input_csv}'...")
    try:
        df = pd.read_csv(args.input_csv)
    except FileNotFoundError:
        raise SystemExit(f"\n--- ERROR: Input file '{args.input_csv}' not found. ---") from None

    missing_columns = [name for name in ('question_id', 'prompt_text') if name not in df.columns]
    if missing_columns:
        raise SystemExit(
            f"\n--- ERROR: Input file '{args.input_csv}' is missing required column(s): "
            f"{', '.join(missing_columns)}. ---"
        )
    print(f"Found {len(df)} questions to process.")

    # --- Model and Tokenizer Setup ---
    print("Setting up the model and tokenizer...")
    if not torch.cuda.is_available():
        raise SystemExit("Error: A CUDA-enabled GPU is required.")

    print(f"Loading model '{args.model_id}' on GPU: {torch.cuda.get_device_name(0)}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        torch_dtype="auto",
        device_map="auto",
    )
    # Reuse EOS as the pad token when the tokenizer defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print("Model loaded successfully.")

    # --- Querying ---
    successful_rows = []
    for _, row in df.iterrows():
        print(f"   - Querying Question {row['question_id']}/{len(df)}...")
        start_time = time.time()

        llm_choice = get_llm_choice(
            model, tokenizer, row['prompt_text'],
            max_retries=args.max_retries,
            max_new_tokens=args.max_new_tokens
        )

        duration = time.time() - start_time

        # Anything other than "A"/"B" is a failure marker from get_llm_choice.
        if llm_choice in ['A', 'B']:
            print(f"     > Received valid choice: '{llm_choice}' in {duration:.2f}s")
            new_row = row.to_dict()
            new_row['llm_choice'] = llm_choice
            successful_rows.append(new_row)
        else:
            print(f"     > Skipping question {row['question_id']} due to parsing/generation failure.")

    print("\n===== Querying Complete! =====")

    # --- Saving ---
    if not successful_rows:
        print("Warning: No successful choices were recorded. The output file will be empty.")
        # Keep the header row so the empty file still has the expected columns.
        final_df = pd.DataFrame(columns=list(df.columns) + ['llm_choice'])
    else:
        final_df = pd.DataFrame(successful_rows)

    final_df.to_csv(args.output_csv, index=False, encoding='utf-8-sig')

    print(f"\nAll {len(final_df)} successful results saved to '{args.output_csv}'")
    print("\nPreview of the final data:")
    print(final_df[['question_id', 'llm_choice']].head())

    print("\n--- Script Finished ---")

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()