import numpy as np
from openai import OpenAI
from sklearn.neighbors import KernelDensity
from sklearn.preprocessing import StandardScaler
from scipy.stats import entropy
import pandas as pd
import re
import io
from sympy import O
import os
from causallearn.utils.KCI.KCI import KCI_CInd  
import argparse
from tenacity import retry, stop_after_attempt, wait_exponential



def read_observed_data(csv_file_path):
    """Load an observational dataset and return its parts as text blocks.

    The CSV must contain the columns ``treatment``, ``y_factual``,
    ``y_counterfactual``, ``y_0`` and ``y_1``; every remaining column is
    treated as a covariate.

    Returns a 6-tuple of strings, in this order:
    covariates (CSV text with header), factual outcome, treatment,
    counterfactual outcome, ``y_0`` and ``y_1`` (each rendered with
    ``DataFrame.to_string`` without the index, header included).
    """
    frame = pd.read_csv(csv_file_path)

    def _column_as_text(column_name):
        # One-column table rendered as plain text: header line, then values.
        return frame[column_name].to_frame().to_string(index=False, header=True)

    reserved = ['treatment', 'y_factual', 'y_counterfactual', 'y_0', 'y_1']
    treatment_txt, factual_txt, counterfactual_txt, y0_txt, y1_txt = [
        _column_as_text(name) for name in reserved
    ]

    # Everything that is not a treatment/outcome column is a covariate.
    covariates_csv = frame.drop(columns=reserved).to_csv(index=False, header=True, sep=',')

    return (
        covariates_csv,
        factual_txt,
        treatment_txt,
        counterfactual_txt,
        y0_txt,
        y1_txt,
    )
    
@retry(stop=stop_after_attempt(100), wait=wait_exponential(multiplier=1, min=2, max=300))
def Counterfactual_Query(prefix_prompt, obs_x, obs_y, obs_t):
    """Ask the chat model for every individual's counterfactual outcome.

    Relies on the module-level ``args``, ``N_individual``, ``client``,
    ``model`` and ``temperature``.

    Parameters:
    - prefix_prompt: text placed at the very start of the prompt
    - obs_x: confounder values, interpolated into the prompt
    - obs_y: observed outcomes, interpolated into the prompt
    - obs_t: observed treatments, interpolated into the prompt

    Returns:
    A one-column CSV string: the header ``y_counterfactual`` followed by one
    0/1 value per individual.

    The model is re-queried until the answer contains exactly
    ``N_individual`` lines starting with ``Individual `` and each of them
    carries a binary ``Y=`` value.

    NOTE(review): malformed answers are retried immediately inside the loop
    and are not counted by the ``@retry`` decorator, so the number of calls is
    unbounded if the model never follows the format.
    """
    # Combine the CSV content with the prompt
    if args.data == 'Twins':
        full_prompt = f"{prefix_prompt}\n \
        Your Task: \n \
        The values of confounders for N={N_individual} individuals are given by {obs_x}. \
        The values of corresponding binary treatment for N={N_individual} individuals are given by {obs_t}. \
        The values of corresponding observed outcomes for N={N_individual} individuals are given by {obs_y}. \
        Each row represents the observed covariates, treatment and outcomes of an individual, respectively. \n\
        For each of the N={N_individual} twins, please infer the values of the counterfactual outcome values corresponding to the alternative value of treatment based on all the observed data and your world knowledge. \n\n\
        Please format your response EXACTLY as follows:\n\
        \n\
        Overall Explanation: [Explain your general approach to determining these counterfactual outcomes]\n\
        \n\
        Values:\n\
        Individual 1: Y=X\n\
        Individual 2: Y=X\n\
        ...\n\
        Individual {N_individual}: Y=X\n\
        \n\
        Summary: [Summarize key patterns or factors that influenced your counterfactual outcome assignments]\n\
        \n\
        Where X is either 0 or 1 indicating survival (0) or mortality (1).\n\
        IMPORTANT: You MUST provide exactly {N_individual} counterfactual outcomes, one for each individual.\
        Display COMPLETE results for ALL {N_individual} individuals without omission."
    else:
        # This branch serves the non-Twins data, so the prompt must talk about
        # "individuals" rather than "twins".
        full_prompt = f"{prefix_prompt}\n \
        Your Task: \n \
        The values of confounders for N={N_individual} individuals are given by {obs_x}. \
        The values of corresponding binary treatment for N={N_individual} individuals are given by {obs_t}. \
        The values of corresponding observed outcomes for N={N_individual} individuals are given by {obs_y}. \
        Each row represents the observed covariates, treatment and outcomes of an individual, respectively. \n\
        For each of the N={N_individual} individuals, please infer the values of the counterfactual outcome values corresponding to the alternative value of treatment based on all the observed data and your world knowledge. \n\n\
        Please format your response EXACTLY as follows:\n\
        \n\
        Overall Explanation: [Explain your general approach to determining these counterfactual outcomes]\n\
        \n\
        Values:\n\
        Individual 1: Y=X\n\
        Individual 2: Y=X\n\
        ...\n\
        Individual {N_individual}: Y=X\n\
        \n\
        Summary: [Summarize key patterns or factors that influenced your counterfactual outcome assignments]\n\
        \n\
        Where X is either 0 or 1 indicating unemployment (0) or employment (1).\n\
        IMPORTANT: You MUST provide exactly {N_individual} counterfactual outcomes, one for each individual.\
        Display COMPLETE results for ALL {N_individual} individuals without omission."

    # Capture the whole numeric token so that "Y=10" or "Y=0.7" can be
    # rejected instead of being read as 1 or 0.
    outcome_pattern = re.compile(r'Y=(-?\d+(?:\.\d+)?)')

    while True:
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": full_prompt}],
            temperature=temperature
        )
        text = response.choices[0].message.content

        print(text)

        individual_lines = [
            line.strip()
            for line in text.split('\n')
            if line.strip().startswith('Individual ')
        ]

        if len(individual_lines) != N_individual:
            print(f'Individual count mismatch, expected {N_individual}, got {len(individual_lines)}, retrying...')
            continue

        counterfactual_values = []
        for i, line in enumerate(individual_lines, start=1):
            match = outcome_pattern.search(line)
            if not match:
                print(f'Cannot parse counterfactual value for individual {i}, retrying...')
                break

            outcome = float(match.group(1))
            if outcome not in (0.0, 1.0):
                # The prompt defines the outcome as binary; anything else is a
                # malformed answer.
                print(f'Counterfactual value for individual {i} is not binary (0 or 1), retrying...')
                break

            counterfactual_values.append(int(outcome))
        else:
            # Reached only when every line was parsed, i.e. exactly
            # N_individual values were collected.
            print(f"Successfully extracted counterfactual values for {N_individual} individuals")
            return 'y_counterfactual\n' + '\n'.join(map(str, counterfactual_values))

