import pandas as pd
import os
import re
import json
import csv
import itertools
import numpy as np
from scipy import stats
from constants import TASK_DIR,TASK_DESCRIPTIONS,MODELS,PROMPT_PRICE_MODELS,COMPLETION_PRICE_MODELS

def extract_session_number(folder_name):
    """Return the integer N from the first ``sessionN`` substring of *folder_name*.

    Returns 0 when the name contains no ``session`` followed by digits.
    """
    found = re.search(r'session(\d+)', folder_name)
    return int(found.group(1)) if found else 0

def get_current_card(wcst_data=None, current_trial=1):
    """Return the card whose ``'trialNumber'`` equals *current_trial*.

    Args:
        wcst_data: mapping of card name -> card dict; each card dict must have a
            ``'trialNumber'`` key. ``None`` (the default) is treated as empty.
        current_trial: trial number to look up.

    Returns:
        The first matching card dict, in the mapping's iteration order.

    Raises:
        ValueError: if no card has the requested trial number.
    """
    # A None sentinel replaces the shared mutable ``{}`` default.
    if wcst_data is None:
        wcst_data = {}
    for card_data in wcst_data.values():
        if card_data['trialNumber'] == current_trial:
            return card_data
    raise ValueError(f"No card found for trial {current_trial}")

def extract_session_info(folder_name):
    """Parse session metadata out of an experiment-log folder name.

    Human folders (names starting with ``WCST_human``) must match
    ``WCST_human_subject<id>_<prompt>_<mode>_session<n>_<date>_<time>``; the
    result then holds ``subject_id``, ``prompt_type``, ``presentation_mode``,
    ``session_number``, ``date`` and ``time``. For any other folder only
    ``session_number`` is extracted from the first ``session<n>`` substring.

    Returns:
        The metadata dict, or ``None`` when the name does not match.
    """
    if not folder_name.startswith('WCST_human'):
        found = re.search(r'session(\d+)', folder_name)
        return {'session_number': int(found.group(1))} if found else None

    found = re.match(
        r'WCST_human_subject(\d+)_(\w+)_(\w+)_session(\d+)_(\d+)_(\d+)',
        folder_name,
    )
    if found is None:
        return None
    subject, prompt, mode, session, date, time = found.groups()
    return {
        'subject_id': int(subject),
        'prompt_type': prompt,
        'presentation_mode': mode,
        'session_number': int(session),
        'date': date,
        'time': time,
    }

