from openai import OpenAI
import pandas as pd
import csv
import numpy as np
import argparse

def parse_args(argv=None):
    """Parse the command-line options of the synthetic-data generator.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            behaviour of a plain command-line run; passing an explicit list
            lets tests or other scripts drive the parser.

    Returns:
        argparse.Namespace with the attributes ``input_file``,
        ``output_file`` and ``consistency_rules_file``.
    """
    parser = argparse.ArgumentParser(description='Generate synthetic patient data using LLM')
    parser.add_argument('--input_file', type=str, default='./data/sampled_30.csv',
                        help='Path to input CSV file for reference')
    parser.add_argument('--output_file', type=str, default='./data/disease_n_30.csv',
                        help='Path to save the generated synthetic data')
    parser.add_argument('--consistency_rules_file', type=str, default='./outputs/consistency_result.jsonl',
                        help='Path to consistency rules file')
    return parser.parse_args(argv)

def load_consistency_rules(file_path):
    """Load the consistency-rule text that is inserted into the prompt.

    The file is read as UTF-8 and surrounding whitespace is stripped. If the
    remaining content is a single JSON string literal (as produced by
    ``json.dumps``), the decoded string is returned, so every JSON escape
    (newlines, tabs, escaped quotes and backslashes, unicode escapes) is
    resolved. Otherwise the legacy handling applies: one pair of surrounding
    double quotes is removed and literal backslash-n sequences are turned
    into real newlines.

    Args:
        file_path: Path of the rules file.

    Returns:
        The rule text as a plain string.
    """
    import json

    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read().strip()

    # Preferred path: the whole file is one JSON-encoded string.
    try:
        decoded = json.loads(content)
    except json.JSONDecodeError:
        decoded = None
    if isinstance(decoded, str):
        return decoded

    # Fallback for non-JSON content: remove outer quotes if present and
    # replace escaped newlines with actual newlines.
    if content.startswith('"') and content.endswith('"'):
        content = content[1:-1]
    return content.replace('\\n', '\n')

# Column layout the model is asked to emit and that is written to the output
# CSV. Kept in one place so the prompt text, the row filter and the output
# header cannot drift apart.
_CSV_COLUMNS = [
    'Diagnosis', 'PatientID', 'Age', 'Gender', 'Ethnicity', 'EducationLevel',
    'BMI', 'Smoking', 'AlcoholConsumption', 'PhysicalActivity', 'DietQuality',
    'SleepQuality', 'FamilyHistoryAlzheimers', 'CardiovascularDisease',
    'Diabetes', 'Depression', 'HeadInjury', 'Hypertension', 'SystolicBP',
    'DiastolicBP', 'CholesterolTotal', 'CholesterolLDL', 'CholesterolHDL',
    'CholesterolTriglycerides', 'MMSE', 'FunctionalAssessment',
    'MemoryComplaints', 'BehavioralProblems', 'ADL', 'Confusion',
    'Disorientation', 'PersonalityChanges', 'DifficultyCompletingTasks',
    'Forgetfulness',
]
_CSV_HEADER = ','.join(_CSV_COLUMNS)

