from src.llm_clients.openai_client import OpenAIClient
import time
from src.utils import parse_json_block
from tqdm import tqdm 
import concurrent.futures
import time
import pandas as pd
import json

# Prompt template sent to the model for each row. The {col_name},
# {dataset_name} and {dataset_desc} placeholders are filled in with
# str.format(); the doubled braces render as literal braces in the JSON example.
# The example must itself be valid JSON (no trailing comma), because the reply
# is parsed with json.loads and the model tends to mirror the example verbatim.
message = """
You are given a column name and the context in which it appears. Your task is to judge whether the column name clearly and accurately conveys its meaning.

Column Name: "{col_name}"
Dataset Name: "{dataset_name}"
Dataset Description: "{dataset_desc}"

Please respond in JSON:
```json
{{
  "valid": "<yes or no>"
}}
```
"""

# Module-level client, created once at import time; process_row sends every
# prompt through this single instance.
agent = OpenAIClient(model='gpt-4.1')

def process_row(row, max_retries=10, initial_wait=1.0, max_wait=60.0):
    """Ask the model whether one column name is understandable and label the row.

    Args:
        row: Object exposing ``to_dict()`` whose dict has the keys
            ``'variable'``, ``'dataset'`` and ``'dataset description'``.
        max_retries: Maximum number of attempts before giving up.
        initial_wait: Seconds to sleep after the first failed attempt.
        max_wait: Upper bound, in seconds, for the doubling backoff delay.

    Returns:
        A copy of the row's dict with the model's ``"valid"`` answer stored
        under ``'undestood'``, or ``None`` if every attempt failed.

    Raises:
        KeyError: If the row lacks one of the keys needed to build the prompt.
    """
    original_row = row.to_dict()
    # The prompt is identical on every attempt, so build it once up front.
    msg = message.format(
        col_name=original_row['variable'],
        dataset_name=original_row['dataset'],
        dataset_desc=original_row['dataset description'],
    )
    row_id = (original_row['variable'], original_row['dataset'])
    wait = initial_wait
    for attempt in range(1, max_retries + 1):
        try:
            resp, _ = agent.call(msg, temp=1)
            try:
                resp_json = json.loads(resp)
            except json.JSONDecodeError:
                # Reply is not bare JSON; fall back to the project helper.
                resp_json = parse_json_block(resp)
            # NOTE(review): 'undestood' is misspelt, but it becomes the column
            # name in the output CSV, so it is kept unchanged here.
            return {**original_row, 'undestood': resp_json['valid']}
        except Exception as e:
            # Retry boundary: client errors, unparsable replies and a missing
            # "valid" key are all treated as a failed attempt.
            if attempt == max_retries:
                print(f"Error processing row {row_id}: {e}, giving up after {max_retries} attempts.")
                break
            print(f"Error processing row {row_id}: {e}, Trying again in {wait} seconds...")
            time.sleep(wait)
            # Exponential backoff, capped so late retries do not stall a worker.
            wait = min(wait * 2, max_wait)
    return None

# Driver: label every row of the input CSV in parallel and checkpoint the
# results to output_file as they arrive.
processed_data = []  # result dicts for rows that were labeled
failed_data = []     # original row dicts for rows that could not be labeled
data = pd.read_csv("benchmark/kaggle_variables.csv")
output_file = "benchmark/kaggle_variables_labeled2.csv"
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
    future_to_row = {executor.submit(process_row, row): row for _, row in data.iterrows()}
    for future in tqdm(concurrent.futures.as_completed(future_to_row), total=len(future_to_row), desc="Processing rows in parallel"):
        row = future_to_row[future]
        try:
            result = future.result()
        except Exception as e:
            print(f"Error processing row: {e}")
            failed_data.append(row.to_dict())
            continue
        if result is None:
            # process_row returned nothing for this row; remember it
            # instead of dropping it silently.
            failed_data.append(row.to_dict())
            continue
        processed_data.append(result)
        # Checkpoint only when a new result pushed the count to a multiple of
        # 10, so the same (or an empty) file is not rewritten on failed rows.
        if len(processed_data) % 10 == 0:
            try:
                pd.DataFrame(processed_data).to_csv(output_file, index=False)
            except OSError as e:
                # A failed checkpoint is not fatal; the final write still runs.
                print(f"Could not write checkpoint to {output_file}: {e}")

# Final write, covering results gathered since the last checkpoint.
pd.DataFrame(processed_data).to_csv(output_file, index=False)
if failed_data:
    print(f"{len(failed_data)} of {len(data)} rows could not be labeled.")