def analyze_single_wcst_session(csv_path):
    """Compute WCST scoring metrics for one session's per-trial CSV log.

    The CSV must provide the columns ``Category_Completed``,
    ``Correct_In_Row``, ``Is_Correct`` (boolean or 0/1), ``Applied_Rules``
    (presumably a string of rule letters such as ``"CS"``, matched against
    ``C``/``S``/``N``) and ``Current_Rule`` (``color_rule``, ``shape_rule`` or
    ``number_rule``).

    Returns:
        dict with the keys:

        * ``Categories_Completed``: max of ``Category_Completed``, +1 when the
          last answered trial ends a run of 10 correct responses.
        * ``Perseverative_Errors``: errors that reuse a rule applied on the
          trial that last completed a category (``Correct_In_Row == 10``) and
          do not include the currently correct rule.
        * ``Non_Perseverative_Errors``: all other errors.
        * ``All_Errors``: number of incorrect trials.
        * ``Trials_to_First_Category``: 1-based position of the first row with
          ``Category_Completed == 1`` (see the adjustment below), NaN if none.
        * ``Conceptual_Level_Responses``: percentage of trials whose running
          count of consecutive correct responses is at least 3.
        * ``Failure_to_Maintain_Set``: trials with 5 <= ``Correct_In_Row`` < 10
          that are immediately followed by an incorrect trial.
        * ``Learning_to_Learn``: negated slope of per-category error counts
          (last category excluded); NaN with fewer than three category groups.
    """
    df = pd.read_csv(csv_path)
    metrics = {}

    # Whether the last answered trial completed a run of 10 correct responses.
    # NOTE(review): this is added on top of the logged counters below —
    # presumably the log ends before Category_Completed is incremented; confirm.
    answered = df[df['Correct_In_Row'] >= 0]
    ended_on_completed_run = (not answered.empty) and answered['Correct_In_Row'].iloc[-1] == 10

    # Categories Completed (CC)
    metrics['Categories_Completed'] = df['Category_Completed'].max()
    if ended_on_completed_run:
        metrics['Categories_Completed'] = metrics['Categories_Completed'] + 1

    # Perseverative / non-perseverative error classification.
    rule_mapping = {
        'color_rule': 'C',
        'shape_rule': 'S',
        'number_rule': 'N',
    }
    perseverative_errors = 0
    non_perseverative_errors = 0
    # Rules applied on the trial that completed the most recent category.
    previous_correct_rules = set()
    # Pre-bound so an unmapped Current_Rule can never raise NameError; an
    # unmapped rule keeps the last mapped one, as before.
    current_correct_rule = None
    for _, row in df.iterrows():
        applied = row['Applied_Rules']
        current_rules = set() if pd.isna(applied) else set(applied)
        if row['Current_Rule'] in rule_mapping:
            current_correct_rule = rule_mapping[row['Current_Rule']]

        if not row['Is_Correct']:
            if (previous_correct_rules
                    and (current_rules & previous_correct_rules)
                    and (current_correct_rule not in current_rules)):
                perseverative_errors += 1
            else:
                non_perseverative_errors += 1

        # Remember the rules used when a category was just completed.
        if row['Is_Correct'] and row['Correct_In_Row'] == 10:
            previous_correct_rules = set(current_rules)

    metrics['Perseverative_Errors'] = perseverative_errors
    metrics['Non_Perseverative_Errors'] = non_perseverative_errors
    # `== False` (not `~`) so 0-encoded values are counted as errors too.
    metrics['All_Errors'] = int((df['Is_Correct'] == False).sum())  # noqa: E712

    # Trials to Complete First Category (TFC); relies on read_csv's RangeIndex.
    first_category_index = df[df['Category_Completed'] == 1].index.min()
    if pd.notna(first_category_index):
        metrics['Trials_to_First_Category'] = first_category_index + 1
        # NOTE(review): adjusting TFC by the *last* run's completion looks
        # suspicious (TFC concerns the first category); kept for compatibility.
        if ended_on_completed_run:
            metrics['Trials_to_First_Category'] = metrics['Trials_to_First_Category'] + 1
    else:
        metrics['Trials_to_First_Category'] = first_category_index  # NaN

    # Conceptual Level Responses: running count of consecutive correct answers,
    # reset at every incorrect trial.
    run_length = df['Is_Correct'].astype(int).groupby((df['Is_Correct'] == 0).cumsum()).cumsum()
    clr_total = (run_length >= 3).sum()
    metrics['Conceptual_Level_Responses'] = round((clr_total / len(df)) * 100, 2)

    # Failure to Maintain Set
    failure_to_maintain = df[
        (df['Correct_In_Row'] >= 5)
        & (df['Correct_In_Row'] < 10)
        & (df['Is_Correct'].shift(-1) == False)  # noqa: E712
    ]
    metrics['Failure_to_Maintain_Set'] = len(failure_to_maintain)

    # Learning to Learn. Cast to bool before negating: bitwise `~` on 0/1
    # integers yields -1/-2 and would corrupt the error counts.
    category_errors = (
        df.groupby('Category_Completed')['Is_Correct']
        .apply(lambda s: (~s.astype(bool)).sum())
        .reset_index()
    )
    if len(category_errors) > 2:
        # The last (possibly unfinished) category is excluded from the fit.
        x = category_errors['Category_Completed'].iloc[:-1]
        y = category_errors['Is_Correct'].iloc[:-1]
        slope, _, _, _, _ = stats.linregress(x, y)
        metrics['Learning_to_Learn'] = round(-slope, 2)  # negated: fewer errors = improvement
    else:
        metrics['Learning_to_Learn'] = np.nan
    return metrics