def _build_prompt(consistency_rule):
    """Return the few-shot generation prompt with *consistency_rule* inserted.

    Every embedded example row has exactly ``len(_CSV_COLUMNS)`` fields.
    """
    return f"""
Your task is to generate 30 realistic yet diverse synthetic patient records in CSV format related to Alzheimer's Disease. \
Each record should accurately reflect real-world causal relationships and distributions.

### Critical Diagnostic Rules (**Not Strictly Follow**):
{consistency_rule}
### Summary of Feature Attributes:
- Diagnosis: Binary, Alzheimer's Disease status (0=No, 1=Yes).
- PatientID: Unique identifier (range 0-6900).
- Age: Patient's age (60-90 years).
- Gender: Binary, 0=Male, 1=Female.
- Ethnicity: Categorical (0=Caucasian, 1=African American, 2=Asian, 3=Other).
- EducationLevel: Ordinal education scale (0=None, 1=High School, 2=Bachelor's, 3=Higher).
- BMI: Body Mass Index (15-40).
- Smoking: Binary, 0=No, 1=Yes.
- AlcoholConsumption: Weekly alcohol units (0-20).
- PhysicalActivity: Weekly physical activity hours (0-10).
- DietQuality: Diet quality score (0-10).
- SleepQuality: Sleep quality score (4-10).
- FamilyHistoryAlzheimers: Binary, family history of Alzheimer's (0=No, 1=Yes).
- CardiovascularDisease: Binary (0=No, 1=Yes).
- Diabetes: Binary (0=No, 1=Yes).
- Depression: Binary (0=No, 1=Yes).
- HeadInjury: Binary (0=No, 1=Yes).
- Hypertension: Binary (0=No, 1=Yes).
- SystolicBP: Systolic blood pressure (90-180 mmHg).
- DiastolicBP: Diastolic blood pressure (60-120 mmHg).
- CholesterolTotal: Total cholesterol (150-300 mg/dL).
- CholesterolLDL: LDL cholesterol (50-200 mg/dL).
- CholesterolHDL: HDL cholesterol (20-100 mg/dL).
- CholesterolTriglycerides: Triglycerides (50-400 mg/dL).
- MMSE: Mini-Mental State Examination score (0-30; lower scores indicate cognitive impairment).
- FunctionalAssessment: Functional assessment score (0-10; lower scores indicate impairment).
- MemoryComplaints: Binary (0=No, 1=Yes).
- BehavioralProblems: Binary (0=No, 1=Yes).
- ADL: Activities of Daily Living score (0-10; lower scores indicate impairment).
- Confusion: Binary (0=No, 1=Yes).
- Disorientation: Binary (0=No, 1=Yes).
- PersonalityChanges: Binary (0=No, 1=Yes).
- DifficultyCompletingTasks: Binary (0=No, 1=Yes).
- Forgetfulness: Binary (0=No, 1=Yes).

### Important Instructions:
- Generate records guided by Critical Diagnostic Rules, allowing reasonable flexibility..
- Innovatively vary other features not explicitly mentioned while preserving plausibility.
- Generate novel yet plausible patient behavior and medical combinations.

### CSV Output Format:
{_CSV_HEADER}
0,5579,71,1,0,1,39.677200017551485,1,11.337109666689928,8.861771133724414,2.107593698690557,9.362158745518007,0,0,0,0,1,0,124,83,225.19982437493013,156.3886910279693,78.99103377364177,83.36409266936745,9.1682859831823,5.204142076973774,0,0,3.9054729731389406,0,0,0,0,0
0,5005,79,1,1,0,25.971216204415374,0,10.001860725392683,0.745796035263594,9.35528666307197,4.86637837048713,0,0,0,1,0,0,129,79,175.19233401496004,125.10145107880324,52.13032163640317,231.3947582339809,19.739525621162176,6.931808947595978,0,0,1.927703387503894,0,0,0,0,0
0,5107,64,0,0,2,30.917727744710376,0,15.593466456981966,2.441693436341085,2.658215656046534,7.74310016492944,0,0,0,0,0,0,90,114,270.5702298919015,176.0514379848779,85.62870368791239,127.96491875605444,22.111082657543378,6.665563753887361,1,0,4.465643900359137,0,0,1,0,0
0,5756,84,1,1,1,19.151933955347488,0,1.2996860965818735,6.323989526468504,8.622826587509884,7.850919107991995,0,0,0,0,1,0,130,78,271.8246591069959,182.81153589961528,50.034535731724375,58.45062134820762,5.090295222629441,4.712646906718977,0,0,8.435555769082017,0,1,0,0,0
0,5153,76,0,3,3,37.24808884988575,0,12.648325014814422,2.549975383213398,6.226853217332627,4.5512495116879785,0,0,0,1,0,0,106,69,206.77533270138656,94.00291130028484,64.90466514976791,131.39076528776258,29.269283663726497,1.784411265200484,0,1,5.284833879466832,0,0,0,0,1
0,5626,82,1,0,1,25.29846198373583,1,5.049710006096609,0.0454054872476972,3.118052946927932,7.489547961993016,0,0,0,1,0,0,153,100,200.24604840330036,66.59997844145343,42.77820779457875,283.8278014912178,17.117807679138267,7.856078367255526,1,0,8.537764565515229,0,0,0,0,0
0,6776,74,1,1,0,19.42499784962132,0,0.1755179376067217,3.47507074401808,2.544481654537286,5.151360813813252,0,0,0,0,1,0,105,63,215.08189364319395,167.00811613853415,43.16218665476405,347.4668974820227,6.331502628072112,0.9798597096298012,0,0,7.476960032497645,0,1,1,1,0
0,5572,69,0,2,1,21.44811977496177,0,7.253156608320552,2.9825314905489755,8.94705523958584,8.508751155296244,0,0,1,0,0,0,106,104,257.8304023325624,164.69917836436343,21.104898236011067,306.79714795349406,21.56491159788373,9.113111160665683,0,0,2.6131940730943626,0,0,0,1,0
0,6136,68,0,0,3,33.60338305000154,0,7.798858755139097,2.7933225421443764,5.882181264666116,7.227692270134742,0,1,0,0,0,0,105,63,185.6868199782167,50.89191007863555,45.47251786671956,168.37882859717269,23.333575753444595,7.312664706638221,0,0,6.507293325339125,1,0,0,0,1
0,6348,83,0,1,1,17.973245869949547,0,17.197310267297162,0.9004968960447313,1.4955207154973016,4.68199489336253,0,0,0,0,0,0,167,90,247.64417484236048,147.2012723375561,25.51865107771663,300.5607139759637,19.690081511996578,7.107829334692999,0,0,9.627375495304335,1,0,0,1,0
0,5776,77,1,1,1,27.390979759194025,1,16.915396689159913,8.054228467102027,8.994876654180965,4.809943326712961,1,0,0,0,0,0,90,101,219.78699384255503,159.46138489180953,32.58587625015303,294.3223144222098,17.983823676206597,8.981968515542748,0,0,0.90130807783941,0,0,0,0,1
0,6541,63,1,0,3,37.07794365859406,1,16.909750667886655,5.699390943809272,7.164247412161105,5.792930044480396,1,0,0,0,0,0,138,119,255.553512764201,70.49404969721286,86.4454067023966,346.9765948534264,25.940092779639624,0.2830858447835138,0,0,5.0783296898430095,0,0,0,1,0
0,5277,89,0,0,2,22.32233492399432,0,17.83480753680989,7.630076330452645,8.798753135970216,7.962467013853606,0,0,1,0,0,0,131,72,289.5714335547823,77.60546206388308,51.17283301978677,283.04835751566486,7.390533358159933,9.168384008554357,0,0,7.986341744083989,1,0,0,0,0
0,5152,85,1,0,0,17.112223917886,1,15.27367007011341,2.4317573175665377,9.1915022344997,6.614969166981309,1,0,1,1,0,0,177,65,270.7311704538462,87.12352537721577,62.84968975839252,239.91061844174743,17.15662826927705,6.906943692290046,0,0,2.396604051322573,1,0,0,1,0
0,4854,88,1,0,1,16.514566734854224,0,16.322112077202103,1.0069766537658764,4.70746710762756,9.38249302529594,1,0,1,0,0,0,112,97,195.94341922089143,159.2484574258704,70.4543307135279,227.6350029614732,24.20557573242864,4.661765419479304,0,1,4.110831569813538,0,1,0,0,1
1,5887,63,1,0,2,30.968971617435468,0,9.035176957074151,9.493289783498945,9.198248290710112,5.182105320561204,0,0,0,0,0,1,132,79,219.70390588498145,143.62894630329427,85.80877495517487,62.12401218044018,14.415136638004782,4.926870504274792,1,0,2.1894509264449926,0,0,0,1,0
1,5719,89,1,3,1,38.866212484431216,0,8.62455412820935,9.869696559439776,4.370038389145444,6.899845630860306,0,1,1,0,0,0,111,117,273.071990019359,193.69538244807828,28.369517626299,129.17800710800535,14.116927960911076,0.1068580913706207,0,0,1.7179200384224569,0,0,0,0,0
1,5053,88,1,0,1,17.647485588168546,0,1.2131935223378876,1.446166633976137,6.002699089235972,7.460286105440367,0,0,0,1,0,0,138,64,187.20345549924292,161.9192920978129,30.96604735854233,245.4290850893635,10.76421935150406,0.639074243014307,0,1,9.091783097614655,0,1,0,0,0
1,6178,70,1,0,1,34.91528643042514,0,2.209860776444472,9.446901906430394,0.1812896914411599,9.72390346754683,0,0,0,0,0,0,129,113,271.8482141463765,136.0905913089838,68.97830641594103,163.5516717570958,8.394910019205936,8.204003172193028,0,1,2.4638378789261277,0,0,0,0,0
1,5092,73,0,0,2,23.963787822360747,0,10.53599794071259,0.1717065343224944,1.2159651831501117,8.49683173997097,0,0,0,0,0,0,113,85,277.5996636666129,166.1942604716594,89.9941835187939,342.70683654428024,17.800582190064457,3.764658664620498,0,0,4.487288894329216,0,0,0,0,0
1,6183,71,1,0,3,28.327030753701543,0,13.341393413749827,3.342403451589877,7.801620757124215,7.7899369095867,0,1,0,0,0,0,144,83,289.9133276246659,195.71583913946245,88.94282691547274,367.9368176375464,0.3435312966681791,2.8875636048506683,1,0,6.602398280180569,0,0,0,0,0
1,5606,70,1,1,1,28.10865428082612,1,8.687297063628051,1.2168690286047112,0.2846484981603669,7.1538219793216085,0,0,0,0,0,0,140,90,156.23229714278253,130.83269136890553,42.17614795107671,231.19831018154773,6.32091521304811,0.217931798635963,0,0,0.2514297256845077,0,0,0,0,0
1,6532,87,0,0,3,25.17331890161245,0,13.18295835752332,0.3381530959583356,6.835913297469584,8.002450428806217,0,0,0,0,0,0,134,85,246.78767555363945,132.3347935188018,41.65665077079805,398.91074353258045,13.437781335642304,2.9999810187749643,0,1,7.7402145058370655,0,0,0,1,1
1,6770,70,0,1,2,20.354777133022782,0,5.889715183397,7.035978806839212,2.952637894649295,4.444632088101517,0,0,0,0,1,0,112,82,238.1152578944317,87.03421628592534,54.925771176056905,75.74020395890776,8.459046872568937,6.161092519081974,0,0,0.2538256095753188,0,0,0,0,0
1,6444,62,0,2,1,24.13164175681785,0,18.35624361515639,9.460848375354482,6.418246099643908,8.784236835428839,0,1,0,1,1,0,135,66,205.27009628761525,110.86748025763126,65.6764305147246,217.42373274812903,9.258446339716173,2.654054391214362,0,1,0.434139434511257,0,0,0,0,0
1,4976,76,1,0,1,27.86539163579045,1,10.671003241734002,8.017559021263565,8.282861860435375,9.135238617958038,1,1,0,0,0,0,99,111,260.4049157425181,167.49056575947802,66.67419217246245,395.4295146934944,3.0052481572292047,3.384758691833944,0,0,0.2108413255037711,0,1,0,0,1
1,6607,65,0,0,1,35.6341485549689,0,13.2319627119798,5.6465093370048525,2.415079836003776,5.919874133234967,0,0,0,0,0,0,170,71,167.37930107306525,177.82928726755597,94.27576053195472,311.97127357174094,22.70302025793628,0.1717863711002165,0,1,3.836315035900393,0,0,0,0,1
1,6121,76,0,0,2,23.671595873398545,0,18.300121065230808,3.148953038840782,8.846350596425212,6.941702942168544,0,0,1,0,1,0,143,70,177.06204286075368,57.35744370296619,24.83139093226457,394.8395897483777,12.350215244542945,2.6952074238666404,0,0,1.3596867417540135,0,0,0,0,0
1,6877,75,1,1,1,35.83212483016938,0,13.036672535301834,2.6049047138650283,4.278959556290383,9.50657665797964,0,0,0,0,1,1,173,87,212.01412376257065,160.94428996494736,60.510029872173504,193.20790325080227,10.644398966209309,3.350043638300988,0,1,2.180182508063119,0,0,0,0,0
1,6763,63,0,0,1,17.76891961167496,0,9.098222460706433,2.576862935026989,5.319946814740533,5.798027614625696,0,0,0,0,0,0,96,94,176.02567534531053,93.22267866462276,76.90720153150679,339.49061593032275,12.349824542555638,7.460747294328897,0,0,8.455109418306929,0,0,0,0,0

{_CSV_HEADER}
"""