@retry(stop=stop_after_attempt(100), wait=wait_exponential(multiplier=1, min=2, max=300))
def Augmentation_Query(prefix_prompt, obs_x, obs_t, obs_y):
    """Propose one new confounder and generate its value for every individual.

    First asks ``Propose_Confounder`` for a name, a type and a value
    description, then hands over to the generator that matches the type.

    Returns:
    ``(cname, sampled_values)`` - the confounder name and the generated values.

    Raises:
    ValueError when the proposed type is not ``continuous``, ``discrete`` or
    ``binary``.
    """
    proposal = Propose_Confounder(prefix_prompt, obs_x, obs_t, obs_y)
    cname = proposal['name']
    ctype = proposal['type']
    value_description = proposal.get('value_description', '')

    # Choose the generator first, then call it once below.
    if ctype == "continuous":
        generate_values = Generate_Continuous_Confounder
    elif ctype == "discrete":
        generate_values = Generate_Discrete_Confounder
    elif ctype == "binary":
        generate_values = Generate_Binary_Confounder
    else:
        raise ValueError(f"Unrecognized confounder type: {ctype}")

    sampled_values = generate_values(prefix_prompt, obs_x, obs_t, obs_y, cname, value_description)
    return cname, sampled_values

@retry(stop=stop_after_attempt(100), wait=wait_exponential(multiplier=1, min=2, max=300))
def Propose_Confounder(prefix_prompt, obs_x, obs_t, obs_y):
    """Ask the chat model to propose one additional confounder.

    Relies on the module-level ``args``, ``N_individual``, ``client``,
    ``model`` and ``temperature``.

    Parameters:
    - prefix_prompt: text placed at the very start of the prompt
    - obs_x, obs_t, obs_y: existing confounders, treatments and outcomes,
      interpolated into the prompt

    Returns:
    A dict with the keys
    - ``name``: text found between the first pair of ``@`` on one line,
      with surrounding whitespace removed
    - ``type``: ``"continuous"``, ``"discrete"`` or ``"binary"``
    - ``value_description``: everything from the first line containing
      ``Value Description:`` to the end of the answer

    The model is re-queried until all three parts can be extracted.

    NOTE(review): malformed answers are retried immediately inside the loop
    and are not counted by the ``@retry`` decorator.
    """
    if args.data == 'Twins':
        full_prompt = f"{prefix_prompt}\n\n \
                Task: \n\
                Given the meanings of existing confounders, based on your world knowledge, please propose ONE additional confounder which BOTH affects the Treatment (whether the baby is heavier or lighter than the other twin) and Outcome (whether the baby twin died within the first year of life).\n\
                Make sure that the proposed confounder has a DIFFERENT meaning compared to existing confounders. \n\n\
                The values of existing confounders, treatments, and outcomes for N={N_individual} individuals are given by: \n\
                Confounders:\n{obs_x}\n\
                Treatments:\n{obs_t}\n\
                Outcomes:\n{obs_y}\n\n\
                For this ONE proposed confounder, please provide:\n\
                1. A clear name for the confounder\n\
                2. The data type of this confounder (choose ONE from the following):\n\
                - 'continuous' (for continuous numerical variables)\n\
                - 'discrete' (for discrete numerical variables with multiple possible values)\n\
                - 'binary' (for binary variables with only 0 and 1 as possible values)\n\
                3. A brief explanation of why it affects both treatment and outcome\n\
                4. A detailed description of the value range and meaning of the values for this confounder\n\n\
                Please format your response EXACTLY as follows:\n\
                @Confounder name@\n\
                Type: [continuous/discrete/binary]\n\
                \n\
                Explanation: [Explain why this confounder affects both treatment and outcome]\n\
                \n\
                Value Description: [Describe the value range and meaning of the values, such as min/max for continuous, possible values for discrete, or 0/1 meaning for binary]\n\
                \n\
                IMPORTANT: The type must be one of: 'continuous', 'discrete', or 'binary'."
    else:
        full_prompt = f"{prefix_prompt}\n\n \
                Task: \n\
                Given the meanings of existing confounders, based on your world knowledge, please propose ONE additional confounder which BOTH affects the Treatment (whether the individual participated in the job training program) and Outcome (whether the individual was employed after the training program).\n\
                Make sure that the proposed confounder has a DIFFERENT meaning compared to existing confounders. \n\n\
                The values of existing confounders, treatments, and outcomes for N={N_individual} individuals are given by: \n\
                Confounders:\n{obs_x}\n\
                Treatments:\n{obs_t}\n\
                Outcomes:\n{obs_y}\n\n\
                For this ONE proposed confounder, please provide:\n\
                1. A clear name for the confounder\n\
                2. The data type of this confounder (choose ONE from the following):\n\
                    - 'continuous' (for continuous numerical variables)\n\
                    - 'discrete' (for discrete numerical variables with multiple possible values)\n\
                    - 'binary' (for binary variables with only 0 and 1 as possible values)\n\
                3. A brief explanation of why it affects both treatment and outcome\n\
                4. A detailed description of the value range and meaning of the values for this confounder\n\n\
                Please format your response EXACTLY as follows:\n\
                @Confounder name@\n\
                Type: [continuous/discrete/binary]\n\
                \n\
                Explanation: [Explain why this confounder affects both treatment and outcome]\n\
                \n\
                Value Description: [Describe the value range and meaning of the values, such as min/max for continuous, possible values for discrete, or 0/1 meaning for binary]\n\
                \n\
                IMPORTANT: The type must be one of: 'continuous', 'discrete', or 'binary'."

    known_types = ("continuous", "discrete", "binary")

    while True:
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": full_prompt}],
            temperature=temperature
        )

        text = response.choices[0].message.content

        print(text)

        cname_match = re.search(r'@(.*?)@', text)
        # The name becomes a column header later on, so drop padding such as
        # "@ Name @" and refuse an empty "@@".
        cname = cname_match.group(1).strip() if cname_match else ""
        if not cname:
            print('Cannot identify confounder name, retrying...')
            continue

        lines = text.split('\n')
        ctype = None
        value_description = ""
        for i, line in enumerate(lines):
            stripped = line.strip()
            if stripped.startswith('Type:'):
                declared = stripped[len('Type:'):].lower()
                mentioned = [name for name in known_types if name in declared]
                # Accept the line only when it names exactly one type; an
                # echoed "[continuous/discrete/binary]" placeholder would
                # otherwise be taken as "continuous".
                if len(mentioned) == 1:
                    ctype = mentioned[0]

            if "Value Description:" in line:
                # Keep this line and everything after it.
                value_description = '\n'.join(lines[i:]).strip()
                break

        if ctype is None:
            print('Confounder type not recognized, retrying...')
            continue

        if not value_description:
            print('Value description not found, retrying...')
            continue

        return {"name": cname, "type": ctype, "value_description": value_description}

