import os
import pickle
import argparse
import itertools
import numpy as np
from utils import create_directory
from gene_expression_experiments.simulator import Simulator
from gene_expression_experiments.tbr_experiment import train_tbr
from gene_expression_experiments.baseline_experiment import train_baseline


# --- Sweep configuration ---------------------------------------------------
# Values passed positionally to train_tbr as its `dpf` argument (the argument
# right after weight_decay); its meaning is defined in tbr_experiment.
# NOTE: these floats are formatted with str() into output directory names, so
# 1e-3 and 0.001 (identical floats) both render as '0.001'.
dpf_list = [0.0, 1e-1, 1e-2, 1e-3, 1e-4]
# Learning rates swept for both the TBR runs and the baseline runs.
lr_list = [1e-3, 1e-4]
# Values passed to Simulator.simulate(k_sparse=...): 0, 1 and 2.
sparsity_list = list(range(3))
# Number of simulation repeats; seeds 0 .. seeds-1 are used for each sparsity level.
seeds = 30
# Shared training flags passed to both train_tbr and train_baseline.
batchnorm = False
weight_decay = 0.0

def parse_args(argv=None):
    """Parse the command-line options for the simulated GTEx sweep.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``output_dir``, ``sc_file``, ``tf_dict``,
        ``num_experiments`` and ``seed`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--output-dir', type=str, default='gtex_experiment')
    parser.add_argument('--sc-file', type=str, default='data/processed_gtex.h5ad')
    parser.add_argument('--tf-dict', type=str, default='data/tf_dict_38183')
    # Both new options default to the original behaviour: run every gene set,
    # in an order drawn from NumPy's unseeded global RNG.
    parser.add_argument('--num-experiments', type=int, default=None,
                        help='number of z-gene sets to run (default: all of them)')
    parser.add_argument('--seed', type=int, default=None,
                        help='seed for picking/ordering the z-gene sets (default: unseeded)')
    return parser.parse_args(argv)


def select_gene_sets(candidates, set_size, num_experiments=None, seed=None):
    """Pick distinct ``set_size``-gene combinations of ``candidates`` in random order.

    Args:
        candidates: Gene names to draw from.
        set_size: Number of genes in each combination.
        num_experiments: How many combinations to return. ``None``, or a value
            larger than the number of combinations, returns all of them.
        seed: Seed for the sampling RNG. ``None`` uses NumPy's global RNG,
            matching the original unseeded ``np.random.choice`` call.

    Returns:
        A list of gene-name lists, with no combination repeated.

    Raises:
        ValueError: If ``num_experiments`` is negative.
    """
    gene_sets = list(itertools.combinations(candidates, set_size))
    # Clamp so that choice(..., replace=False) never asks for more than exist.
    if num_experiments is None or num_experiments > len(gene_sets):
        num_experiments = len(gene_sets)
    if num_experiments < 0:
        raise ValueError('num_experiments must be non-negative, got {}'.format(num_experiments))
    if num_experiments == 0:
        # Older NumPy versions reject choice(0, 0), so short-circuit here.
        return []
    rng = np.random if seed is None else np.random.RandomState(seed)
    idx = rng.choice(len(gene_sets), num_experiments, replace=False)
    return [list(gene_sets[i]) for i in idx]


def run_gene_set(z_genes, output_dir_prefix, sc_file, tf_file):
    """Simulate data for one z-gene set and run the TBR and baseline sweeps on it.

    The simulator's ``data_dict`` is pickled, wrapped in a one-element list, to
    ``<output_dir_prefix>_<z_genes joined by '_'>/data_dict``. Every training
    run writes into its own sub-directory of that folder.

    Args:
        z_genes: Gene names passed to ``Simulator(z_genes=...)``.
        output_dir_prefix: Prefix of the per-gene-set output directory.
        sc_file: Path passed to ``Simulator(sc_file=...)``.
        tf_file: Path passed to ``Simulator(tf_file=...)``.
    """
    simulator = Simulator(z_genes=z_genes, sc_file=sc_file, tf_file=tf_file, regress_out=False)
    data_dict = simulator.data_dict
    base_output_dir = '{}_{}'.format(output_dir_prefix, '_'.join(z_genes))
    create_directory(base_output_dir, remove_curr=False)
    with open(os.path.join(base_output_dir, 'data_dict'), 'wb') as f:
        pickle.dump([data_dict], f)

    for sp in sparsity_list:
        for s in range(seeds):
            # One simulated response per (sparsity, seed). All TBR and baseline
            # configurations below are trained on this same y.
            y = simulator.simulate(k_sparse=sp, seed=s)
            # TBR sweep over dpf x learning rate.
            for dpf in dpf_list:
                for lr in lr_list:
                    output_dir = os.path.join(
                        base_output_dir, 'tbr_tuning_rep_{}_{}_{}_{}'.format(s, dpf, lr, sp))
                    train_tbr(data_dict, y, lr, weight_decay, dpf, output_dir=output_dir, batchnorm=batchnorm)
            # Baseline sweep over learning rate only.
            for lr in lr_list:
                output_dir = os.path.join(
                    base_output_dir, 'baseline_tuning_rep_{}_{}_{}'.format(s, lr, sp))
                train_baseline(data_dict, y, lr, weight_decay, output_dir=output_dir, batchnorm=batchnorm)


def main(argv=None):
    """Run the simulation and training sweep over randomly ordered 5-gene z sets."""
    args = parse_args(argv)

    if batchnorm:
        print('Batchnorm')

    z_candidates = ['C1orf43', 'CHMP2A', 'EMC7', 'PSMB2', 'PSMB4', 'REEP5', 'SNRPD3', 'VCP', 'VPS29']
    for z_genes in select_gene_sets(z_candidates, 5, args.num_experiments, args.seed):
        run_gene_set(z_genes, args.output_dir, args.sc_file, args.tf_dict)


if __name__ == '__main__':
    main()