def _parse_rows(text, n_fields):
    """Return the well-formed data rows found in one model response.

    A line is kept only if it splits on commas into exactly *n_fields*
    values and every value parses as a number. This drops the echoed header
    line, Markdown code fences and explanatory prose. Surrounding whitespace
    (including a trailing carriage return) is removed from each value.
    """
    rows = []
    for line in text.strip().splitlines():
        fields = [value.strip() for value in line.strip().split(',')]
        if len(fields) != n_fields:
            continue
        try:
            for value in fields:
                float(value)
        except ValueError:
            continue
        rows.append(fields)
    return rows

def main():
    """Generate synthetic Alzheimer's patient records with an LLM and save them.

    Steps:
    1. Build a few-shot prompt from the embedded example rows and the rules
       loaded from ``--consistency_rules_file``.
    2. Send the prompt to the chat-completions API repeatedly, keeping every
       response line that is a well-formed data row.
    3. Stop once at least ``target_rows`` rows are collected or
       ``max_requests`` requests have been made.
    4. Write all collected rows to ``--output_file`` under the prompt's CSV
       header.
    """
    import os

    args = parse_args()

    # ####### Your API key and base url ########
    # Environment variables take precedence, so real credentials need not be
    # hard-coded. The placeholders remain the fallback.
    client = OpenAI(
        api_key=os.environ.get("OPENAI_API_KEY", "Your api key"),
        base_url=os.environ.get("OPENAI_BASE_URL", "Your chosen base_url"),
    )

    # The reference file is only checked for schema compatibility. The
    # few-shot examples the model sees are embedded in the prompt itself.
    reference_df = pd.read_csv(args.input_file)
    missing_columns = [c for c in _CSV_COLUMNS if c not in reference_df.columns]
    if missing_columns:
        print(f"Warning: {args.input_file} lacks columns {missing_columns}; "
              "the output follows the prompt's column layout.")

    consistency_rule = load_consistency_rules(args.consistency_rules_file)
    prompt = _build_prompt(consistency_rule)
    print(prompt)

    # Safety cap so unparseable responses cannot loop (and bill) forever.
    max_requests = 20
    target_rows = 30
    all_data = []
    valid_len = []  # valid-row count of each response that produced any
    request_count = 0

    while len(all_data) < target_rows and request_count < max_requests:
        response = client.chat.completions.create(
            model="gpt-3.5-turbo-1106",
            messages=[
                {"role": "system", "content": "You are a tabular synthetic data generation model."},
                {"role": "user", "content": prompt},
            ],
            temperature=0.9,
            top_p=0.95,
        )

        # content may be None (e.g. a filtered response); treat it as empty.
        generated_data = response.choices[0].message.content or ""
        request_count += 1
        print(f"\nGenerated Dataset {request_count}:\n")
        print(generated_data)

        valid_data = _parse_rows(generated_data, len(_CSV_COLUMNS))
        all_data.extend(valid_data)
        if valid_data:
            valid_len.append(len(valid_data))

        print("generated data length: ", len(all_data), "/", target_rows)

    if len(all_data) < target_rows:
        print(f"Warning: stopped after {request_count} requests with only "
              f"{len(all_data)}/{target_rows} valid rows.")

    # The header is the one the model was asked to produce, so it always has
    # exactly as many names as each accepted row has values.
    with open(args.output_file, mode='w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerow(_CSV_COLUMNS)
        writer.writerows(all_data)

    if valid_len:
        len_value = np.mean(valid_len)
        len_std = np.std(valid_len)
        print(f"per batch average: {len_value:.2f} ± {len_std:.2f} valid rows")
    print(f"Generated data saved to: {args.output_file}")

if __name__ == '__main__':
    # Script entry point: run the generator only when executed directly,
    # never as a side effect of importing this module.
    main()


