from programmable_synthesizer import ProgrammableSynthesizer
import argparse
import os
from utils import Timer
import pickle
import numpy as np
import torch


def _dump_pickle_atomically(obj, path):
    """Pickle *obj* to *path* so that *path* never holds a partial file.

    The data is written to a sibling ``<path>.tmp`` file first and then moved
    into place with ``os.replace``. The training loop treats the mere
    existence of *path* as "this sample is finished", so a run interrupted
    mid-write must not leave a truncated pickle under the final name.

    Args:
        obj: Object to serialize with ``pickle.dump``.
        path: Final destination of the pickle file.

    Raises:
        Whatever ``open``, ``pickle.dump`` or ``os.replace`` raise; the
        temporary file is removed (best effort) before re-raising.
    """
    tmp_path = f'{path}.tmp'
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(obj, f)
        os.replace(tmp_path, path)
    except BaseException:
        # BaseException so that Ctrl-C during the dump is cleaned up as well.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def main(args):
    """Train ``args.n_samples`` ProgSyn models and pickle each base model.

    For every sample a ``ProgrammableSynthesizer`` is built from a fixed
    prompt, fitted, and its ``base_model`` attribute is pickled to
    ``experiment_data/non_dp_benchmarks/trained_models/<dataset>/``
    ``random_seed_<seed>/ProgSyn/``. A sample whose output file already
    exists is skipped unless ``args.force`` is set.

    Args:
        args: Namespace with the attributes ``dataset`` (str), ``n_samples``
            (int), ``workload`` (str, one of ``'all_two'``, ``'all_three'``,
            ``'all_three_with_labels'``), ``random_seed`` (int) and
            ``force`` (bool).

    Raises:
        KeyError: If ``args.workload`` is not a known workload name. This is
            checked before any seeding or filesystem work happens.
    """
    available_workloads = {
        'all_two': 2,
        'all_three': 3,
        'all_three_with_labels': 'all_three_with_labels'
    }

    # Resolve the workload once, up front: an invalid name should fail
    # immediately with a readable message, not midway through the loop.
    try:
        workload = available_workloads[args.workload]
    except KeyError:
        raise KeyError(
            f'Unknown workload {args.workload!r}; '
            f'expected one of {sorted(available_workloads)}'
        ) from None

    # set the random seed
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    prompt = f"SYNTHESIZE: {args.dataset}; END;"

    base_path = f'experiment_data/non_dp_benchmarks/trained_models/{args.dataset}/random_seed_{args.random_seed}/ProgSyn/'
    os.makedirs(base_path, exist_ok=True)

    timer = Timer(args.n_samples)
    for sample in range(args.n_samples):

        timer.start()
        # '\r' keeps the progress report on a single, continuously updated line.
        print(f'Sample: {sample+1}/{args.n_samples}    {timer}', end='\r')
        specific_path = base_path + f'trained_ProgSyn_{sample+1}_{args.n_samples}_{args.workload}_{args.random_seed}.pickle'

        if os.path.isfile(specific_path) and not args.force:
            # Already trained: drop this step from the timer's step count
            # (presumably so its estimate only covers real work -- confirm
            # against utils.Timer).
            timer.total_steps -= 1
            continue

        # NOTE(review): every sample is constructed with the same random_seed;
        # whether the samples differ depends on ProgrammableSynthesizer -- confirm.
        synthesizer = ProgrammableSynthesizer(prompt, workload=workload, random_seed=args.random_seed)
        synthesizer.fit(force=True, save=False)

        # Atomic write: the existence check above must only ever see complete files.
        _dump_pickle_atomically(synthesizer.base_model, specific_path)

        timer.end()


if __name__ == '__main__':
    # Script entry point: declare the command-line options in one table,
    # register them in order, then hand the parsed namespace to main().
    cli = argparse.ArgumentParser('ps_trainer')
    option_specs = (
        ('--dataset', {'type': str, 'default': 'ADULT', 'help': 'Select the dataset'}),
        ('--n_samples', {'type': int, 'default': 5, 'help': 'Number of retrain samples'}),
        ('--workload', {'type': str, 'default': 'all_three_with_labels', 'help': 'Select the workload to train on'}),
        ('--random_seed', {'type': int, 'default': 42, 'help': 'Random seed'}),
        # Boolean switch: absent -> False, present -> True.
        ('--force', {'action': 'store_true', 'help': 'Force the training'}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    main(cli.parse_args())