@retry(stop=stop_after_attempt(100), wait=wait_exponential(multiplier=1, min=2, max=300))
def Generate_Continuous_Confounder(prefix_prompt, obs_x, obs_t, obs_y, cname, value_description=""):
    """Generate values of a continuous confounder for every individual.

    The chat model is asked for a normal distribution (mean and standard
    deviation) per individual; one value is then drawn from each distribution
    with ``np.random.normal``.

    Relies on the module-level ``N_individual``, ``client``, ``model`` and
    ``temperature``.

    Parameters:
    - prefix_prompt: text placed at the very start of the prompt
    - obs_x, obs_t, obs_y: existing confounders, treatments and outcomes,
      interpolated into the prompt
    - cname: name of the confounder
    - value_description: optional description of the value range; when
      non-empty it is added to the prompt as guidance

    Returns:
    A list with one sampled value per individual, in answer order.

    The model is re-queried until the answer contains exactly
    ``N_individual`` parsable ``Individual k: mean=..., std=...`` lines with
    non-negative standard deviations.

    NOTE(review): malformed answers are retried immediately inside the loop
    and are not counted by the ``@retry`` decorator.
    """
    value_guidance = ""
    if value_description:
        value_guidance = f"\n\nThe value range and meaning for this confounder are described as follows:\n{value_description}\n\nPlease ensure your mean and standard deviation values align with this description."

    full_prompt = f"{prefix_prompt}\n\n \
            Task: \n\
            For the continuous confounder named '{cname}' that you proposed, please specify a normal distribution (mean and standard deviation) for each individual from which we can sample the confounder value.{value_guidance}\n\n\
            The values of existing confounders, treatments, and outcomes for N={N_individual} individuals are given by: \n\
            Confounders:\n{obs_x}\n\
            Treatments:\n{obs_t}\n\
            Outcomes:\n{obs_y}\n\n\
            Please format your response EXACTLY as follows:\n\
            \n\
            Overall Explanation: [Explain your approach to determining these distributions]\n\
            \n\
            Value:\n\
            Individual 1: mean=X.XX, std=Y.YY\n\
            Individual 2: mean=X.XX, std=Y.YY\n\
            ...\n\
            Individual {N_individual}: mean=X.XX, std=Y.YY\n\
            \n\
            Summary: [Summarize key patterns in your distribution assignments]\n\
            \n\
            IMPORTANT: You MUST provide exactly {N_individual} distributions, one for each individual.\
                       Display COMPLETE results for ALL {N_individual} individuals without omission."

    # Match one well-formed decimal number. A plain character class such as
    # [\d.-]+ would also swallow a trailing full stop ("std=0.50.") and make a
    # correct answer unparsable.
    number = r'-?(?:\d+\.?\d*|\.\d+)'
    distribution_pattern = re.compile(
        rf'Individual \d+:\s*mean=({number}),\s*std=({number})'
    )

    while True:
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": full_prompt}],
            temperature=temperature
        )

        text = response.choices[0].message.content

        print(text)

        individual_lines = [
            line.strip()
            for line in text.split('\n')
            if line.strip().startswith('Individual ')
        ]

        if len(individual_lines) != N_individual:
            print(f'Individual count mismatch, expected {N_individual}, got {len(individual_lines)}, retrying...')
            continue

        means = []
        stds = []
        for i, line in enumerate(individual_lines, start=1):
            match = distribution_pattern.search(line)
            if not match:
                print(f'Cannot parse distribution for individual {i}, retrying...')
                break

            mean = float(match.group(1))
            std = float(match.group(2))
            if std < 0:
                # np.random.normal refuses a negative scale; ask again here
                # rather than failing while sampling.
                print(f'Negative standard deviation for individual {i}, retrying...')
                break

            means.append(mean)
            stds.append(std)
        else:
            # Reached only when all N_individual lines were parsed.
            print(f"Successfully extracted distributions for continuous confounder '{cname}'")
            # One independent draw per individual.
            return [
                np.random.normal(loc=mean, scale=std, size=1)[0]
                for mean, std in zip(means, stds)
            ]

