from programmable_synthesizer import ProgrammableSynthesizer
import argparse
import os
from utils import Timer
import pickle
import numpy as np
import torch


def main(args):
    """Train ProgSyn models under differential privacy and pickle them to disk.

    For every epsilon (only ``args.epsilon`` when ``args.single_setup`` is set,
    otherwise the sweep 0.1, 0.25, 0.5, 0.75, 1.0) and every sample index
    ``1..args.n_samples``, a ``ProgrammableSynthesizer`` is fitted on
    ``args.dataset`` with a DP prompt (DELTA=1e-9).  Its ``base_model`` is then
    pickled under
    ``experiment_data/dp_benchmarks/trained_models/<dataset>/random_seed_<seed>/ProgSyn/``.
    Samples whose pickle already exists are skipped unless ``args.force`` is set.

    Args:
        args: parsed command-line namespace with ``dataset``, ``n_samples``,
            ``workload``, ``random_seed``, ``single_setup``, ``epsilon`` and
            ``force`` attributes.

    Raises:
        ValueError: if ``args.workload`` is not a supported workload name.
    """
    # Maps CLI workload names to the value ProgrammableSynthesizer expects.
    available_workloads = {
        'all_two': 2,
        'all_three': 3,
        'all_three_with_labels': 'all_three_with_labels',
    }
    # Validate up front so a typo fails before any directory is created or model trained.
    if args.workload not in available_workloads:
        raise ValueError(
            f"Unknown workload {args.workload!r}; expected one of {sorted(available_workloads)}"
        )
    workload = available_workloads[args.workload]

    # set the random seed
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    epsilons = [args.epsilon] if args.single_setup else [0.1, 0.25, 0.5, 0.75, 1.0]

    base_path = os.path.join(
        'experiment_data', 'dp_benchmarks', 'trained_models', args.dataset,
        f'random_seed_{args.random_seed}', 'ProgSyn',
    )
    os.makedirs(base_path, exist_ok=True)

    timer = Timer(args.n_samples * len(epsilons))
    for epsilon in epsilons:

        prompt = f"""
            SYNTHESIZE: {args.dataset};
                ENSURE: DIFFERENTIAL PRIVACY:
                    EPSILON={epsilon}, DELTA=1e-9;
            END;
        """

        for sample in range(args.n_samples):
            file_name = (
                f'trained_ProgSyn_{sample+1}_{args.n_samples}_{args.workload}'
                f'_{epsilon}_{args.random_seed}.pickle'
            )
            specific_path = os.path.join(base_path, file_name)

            if os.path.isfile(specific_path) and not args.force:
                # Already trained: drop this sample from the timer's step count.
                timer.total_steps -= 1
                continue

            # Start timing only for samples that will also reach timer.end().
            timer.start()
            print(f'Epsilon: {epsilon}    Sample: {sample+1}/{args.n_samples}    {timer}', end='\r')

            # NOTE(review): every sample passes the same random_seed; if the synthesizer
            # reseeds from it, all n_samples models may be identical — confirm intended.
            synthesizer = ProgrammableSynthesizer(prompt, workload=workload, random_seed=args.random_seed)
            synthesizer.fit(force=True, save=False, verbose=False)

            # Write to a temp file and atomically rename, so an interrupted dump never
            # leaves a truncated pickle that a later run would skip as "already trained".
            tmp_path = specific_path + '.tmp'
            with open(tmp_path, 'wb') as f:
                pickle.dump(synthesizer.base_model, f)
            os.replace(tmp_path, specific_path)

            timer.end()

    # Terminate the '\r'-overwritten progress line.
    print()


if __name__ == '__main__':
    # Command-line entry point: parse options and run the training sweep.
    parser = argparse.ArgumentParser('ps_trainer')
    parser.add_argument('--dataset', type=str, default='ADULT', help='Select the dataset')
    parser.add_argument('--n_samples', type=int, default=5, help='Number of retrain samples')
    # Choices mirror the keys of available_workloads in main(); keep them in sync.
    # argparse then rejects an unknown workload with a usage message before main() runs.
    parser.add_argument(
        '--workload',
        type=str,
        default='all_three',
        choices=['all_two', 'all_three', 'all_three_with_labels'],
        help='Select the workload to train on',
    )
    parser.add_argument('--random_seed', type=int, default=42, help='Random seed')
    parser.add_argument('--single_setup', action='store_true', help='Select to run only a single setup')
    parser.add_argument('--epsilon', type=float, default=1.0, help='Epsilon to run for single setup')
    parser.add_argument('--force', action='store_true', help='Force the training')
    in_args = parser.parse_args()
    main(in_args)
