import argparse

import pandas as pd
import torch

import wandb


def main():
    """Generate synthetic LLM data for a study and log the result to W&B.

    Command-line arguments:
        --model_name: Model identifier passed to the study's
            ``generate_synthetic_data`` and embedded in the output file name.
        --n_samples: Optional number of rows to randomly sample from the
            processed dataframe (fixed seed 42). All rows are used if omitted.
        --study_name: Name of the study. Used as the directory containing
            ``df_processed.csv``, as the package from which
            ``generate_personas_opensource`` is imported, and as the prefix
            of the W&B project name.
        --entity: W&B entity under which the run is created.

    Side effects:
        Starts a W&B run, writes ``{study_name}/df_{model_name}.csv`` and
        logs the generated dataframe as the ``final_dataframe`` W&B table.
    """
    from importlib import import_module
    from pathlib import Path

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--n_samples", type=int, default=None)
    parser.add_argument("--study_name", type=str, required=True)
    parser.add_argument("--entity", type=str, required=True)
    args = parser.parse_args()

    study_name = args.study_name
    wandb.init(
        project=f"{study_name}_LLM_Generation",
        entity=args.entity,
        config={"model": args.model_name},
    )

    # Informational only: the device is reported but not passed on from here.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Running on device: {device}")

    df = pd.read_csv(f"{study_name}/df_processed.csv")

    if args.n_samples is not None:
        # Fixed seed so repeated runs draw the same subset of rows.
        df = df.sample(n=args.n_samples, random_state=42)

    # The generator lives in a per-study module, so it is resolved at runtime
    # from the study name (the study must be importable as a package).
    generator_module = import_module(f"{study_name}.generate_personas_opensource")
    generate_synthetic_data = generator_module.generate_synthetic_data

    df_llm = generate_synthetic_data(df, args.model_name)

    print("Synthetic Data Generation Complete")

    # A model name may contain "/" (e.g. an "org/model" style identifier),
    # which makes the output path point into a subdirectory. Create it first
    # so the write cannot fail on a missing directory after generation is done.
    output_path = f"{study_name}/df_{args.model_name}.csv"
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    df_llm.to_csv(output_path)

    wandb_table = wandb.Table(dataframe=df_llm)
    wandb.log({"final_dataframe": wandb_table})


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