@retry(stop=stop_after_attempt(100), wait=wait_exponential(multiplier=1, min=2, max=300))
def Generate_Discrete_Confounder(prefix_prompt, obs_x, obs_t, obs_y, cname, value_description=""):
    """Ask the chat model for one discrete confounder value per individual.

    Relies on the module-level ``N_individual``, ``client``, ``model`` and
    ``temperature``.

    Parameters:
    - prefix_prompt: text placed at the very start of the prompt
    - obs_x, obs_t, obs_y: existing confounders, treatments and outcomes,
      interpolated into the prompt
    - cname: name of the confounder
    - value_description: optional description of the allowed values; when
      non-empty it is added to the prompt as guidance

    Returns:
    A list of floats, one per individual, in answer order.

    The model is queried again whenever the answer does not contain exactly
    ``N_individual`` parsable ``Individual k: value=...`` lines.
    """
    value_guidance = (
        f"\n\nThe value range and meaning for this confounder are described as follows:\n{value_description}\n\nPlease ensure your values align with this description."
        if value_description
        else ""
    )

    full_prompt = f"{prefix_prompt}\n\n \
            Task: \n\
            For the discrete confounder named '{cname}' that you proposed, please specify a discrete value for each individual.{value_guidance}\n\n\
            The values of existing confounders, treatments, and outcomes for N={N_individual} individuals are given by: \n\
            Confounders:\n{obs_x}\n\
            Treatments:\n{obs_t}\n\
            Outcomes:\n{obs_y}\n\n\
            Please provide the possible values and their meaning for this discrete variable first, then assign a specific value to each individual.\n\n\
            Please format your response EXACTLY as follows:\n\
            \n\
            Overall Explanation: [Explain your approach and the meaning of possible values]\n\
            \n\
            Value\n\
            Individual 1: value=X\n\
            Individual 2: value=X\n\
            ...\n\
            Individual {N_individual}: value=X\n\
            \n\
            Summary: [Summarize key patterns in your value assignments]\n\
            \n\
            IMPORTANT: You MUST provide exactly {N_individual} values, one for each individual.\
                Display COMPLETE results for ALL {N_individual} individuals without omission."

    value_pattern = re.compile(r'Individual \d+:\s*value=([\d.-]+)')

    while True:
        reply = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": full_prompt}],
            temperature=temperature
        )

        text = reply.choices[0].message.content

        print(text)

        # Only the per-individual answer lines matter for parsing.
        answer_lines = [
            raw.strip()
            for raw in text.split('\n')
            if raw.strip().startswith('Individual ')
        ]

        found = len(answer_lines)
        if found != N_individual:
            print(f'Individual count mismatch, expected {N_individual}, got {found}, retrying...')
            continue

        discrete_values = []
        for position, entry in enumerate(answer_lines, start=1):
            hit = value_pattern.search(entry)
            if hit is None:
                print(f'Cannot parse discrete value for individual {position}, retrying...')
                break

            try:
                discrete_values.append(float(hit.group(1)))
            except ValueError:
                # The captured token was not a valid number (e.g. "1-2").
                print(f'Numeric parsing error for individual {position}, retrying...')
                break
        else:
            # No break above: every one of the N_individual lines was parsed.
            print(f"Successfully extracted values for discrete confounder '{cname}'")
            return discrete_values

@retry(stop=stop_after_attempt(100), wait=wait_exponential(multiplier=1, min=2, max=300))
def Generate_Binary_Confounder(prefix_prompt, obs_x, obs_t, obs_y, cname, value_description=""):
    """Ask the chat model for one binary confounder value per individual.

    Relies on the module-level ``N_individual``, ``client``, ``model`` and
    ``temperature``.

    Parameters:
    - prefix_prompt: text placed at the very start of the prompt
    - obs_x, obs_t, obs_y: existing confounders, treatments and outcomes,
      interpolated into the prompt
    - cname: name of the confounder
    - value_description: optional description of what 0 and 1 mean; when
      non-empty it is added to the prompt as guidance

    Returns:
    A list of ints (each 0 or 1), one per individual, in answer order.

    The model is re-queried until the answer contains exactly
    ``N_individual`` ``Individual k: value=...`` lines whose values are all
    0 or 1.

    NOTE(review): malformed answers are retried immediately inside the loop
    and are not counted by the ``@retry`` decorator.
    """
    value_guidance = ""
    if value_description:
        value_guidance = f"\n\nThe value range and meaning for this confounder are described as follows:\n{value_description}\n\nPlease ensure your binary values (0 or 1) align with this description."

    full_prompt = f"{prefix_prompt}\n\n \
            Task: \n\
            For the binary confounder named '{cname}' that you proposed, please specify a binary value (0 or 1) for each individual.{value_guidance}\n\n\
            Please indicate what 0 and 1 represent for this binary variable (e.g., 0=absent, 1=present).\n\n\
            The values of existing confounders, treatments, and outcomes for N={N_individual} individuals are given by: \n\
            Confounders:\n{obs_x}\n\
            Treatments:\n{obs_t}\n\
            Outcomes:\n{obs_y}\n\n\
            Please format your response EXACTLY as follows:\n\
            \n\
            Overall Explanation: [Explain your approach and the meaning of values 0 and 1]\n\
            \n\
            Values:\n\
            Individual 1: value=X\n\
            Individual 2: value=X\n\
            ...\n\
            Individual {N_individual}: value=X\n\
            \n\
            Summary: [Summarize key patterns in your value assignments]\n\
            \n\
            IMPORTANT: X MUST be either 0 or 1 for each individual. You MUST provide exactly {N_individual} values, one for each individual.\
                Display COMPLETE results for ALL {N_individual} individuals without omission, even if the individuals have the same pattern or value."

    # Capture the whole numeric token. Matching a single [01] character would
    # read "value=10" as 1 and "value=0.7" as 0.
    value_pattern = re.compile(r'Individual \d+:\s*value=(-?\d+(?:\.\d+)?)')

    while True:
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": full_prompt}],
            temperature=temperature
        )

        text = response.choices[0].message.content

        print(text)

        individual_lines = [
            line.strip()
            for line in text.split('\n')
            if line.strip().startswith('Individual ')
        ]

        if len(individual_lines) != N_individual:
            print(f'Individual count mismatch, expected {N_individual}, got {len(individual_lines)}, retrying...')
            continue

        binary_values = []
        for i, line in enumerate(individual_lines, start=1):
            match = value_pattern.search(line)
            if not match:
                print(f'Cannot parse binary value for individual {i}, retrying...')
                break

            value = float(match.group(1))
            if value not in (0.0, 1.0):
                print(f'Value for individual {i} is not binary (0 or 1), retrying...')
                break

            binary_values.append(int(value))
        else:
            # Reached only when all N_individual lines carried a 0/1 value.
            print(f"Successfully extracted values for binary confounder '{cname}'")
            return binary_values


