#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import itertools
import multiprocessing as mp
import os
import subprocess
import sys

import numpy as np


# Network params
class NParams:
    """Generator-network settings for a single training job.

    Every constructor argument is stored, unchanged, as an attribute with the
    same name (gen_circuit_type, n_qubits, n_ancilla, n_layers, n_diff_steps).
    """

    def __init__(self, gen_circuit_type, n_qubits, n_ancilla, n_layers, n_diff_steps):
        (
            self.gen_circuit_type,
            self.n_qubits,
            self.n_ancilla,
            self.n_layers,
            self.n_diff_steps,
        ) = (gen_circuit_type, n_qubits, n_ancilla, n_layers, n_diff_steps)

# Data params
class DParams:
    """Dataset settings for a single training job.

    Every constructor argument is stored, unchanged, as an attribute with the
    same name (dat_name, input_type, n_train, n_test).
    """

    def __init__(self, dat_name, input_type, n_train, n_test):
        (
            self.dat_name,
            self.input_type,
            self.n_train,
            self.n_test,
        ) = (dat_name, input_type, n_train, n_test)

# Optimizer params
class OParams:
    """Optimizer / training-loop settings for a single training job.

    Every constructor argument is stored, unchanged, as an attribute with the
    same name (lr, mag, n_outer_epochs, batch_size, dist_type, vendi_lambda).
    """

    def __init__(self, lr, mag, n_outer_epochs, batch_size, dist_type, vendi_lambda):
        (
            self.lr,
            self.mag,
            self.n_outer_epochs,
            self.batch_size,
            self.dist_type,
            self.vendi_lambda,
        ) = (lr, mag, n_outer_epochs, batch_size, dist_type, vendi_lambda)


def execute_job(bin, nparams, dparams, oparams, save_dir, load_params, rseed, plot_bloch, n_threads):
    """Run one training job as a child Python process and wait for it to finish.

    The script ``bin`` is executed with the same interpreter that runs this
    launcher (``sys.executable``). Every setting is forwarded as a separate
    ``--flag value`` argument. Arguments are passed as an argv list rather than
    a shell string, so values containing spaces or shell metacharacters arrive
    verbatim.

    Args:
        bin: Path of the Python script to run. The name shadows the builtin
            ``bin``; it is kept so existing callers keep working.
        nparams: Object exposing n_qubits, n_ancilla, n_layers, n_diff_steps
            and gen_circuit_type (e.g. ``NParams``).
        dparams: Object exposing dat_name, input_type, n_train and n_test
            (e.g. ``DParams``).
        oparams: Object exposing n_outer_epochs, batch_size, dist_type, lr, mag
            and vendi_lambda (e.g. ``OParams``).
        save_dir: Forwarded as ``--save_dir``.
        load_params: Forwarded as ``--load_params``.
        rseed: Random seed, forwarded as ``--rseed``.
        plot_bloch: Forwarded as ``--bloch``.
        n_threads: Forwarded as ``--threads``.

    Returns:
        The child process's exit code (0 on success).
    """
    print(f'Start process with rseed={rseed}')
    options = [
        ('--n_qubits', nparams.n_qubits),
        ('--n_ancilla', nparams.n_ancilla),
        ('--n_layers', nparams.n_layers),
        ('--n_diff_steps', nparams.n_diff_steps),
        ('--gen_circuit_type', nparams.gen_circuit_type),
        ('--dat_name', dparams.dat_name),
        ('--input_type', dparams.input_type),
        ('--n_train', dparams.n_train),
        ('--n_test', dparams.n_test),
        ('--n_outer_epochs', oparams.n_outer_epochs),
        ('--batch_size', oparams.batch_size),
        ('--dist_type', oparams.dist_type),
        ('--lr', oparams.lr),
        ('--mag', oparams.mag),
        ('--vendi_lambda', oparams.vendi_lambda),
        ('--rseed', rseed),
        ('--bloch', plot_bloch),
        ('--load_params', load_params),
        ('--threads', n_threads),
        ('--save_dir', save_dir),
    ]
    # sys.executable instead of a bare "python": the child runs in the same
    # interpreter/virtualenv as this launcher even if "python" is not on PATH.
    cmd = [sys.executable, str(bin)]
    for flag, value in options:
        cmd.extend((flag, str(value)))
    # No shell is involved (shell=False by default), so nothing is re-split or
    # interpreted; the exit status is captured instead of being discarded.
    returncode = subprocess.run(cmd).returncode
    if returncode != 0:
        print(f'Process with rseed={rseed} failed with exit code {returncode}')
    print(f'Finish process with rseed={rseed}')
    return returncode

    
