import argparse
import pandas as pd
import os 
from src.correlation import Variable, Correlation
from src.llm_clients.openai_client import OpenAIClient
from src.llm_clients.gemini_client import GeminiClient
from src.priors.gaussian_prior import TruncatedGaussian
from src.priors.kde_prior import KDEPrior
from tqdm import tqdm 
import concurrent.futures
import json
import time

def provider_for_model(model):
    """Return which LLM client family should handle ``model``.

    Args:
        model: Model name as passed on the command line.

    Returns:
        ``"gemini"`` if the name contains ``gemini``, ``"openai"`` if it
        contains ``gpt`` or ``mini``, otherwise ``None``.
    """
    # 'gemini' has to be tested first: the substring 'mini' occurs inside
    # 'gemini', so testing 'mini' first would route Gemini models to OpenAI.
    if 'gemini' in model:
        return "gemini"
    if 'gpt' in model or 'mini' in model:
        return "openai"
    return None


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Elicit various types of correlation priors from LLMs.")
    parser.add_argument('--input_file', type=str, default='benchmark/real_world_correlations.csv',
                        help='Path to the input correlations file.')
    parser.add_argument('--output_dir', type=str, default=None,
                        help='Path to the output directory.')
    parser.add_argument('--model', type=str, default='gpt-4o',
                        help='The LLM model to use.')
    parser.add_argument('--prior', type=str, default='lc_prior',
                        help='Prior type to use. Options: "gaussian_prior", "kde_prior", "lc_prior"')
    parser.add_argument('--num_iter', type=int, default=1,
                        help='Number of iterations to run.')
    parser.add_argument('--ref_file', type=str, default=None,
                        help= 'Path to a previous output file for lookup; any correlation '
                                'already in this file will be reused instead of being reprocessed.')
    parser.add_argument('--workers', type=int, default=2,
                        help='The number of workers to use for parallel processing.')

    args = parser.parse_args()

    input_file = args.input_file

    # Pick the LLM client; fail loudly instead of leaving `agent` undefined.
    provider = provider_for_model(args.model)
    if provider == "gemini":
        agent = GeminiClient(model=args.model)
    elif provider == "openai":
        agent = OpenAIClient(model=args.model)
    else:
        raise ValueError(f"Unsupported model: {args.model}")

    # Initialize prior based on argument
    prior_kind = args.prior.lower()
    if prior_kind == "gaussian_prior":
        prior_model = TruncatedGaussian(agent=agent)
    elif prior_kind in ("lc_prior", "kde_prior"):
        # lc_prior and kde_prior obtain logits from the LLM the same way,
        # so both are served by KDEPrior.
        prior_model = KDEPrior(agent=agent)
        # prior_model.semantic_grounding_and_disambiguate = True
        # prior_model.use_new_context = True
    else:
        raise ValueError(f"Unsupported prior type: {args.prior}")

    # Pieces of the output file name: <input stem>_<model>_<prior>_iter_<i>.csv
    base = os.path.splitext(os.path.basename(input_file))[0]
    safe_model = args.model.replace(" ", "_")
    safe_prior = args.prior.replace(" ", "_")

    # Default to ./outputs and make sure the directory exists either way,
    # otherwise the first CSV write fails on a fresh checkout.
    output_dir = "outputs" if args.output_dir is None else args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    # The inputs do not change between iterations, so read them once.
    data = pd.read_csv(input_file)
    ref_df = pd.read_csv(args.ref_file) if args.ref_file is not None else None

    # run iteration numbers
    file_names = []
    for i in range(args.num_iter):
        output_file = os.path.join(output_dir, f"{base}_{safe_model}_{safe_prior}_iter_{i}.csv")
        file_names.append(output_file)
        print(output_file)
        # Per-iteration accumulators: successful result dicts and failed pair ids.
        processed_data = []
        failed_data = []
        def process_row(row):
            """Elicit a prior for one benchmark row, retrying on failure.

            Args:
                row: A row of the input table supporting ``to_dict()``. The
                    keys read here are ``pair_id``, ``var1``, ``var2``,
                    ``dataset``, ``dataset description``, ``var1_desc``,
                    ``var2_desc`` and ``r_obs`` (plus ``new_context`` when the
                    prior model asks for it).

            Returns:
                The matching reference-file row as a dict when ``--ref_file``
                is given and contains this ``pair_id``; otherwise the original
                row merged with the computed stats and prior. ``None`` if
                every attempt failed.

            Raises:
                ValueError: If ``args.prior`` is not a supported prior type.
            """
            original_row = row.to_dict()
            pair_id = original_row['pair_id']

            # Reuse a previously computed result for this pair when available.
            if args.ref_file is not None:
                ref_row = ref_df.loc[ref_df['pair_id'] == pair_id]
                if not ref_row.empty:
                    return ref_row.iloc[0].to_dict()

            # Only the token-probability prior consumes the extra context column.
            new_context = None
            if prior_model.name == "token_prob_dist" and prior_model.use_new_context:
                new_context = original_row['new_context']

            # Both variables come from the same dataset and share its description.
            variable1 = Variable(
                attr=original_row['var1'],
                table=original_row['dataset'],
                table_desc=original_row['dataset description'],
                var_desc=original_row['var1_desc'],
            )
            variable2 = Variable(
                attr=original_row['var2'],
                table=original_row['dataset'],
                table_desc=original_row['dataset description'],
                var_desc=original_row['var2_desc'],
            )
            correlation = Correlation(var1=variable1, var2=variable2, new_context=new_context)

            # A bad prior name is a configuration error, not a transient one:
            # raise immediately instead of retrying it with back-off.
            prior_kind = args.prior.lower()
            if prior_kind not in ("gaussian_prior", "kde_prior", "lc_prior"):
                raise ValueError(f"Unsupported prior type: {args.prior}")

            max_retries = 10
            wait = 1  # seconds; doubled after every failed attempt
            for attempt in range(1, max_retries + 1):
                try:
                    prior = prior_model.get_prior(correlation)
                    if prior_kind == "gaussian_prior":
                        stats = prior_model.get_stats(
                            r_obs=original_row['r_obs'],
                            r_pred=prior['predicted_coef'],
                            sigma=prior['predicted_std'],
                        )
                    else:
                        stats = prior_model.get_stats(
                            distribution=prior['distribution'],
                            r_obs=original_row['r_obs'],
                        )
                    return {**original_row, **stats, **prior}
                except Exception as e:
                    if attempt == max_retries:
                        # No further attempt follows, so do not sleep again.
                        print(f"Error processing row {pair_id}: {e}, giving up after {max_retries} attempts.")
                        break
                    print(f"Error processing row {pair_id}: {e}, Trying again in {wait} seconds...")
                    time.sleep(wait)
                    wait *= 2
            return None

        # Process rows in parallel. Results arrive in completion order, so the
        # output rows are not guaranteed to follow the input order.
        with concurrent.futures.ThreadPoolExecutor(max_workers=args.workers) as executor:
            future_to_row = {executor.submit(process_row, row): row for _, row in data.iterrows()}
            for future in tqdm(concurrent.futures.as_completed(future_to_row), total=len(future_to_row), desc="Processing rows in parallel"):
                row = future_to_row[future]
                try:
                    result = future.result()
                except Exception as e:
                    failed_data.append(row['pair_id'])
                    print(f"Error processing row: {e}")
                    continue
                if result is None:
                    # process_row gave up after exhausting its retries.
                    failed_data.append(row['pair_id'])
                    continue
                processed_data.append(result)
                # Checkpoint every 10 successful rows. This is checked only
                # right after an append, so failures (or an empty list) never
                # trigger a redundant rewrite of the file.
                if len(processed_data) % 10 == 0:
                    try:
                        pd.DataFrame(processed_data).to_csv(output_file, index=False)
                    except OSError as e:
                        # Checkpoints are best effort; the row itself succeeded,
                        # so it must not be recorded as failed.
                        print(f"Error writing checkpoint to {output_file}: {e}")
        # Final write with every successful row of this iteration.
        df = pd.DataFrame(processed_data)
        df.to_csv(output_file, index=False)
        # Dump the pair ids that could not be processed next to the CSV,
        # as <output stem>_failed.json.
        failed_data_file = os.path.splitext(output_file)[0] + "_failed.json"
        with open(failed_data_file, "w", encoding="utf-8") as f:
            json.dump(failed_data, f, indent=2)