def Concat_Confounder(obs_x, cname, sampled_values):
    """Append a new confounder column to CSV text.

    Parameters:
    - obs_x: CSV text whose first line is the header
    - cname: header of the new column; it is CSV-quoted when it contains a
      comma or a double quote so that it stays a single field
    - sampled_values: one value per data row, in row order; values beyond
      the number of data rows are ignored

    Returns:
    The augmented CSV text with ``\\n`` line endings and no trailing newline.

    Raises:
    IndexError when ``sampled_values`` has fewer entries than there are data
    rows.
    """
    # Normalise CRLF first: otherwise every line keeps a trailing "\r" and
    # the new field would be appended after it.
    rows = obs_x.strip().replace("\r\n", "\n").split("\n")

    data_row_count = len(rows) - 1
    if len(sampled_values) < data_row_count:
        raise IndexError(
            f"Expected at least {data_row_count} values for confounder "
            f"'{cname}', got {len(sampled_values)}"
        )

    header_name = str(cname)
    if ',' in header_name or '"' in header_name:
        # An unquoted comma would split the name into two header fields.
        header_name = '"' + header_name.replace('"', '""') + '"'

    augmented = [f"{rows[0]},{header_name}"]
    augmented.extend(
        f"{row},{value}" for row, value in zip(rows[1:], sampled_values)
    )
    return "\n".join(augmented)

def Statistical_Ignorability_Test(counter_y: str, obs_x: str, obs_y: str, obs_t: str, alpha: float = 0.1):
    """Test whether treatment is independent of the outcomes given covariates.

    Both outcomes (under treatment 0 and under treatment 1) are assembled for
    every individual from the observed and counterfactual outcomes, and a KCI
    conditional-independence test of T and (Y_0, Y_1) given X is run.

    Parameters:
    - counter_y: CSV text of counterfactual outcomes
    - obs_x: CSV text of covariates
    - obs_y: CSV text of observed outcomes
    - obs_t: CSV text of treatments; must contain a ``treatment`` column
      (surrounding whitespace in the header is ignored)
    - alpha: significance level; defaults to 0.1

    Returns:
    True when the p-value is greater than ``alpha``, i.e. independence is not
    rejected; False otherwise.

    Raises:
    ValueError when the outcome and treatment tables do not line up row for
    row.
    """
    x_frame = pd.read_csv(io.StringIO(obs_x))
    factual_frame = pd.read_csv(io.StringIO(obs_y))
    counterfactual_frame = pd.read_csv(io.StringIO(counter_y))
    t_frame = pd.read_csv(io.StringIO(obs_t))

    t_frame.columns = t_frame.columns.str.strip()
    treated = (t_frame["treatment"] == 1).to_numpy().reshape(-1, 1)

    # Work on raw arrays: the two outcome tables carry different column
    # labels, so the combination must be purely positional.
    factual = factual_frame.to_numpy()
    counterfactual = counterfactual_frame.to_numpy()
    if factual.shape != counterfactual.shape or factual.shape[0] != treated.shape[0]:
        raise ValueError(
            "Outcome and treatment tables must have matching rows: "
            f"observed {factual.shape}, counterfactual {counterfactual.shape}, "
            f"treatment rows {treated.shape[0]}"
        )

    # A treated individual's observed outcome is Y_1 and the counterfactual
    # is Y_0; for everyone else it is the other way round.
    Y_0 = np.where(treated, counterfactual, factual)
    Y_1 = np.where(treated, factual, counterfactual)

    Y = np.hstack([Y_0, Y_1])
    T = t_frame.values.reshape(-1, 1)
    X = x_frame.values

    kci = KCI_CInd()
    p_value, _ = kci.compute_pvalue(T, Y, X)

    print(f"p-Value: {p_value}")

    return bool(p_value > alpha)
    
    


def _combine_batch_strings(batch_data_cache, variable):
    """Merge one per-batch text table into a single header-plus-rows string.

    The first line of each batch string is treated as its header. The header
    of the last batch is kept, and the remaining lines of every batch are
    appended in batch order.
    """
    header = ''
    all_data_lines = []
    for batch_data in batch_data_cache:
        lines = batch_data[variable].split('\n')
        header = lines[0]
        all_data_lines.extend(lines[1:])
    return f"{header}\n" + "\n".join(all_data_lines)


