import os

# Number of CPU threads each numeric backend may use (set to desired number of cores).
thr = 1

# The limits are exported before torch (and numpy, further down) are imported,
# presumably so the native thread pools read them when the libraries load.
os.environ.update(
    dict.fromkeys(
        (
            "OMP_NUM_THREADS",
            "OPENBLAS_NUM_THREADS",
            "MKL_NUM_THREADS",
            "VECLIB_MAXIMUM_THREADS",
            "NUMEXPR_NUM_THREADS",
        ),
        f"{thr}",
    )
)

import torch

# Apply the same cap to torch's own thread setting.
torch.set_num_threads(thr)

import numpy as np
import pandas as pd
import time
import argparse
from sdv.single_table import CTGANSynthesizer, TVAESynthesizer, CopulaGANSynthesizer, GaussianCopulaSynthesizer
from sdv.metadata import Metadata
import matplotlib.pyplot as plt
from fairlearn.datasets import fetch_acs_income

from utils import coin_toss


def main(data_name: str, model: str, trainsize: int = 1000, samplesize: int = 1000, epc: int = 100) -> None:
    """Fit an SDV synthesizer on a tabular dataset and write synthetic samples.

    The dataset is loaded, rows containing missing values are dropped, SDV
    metadata is detected and saved, the torch and NumPy RNGs are seeded from
    ``coin_toss``, the chosen synthesizer is fitted on the first ``trainsize``
    rows, and ``samplesize`` synthetic rows are sampled. All outputs are
    written under the relative ``datasets/`` directory:

    * ``{data_name}_metadata.json`` -- detected metadata (overwritten),
    * ``{model}_synthesizer_{data_name}_{trainsize}_epoch{epc}.pkl`` -- the
      fitted synthesizer,
    * ``{data_name}_syn_{model}_{trainsize}_{samplesize}_epoch{epc}.csv`` --
      the synthetic rows (no index column).

    Args:
        data_name: One of ``"credit"``, ``"acs"`` or ``"adult"``.
        model: One of ``"CTGAN"``, ``"TVAE"``, ``"CopulaGAN"`` or
            ``"GaussianCopula"``.
        trainsize: Number of leading rows (after ``dropna``) used for fitting.
            If the dataset has fewer rows, all of them are used.
        samplesize: Number of synthetic rows to generate.
        epc: Training epochs for CTGAN, TVAE and CopulaGAN. It is not passed
            to GaussianCopula, but it still appears in the output file names.

    Raises:
        ValueError: If ``data_name`` or ``model`` is not one of the supported
            values. Both are checked before any data is loaded or any file is
            written.
    """
    # Lazy loaders: nothing is read or downloaded until a loader is called.
    loaders = {
        # ground truth: 1 (default), 0 (no default)
        "credit": lambda: pd.read_csv("datasets/default_credit.csv", header=0),
        "acs": lambda: fetch_acs_income(as_frame=True).frame,
        "adult": lambda: pd.read_csv("datasets/adult.csv", header=0),
    }
    # Synthesizer factories keyed by model name; each takes the detected metadata.
    builders = {
        "CTGAN": lambda md: CTGANSynthesizer(metadata=md, epochs=epc, verbose=True),
        "TVAE": lambda md: TVAESynthesizer(metadata=md, epochs=epc, verbose=True),
        "CopulaGAN": lambda md: CopulaGANSynthesizer(metadata=md, epochs=epc, verbose=True),
        "GaussianCopula": lambda md: GaussianCopulaSynthesizer(metadata=md),
    }

    # Fail fast: reject bad selectors before loading data, overwriting the
    # metadata file or running the coin toss. The dataset check comes first so
    # that it keeps priority when both arguments are invalid.
    if data_name not in loaders:
        raise ValueError(f"Unknown dataset: {data_name}")
    if model not in builders:
        raise ValueError(f"Unknown model: {model}")

    # === Load and preprocess data ===
    df = loaders[data_name]().dropna()

    # The "acs" dataset is fetched rather than read from datasets/, so the
    # output directory may not exist yet.
    os.makedirs("datasets", exist_ok=True)

    # Detect metadata from the dataframe
    metadata = Metadata.detect_from_dataframe(data=df, table_name=data_name)
    metadata.save_to_json(filepath=f'datasets/{data_name}_metadata.json', mode='overwrite')

    # Do coin toss
    toss_start = time.perf_counter()
    print("Performing coin toss...")
    # Fixed seed for reproducibility. coin_toss is a project helper; it is
    # assumed to return a bytes-like object -- TODO confirm the unit of 1024.
    r_bytes = coin_toss(1024, seed=42)
    half = len(r_bytes) // 2
    # First half seeds torch, second half seeds NumPy. The moduli presumably
    # keep the values inside the ranges the seeding functions accept -- confirm.
    r_torch = int.from_bytes(r_bytes[:half], byteorder='big') % (2**64 - 1)
    r_np = int.from_bytes(r_bytes[half:], byteorder='big') % (2**32 - 1)
    print(f"Coin toss took {time.perf_counter() - toss_start:.2f} seconds")
    torch.manual_seed(r_torch)
    np.random.seed(r_np)

    print(f"Fitting {model} model...")
    fit_start = time.perf_counter()
    synthesizer = builders[model](metadata)
    # Only the first `trainsize` rows are used for fitting.
    synthesizer.fit(df.iloc[:trainsize])
    print(f"Fitting {model} took {time.perf_counter() - fit_start:.2f} seconds")
    synthesizer.save(f"datasets/{model}_synthesizer_{data_name}_{trainsize}_epoch{epc}.pkl")

    # Generate synthetic data
    print("Generating synthetic data...")
    gen_start = time.perf_counter()
    synthetic_df = synthesizer.sample(samplesize)
    print(f"Generating synthetic data took {time.perf_counter() - gen_start:.2f} seconds")
    synthetic_df.to_csv(f"datasets/{data_name}_syn_{model}_{trainsize}_{samplesize}_epoch{epc}.csv", index=False)

if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    # `choices` makes argparse reject unsupported datasets/models with a usage
    # message before any data is loaded.
    parser = argparse.ArgumentParser(description="Synthetic tabular data generator")
    parser.add_argument("--data", type=str, default="credit", choices=["credit", "acs", "adult"],
                        help="Name of the dataset to synthesize from")
    parser.add_argument("--model", type=str, default="CTGAN",
                        choices=["CTGAN", "TVAE", "CopulaGAN", "GaussianCopula"],
                        help="Model to use for synthesis")
    parser.add_argument("--trainsize", type=int, default=1000, help="Number of training samples for synthesizing data")
    parser.add_argument("--samplesize", type=int, default=1000, help="Number of samples to generate")
    parser.add_argument("--epc", type=int, default=100,
                        help="Number of training epochs (not used by GaussianCopula)")
    args = parser.parse_args()

    main(args.data, args.model, args.trainsize, args.samplesize, args.epc)