def analyze_wcst(task_name, data_folder, model_name, prompt_type, presentation_mode, out_folder, impairment_type=None):
    """Score every session of one WCST configuration and save a combined CSV.

    Session folders are the entries of *data_folder* starting with
    ``{task_name}_{model_name}_{prompt_type}_{presentation_mode}_session``, or,
    for ``model_name == "human"``, with ``{task_name}_human_subject`` (human
    folders are not filtered by prompt type or presentation mode). With
    *impairment_type* only folders ending with it are kept; otherwise folders
    ending with any known impairment suffix are dropped. Each remaining
    folder's ``trial_data.csv`` (folders without one are skipped) is scored
    with ``analyze_single_wcst_session``.

    The CSV written to *out_folder* (or ``<out_folder>/impairment`` when an
    impairment is given) has one numbered row per scored session, a blank
    row, then Mean / Standard Deviation / Minimum / Maximum / Median rows.

    Returns:
        ``(metrics_df, stats_summary)``: the per-session metrics DataFrame and a
        dict mapping each metric to its rounded summary statistics.
    """
    valid_impairment_types = ("Goal_Maint", "Inhib_Ctrl", "Adapt_Upd")

    if model_name == "human":
        prefix = f'{task_name}_human_subject'
        sort_field = 'subject_id'
    else:
        prefix = f'{task_name}_{model_name}_{prompt_type}_{presentation_mode}_session'
        sort_field = 'session_number'
    session_folders = [f for f in os.listdir(data_folder) if f.startswith(prefix)]

    if impairment_type:
        session_folders = [f for f in session_folders if f.endswith(impairment_type)]
    else:
        session_folders = [f for f in session_folders if not f.endswith(valid_impairment_types)]

    # Folders whose names cannot be parsed sort first instead of crashing the run.
    session_folders.sort(key=lambda folder: (extract_session_info(folder) or {}).get(sort_field, 0))

    all_session_metrics = []
    for session_folder in session_folders:
        csv_path = os.path.join(data_folder, session_folder, 'trial_data.csv')
        if os.path.exists(csv_path):
            all_session_metrics.append(analyze_single_wcst_session(csv_path))
    metrics_df = pd.DataFrame(all_session_metrics)

    stats_summary = {}
    for column in metrics_df.columns:
        values = metrics_df[column]
        stats_summary[column] = {
            'mean': round(values.mean(), 2),
            # skipna=False: a NaN in any session makes the std NaN.
            'std': round(values.std(skipna=False), 2),
            'min': round(values.min(), 2),
            'max': round(values.max(), 2),
            'median': round(values.median(), 2),
        }

    if impairment_type:
        file_name = f'{task_name}_{model_name}_{prompt_type}_{presentation_mode}_{impairment_type}_analysis.csv'
        out_folder = os.path.join(out_folder, "impairment")
    else:
        file_name = f'{task_name}_{model_name}_{prompt_type}_{presentation_mode}_analysis.csv'
    os.makedirs(out_folder, exist_ok=True)
    file_path = os.path.join(out_folder, file_name)

    metric_mapping = {
        'Mean': 'mean',
        'Standard Deviation': 'std',
        'Minimum': 'min',
        'Maximum': 'max',
        'Median': 'median',
    }
    with open(file_path, mode='w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        # Header: a 'Session' label column followed by one column per metric.
        writer.writerow(['Session'] + metrics_df.columns.tolist())
        # Sessions are numbered 1..n in sorted order (not by folder session number).
        for index, row in metrics_df.iterrows():
            writer.writerow([index + 1] + row.tolist())
        writer.writerow([])  # blank separator row
        for label, key in metric_mapping.items():
            writer.writerow([label] + [stats_summary[column][key] for column in metrics_df.columns])
    print(f"Combined results saved to {file_path}")

    return metrics_df, stats_summary


def analyze_tokens_spent(data_folder, task_name, model_name, prompt_type, presentation_mode, impairment=''):
    """Summarise token usage over all sessions of one model configuration.

    Session folders are the entries of *data_folder* whose names start with
    ``{task_name}_{model_name}_{prompt_type}_{presentation_mode}_session``.
    With a non-empty *impairment* only folders ending with it are used;
    otherwise folders ending with any known impairment suffix are excluded.
    Folders without a ``trial_data.csv`` are skipped.

    Returns:
        ``(last_token_avg, trial_token_avg, total_token_all_trial)``: the mean
        (rounded to an int) of each session's last ``Total_tokens`` value, the
        mean (rounded to an int) of each session's summed ``Prompt_tokens`` +
        ``Completion_tokens``, and the sum of those per-session totals.

    Raises:
        ValueError: if no matching session with a ``trial_data.csv`` exists.
    """
    impairment_suffixes = ('Goal_Maint', 'Inhib_Ctrl', 'Adapt_Upd')
    prefix = f'{task_name}_{model_name}_{prompt_type}_{presentation_mode}_session'
    session_folders = [f for f in os.listdir(data_folder) if f.startswith(prefix)]
    if impairment:
        session_folders = [f for f in session_folders if f.endswith(impairment)]
    else:
        session_folders = [f for f in session_folders if not f.endswith(impairment_suffixes)]

    # Only means and a sum are reported, so folder order is irrelevant.
    last_trial_tokens = []
    session_totals = []
    for session_folder in session_folders:
        csv_path = os.path.join(data_folder, session_folder, 'trial_data.csv')
        if not os.path.exists(csv_path):
            continue
        df = pd.read_csv(csv_path)
        session_totals.append(df['Prompt_tokens'].sum() + df['Completion_tokens'].sum())
        last_trial_tokens.append(df['Total_tokens'].iloc[-1])

    if not session_totals:
        raise ValueError(
            f"No sessions with trial_data.csv found for {prefix!r} "
            f"(impairment={impairment!r}) in {data_folder!r}")

    last_token_avg = int(round(pd.Series(last_trial_tokens).mean(), 0))
    trial_token_avg = int(round(pd.Series(session_totals).mean(), 0))
    total_token_all_trial = sum(session_totals)
    return last_token_avg, trial_token_avg, total_token_all_trial

def analyze_price_spent(data_folder, task_name, model_name, prompt_type, presentation_mode, impairment='',
                        prompt_price=None, completion_price=None):
    """Estimate the API cost of all sessions of one model configuration.

    Session folders are selected exactly as in ``analyze_tokens_spent``: names
    starting with ``{task_name}_{model_name}_{prompt_type}_{presentation_mode}_session``,
    restricted to those ending with *impairment* when it is non-empty and
    otherwise excluding known impairment suffixes; folders without a
    ``trial_data.csv`` are skipped.

    Args:
        prompt_price: price in dollars per one million prompt tokens; defaults to
            ``PROMPT_PRICE_MODELS[model_name]``.
        completion_price: price in dollars per one million completion tokens;
            defaults to ``COMPLETION_PRICE_MODELS[model_name]``.

    Returns:
        ``(session_avg, total)`` as strings formatted ``"$X.XX"``: the mean cost
        per session and the summed cost over all sessions. Rounding happens
        only at formatting time so small per-session costs are not lost.

    Raises:
        KeyError: if a price is not given and *model_name* has no price entry.
        ValueError: if no matching session with a ``trial_data.csv`` exists.
    """
    if prompt_price is None:
        prompt_price = PROMPT_PRICE_MODELS[model_name]
    if completion_price is None:
        completion_price = COMPLETION_PRICE_MODELS[model_name]

    impairment_suffixes = ('Goal_Maint', 'Inhib_Ctrl', 'Adapt_Upd')
    prefix = f'{task_name}_{model_name}_{prompt_type}_{presentation_mode}_session'
    session_folders = [f for f in os.listdir(data_folder) if f.startswith(prefix)]
    if impairment:
        session_folders = [f for f in session_folders if f.endswith(impairment)]
    else:
        session_folders = [f for f in session_folders if not f.endswith(impairment_suffixes)]

    session_prices = []
    for session_folder in session_folders:
        csv_path = os.path.join(data_folder, session_folder, 'trial_data.csv')
        if not os.path.exists(csv_path):
            continue
        df = pd.read_csv(csv_path)
        prompt_cost = df['Prompt_tokens'].sum() * prompt_price / 1_000_000
        completion_cost = df['Completion_tokens'].sum() * completion_price / 1_000_000
        # Keep full precision; rounding per session would bias the aggregates.
        session_prices.append(prompt_cost + completion_cost)

    if not session_prices:
        raise ValueError(
            f"No sessions with trial_data.csv found for {prefix!r} "
            f"(impairment={impairment!r}) in {data_folder!r}")

    total_price = sum(session_prices)
    session_total_price_avg = total_price / len(session_prices)
    return f'${session_total_price_avg:.2f}', f'${total_price:.2f}'

def _token_price_row(data_folder, task_label, task_name, model_name, prompt_type, presentation_mode, impairment=''):
    """Return one results row combining token and price statistics for a configuration."""
    last_token_avg, trial_token_avg, total_token_all_trial = analyze_tokens_spent(
        data_folder, task_name, model_name, prompt_type, presentation_mode, impairment=impairment)
    session_total_price_avg, total_price_all_trial = analyze_price_spent(
        data_folder, task_name, model_name, prompt_type, presentation_mode, impairment=impairment)
    return {
        'Task': task_label,
        'Model': model_name,
        'Prompt_Type': prompt_type,
        'Presentation_Mode': presentation_mode,
        'Last_Token_Average': last_token_avg,
        'Trial_Token_Average': trial_token_avg,
        'Total_Tokens_all_Trials': total_token_all_trial,
        'Session_Total_Price_Avg': session_total_price_avg,
        'Total_Price_all_Trials': total_price_all_trial,
    }


def analyze_selected_combinations(data_folder, out_folder, task_names=None, model_names=None, prompt_types=None, presentation_modes=None, impairment=None):
    """Tabulate token usage and cost for the selected configurations and save them.

    Every combination of *task_names* x *model_names* x *prompt_types* x
    *presentation_modes* is analysed (for ``WCST_without_restriction`` only the
    CoT/OT combination). Afterwards, for every model, WCST with CoT/OT is
    analysed once per impairment. A configuration that fails is reported on
    stdout and skipped. Results are written to
    ``<out_folder>/Tokens_spent_analysis.csv`` (the folder is created if needed).

    Args:
        task_names: defaults to all keys of ``TASK_DESCRIPTIONS``.
        model_names: defaults to ``MODELS``.
        prompt_types: defaults to ``['STA', 'CoT']``.
        presentation_modes: defaults to ``['OI', 'OT']``.
        impairment: a single impairment name, an iterable of names, or ``None``
            for all of ``Goal_Maint``, ``Inhib_Ctrl`` and ``Adapt_Upd``.

    Returns:
        The results DataFrame, one row per successfully analysed configuration.
    """
    if task_names is None:
        task_names = list(TASK_DESCRIPTIONS.keys())
    if model_names is None:
        model_names = MODELS
    if prompt_types is None:
        prompt_types = ['STA', 'CoT']
    if presentation_modes is None:
        presentation_modes = ['OI', 'OT']
    # Previously `impairments` was left unbound whenever `impairment` was given.
    if impairment is None:
        impairments = ['Goal_Maint', 'Inhib_Ctrl', 'Adapt_Upd']
    elif isinstance(impairment, str):
        impairments = [impairment]
    else:
        impairments = list(impairment)

    results = []
    for combination in itertools.product(task_names, model_names, prompt_types, presentation_modes):
        task_name, model_name, prompt_type, presentation_mode = combination
        # Only CoT/OT is analysed for WCST_without_restriction
        # (presumably the only configuration that was run).
        if task_name == "WCST_without_restriction" and (prompt_type, presentation_mode) != ("CoT", "OT"):
            continue
        try:
            results.append(_token_price_row(
                data_folder, task_name, task_name, model_name, prompt_type, presentation_mode))
        except Exception as e:
            print(f"Error analyzing {combination}: {e}")

    # Impairment variants: WCST with CoT prompting and OT presentation only.
    for model_name in model_names:
        for imp in impairments:
            combination = ("WCST", model_name, "CoT", "OT", imp)
            try:
                results.append(_token_price_row(
                    data_folder, f"WCST_{imp}", "WCST", model_name, "CoT", "OT", impairment=imp))
            except Exception as e:
                print(f"Error analyzing {combination}: {e}")

    results_df = pd.DataFrame(results)
    os.makedirs(out_folder, exist_ok=True)
    file_path = os.path.join(out_folder, 'Tokens_spent_analysis.csv')
    results_df.to_csv(file_path, index=False)
    return results_df

if __name__ == "__main__":
    data_folder = "./experiment_logs"
    out_folder = "./analyze_results"

    # Per-configuration WCST metrics can be produced with, e.g.:
    #   analyze_wcst("WCST", data_folder, "GPT-4o", 'CoT', 'OT', out_folder)
    #   analyze_wcst("WCST_without_restriction", data_folder, "GPT-4o", 'CoT', 'OT', out_folder)
    #   analyze_wcst("WCST", data_folder, "GPT-4o", 'CoT', 'OT', out_folder, "Goal_Maint")
    #   analyze_wcst("WCST", data_folder, "human", 'STA', 'OI', out_folder)
    # and per-configuration cost with, e.g.:
    #   analyze_price_spent(data_folder, "WCST", "GPT-4o", 'CoT', 'OT', impairment='Goal_Maint')
    # Models: "Gemini-1.5 Pro", "Claude-3.5 Sonnet", "GPT-4o"; prompt types
    # 'STA'/'CoT'; presentation modes 'OI'/'OT'; impairments 'Goal_Maint',
    # 'Inhib_Ctrl', 'Adapt_Upd'.

    # Token and price summary across all selected configurations.
    evaluated_models = ["Gemini-1.5 Pro", 'Claude-3.5 Sonnet', "GPT-4o"]
    analyze_selected_combinations(
        data_folder,
        out_folder,
        task_names=["WCST", "WCST_without_restriction"],
        model_names=evaluated_models,
        prompt_types=None,
        presentation_modes=None,
        impairment=None,
    )
    