def batch_process_data(csv_file_path, output_path, max_batch_size=300, max_confounders=5, aug_num=5):
    """
    Batch process large datasets with an iterative approach:
    1. Generate up to `aug_num` confounders, each proposed from the first batch
    2. Generate values for each new confounder across all batches
    3. Generate counterfactual outcomes for all batches
    4. Run ignorability test on the entire dataset
    5. If test fails, add more confounders (repeat steps 1-4)
    6. Continue until ignorability is achieved or max_confounders is reached

    Parameters:
    - csv_file_path: Input data file path
    - output_path: Output data file path
    - max_batch_size: Maximum number of samples to process per batch
    - max_confounders: Maximum number of confounders to add (default: 5)
    - aug_num: Number of confounders added per round before the ignorability
      test is re-run (default: 5); a round never exceeds what is left of
      max_confounders

    Returns:
    - DataFrame with the concatenated results of all batches (also written to
      output_path)

    Raises:
    - ValueError: if max_batch_size, max_confounders or aug_num is smaller
      than 1, if the input dataset has no rows, or if a proposed confounder
      has a type other than 'continuous', 'discrete' or 'binary'

    Side effects: reads the module-level `args` and `save_model_name`,
    rebinds the module-level `N_individual` to the size of the batch being
    processed, and writes temporary batch CSVs, per-batch results and the
    final result to disk.
    """
    # Validate before any file or model access. aug_num < 1 would never
    # increase the confounder count (an endless loop while the test fails),
    # and max_confounders < 1 would skip counterfactual generation entirely.
    if max_batch_size < 1:
        raise ValueError(f"max_batch_size must be at least 1, got {max_batch_size}")
    if max_confounders < 1:
        raise ValueError(f"max_confounders must be at least 1, got {max_confounders}")
    if aug_num < 1:
        raise ValueError(f"aug_num must be at least 1, got {aug_num}")

    global N_individual
    print(f"Reading full dataset: {csv_file_path}")
    # Read the full dataset
    full_df = pd.read_csv(csv_file_path)
    total_individuals = len(full_df)
    print(f"Total samples in dataset: {total_individuals}")
    if total_individuals == 0:
        raise ValueError(f"Input dataset contains no rows: {csv_file_path}")
    
    # Determine the number of batches needed
    num_batches = (total_individuals + max_batch_size - 1) // max_batch_size  # Round up
    print(f"Processing data in {num_batches} batches, with maximum {max_batch_size} samples per batch")
    
    # Read confounder descriptions
    with open(f'dataset/{args.data}/covar_desc.txt', 'r') as f:
        attr_info = f.read()  
    
    # Record confounder information to maintain consistency
    confounder_info_list = []
    
    # Iteratively add confounders and test ignorability
    is_ignorable = False
    num_confounders_added = 0

    # One value generator per supported confounder type.
    value_generators = {
        "continuous": Generate_Continuous_Confounder,
        "discrete": Generate_Discrete_Confounder,
        "binary": Generate_Binary_Confounder,
    }
    
    # Generate representative prefix prompt
    # NOTE(review): prompt text is sent to the model verbatim (including the
    # continuation-line indentation), so it is left untouched. The Jobs
    # prompts here and in step 2 say "was unemployed" while defining 1 as
    # employed, and the three Jobs prompts differ in wording - confirm intent.
    representative_batch_size = min(max_batch_size, total_individuals)
    if args.data == 'Twins':
        representative_prompt = f"Introduction of Twins Dataset: \n \
            The Twins dataset is widely used in causal inference research for estimating Average Treatment Effects (ATE). \n \
            It is derived from the Linked Birth and Infant Death Data (LBIDD) provided by the U.S. Centers for Disease Control and Prevention (CDC). \n \
            The dataset contains information about {representative_batch_size} babies, where the treatment is defined as whether the baby is heavier (T=1) or lighter (T=0) than the other twin.\n \
            The outcome of interest is the mortality status within the first year of life. \n \
            The Twins dataset contains:\n  \
            (1) Treatment Value (T): T ∈ {{0, 1}} indicating whether the baby is heavier (T=1) or lighter (T=0) than the other twin.\n \
            (2) Observed Outcome (Y): Whether the baby twin died within the first year of life, given the treatment T. 0 means the baby survived, and 1 means the baby died. \n \
            (3) Confounders (X1 to Xn): Features affecting both the treatment and outcome. The information of confounder in the form 'key':'meaning' as follows: {attr_info} \n "
    else:
        representative_prompt = f"Introduction of Jobs Dataset: \n \
            The Jobs dataset is widely used in causal inference research for estimating treatment effects on employment outcomes. \n \
            It is derived from a randomized evaluation of a job training program conducted in the United States. \n \
            The dataset contains information about {representative_batch_size} individuals, where the treatment is defined as whether the individual participated in the job training program (T=1) or not (T=0).\n \
            The outcome of interest is the individual's employment status observed after the treatment decision. \n \
            The Jobs dataset contains:\n  \
            (1) Treatment Value (T): T ∈ {{0, 1}} indicating whether the individual participated in the job training program (T=1) or not (T=0).\n \
            (2) Observed Outcome (Y): Whether the individual was unemployed after the treatment decision. 1 means employed, and 0 means unemployed. \n \
            (3) Confounders (X1 to Xn): Features affecting both the treatment and outcome. The information of confounder in the form 'key':'meaning' as follows: {attr_info} \n "

    # Initialize batch data cache to avoid reloading for each iteration
    temp_dir = f"dataset/{args.data}/temp"
    os.makedirs(temp_dir, exist_ok=True)
    batch_data_cache = []
    for batch_idx in range(num_batches):
        start_idx = batch_idx * max_batch_size
        end_idx = min((batch_idx + 1) * max_batch_size, total_individuals)
        current_batch_size = end_idx - start_idx
        
        # Save batch data temporarily
        batch_csv_path = f"{temp_dir}/temp_batch_{batch_idx+1}.csv"
        batch_df = full_df.iloc[start_idx:end_idx]
        batch_df.to_csv(batch_csv_path, index=False)
        
        # Process the batch data
        obs_x, obs_y, obs_t, obs_cfy, obs_y0, obs_y1 = read_observed_data(batch_csv_path)
        
        # Store the data in cache
        batch_data_cache.append({
            'csv_path': batch_csv_path,
            'batch_size': current_batch_size,
            'obs_x': obs_x,
            'obs_y': obs_y,
            'obs_t': obs_t,
            'obs_cfy': obs_cfy,
            'obs_y0': obs_y0,
            'obs_y1': obs_y1
        })


    while not is_ignorable and num_confounders_added < max_confounders:
        # Cap the round so the total never exceeds max_confounders.
        round_size = min(aug_num, max_confounders - num_confounders_added)
        for _ in range(round_size):
            print(f"\n===== Adding confounder {num_confounders_added+1}/{max_confounders} =====")
            
            # Step 1: Generate a new confounder based on the first batch
            print(f"\n--- Generating new confounder concept ---")
            
            # Use the first batch data for confounder proposal
            first_batch = batch_data_cache[0]
            N_individual = first_batch['batch_size']
            
            # Propose a new confounder
            confounder_info = Propose_Confounder(representative_prompt, first_batch['obs_x'], first_batch['obs_t'], first_batch['obs_y'])
            cname = confounder_info['name']
            ctype = confounder_info['type']
            value_description = confounder_info['value_description']

            # Fail before generating anything: an unrecognised type would
            # otherwise leave the sampled values unset or stale.
            if ctype not in value_generators:
                raise ValueError(
                    f"Unsupported confounder type '{ctype}' for confounder '{cname}'; "
                    f"expected one of {sorted(value_generators)}"
                )
            generate_values = value_generators[ctype]
            
            # Add to our list of confounders
            confounder_info_list.append(confounder_info)
            num_confounders_added += 1
            
            # Step 2: Generate values for this confounder across all batches
            for batch_idx, batch_data in enumerate(batch_data_cache):
                print(f"\n--- Generating values for confounder '{cname}' ({ctype}) in batch {batch_idx+1}/{len(batch_data_cache)} ---")
                
                # Set the number of individuals for this batch
                N_individual = batch_data['batch_size']
                
                # Create batch-specific prefix prompt
                if args.data == 'Twins':
                    batch_prompt = f"Introduction of Twins Dataset: \n \
                        The Twins dataset is widely used in causal inference research for estimating Average Treatment Effects (ATE). \n \
                        It is derived from the Linked Birth and Infant Death Data (LBIDD) provided by the U.S. Centers for Disease Control and Prevention (CDC). \n \
                        The dataset contains information about {N_individual} babies, where the treatment is defined as whether the baby is heavier (T=1) or lighter (T=0) than the other twin.\n \
                        The outcome of interest is the mortality status within the first year of life. \n \
                        The Twins dataset contains:\n  \
                        (1) Treatment Value (T): T ∈ {{0, 1}} indicating whether the baby is heavier (T=1) or lighter (T=0) than the other twin.\n \
                        (2) Observed Outcome (Y): Whether the baby twin died within the first year of life, given the treatment T. 0 means the baby survived, and 1 means the baby died. \n \
                        (3) Confounders (X1 to Xn): Features affecting both the treatment and outcome. The information of confounder in the form 'key':'meaning' as follows: {attr_info} \n "
                else:
                    batch_prompt = f"Introduction of Jobs Dataset: \n \
                        The Jobs dataset is widely used in causal inference research for estimating Average Treatment Effects (ATE). \n \
                        It is derived from a randomized evaluation of a job training program conducted in the United States. \n \
                        The dataset contains information about {N_individual} individuals, where the treatment is defined as whether the individual participated in the job training program (T=1) or not (T=0).\n \
                        The outcome of interest is the individual's employment status observed after the treatment decision. \n \
                        The Jobs dataset contains:\n  \
                        (1) Treatment Value (T): T ∈ {{0, 1}} indicating whether the individual participated in the job training program (T=1) or not (T=0).\n \
                        (2) Observed Outcome (Y): Whether the individual was unemployed after the treatment decision. 1 means employed, and 0 means unemployed. \n \
                        (3) Confounders (X1 to Xn): Features affecting both the treatment and outcome. The information of confounder in the form 'key':'meaning' as follows: {attr_info} \n "

                # Generate variable values with the generator matching the type
                sampled_values = generate_values(batch_prompt, batch_data['obs_x'], batch_data['obs_t'], batch_data['obs_y'], cname, value_description)
                
                # Add confounder to features
                batch_data['obs_x'] = Concat_Confounder(batch_data['obs_x'], cname, sampled_values)
        
        # Step 3: Generate counterfactual outcomes for all batches
        for batch_idx, batch_data in enumerate(batch_data_cache):
            print(f"\n--- Generating counterfactual outcomes for batch {batch_idx+1}/{len(batch_data_cache)} ---")
            
            # Set the number of individuals for this batch
            N_individual = batch_data['batch_size']
            
            # Create batch-specific prefix prompt
            if args.data == 'Twins':
                batch_prompt = f"Introduction of Twins Dataset: \n \
                    The Twins dataset is widely used in causal inference research for estimating Average Treatment Effects (ATE). \n \
                    It is derived from the Linked Birth and Infant Death Data (LBIDD) provided by the U.S. Centers for Disease Control and Prevention (CDC). \n \
                    The dataset contains information about {N_individual} babies, where the treatment is defined as whether the baby is heavier (T=1) or lighter (T=0) than the other twin.\n \
                    The outcome of interest is the mortality status within the first year of life. \n \
                    The Twins dataset contains:\n  \
                    (1) Treatment Value (T): T ∈ {{0, 1}} indicating whether the baby is heavier (T=1) or lighter (T=0) than the other twin.\n \
                    (2) Observed Outcome (Y): Whether the baby twin died within the first year of life, given the treatment T. 0 means the baby survived, and 1 means the baby died. \n \
                    (3) Confounders (X1 to Xn): Features affecting both the treatment and outcome. The information of confounder in the form 'key':'meaning' as follows: {attr_info} \n "
            else:
                batch_prompt = f"Introduction of Jobs Dataset: \n \
                    The Jobs dataset is widely used in causal inference research for evaluating the performance of treatment effect estimation algorithms. \n \
                    It is constructed by combining experimental and observational data from the National Supported Work (NSW) demonstration and comparison group data (e.g., PSID or CPS). \n \
                    The dataset contains information about {N_individual} individuals, where the treatment is defined as whether a person participated in the job training program (T=1) or not (T=0).\n \
                    The outcome of interest is the individual's employment after the program. \n \
                    The Jobs dataset contains:\n  \
                    (1) Treatment Value (T): T ∈ {{0, 1}} indicating whether the individual participated (T=1) or did not participate (T=0) in the job training program.\n \
                    (2) Observed Outcome (Y): The individual's employment observed after the treatment decision. Y=1 means employed, Y=0 means not employed. \n \
                    (3) Confounders (X1 to Xn): Features affecting both the treatment assignment and the outcome, such as age, education, prior income, etc. The information of confounder in the form 'key':'meaning' as follows: {attr_info} \n "


            # Generate counterfactual outcomes
            counter_y = Counterfactual_Query(batch_prompt, batch_data['obs_x'], batch_data['obs_y'], batch_data['obs_t'])
            batch_data['counter_y'] = counter_y
        
        # Step 4: Run ignorability test on the rows of all batches combined
        print("\n--- Running ignorability test after adding confounder ---")

        test_result = Statistical_Ignorability_Test(
            _combine_batch_strings(batch_data_cache, 'counter_y'),
            _combine_batch_strings(batch_data_cache, 'obs_x'),
            _combine_batch_strings(batch_data_cache, 'obs_y'),
            _combine_batch_strings(batch_data_cache, 'obs_t'),
        )

        if test_result:
            print(f"✓ Ignorability test PASSED with {num_confounders_added} confounders!")
            is_ignorable = True
        else:
            print(f"× Ignorability test FAILED with {num_confounders_added} confounders.")
            if num_confounders_added < max_confounders:
                print("Adding another confounder...")
            else:
                print(f"Reached maximum number of confounders ({max_confounders}). Continuing with current set.")
    
    # Process and save all batch results
    all_batch_results = []

    # Per-batch results are kept for debugging or intermediate result checking
    batches_dir = f"dataset/Augment_data/{args.data}/{save_model_name}/batches"
    os.makedirs(batches_dir, exist_ok=True)
    
    for batch_idx, batch_data in enumerate(batch_data_cache):
        print(f"\n--- Finalizing batch {batch_idx+1}/{len(batch_data_cache)} ---")
        
        # Convert results to DataFrame
        obs_t_df = pd.read_csv(io.StringIO(batch_data['obs_t']))
        obs_y_df = pd.read_csv(io.StringIO(batch_data['obs_y']))
        obs_x_df = pd.read_csv(io.StringIO(batch_data['obs_x']))
        obs_y_cf_df = pd.read_csv(io.StringIO(batch_data['counter_y']))
        obs_y_0_df = pd.read_csv(io.StringIO(batch_data['obs_y0']))
        obs_y_1_df = pd.read_csv(io.StringIO(batch_data['obs_y1']))
        
        # Combine results
        batch_result_df = pd.concat([obs_t_df, obs_y_df, obs_y_cf_df, obs_y_0_df, obs_y_1_df, obs_x_df], axis=1)
        batch_result_df.columns = batch_result_df.columns.str.strip()
        
        # Add to overall results list
        all_batch_results.append(batch_result_df)
        
        # Save current batch results
        batch_output_path = f"{batches_dir}/batch_{batch_idx+1}_aug.csv"
        batch_result_df.to_csv(batch_output_path, index=False)
        print(f"Batch {batch_idx+1} processing complete, results saved to: {batch_output_path}")
        
        # The temporary batch CSV (batch_data['csv_path']) is not removed here.
    
    # Combine all batch results
    final_result_df = pd.concat(all_batch_results, axis=0)
    print(f"\nAll batches processed, total samples: {len(final_result_df)}")
    
    # Print confounder information
    print(f"\nConfounders added ({len(confounder_info_list)}):")
    for i, info in enumerate(confounder_info_list):
        print(f"{i+1}. {info['name']} (Type: {info['type']})")
    
    # Save final results; a bare file name has no directory to create
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    final_result_df.to_csv(output_path, index=False)
    print(f"Final results saved to: {output_path}")
    
    return final_result_df


