import argparse

from src import Burgers1D_CRunner, Burgers2D_CRunner
from src import Poisson2D_CRunner, Poisson2D_CGRunner, \
    Poisson3D_CGRunner, Poisson2D_MSRunner
from src import Heat2D_VCRunner, Heat2D_MSRunner, \
    Heat2D_CGRunner, Heat2D_LTRunner
from src import Wave1D_CRunner, Wave2D_CGRunner, Wave2D_MSRunner
from src import NS2D_CRunner, NS2D_CGRunner, NS2D_LTRunner
from src import GSRunner, KSRunner
from src import PInvRunner, HInvRunner

from utils.parallel import Server

# Command-line options. The parsed namespace ``command_args`` supplies the
# experiment name, output path, device and repeat count used to build the
# Server in the ``__main__`` section.
parser = argparse.ArgumentParser()
parser.add_argument("--name", default="Test", type=str)
parser.add_argument("--path", default=None, type=str)
parser.add_argument("--device", default="Auto", type=str)
parser.add_argument("--repeat", default=1, type=int)
command_args = parser.parse_args()

# Registry of every runner class the launcher knows about. Tasks refer to a
# runner by its class ``__name__``; ``job`` resolves that name against this
# list, and ``main_experiment`` enqueues the entries in this exact order.
runners = [
    Burgers1D_CRunner,

    Poisson2D_CRunner,
    Poisson2D_CGRunner,
    Poisson3D_CGRunner,
    Poisson2D_MSRunner,

    Heat2D_VCRunner,
    Heat2D_MSRunner,
    Heat2D_CGRunner,

    NS2D_CRunner,
    NS2D_CGRunner,

    Wave1D_CRunner,
    Wave2D_CGRunner,
    Wave2D_MSRunner,

    Heat2D_LTRunner,
    NS2D_LTRunner,

    GSRunner,

    Burgers2D_CRunner,

    KSRunner,

    PInvRunner,
    HInvRunner,
]

def job(runner, **kwargs):
    """Instantiate, configure and run a single benchmark task.

    Args:
        runner: the ``__name__`` of one of the classes listed in ``runners``.
        **kwargs: task options.
            ``repeatid`` (required) is stored as the config's round number.
            ``ablation``: when truthy, ``use_precondition`` is required and
            is copied into the config. ``drop_tolerance`` is also required
            when ``use_precondition`` is truthy.
            ``timeexp``: when truthy, ``mesh_len`` is required and is used
            as the grid size. The ILU drop tolerance is fixed to 0.1 and the
            sparse solver is enabled.

    Raises:
        ValueError: if ``runner`` does not name a class in ``runners``.
        KeyError: if ``repeatid`` or an option required by the selected
            experiment is missing from ``kwargs``.
    """
    # Tasks carry the runner class *name*; resolve it to the first class in
    # the registry with that name.
    runner_cls = next((cls for cls in runners if cls.__name__ == runner), None)
    if runner_cls is None:
        known = ", ".join(cls.__name__ for cls in runners)
        raise ValueError(f"Unknown runner {runner!r}; expected one of: {known}")

    instance = runner_cls()
    instance.load_config()
    # NOTE(review): the dict returned by get_config() is mutated below before
    # setup(), so it is presumably the runner's live config -- confirm.
    config = instance.get_config()
    config["Round Number"] = kwargs['repeatid']

    # ``serv`` is only bound inside the ``__main__`` guard. A worker process
    # that re-imports this module (e.g. multiprocessing "spawn") will not
    # have it. In that case, fall back to the command-line repeat count the
    # Server was constructed with.
    server = globals().get("serv")
    config["Total Rounds"] = (
        server.repeat if server is not None else command_args.repeat
    )

    if kwargs.get('ablation'):  # ablation study experiments
        config["Use Preconditioner"] = kwargs['use_precondition']
        if kwargs['use_precondition']:
            config["Drop Tolerance (in ILU)"] = kwargs['drop_tolerance']

    if kwargs.get('timeexp'):  # time experiment
        config["Grid Size (mesh length)"] = kwargs['mesh_len']
        config["Drop Tolerance (in ILU)"] = 0.1
        config["Use Sparse Solver"] = True

    instance.setup()
    instance.run()

if __name__ == "__main__":
    serv = Server(
        exp_name=command_args.name,
        path=command_args.path,
        device=command_args.device,
        repeat=command_args.repeat,
    )
    # ``job`` is registered as the default task target. The dicts queued
    # below presumably become its keyword arguments -- see utils.parallel.
    serv.set_default_task({"target": job})

    def main_experiment():
        """Queue one task for every runner in ``runners``, in list order."""
        for runner_cls in runners:
            serv.add_task({"runner": runner_cls.__name__})

    def ablation_experiment():
        """Queue the preconditioner ablation for the Poisson runners.

        Each runner gets one task per ILU drop tolerance with the
        preconditioner enabled, then one task with it disabled.
        """
        drop_tolerances = (1e-4, 1e-3, 1e-2, 1e-1)
        poisson_runners = (Poisson2D_CRunner, Poisson2D_CGRunner,
                           Poisson3D_CGRunner, Poisson2D_MSRunner)
        for runner_cls in poisson_runners:
            variants = [(True, tol) for tol in drop_tolerances] + [(False, None)]
            for use_precondition, tol in variants:
                serv.add_task({
                    "runner": runner_cls.__name__,
                    "ablation": True,
                    "use_precondition": use_precondition,
                    "drop_tolerance": tol,
                })

    def main_experiment_timed():
        """Queue the four-runner timed subset of the main experiment."""
        timed_runners = (Burgers1D_CRunner, Heat2D_VCRunner,
                         NS2D_CRunner, Wave1D_CRunner)
        for runner_cls in timed_runners:
            serv.add_task({"runner": runner_cls.__name__})

    def time_mesh_experiment():
        """Queue Poisson3D_CGRunner timing tasks over several mesh lengths."""
        for mesh_len in (0.01, 0.02, 0.03, 0.04, 0.05):
            serv.add_task({
                "runner": "Poisson3D_CGRunner",
                "timeexp": True,
                "mesh_len": mesh_len,
            })

    # Experiment groups that can be queued. Only the names listed in
    # ``selected_experiments`` are enqueued; edit the tuple to pick others.
    experiments = {
        "main": main_experiment,
        "ablation": ablation_experiment,
        "time_mesh": time_mesh_experiment,
        "main_timed": main_experiment_timed,
    }
    selected_experiments = ("main",)
    for experiment_name in selected_experiments:
        experiments[experiment_name]()

    serv.run()