import pandas as pd
from src.correlation import Correlation
from model_client import ModelClient
from src.utils import parse_json_block
import json
import os

def strip_avg_prefix(attr_name):
    """Return *attr_name* without a leading ``'avg_'``; unchanged if absent."""
    prefix = 'avg_'
    return attr_name[len(prefix):] if attr_name.startswith(prefix) else attr_name


def load_processed(path):
    """Load the list of already-saved result records from *path*.

    Returns an empty list when *path* does not exist, so a first run and a
    resumed run go through the same code afterwards.
    """
    if not os.path.exists(path):
        return []
    with open(path, 'r', encoding='utf-8') as json_file:
        return json.load(json_file)


def save_processed(path, records):
    """Write *records* to *path* as indented JSON without risking the old file.

    The data is first dumped to a sibling ``<path>.tmp`` file and then moved
    over *path* with ``os.replace``. If the process is interrupted while
    dumping, the previous checkpoint is still intact and the run can resume;
    writing to *path* directly would leave a truncated, unparseable file.
    """
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as json_file:
        json.dump(records, json_file, indent=4)
    os.replace(tmp_path, path)


def main(input_file='data/test_data2.csv',
         output_file='data/test_data_output2.json'):
    """Query the model once per CSV row and checkpoint results to JSON.

    Rows whose ``id`` already appears in *output_file* are skipped, so the
    script can be re-run after an interruption without repeating model calls.

    Args:
        input_file: CSV read with ``pd.read_csv``; the loop reads the columns
            ``id``, ``agg_attr1``, ``tbl_name1``, ``agg_attr2``, ``tbl_name2``.
        output_file: JSON file holding a list of result records; it is
            rewritten after every processed row.
    """
    processed_data = load_processed(output_file)
    processed_row_ids = {item['id'] for item in processed_data}
    data = pd.read_csv(input_file)
    agent = ModelClient()
    # Build one Correlation per row that has not been processed yet.
    for _, row in data.iterrows():
        if row['id'] in processed_row_ids:
            continue
        original_row = row.to_dict()
        # The last two arguments are fixed for every row; their meaning is
        # defined by Correlation (not visible here).
        corr = Correlation(
            strip_avg_prefix(row['agg_attr1']), row['tbl_name1'],
            strip_avg_prefix(row['agg_attr2']), row['tbl_name2'],
            'location', 'census tract',
        )
        user_msg = corr.get_user_msg()
        print(user_msg)
        response, usage = agent.call(user_msg)
        print(usage)
        print(response)
        parsed = parse_json_block(response)
        pred_coef, pred_std = parsed['coefficient'], parsed['standard deviation']
        combined_data = {
            **original_row,
            'predicted_coef': float(pred_coef),
            'predicted_std': float(pred_std),
            'usage': usage,
        }
        processed_data.append(combined_data)
        # Checkpoint after every row so completed model calls are never lost.
        save_processed(output_file, processed_data)


if __name__ == "__main__":
    main()