import argparse


# Modes accepted by every ``--<name>_schedule`` option.
_SCHEDULE_MODES = ('constant', 'constant_constant', 'exponential', 'exponential_constant',
                   'constant_exponential_constant', 'constant_linear_constant')


def _add_schedule_arguments(parser, name, mode='constant', initial=0.0, final=0.0,
                            initial_constant_epochs=0, decay_epochs=0):
    """Register ``--<name>_schedule`` and its four companion options on *parser*.

    Adds, in this order: ``--<name>_schedule`` (one of ``_SCHEDULE_MODES``),
    ``--<name>_schedule_initial`` / ``_final`` (floats) and
    ``--<name>_schedule_initial_constant_epochs`` / ``_decay_epochs`` (ints).
    """
    prefix = '--' + name + '_schedule'
    parser.add_argument(prefix, default=mode, type=str, choices=_SCHEDULE_MODES)
    parser.add_argument(prefix + '_initial', default=initial, type=float)
    parser.add_argument(prefix + '_final', default=final, type=float)
    parser.add_argument(prefix + '_initial_constant_epochs', default=initial_constant_epochs, type=int)
    parser.add_argument(prefix + '_decay_epochs', default=decay_epochs, type=int)


def parse_arguments():
    """Build the project's command-line parser and parse ``sys.argv``.

    Options are grouped by concern: dataset, synthetic data generation, model,
    SCM, encoder/decoder, intervention encoder, graph learning, learning-rate
    and step schedules, epoch scheduling, evaluation, loading, training, logging.

    Returns:
        argparse.Namespace: one attribute per option (dest = option name
        without the leading ``--``).

    Raises:
        SystemExit: from argparse on an unknown option, a value outside
            ``choices`` or a failed type conversion, and after printing help.
    """
    # ArgumentParser's first positional parameter is ``prog``; pass the text as
    # ``description`` so usage/error lines show the real script name. Help is
    # left enabled because this parser is parsed directly, not used as a parent.
    parser = argparse.ArgumentParser(description='Parse main configuration file')

    parser.add_argument("--dataset", default='epickitchens', type=str, choices=['synthetic', 'epickitchens', 'procthor'])
    parser.add_argument("--path_data", default='./data/epickitchens/', type=str)
    parser.add_argument("--description", default='', type=str)
    parser.add_argument("--expdir", default='./experiments/procthor/test', type=str)
    parser.add_argument("--dim_x", default=3, type=int, help='Dimensions of inputs (same as z in synthetic)')
    # NOTE(review): the defaults below (4 actions, 277 objects) match none of the
    # counts quoted in their help text for the default dataset ('epickitchens':
    # 97 / 293); confirm the intended values.
    parser.add_argument("--num_actions", default=4, type=int, help='Number of actions (7 in procthor, 97 in epickitchens)')
    parser.add_argument("--num_objects", default=277, type=int, help='Number of objects (23 in procthor, 293 in epickitchens)')
    parser.add_argument("--select_actions", default=None, type=int)
    parser.add_argument("--resolution", default=64, type=int)
    parser.add_argument("--mean_gaussian", default=3.0, type=float)

    # Synthetic data generation
    parser.add_argument("--nature_seed", default=1, type=int, help='seed for generating groundtruth graph')
    parser.add_argument("--intervention_set", default="atomic", type=str, choices=['atomic', 'atomic_or_none', 'all'])
    parser.add_argument("--intervention_type", default="soft", type=str, choices=['soft', 'hard'])
    parser.add_argument("--dim_z", default=20, type=int, help='Number of variables in the causal graph')
    # 3 clusters for now: one for soft intervention effect, one for parent, one for exogenous
    parser.add_argument("--num_components", default=3, type=int,
                        help='Number of components in the MoG of Supernode exogenous')
    parser.add_argument("--nature_mode", default="random", type=str, choices=["random", "full", "empty", "chain", "known"],
                        help='mode for initializing graph edges')
    parser.add_argument('--nature_edges', nargs='*', type=int, default=None, help='a list of edges')
    parser.add_argument("--nature_manifold_thickness", default=1.e-9, type=float,
                        help='multiplier for noise used for the non-intervened-upon variables')
    parser.add_argument("--nature_causal_effects", default="bimodal", type=str, choices=["bimodal", "standard"],
                        help='how to initialize matrix in linear SCM')
    parser.add_argument("--samples_train", default=100000, type=int, help='Number of training samples')
    parser.add_argument("--samples_val", default=10000, type=int, help='Number of validation samples')
    parser.add_argument("--samples_test", default=10000, type=int, help='Number of test samples')
    # NOTE(review): type=int rejects fractional values such as 0.1; confirm whether float was intended.
    parser.add_argument("--nature_observation_noise", default=None, type=int, help='observation_noise')

    # Model
    parser.add_argument("--model", default='softilcm', type=str, choices=['elcm', 'ilcm', 'softilcm', 'betavae', 'dvae'])
    parser.add_argument("--noise_model", default='additive', type=str, choices=['nonlinear_parents', 'multiplicative', 'additive_multiplicative', 'mixed', 'additive', 'MoG'])

    # Causal-Triplet specific Models
    parser.add_argument("--dim", default=64, type=int)
    parser.add_argument("--linear", default=False, action='store_true')
    parser.add_argument("--bbox", default=False, action='store_true')

    # SCM specific
    parser.add_argument("--scm", default='mlp', type=str, choices=['ground_truth', 'unstructured', 'mlp'])
    parser.add_argument("--scm_adjacency_matrix", default=None, type=str,
                        choices=['enco', 'dds', 'fixed_order', None, "none_fixed_order", "none", "none_trivial"])
    parser.add_argument("--scm_manifold_thickness", default=0.01, type=float, help='')
    parser.add_argument("--scm_hidden_units", default=64, type=int)
    parser.add_argument("--scm_hidden_layers", default=2, type=int)
    parser.add_argument("--scm_homoskedastic", default=False, action='store_true')
    parser.add_argument("--scm_min_std", default=0.2, type=float, help='')
    parser.add_argument("--scm_init", default='broad', type=str, choices=['default', 'strong_effects', 'broad'])
    parser.add_argument("--var_diminish", default=1.0, type=float, help='diminishing factor distribution variance')

    # Encoder/Decoder specific
    parser.add_argument("--encoder", default='resnet50', type=str, choices=['mlp', 'clip', 'resnet18', 'resnet50', 'group'])
    parser.add_argument("--encoder_hidden_layers", default=2, type=int)
    parser.add_argument("--encoder_hidden_units", default=64, type=int)
    parser.add_argument("--encoder_fix_std", default=False, action='store_true',
                        help='whether to fix std or make it learnable')
    parser.add_argument("--encoder_std", default=0.01, type=float, help='std for initialization')
    # NOTE(review): help text duplicates --encoder_std's; it presumably should describe a minimum std.
    parser.add_argument("--encoder_min_std", default=0.001, type=float, help='std for initialization')

    parser.add_argument("--decoder_hidden_layers", default=2, type=int)
    parser.add_argument("--decoder_hidden_units", default=64, type=int)
    parser.add_argument("--decoder_fix_std", default=False, action='store_true')
    parser.add_argument("--decoder_std", default=0.1, type=float)
    parser.add_argument("--decoder_min_std", default=0.001, type=float)
    parser.add_argument("--averaging_strategy", default='stochastic', type=str, choices=['stochastic', 'average', 'z2'],
                        help='')
    parser.add_argument("--amin", default=0.0, type=float, help='min value for attention map')
    parser.add_argument("--mask", default=False, action='store_true')

    # Intervention encoder specific
    parser.add_argument("--intervention_encoder", default='learnable_heuristic', type=str,
                        choices=['learnable_heuristic', 'mlp'])
    parser.add_argument("--intervention_encoder_hidden_layers", default=2, type=int)
    parser.add_argument("--intervention_encoder_hidden_units", default=64, type=int)

    # Graph learning setting
    parser.add_argument("--graph_sampling_initial_unfreeze_epoch", default=0, type=int)
    parser.add_argument("--graph_sampling_mode", default="deterministic", type=str,
                        choices=['deterministic', 'hard', 'soft'])
    parser.add_argument("--graph_sampling_temperature", default=1.0, type=float)
    parser.add_argument("--graph_sampling_samples", default=1, type=int)
    parser.add_argument("--graph_sampling_final_freeze_epoch", default=1000, type=int)
    parser.add_argument("--graph_sampling_final_mode", default="deterministic", type=str,
                        choices=['deterministic', 'hard', 'soft'])
    parser.add_argument("--graph_sampling_final_temperature", default=1.0, type=float)
    parser.add_argument("--graph_sampling_final_samples", default=1, type=int)

    ### Training
    # Learning rate Scheduler
    parser.add_argument("--lr_schedule", default="cosine", type=str,
                        choices=['constant', 'cosine', 'cosine_restarts', 'cosine_restarts_reset', 'step'])
    parser.add_argument("--lr", default=3e-4, type=float)
    parser.add_argument("--lr_schedule_minimal", default=1.e-8, type=float)  # for all except step
    parser.add_argument("--lr_schedule_increase_period_by_factor", default=1, type=int)  # for cosine_restarts
    parser.add_argument("--lr_schedule_restart_every_epochs", default=30, type=int)  # for cosine_restarts
    parser.add_argument("--lr_schedule_step_every_epochs", default=0, type=int)  # for step
    parser.add_argument("--lr_schedule_step_gamma", default=0.1, type=float)  # for step

    # Step Scheduler: each call registers --<name>_schedule plus its _initial,
    # _final, _initial_constant_epochs and _decay_epochs options.
    _add_schedule_arguments(parser, 'manifold_thickness', initial=0.01, final=0.01,
                            initial_constant_epochs=5, decay_epochs=45)
    _add_schedule_arguments(parser, 'beta', mode='constant_linear_constant', final=1.0, decay_epochs=10)

    parser.add_argument("--increase_intervention_beta", default=1.0, type=float)

    _add_schedule_arguments(parser, 'z_regularization', mode='constant_linear_constant')
    _add_schedule_arguments(parser, 'edge_regularization')
    _add_schedule_arguments(parser, 'cyclicity_regularization')
    _add_schedule_arguments(parser, 'consistency_regularization')
    _add_schedule_arguments(parser, 'inverse_consistency_regularization')
    _add_schedule_arguments(parser, 'intervention_entropy_regularization')
    _add_schedule_arguments(parser, 'intervention_encoder_offset', mode='constant_exponential_constant')

    # epoch scheduler
    parser.add_argument("--epochs", default=500, type=int)
    parser.add_argument("--start_epoch", default=0, type=int)
    parser.add_argument("--pretrain_epochs", default=None, type=int)
    parser.add_argument("--model_interventions_after_epoch", default=None, type=int)
    parser.add_argument("--fix_topological_order_epoch", default=None, type=int)
    parser.add_argument("--model_noise_after_epoch", default=None, type=int)
    parser.add_argument("--freeze_encoder_epoch", default=1000, type=int)
    parser.add_argument("--deterministic_intervention_encoder_after_epoch", default=None, type=int)
    parser.add_argument("--pretrain_beta", default=0.001, type=float)
    parser.add_argument("--full_likelihood", default=False, action='store_true')
    parser.add_argument("--adversarial", default=False, action='store_true')
    # Left unrestricted like --eval_likelihood_reduction: an empty ``choices``
    # list would make argparse reject every value given on the command line.
    parser.add_argument("--likelihood_reduction", default='sum', type=str)
    parser.add_argument("--clip_grad_norm", default=None, type=float)
    parser.add_argument("--early_stopping_var", default='loss', type=str, choices=['nll', 'loss'])

    # Eval settings
    parser.add_argument("--eval_full_likelihood", default=False, action='store_true')
    parser.add_argument("--eval_likelihood_reduction", default='sum', type=str)
    parser.add_argument("--eval_graph_sampling", default='deterministic', type=str)
    parser.add_argument("--eval_graph_sampling_temperature", default=1.0, type=float)
    parser.add_argument("--eval_graph_sampling_samples", default=1, type=int)
    parser.add_argument("--enco_lambda", default=0.01, type=float)
    parser.add_argument("--iwae_samples", default=10, type=int)
    parser.add_argument("--eval_beta", default=1.0, type=float)
    parser.add_argument("--val_every_epoch", default=10, type=int)

    # setting
    parser.add_argument("--seed", default=2022, type=int)

    parser.add_argument("--ood", default='noun', type=str)
    parser.add_argument("--translation", default=0.0, type=float)
    parser.add_argument("--train_size", default=5000, type=int, help='size of training data')
    parser.add_argument("--num_samples", default=1, type=int, help='Number of samples used to approximate expectations')

    # loader
    parser.add_argument("--num_workers", default=8, type=int)
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--gpu", default=0, type=int)
    parser.add_argument("--distributed", default=False, action='store_true')

    # train
    parser.add_argument("--finetune", default=False, action='store_true')
    parser.add_argument("--ckpt", default=None, type=str)
    parser.add_argument("--critic_state", default=0.0, type=float)
    parser.add_argument("--sparse", default=0.0, type=float)
    parser.add_argument("--eval", default='quantitative', type=str, choices=['quantitative', 'qualitative'])

    # log
    parser.add_argument('--print_freq', default=10, type=int, help='print frequency')

    return parser.parse_args()
