from utils import run_sbatch, dict2options
import argparse

def remove_retrain():
    """Launch the remove-and-retrain pipeline for every settings combination.

    Each option string / job name pair produced by ``dict2options(settings)``
    is passed to ``run_sbatch`` for ``experiments/remove_retrain_pipeline.py``
    with ``time_hrs='12'``; the log path is
    ``experiments/results/remove_retrain/logs/{name}``. The final option
    string and the job name are echoed to stdout.
    """
    settings = {
        'dataset_name': ['heloc_noisy'],
        'model_name': ['ann_m'],
        'attributor_name': ['influence_cg', 'influence_lissa', 'influence_gn', 'tracin', 'random'],
        'grouping': ['grad_kmeans'],
        'invert': [True, False],
        'n_step': [64],
    }
    # Leading space keeps the flag separate from the last generated option even
    # when ``opt`` has no trailing whitespace; ``rstrip`` below avoids doubling
    # the separator when it does (in that case the result is unchanged).
    extra_flags = ' --use_model_cache '
    options_str, name_str = dict2options(settings)
    for (opt, name) in zip(options_str, name_str):
        options = opt.rstrip() + extra_flags
        run_sbatch('experiments/remove_retrain_pipeline.py',
                   options=options, job_name=name,
                   log_file=f'experiments/results/remove_retrain/logs/{name}', time_hrs='12')
        print(options)
        print(f'{name}')

def compute_clusters():
    """Launch ``experiments/compute_clusters.py`` for every settings combination.

    Each option string / job name pair from ``dict2options(settings)`` is
    passed to ``run_sbatch`` with ``time_hrs='4'``. The job name is
    ``c_{dataset}_{model}_{rest}``, where ``rest`` is the part of the generated
    name after the model name. Logs go to
    ``experiments/results/clusters/{dataset}_{model}/logs/{rest}``. The option
    string and ``rest`` are echoed to stdout.

    GPU settings are chosen per job from that job's own dataset, so
    ``settings['dataset_name']`` may list datasets from different families.
    """
    settings = {
        'dataset_name': ['qnli_noisy'],
        'model_name': ['bert'],
        'grouping': ['equal', 'grad_kmeans', 'repr_kmeans'],
        'n_trials': [10],
    }
    options_str, name_str = dict2options(settings)
    for (opt, name) in zip(options_str, name_str):
        dataset_name = opt.split('dataset_name=')[1].split(' ')[0]
        model_name = opt.split('model_name=')[1].split(' ')[0]
        # Dataset family = prefix before the first '_' (e.g. 'qnli_noisy' -> 'qnli').
        name_start = dataset_name.split('_')[0]
        gpu = name_start in ['qnli', 'cifar10', 'mnist']
        gpu_name = 'nvidia_h100_80gb_hbm3' if name_start == 'qnli' else None
        # '-c' is only appended for GPU jobs; NOTE(review): presumably selects
        # CUDA in the target script — confirm there.
        extra_flags = ' -c' if gpu else ''
        # maxsplit=1 keeps everything after the FIRST occurrence of the model
        # name, even if that substring appears again later in the name.
        name = name.split(model_name, 1)[1].lstrip('_')
        log_folder = f'experiments/results/clusters/{dataset_name}_{model_name}/logs'
        run_sbatch('experiments/compute_clusters.py',
                   options=opt+extra_flags, job_name=f'c_{dataset_name}_{model_name}_{name}',
                   log_file=f'{log_folder}/{name}', time_hrs='4',
                   gpu=gpu, gpu_name=gpu_name)
        print(opt + extra_flags)
        print(f'{name}')