if __name__ == "__main__":
    # Script entry point: parse options, create the API client and run the
    # batch augmentation pipeline. The names bound here are module-level
    # globals; batch_process_data reads `args` and `save_model_name`.
    # NOTE(review): `client`, `model` and `temperature` are presumably read by
    # the query helpers defined earlier in the file - confirm before renaming.
    
    parser = argparse.ArgumentParser(description='hparams')
    parser.add_argument('--model', type=str, default='gpt-4o', 
                        help="select from ['gpt-4o', 'deepseek-r1', 'llama', 'qwen']")
    parser.add_argument('--data', type=str, default='Twins', 
                        help="select from ['Twins', 'Jobs']")
    args = parser.parse_args()


    model = args.model
    # Endpoint and key come from the environment when set, so real
    # credentials need not be written into this file; the placeholders
    # remain the fallback.
    client = OpenAI(
            base_url=os.environ.get('OPENAI_BASE_URL', 'your_url'),
            api_key=os.environ.get('OPENAI_API_KEY', 'your_api_key'),
        )

    temperature = 0.7
    

    save_model_name = model
    
    # Batch processing large dataset mode    
    input_file = f"dataset/{args.data}/Original_data.csv"
    output_file = f"dataset/Augment_data/{args.data}/{save_model_name}/Augment_data.csv"
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    # Maximum samples processed per batch
    max_batch_size = 150
    # Maximum number of confounders added per dataset
    max_confounders = 1
    # Number of confounders added per round before re-testing ignorability
    aug_num = 1
    # Execute batch processing
    result_df = batch_process_data(
        csv_file_path=input_file, 
        output_path=output_file, 
        max_batch_size=max_batch_size, 
        max_confounders=max_confounders,
        aug_num=aug_num
    )
    
    print(f"All data processed! Total samples: {len(result_df)}")
    print(f"Results saved to: {output_file}")
    print(f"Maximum number of confounders added: {max_confounders}")