if __name__ == '__main__':
    # Hyper-parameter sweep launcher. Options whose value may be a
    # comma-separated list (--dat_name, --input_type, --n_layers, --lr, --mag,
    # --vendi_lambda, --dist_type, --rseed) are expanded. execute_job is then run
    # once for every combination in their Cartesian product, using a process
    # pool with at most one worker per CPU core.
    parser = argparse.ArgumentParser()
    parser.add_argument('--save_dir', type=str, default='../results/del_test2')
    parser.add_argument('--bin', type=str, default='../source/main_gen_demo.py')

    parser.add_argument('--load_params', type=str, default='params_file')

    # For network
    parser.add_argument('--n_layers', type=str, default='10', help='number layers for generator')
    parser.add_argument('--n_diff_steps', type=int, default=10, help='number steps for diffusion')

    # For data
    parser.add_argument('--dat_name', type=str, default='cluster0', help='name of the data: cluster0, line')
    parser.add_argument('--input_type', type=str, default='rand', help='type of the input')
    parser.add_argument('--n_qubits', type=int, default=1, help='Number of data qubits')
    parser.add_argument('--n_ancilla', type=int, default=1, help='Number of ancilla qubits')

    parser.add_argument('--n_train', type=int, default=100, help='Number of training data')
    parser.add_argument('--n_test', type=int, default=100, help='Number of test data')

    # For epoch in training
    parser.add_argument('--n_outer_epochs', type=int, default=1000, help='Number of outer training epoch')
    parser.add_argument('--batch_size', type=int, default=100, help='Batch size for training')
    parser.add_argument('--dist_type', type=str, default='wass', help='Type of distance: wass or mmd')

    # For update matrix
    parser.add_argument('--lr', type=str, default='0.001', help='Learning rate for generator')
    parser.add_argument('--mag', type=str, default='1.0', help='Magnitude of initial parameters')
    parser.add_argument('--vendi_lambda', type=str, default='0.0', help='Vendi loss lambda')

    # For gen circuit type
    parser.add_argument('--gen_circuit_type', type=str, default='rxycz', help='type of generator circuit')

    # For system
    parser.add_argument('--rseed', type=str, default='0', help='Random seed')
    parser.add_argument('--bloch', type=int, default=0, help='Plot bloch')
    parser.add_argument('--threads', type=int, default=1, help='Number of threads')

    args = parser.parse_args()

    # Sweep axes: split each comma-separated option and convert every item to
    # the type the job expects. str.split always yields at least one item, so
    # every axis (and therefore the job list) is non-empty.
    dat_names = args.dat_name.split(',')
    input_type_ls = args.input_type.split(',')
    n_layers_ls = [int(x) for x in args.n_layers.split(',')]
    lrs = [float(x) for x in args.lr.split(',')]
    mags = [float(x) for x in args.mag.split(',')]
    vendi_lambdas = [float(x) for x in args.vendi_lambda.split(',')]
    dist_types = args.dist_type.split(',')
    rseeds = [int(x) for x in args.rseed.split(',')]

    # itertools.product varies the right-most axis fastest. That is the same
    # job order as the equivalent nested for-loops (rseed innermost).
    sweep = itertools.product(
        dat_names, input_type_ls, n_layers_ls, lrs, mags, vendi_lambdas, dist_types, rseeds
    )
    args_list = []
    for dat_name, input_type, n_layers, lr, mag, vendi_lambda, dist_type, rseed in sweep:
        dparams = DParams(dat_name, input_type, args.n_train, args.n_test)
        nparams = NParams(
            gen_circuit_type=args.gen_circuit_type,
            n_qubits=args.n_qubits,
            n_ancilla=args.n_ancilla,
            n_layers=n_layers,
            n_diff_steps=args.n_diff_steps,
        )
        oparams = OParams(
            lr=lr,
            mag=mag,
            n_outer_epochs=args.n_outer_epochs,
            batch_size=args.batch_size,
            dist_type=dist_type,
            vendi_lambda=vendi_lambda,
        )
        # Positional order must match execute_job's signature.
        args_list.append((
            args.bin,
            nparams,
            dparams,
            oparams,
            args.save_dir,
            args.load_params,
            rseed,
            args.bloch,
            args.threads,
        ))

    # NOTE(review): each job is also given --threads; if the child script really
    # uses that many threads, consider capping workers at
    # cpu_count() // args.threads to avoid oversubscribing the machine.
    n_workers = min(mp.cpu_count(), len(args_list))

    with mp.Pool(processes=n_workers) as pool:
        # starmap calls execute_job(*job) for each tuple in args_list.
        # chunksize=1 dispatches jobs one at a time, so a worker that finishes
        # early picks up the next job immediately.
        pool.starmap(execute_job, args_list, chunksize=1)