def attr_scores():
    """Launch ``experiments/attribute_scores.py`` for every settings combination.

    Each option string / job name pair from ``dict2options(settings)`` is
    passed to ``run_sbatch`` with ``time_hrs='8'``. The job name is
    ``a_{dataset}_{model}_{rest}_5`` and the log file is
    ``experiments/results/attr_scores/{dataset}_{model}/logs/{rest}_5``, where
    ``rest`` is the part of the generated name after the model name.
    NOTE(review): the meaning of the hard-coded ``_5`` suffix is not visible
    here. The option string and ``rest`` are echoed to stdout.

    GPU settings are chosen per job from that job's own dataset, so the
    multiple entries in ``settings['dataset_name']`` are handled correctly.
    """
    settings = {
        'dataset_name': ['qnli', 'qnli_noisy'],
        'model_name': ['bert'],
        'attributor_name': ['influence_lissa'],
        'grouping': ['equal', 'grad_kmeans', 'repr_kmeans'],
        'n_trials': [3],
    }
    options_str, name_str = dict2options(settings)
    for (opt, name) in zip(options_str, name_str):
        dataset_name = opt.split('dataset_name=')[1].split(' ')[0]
        model_name = opt.split('model_name=')[1].split(' ')[0]
        # Dataset family = prefix before the first '_' (e.g. 'qnli_noisy' -> 'qnli').
        name_start = dataset_name.split('_')[0]
        gpu = name_start in ['qnli', 'cifar10', 'mnist']
        gpu_name = 'nvidia_h100_80gb_hbm3' if name_start == 'qnli' else None
        # '-c' is only appended for GPU jobs; NOTE(review): presumably selects
        # CUDA in the target script — confirm there.
        extra_flags = ' -c' if gpu else ''
        # maxsplit=1 keeps everything after the FIRST occurrence of the model
        # name, even if that substring appears again later in the name.
        name = name.split(model_name, 1)[1].lstrip('_')
        log_folder = f'experiments/results/attr_scores/{dataset_name}_{model_name}/logs'
        run_sbatch('experiments/attribute_scores.py',
                   options=opt+extra_flags, job_name=f'a_{dataset_name}_{model_name}_{name}_5',
                   log_file=f'{log_folder}/{name}_5', time_hrs='8',
                   gpu=gpu, gpu_name=gpu_name)
        print(opt + extra_flags)
        print(f'{name}')

def retrain_outputs():
    """Launch ``experiments/retrain_outputs.py`` for every settings combination.

    Each option string / job name pair from ``dict2options(settings)`` is
    passed to ``run_sbatch`` with ``time_hrs='6'``. The job name is
    ``r_{dataset}_{model}_{rest}``, where ``rest`` is the part of the generated
    name after the model name. Logs go to
    ``experiments/results/retrain_outputs/{dataset}_{model}/logs/{rest}``. The
    option string and ``rest`` are echoed to stdout.

    GPU settings are chosen per job from that job's own dataset, so
    ``settings['dataset_name']`` may list datasets from different families.
    """
    settings = {
        'dataset_name': ['qnli'],
        'model_name': ['bert'],
        'attributor_name': ['influence_lissa'],
        'grouping': ['equal', 'grad_kmeans', 'repr_kmeans'],
        'invert': [False, True],
        'group_size': [1, 4, 16, 64, 256, 1024],
        'n_trials': [3],
    }
    options_str, name_str = dict2options(settings)
    for (opt, name) in zip(options_str, name_str):
        dataset_name = opt.split('dataset_name=')[1].split(' ')[0]
        model_name = opt.split('model_name=')[1].split(' ')[0]
        # Dataset family = prefix before the first '_' (e.g. 'qnli_noisy' -> 'qnli').
        name_start = dataset_name.split('_')[0]
        gpu = name_start in ['qnli', 'cifar10', 'mnist']
        # NOTE(review): unlike the other launchers, cifar10 is also pinned to
        # H100 here — confirm whether that difference is intentional.
        gpu_name = 'nvidia_h100_80gb_hbm3' if name_start in ['qnli', 'cifar10'] else None
        # '-c' is only appended for GPU jobs; NOTE(review): presumably selects
        # CUDA in the target script — confirm there.
        extra_flags = ' -c' if gpu else ''
        # maxsplit=1 keeps everything after the FIRST occurrence of the model
        # name, even if that substring appears again later in the name.
        name = name.split(model_name, 1)[1].lstrip('_')
        log_folder = f'experiments/results/retrain_outputs/{dataset_name}_{model_name}/logs'
        run_sbatch('experiments/retrain_outputs.py',
                   options=opt+extra_flags, job_name=f'r_{dataset_name}_{model_name}_{name}',
                   log_file=f'{log_folder}/{name}', time_hrs='6',
                   gpu=gpu, gpu_name=gpu_name)
        print(opt + extra_flags)
        print(f'{name}')

