"""
Use the correlations from benchmark/data/metadata_cleaned.csv,
load each row into a Correlation object, and
call the LLM model to get the predicted coefficient and standard deviation.
"""

import argparse
import pandas as pd
import os 
from src.correlation import Variable, Correlation
from src.llm_clients.openai_client import OpenAIClient
from src.llm_clients.gemini_client import GeminiClient
from src.priors.gaussian_prior import TruncatedGaussian
from src.priors.mixture_gaussian import GaussianMixture
from src.priors.range_derived_gaussian import RangeGaussian
from src.priors.truncated_gaussian_with_reasoning import TruncatedGaussianReasoning
from priors.kde_prior import KDEPrior
from tqdm import tqdm 

def _build_agent(model):
    """Return the LLM client matching the model name.

    Names containing "gpt" get an OpenAIClient; names containing "gemini" get
    a GeminiClient (checked in that order).

    Raises:
        ValueError: If the name matches neither family. Without this check the
            script would fail later with a NameError on an unbound client.
    """
    if 'gpt' in model:
        return OpenAIClient(model=model)
    if 'gemini' in model:
        return GeminiClient(model=model)
    raise ValueError(f"Unsupported model: {model}")


def _build_prior_model(prior_name, agent):
    """Instantiate the prior model selected by ``prior_name`` (case-insensitive).

    Raises:
        ValueError: If ``prior_name`` is not one of the supported prior types.
    """
    prior_classes = {
        "truncated gaussian": TruncatedGaussian,
        "gaussian mixture": GaussianMixture,
        "truncated gaussian with reasoning": TruncatedGaussianReasoning,
        "range gaussian": RangeGaussian,
        "token prob dist": KDEPrior,
    }
    try:
        prior_cls = prior_classes[prior_name.lower()]
    except KeyError:
        raise ValueError(f"Unsupported prior type: {prior_name}") from None
    return prior_cls(agent=agent)


def _row_to_correlation(record):
    """Build the Correlation for one input row given as a plain dict.

    Both variables share the row's dataset name and dataset description.
    """
    variable1 = Variable(
        attr=record['var1'],
        table=record['dataset'],
        table_desc=record['dataset description'],
        var_desc=record['var1_desc'],
    )
    variable2 = Variable(
        attr=record['var2'],
        table=record['dataset'],
        table_desc=record['dataset description'],
        var_desc=record['var2_desc'],
    )
    return Correlation(var1=variable1, var2=variable2)


def _evaluate_prior(prior_key, prior_model, correlation, r_obs):
    """Query the prior for one correlation and score it against ``r_obs``.

    Each prior type exposes different fields in the dict returned by
    ``get_prior`` and expects different keyword arguments in ``get_stats``,
    so the call is dispatched on the lower-cased prior name.

    Returns:
        A ``(prior, stats)`` tuple of the two values returned by the model.

    Raises:
        ValueError: If ``prior_key`` is not a supported prior type.
    """
    if prior_key in ("truncated gaussian", "truncated gaussian with reasoning"):
        prior = prior_model.get_prior(correlation)
        stats = prior_model.get_stats(
            r_obs=r_obs, r_pred=prior['predicted_coef'], sigma=prior['predicted_std'])
    elif prior_key == "gaussian mixture":
        prior = prior_model.get_prior(correlation, num_components=10)
        stats = prior_model.get_stats(
            r_obs=r_obs, means=prior['mean_l'], stds=prior['std_l'])
    elif prior_key == "range gaussian":
        prior = prior_model.get_prior(correlation)
        stats = prior_model.get_stats(
            r_obs=r_obs, r_pred=prior['predicted_coef'],
            z_mean=prior['z_mean'], z_std=prior['z_std'])
    elif prior_key == "token prob dist":
        prior = prior_model.get_prior(correlation)
        stats = prior_model.get_stats(distribution=prior['distribution'], r_obs=r_obs)
    else:
        raise ValueError(f"Unsupported prior type: {prior_key}")
    return prior, stats


def main():
    """Run the prior evaluation over the input CSV ``--num_iter`` times.

    Each iteration evaluates every row whose ``skip`` cell is not True and
    writes one CSV named ``{input base}_{model}_{prior}_iter_{i}.csv``. A
    result row is the input row merged with the stats and prior dicts; on a
    key clash the prior value wins, then the stats value.

    Returns:
        The list of output file paths, one per iteration.

    Raises:
        ValueError: If the model or the prior type is not supported.
    """
    parser = argparse.ArgumentParser(description="Evaluate correlations using LLM and specified prior model")
    parser.add_argument('--input_file', type=str, default='benchmark/data/metadata_sample_20.csv',
                        help='Path to the input correlations file.')
    # The value is used as a directory (file names are always generated), so
    # the help text says so.
    parser.add_argument('--output_file', type=str, default=None,
                        help='Directory for the output files. If not specified, "outputs" is used.')
    parser.add_argument('--model', type=str, default='gpt-4o-mini',
                        help='The LLM model to use.')
    parser.add_argument('--prior', type=str, default='truncated gaussian',
                        help='Prior type to use (e.g., "truncated gaussian").')
    parser.add_argument('--num_iter', type=int, default=1,
                        help='Number of iterations to run.')
    args = parser.parse_args()

    input_file = args.input_file
    prior_key = args.prior.lower()

    agent = _build_agent(args.model)
    prior_model = _build_prior_model(args.prior, agent)

    # Components of the generated output file name.
    base = os.path.splitext(os.path.basename(input_file))[0]
    safe_model = args.model.replace(" ", "_")
    safe_prior = prior_model.name.replace(" ", "_")

    output_dir = "outputs" if args.output_file is None else args.output_file
    # Create the directory up front: otherwise a missing directory only
    # surfaces at to_csv, after every row has already been sent to the LLM.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # The input does not change between iterations, so read it once.
    data = pd.read_csv(input_file)

    file_names = []
    for i in range(args.num_iter):
        output_file = os.path.join(output_dir, f"{base}_{safe_model}_{safe_prior}_iter_{i}.csv")
        file_names.append(output_file)
        print(output_file)
        processed_data = []
        for _, row in tqdm(data.iterrows(), total=data.shape[0], desc="Processing rows"):
            # Explicit comparison on purpose: a NaN cell is truthy but must
            # not count as a skip.
            if row['skip'] == True:  # noqa: E712
                continue
            original_row = row.to_dict()
            correlation = _row_to_correlation(original_row)
            prior, stats = _evaluate_prior(
                prior_key, prior_model, correlation, original_row['r_obs'])
            processed_data.append({**original_row, **stats, **prior})
        df = pd.DataFrame(processed_data)
        print(output_file)
        df.to_csv(output_file, index=False)
    return file_names


if __name__ == "__main__":
    file_names = main()