import argparse
import pandas as pd
import os 
from src.correlation import Variable, Correlation
from src.llm_clients.openai_client import OpenAIClient
from src.llm_clients.gemini_client import GeminiClient
from prompts.get_context import GET_CONTEXT
from tqdm import tqdm 
import concurrent.futures
import json
from src.utils import parse_json_block
import time

def process_row(row, max_retries=10, initial_wait=1):
    """Ask the Gemini client for a ``new_context`` for one benchmark row.

    Args:
        row: One row of the benchmark DataFrame (as yielded by ``iterrows``).
            Reads the ``skip``, ``new_context``, ``pair_id``, ``var1``,
            ``var2``, ``dataset``, ``dataset description``, ``var1_desc``,
            ``var2_desc`` and ``r_obs`` columns.
        max_retries: Maximum number of LLM call attempts before giving up.
        initial_wait: Seconds to wait after the first failed attempt; the wait
            doubles after every further failure.

    Returns:
        ``None`` if the row is flagged ``skip`` or every attempt failed.
        Otherwise the row as a dict: unchanged if it already had a
        ``new_context``, else with ``new_context`` taken from the ``new_context``
        key of the JSON block parsed out of the model's reply.
    """
    # `== True` (not `is True`) so numpy.bool_ values from pandas also match;
    # NaN compares unequal and is therefore not treated as skipped.
    if row['skip'] == True:
        return None
    # Rows that already have a context are passed through, which lets a re-run
    # resume where a previous run stopped.
    if pd.notna(row['new_context']):
        return row.to_dict()

    original_row = row.to_dict()
    pair_id = original_row['pair_id']
    variable1 = Variable(
        attr=original_row['var1'],
        table=original_row['dataset'],
        table_desc=original_row['dataset description'],
        var_desc=original_row['var1_desc'],
    )
    variable2 = Variable(
        attr=original_row['var2'],
        table=original_row['dataset'],
        table_desc=original_row['dataset description'],
        var_desc=original_row['var2_desc'],
    )
    agent = GeminiClient(model='gemini-2.0-flash')
    msg = GET_CONTEXT.format(attr1=variable1.attr, attr2=variable2.attr, table=variable1.table,
                             tbl_desc=variable1.table_desc,
                             var1_desc=variable1.var_desc, var2_desc=variable2.var_desc,
                             r_obs=original_row['r_obs'])

    wait = initial_wait
    for attempt in range(1, max_retries + 1):
        try:
            print(f"Processing row {pair_id}...")
            resp, _usage = agent.call(msg)
            print(resp)
            json_block = parse_json_block(resp)
            original_row['new_context'] = json_block['new_context']
            return original_row
        except Exception as e:
            if attempt == max_retries:
                # Last attempt: no point sleeping before giving up.
                print(f"Error processing row {pair_id}: {e}, giving up after {max_retries} attempts.")
                break
            print(f"Error processing row {pair_id}: {e}, Trying again in {wait} seconds...")
            time.sleep(wait)
            wait *= 2  # exponential backoff
    return None

def build_output_frame(data, results):
    """Return every row of *data*, in input order, with processed rows substituted.

    Rows whose index label is missing from *results* (skipped, failed, or not
    yet finished) are emitted unchanged. Overwriting the input CSV therefore
    never drops rows, and a later run can retry them.

    Args:
        data: The DataFrame read from the input CSV.
        results: Mapping from a *data* index label to the dict returned by
            ``process_row`` for that row.

    Returns:
        A new DataFrame with the same columns (and column order) as *data*.
    """
    records = [results.get(idx, row.to_dict()) for idx, row in data.iterrows()]
    return pd.DataFrame(records, columns=data.columns)


def write_csv_atomically(df, path):
    """Write *df* to *path* as CSV without ever leaving a half-written file.

    The frame is written to ``<path>.tmp`` first and then moved over *path*
    with ``os.replace``. An interruption mid-write therefore cannot truncate
    the existing file, which by default is also this script's input.
    """
    tmp_path = f"{path}.tmp"
    df.to_csv(tmp_path, index=False)
    os.replace(tmp_path, path)


def main():
    """Fill ``new_context`` for every benchmark row and write the results back.

    Every input row is kept in the output, in its original order:

    * Processed rows carry their new context.
    * Skipped, failed and unfinished rows are written unchanged, so a re-run
      resumes them (``process_row`` passes through rows that already have a
      context).

    The output is rewritten every ``--checkpoint-every`` completed rows and
    once at the end. The ``pair_id`` of each failed (not skipped) row is
    dumped to ``<output stem>_failed.json``.
    """
    parser = argparse.ArgumentParser(description="Generate new_context values for benchmark rows.")
    parser.add_argument("--input", default="benchmark/benchmark_4_labeled.csv",
                        help="CSV to read (default: %(default)s)")
    parser.add_argument("--output", default=None,
                        help="CSV to write; defaults to --input, which is overwritten in place")
    parser.add_argument("--workers", type=int, default=1,
                        help="number of worker threads (default: %(default)s)")
    parser.add_argument("--checkpoint-every", type=int, default=10,
                        help="rewrite the output after this many completed rows (default: %(default)s)")
    args = parser.parse_args()
    if args.workers < 1:
        parser.error("--workers must be at least 1")
    if args.checkpoint_every < 1:
        parser.error("--checkpoint-every must be at least 1")
    output_file = args.output or args.input

    data = pd.read_csv(args.input)
    results = {}      # data index label -> row dict returned by process_row
    failed_data = []  # pair_ids of rows that could not be processed

    with concurrent.futures.ThreadPoolExecutor(max_workers=args.workers) as executor:
        future_to_row = {
            executor.submit(process_row, row): (idx, row) for idx, row in data.iterrows()
        }
        progress = tqdm(concurrent.futures.as_completed(future_to_row),
                        total=len(future_to_row), desc="Processing rows in parallel")
        for completed, future in enumerate(progress, start=1):
            idx, row = future_to_row[future]
            try:
                result = future.result()
            except Exception as e:
                # Errors that escape process_row's own retry loop: record the
                # row as failed and keep going instead of aborting the run.
                print(f"Error processing row {row['pair_id']}: {e}")
                failed_data.append(row['pair_id'])
            else:
                if result is not None:
                    results[idx] = result
                # process_row also returns None for skipped rows; those are not
                # failures. `== True` mirrors process_row's own skip test.
                elif not (row['skip'] == True):
                    failed_data.append(row['pair_id'])
            if completed % args.checkpoint_every == 0:
                write_csv_atomically(build_output_frame(data, results), output_file)

    write_csv_atomically(build_output_frame(data, results), output_file)
    failed_data_file = os.path.splitext(output_file)[0] + "_failed.json"
    with open(failed_data_file, "w", encoding="utf-8") as f:
        json.dump(failed_data, f, indent=2)


if __name__ == "__main__":
    main()