def test():
    """Launch ``experiments/test.py`` for every settings combination.

    Each option string / job name pair from ``dict2options(settings)`` is
    passed to ``run_sbatch`` with ``time_hrs='4'``, using the generated name
    both as the job name and as the log file under ``experiments/test/logs``.
    The option string and name are echoed to stdout.

    GPU settings are chosen per job from that job's own dataset (parsed from
    the option string), so the multiple entries in
    ``settings['dataset_name']`` are handled correctly.
    """
    settings = {
        'dataset_name': ['qnli', 'qnli_noisy'],
        'model_name': ['bert'],
        'learning_rate': [2e-5],
        'num_epochs': [3],
        'weight_decay': [0],
        'batch_size': [512],
        'seed': [0, 1, 2],
    }
    options_str, name_str = dict2options(settings)
    for (opt, name) in zip(options_str, name_str):
        dataset_name = opt.split('dataset_name=')[1].split(' ')[0]
        # Dataset family = prefix before the first '_' (e.g. 'qnli_noisy' -> 'qnli').
        name_start = dataset_name.split('_')[0]
        gpu = name_start in ['qnli', 'cifar10', 'mnist']
        gpu_name = 'nvidia_h100_80gb_hbm3' if name_start == 'qnli' else None
        # '-c' is only appended for GPU jobs; NOTE(review): presumably selects
        # CUDA in the target script — confirm there.
        extra_flags = ' -c' if gpu else ''
        run_sbatch('experiments/test.py',
                   options=opt+extra_flags, job_name=f'{name}',
                   log_file=f'experiments/test/logs/{name}', time_hrs='4',
                   gpu=gpu, gpu_name=gpu_name)
        print(opt + extra_flags)
        print(f'{name}')

def compute_properties():
    """Launch ``experiments/compute_properties.py`` for every settings combination.

    Each option string / job name pair from ``dict2options(settings)`` is
    passed to ``run_sbatch`` with ``time_hrs='1'``. The job name is
    ``p_{dataset}_{model}_{rest}``, where ``rest`` is the part of the generated
    name after the model name. Logs go to
    ``experiments/results/properties/{dataset}_{model}/logs/{rest}``. The
    option string and ``rest`` are echoed to stdout.

    GPU settings are chosen per job from that job's own dataset, so
    ``settings['dataset_name']`` may list datasets from different families.
    """
    settings = {
        'dataset_name': ['heloc'],
        'model_name': ['lr', 'ann_s', 'ann_m'],
        'property_name': ['accuracy'],
        'attributor_name': ['trak'],
        'grouping': ['equal', 'kmeans', 'grad_kmeans', 'repr_kmeans'],
        'invert': [True, False],
        'n_trials': [1],
    }
    options_str, name_str = dict2options(settings)
    for (opt, name) in zip(options_str, name_str):
        dataset_name = opt.split('dataset_name=')[1].split(' ')[0]
        model_name = opt.split('model_name=')[1].split(' ')[0]
        # Dataset family = prefix before the first '_' (e.g. 'qnli_noisy' -> 'qnli').
        name_start = dataset_name.split('_')[0]
        gpu = name_start in ['qnli', 'cifar10', 'mnist']
        gpu_name = 'nvidia_h100_80gb_hbm3' if name_start == 'qnli' else None
        # '-c' is only appended for GPU jobs; NOTE(review): presumably selects
        # CUDA in the target script — confirm there.
        extra_flags = ' -c' if gpu else ''
        # maxsplit=1 keeps everything after the FIRST occurrence of the model
        # name (short names like 'lr' could otherwise recur and truncate it).
        name = name.split(model_name, 1)[1].lstrip('_')
        log_folder = f'experiments/results/properties/{dataset_name}_{model_name}/logs'
        run_sbatch('experiments/compute_properties.py',
                   options=opt+extra_flags, job_name=f'p_{dataset_name}_{model_name}_{name}',
                   log_file=f'{log_folder}/{name}', time_hrs='1',
                   gpu=gpu, gpu_name=gpu_name)
        print(opt + extra_flags)
        print(f'{name}')

if __name__ == "__main__":
    # Dispatch table: CLI experiment name -> launcher function defined above.
    # Its key order also defines the order of the argparse ``choices``.
    launchers = {
        'remove_retrain': remove_retrain,
        'compute_clusters': compute_clusters,
        'attr_scores': attr_scores,
        'retrain_outputs': retrain_outputs,
        'test': test,
        'compute_properties': compute_properties,
    }

    # CLI: select which experiment grid to launch (default: attr_scores).
    parser = argparse.ArgumentParser()
    parser.add_argument('-e', '--experiment', type=str, default='attr_scores',
                        choices=list(launchers),
                        help='Experiment to run')
    args = parser.parse_args()

    print(f'Running {args.experiment} experiment')

    launchers[args.experiment]()