import os
import pickle
import argparse
import itertools
import numpy as np
from utils import create_directory
from gene_expression_experiments.simulator import Simulator
from gene_expression_experiments.tbr_experiment import train_tbr
from gene_expression_experiments.baseline_experiment import train_baseline


# Hyper-parameter grid swept by the experiment loop in this script.
dpf_list = [0.0, 1e-1, 1e-2, 1e-3, 1e-4]  # values for train_tbr's `dpf` argument (meaning defined in train_tbr)
lr_list = [1e-3, 1e-4]                    # learning rates tried for both TBR and the baseline
sparsity_list = list(range(3))            # k_sparse levels passed to Simulator.simulate: 0, 1, 2
seeds = 30                                # simulation seeds 0 .. seeds-1 run for every sparsity level
weight_decay = 0.0                        # forwarded unchanged to train_tbr / train_baseline
batchnorm = False                         # forwarded as the `batchnorm` keyword to both trainers

# Transfer-learning hold-out configuration.
num_test_types = 3    # cell types held out per test round
num_test_rounds = 5   # hold-out draws per sampled gene set

def sample_test_types(cell_types, num_types, rng=None):
    """Draw ``num_types`` distinct cell types to hold out for transfer testing.

    Args:
        cell_types: iterable of per-sample cell-type labels; duplicates are
            collapsed before sampling. NOTE: labels must be mutually
            comparable (e.g. all ``str``) because they are sorted.
        num_types: number of distinct types to hold out.
        rng: object exposing ``choice(a, size=..., replace=...)``, e.g.
            ``np.random.default_rng(...)``; defaults to the global
            ``np.random`` state.

    Returns:
        numpy array of ``num_types`` distinct labels.

    Raises:
        ValueError: if fewer than ``num_types`` distinct labels exist.
    """
    rng = np.random if rng is None else rng
    # Sort so the candidate order does not depend on set iteration order,
    # which for string labels changes between interpreter runs (hash randomization).
    unique_types = sorted(set(cell_types))
    if num_types > len(unique_types):
        raise ValueError('need {} distinct cell types, found {}'.format(num_types, len(unique_types)))
    # Without replacement, so a round never holds out the same type twice.
    return rng.choice(unique_types, size=num_types, replace=False)


def main():
    """Run the GTEx transfer-learning sweep over random confounder-gene sets.

    For each of ``num_experiments`` randomly chosen 5-gene subsets of the
    candidate genes: build a Simulator, pickle its data_dict into a per-subset
    output directory, then for ``num_test_rounds`` random held-out cell-type
    sets train TBR over every (dpf, lr) pair and the baseline over every lr,
    for each sparsity level and simulation seed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--output-dir', type=str, default='gtex_transfer_test_experiment')
    parser.add_argument('--sc-file', type=str, default='data/processed_gtex.h5ad')
    parser.add_argument('--tf-dict', type=str, default='data/tf_dict_38183')
    parser.add_argument('--seed', type=int, default=None,
                        help='seed for the global numpy RNG used by the random draws in this script; '
                             'unseeded when omitted')
    args = parser.parse_args()
    base_output_dir_temp = args.output_dir

    if args.seed is not None:
        np.random.seed(args.seed)

    if batchnorm:
        print('Batchnorm')

    z_candidates = ['C1orf43', 'CHMP2A', 'EMC7', 'PSMB2', 'PSMB4', 'REEP5', 'SNRPD3', 'VCP', 'VPS29']
    # All 5-gene subsets of the candidates (C(9, 5) = 126 of them).
    z_gene_sets = list(itertools.combinations(z_candidates, 5))
    num_experiments = 10
    # Sample distinct subsets so no two experiments share an output directory.
    chosen = np.random.choice(len(z_gene_sets), num_experiments, replace=False)
    for i in chosen:
        z_genes = list(z_gene_sets[i])
        simulator = Simulator(z_genes=z_genes, sc_file=args.sc_file, tf_file=args.tf_dict, regress_out=False)
        data_dict = simulator.data_dict
        base_output_dir = '{}_{}'.format(base_output_dir_temp, '_'.join(z_genes))
        create_directory(base_output_dir, remove_curr=False)
        # Stored as a one-element list [data_dict]; presumably loaders expect this shape.
        with open(os.path.join(base_output_dir, "data_dict"), 'wb') as f:
            pickle.dump([data_dict], f)

        # Select held-out cell types for transfer learning.
        for _ in range(num_test_rounds):
            test_types = sample_test_types(data_dict['cell_types'], num_test_types)
            # NOTE: two rounds may draw the same set; their output dirs would then coincide.
            types_tag = '_'.join(test_types)

            for sp in sparsity_list:
                print('Sparsity: {}'.format(sp))
                for s in range(seeds):
                    y = simulator.simulate(k_sparse=sp, seed=s)
                    # TBR over the full (dpf, lr) grid
                    for dpf in dpf_list:
                        for lr in lr_list:
                            output_dir = '{}/tbr_tuning_rep_{}_{}_{}_{}_{}'.format(base_output_dir, s, dpf, lr, sp, types_tag)
                            train_tbr(data_dict, y, lr, weight_decay, dpf, output_dir=output_dir, batchnorm=batchnorm, test_types=test_types)

                    # baseline over the learning rates
                    for lr in lr_list:
                        output_dir = '{}/baseline_tuning_rep_{}_{}_{}_{}'.format(base_output_dir, s, lr, sp, types_tag)
                        train_baseline(data_dict, y, lr, weight_decay, output_dir=output_dir, batchnorm=batchnorm, test_types=test_types)


if __name__ == '__main__':
    main()