import os
import sys
os.environ["DDEBACKEND"] = "pytorch"

import numpy as np
import multiprocessing as mp
import time
import torch

from cases import Burgers1D, Poisson1D, Wave1D, Helmholtz2d
from callbacks import TesterCallback

import argparse
# Command-line interface: which PDE case to run and which CUDA device
# indices the experiment may use.
parser = argparse.ArgumentParser()
parser.add_argument("--case", type=str, default="wave")
parser.add_argument("--gpus", type=str, default="0")
command_args = parser.parse_args()
casename = command_args.case
# "--gpus" is a comma-separated list of integer device indices, e.g. "0,1,3".
gpus = [int(token.strip()) for token in command_args.gpus.split(',')]

class HookedStdout:
    """Tee-like text stream: every write goes to a log file and, optionally,
    to a wrapped stream.

    Intended as a drop-in replacement for ``sys.stdout`` / ``sys.stderr`` so
    that console output is also captured on disk.

    Args:
        filename: Path of the log file; it is opened in ``'w'`` mode, so an
            existing file is truncated.
        stdout: Stream to mirror to (anything with ``write``/``flush``).
            When ``None``, output goes to the file only.
    """

    def __init__(self, filename, stdout=None) -> None:
        self.stdout = stdout
        self.file = open(filename, 'w')

    def write(self, data):
        """Write *data* to the wrapped stream (if any) and to the log file.

        Returns the number of characters written to the file, matching the
        usual text-stream ``write`` contract.
        """
        if self.stdout is not None:
            self.stdout.write(data)
        return self.file.write(data)

    def flush(self):
        """Flush the wrapped stream (if any) and the log file."""
        if self.stdout is not None:
            self.stdout.flush()
        self.file.flush()

    def close(self):
        """Close the log file; the wrapped stream is left open.

        Safe to call more than once.
        """
        if not self.file.closed:
            self.file.close()

def burger_subprocess(save_path, gpu, nu, load_path=None):
    """Worker entry point: train one network on the ``Burgers1D`` case.

    Args:
        save_path: Directory that receives ``log.txt`` / ``err.txt``; also
            handed to deepxde as ``model_save_path`` and to ``TesterCallback``.
        gpu: CUDA device index selected with ``torch.cuda.set_device``.
        nu: Forwarded to ``Burgers1D(nu=...)``.
        load_path: Optional directory. When truthy, weights are restored from
            ``<load_path>/-20000.pt`` before training starts.
    """
    import sys
    # Mirror both output streams to files before anything else prints.
    log_file = os.path.join(save_path, "log.txt")
    sys.stdout = HookedStdout(log_file, sys.stdout)
    err_file = os.path.join(save_path, "err.txt")
    sys.stderr = HookedStdout(err_file, sys.stderr)

    import torch
    import deepxde as dde
    from callbacks import TesterCallback
    # Pin this process to its device and make float32 CUDA tensors the default.
    torch.cuda.set_device(f"cuda:{gpu}")
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
    dde.config.set_default_float('float32')

    pde = Burgers1D(nu=nu)
    # 2 inputs -> five hidden layers of width 100 -> 1 output.
    layer_sizes = [2, 100, 100, 100, 100, 100, 1]
    net = dde.nn.FNN(layer_sizes, "tanh", "Glorot normal")
    net.apply_output_transform(pde.output_transform)
    net = net.float()
    opt = torch.optim.Adam(net.parameters(), lr=1e-3, betas=(0.99, 0.99))

    model = pde.model(net)
    model.compile(opt)
    if load_path:
        checkpoint = os.path.join(load_path, "-20000.pt")
        model.restore(checkpoint, verbose=1)

    # Reference data: the first two columns are used as inputs, the rest as targets.
    reference = (pde.ref_data[:, :2], pde.ref_data[:, 2:])
    tester = TesterCallback(test_data=reference, save_path=save_path, log_every=100)
    model.train(
        iterations=20000,
        display_every=100,
        model_save_path=save_path,
        callbacks=[tester],
    )

def burger_experiment(name, devices):
    """Launch the Burgers sweep over ``nu``, one worker process per device slot.

    For every ``nu`` in the sweep and every entry of *devices*, the function
    waits for the previous job on that slot, creates ``runs/<name>/<i>-<j>/``
    and starts ``burger_subprocess`` there. From the second ``nu`` onwards the
    job receives the directory of the previous ``nu`` on the same slot as
    ``load_path``.

    Args:
        name: Run name; output goes under ``runs/<name>``.
        devices: Sequence of CUDA device indices, one job slot per entry.

    Raises:
        FileExistsError: If ``runs/<name>`` already exists.
    """
    # force=True keeps a second experiment in the same interpreter from
    # failing with "context has already been set".
    mp.set_start_method('spawn', force=True)
    path = os.path.join("runs", name)
    # makedirs also creates a missing "runs" parent; it still raises when the
    # run directory itself exists, so earlier results are never overwritten.
    os.makedirs(path)
    process = [None] * len(devices)

    nu_list = np.logspace(-2, 0, 21, base=10) / np.pi
    for i, nu in enumerate(nu_list):
        for j, gpu in enumerate(devices):
            # Block until the previous job on this slot has finished; the next
            # job may restore from its checkpoint.
            if process[j] is not None:
                process[j].join()
            # Trailing separator is kept: the path is used as model_save_path,
            # and workers restore "<dir>/-20000.pt" (presumably deepxde appends
            # "-<step>.pt" to the prefix - confirm against deepxde).
            save_path = os.path.join(path, f"{i}-{j}") + '/'
            os.mkdir(save_path)
            load_path = os.path.join(path, f"{i-1}-{j}") if i else None
            process[j] = mp.Process(target=burger_subprocess, args=(save_path, gpu, nu, load_path))
            process[j].start()
            # Pause between launches (presumably to stagger worker start-up).
            time.sleep(10)

def wave_subprocess(save_path, gpu, C, load_path=None):
    """Worker entry point: train one network on the ``Wave1D`` case.

    Args:
        save_path: Directory that receives ``log.txt`` / ``err.txt``; also
            handed to deepxde as ``model_save_path`` and to ``TesterCallback``.
        gpu: CUDA device index selected with ``torch.cuda.set_device``.
        C: Forwarded to ``Wave1D(C=...)``.
        load_path: Optional directory. When truthy, weights are restored from
            ``<load_path>/-20000.pt`` before training starts.
    """
    import sys
    # Mirror both output streams to files before anything else prints.
    log_file = os.path.join(save_path, "log.txt")
    sys.stdout = HookedStdout(log_file, sys.stdout)
    err_file = os.path.join(save_path, "err.txt")
    sys.stderr = HookedStdout(err_file, sys.stderr)

    import torch
    import deepxde as dde
    from callbacks import TesterCallback
    # Pin this process to its device and make float32 CUDA tensors the default.
    torch.cuda.set_device(f"cuda:{gpu}")
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
    dde.config.set_default_float('float32')

    pde = Wave1D(C=C)
    # 2 inputs -> five hidden layers of width 100 -> 1 output.
    layer_sizes = [2, 100, 100, 100, 100, 100, 1]
    net = dde.nn.FNN(layer_sizes, "tanh", "Glorot normal")
    net.apply_output_transform(pde.output_transform)
    net = net.float()
    opt = torch.optim.Adam(net.parameters(), lr=1e-3, betas=(0.99, 0.99))

    model = pde.model(net)
    model.compile(opt)
    if load_path:
        checkpoint = os.path.join(load_path, "-20000.pt")
        model.restore(checkpoint, verbose=1)

    tester = TesterCallback(test_data=(pde.test_x, pde.test_y), save_path=save_path, log_every=100)
    model.train(
        iterations=20000,
        display_every=100,
        model_save_path=save_path,
        callbacks=[tester],
    )

def wave_experiment(name, devices, transfer=True):
    """Launch the wave sweep over ``C``, one worker process per device slot.

    For every ``C`` in the sweep and every entry of *devices*, the function
    waits for the previous job on that slot, creates ``runs/<name>/<i>-<j>/``
    and starts ``wave_subprocess`` there.

    Args:
        name: Run name; output goes under ``runs/<name>``.
        devices: Sequence of CUDA device indices, one job slot per entry.
        transfer: When true, every job after the first ``C`` receives the
            directory of the previous ``C`` on the same slot as ``load_path``;
            when false, ``load_path`` is always ``None``.

    Raises:
        FileExistsError: If ``runs/<name>`` already exists.
    """
    # force=True keeps a second experiment in the same interpreter from
    # failing with "context has already been set".
    mp.set_start_method('spawn', force=True)
    path = os.path.join("runs", name)
    # makedirs also creates a missing "runs" parent; it still raises when the
    # run directory itself exists, so earlier results are never overwritten.
    os.makedirs(path)
    process = [None] * len(devices)

    C_list = np.arange(1.1, 5.05, 0.1)
    for i, C in enumerate(C_list):
        for j, gpu in enumerate(devices):
            # Block until the previous job on this slot has finished; the next
            # job may restore from its checkpoint.
            if process[j] is not None:
                process[j].join()
            # Trailing separator is kept: the path is used as model_save_path,
            # and workers restore "<dir>/-20000.pt" (presumably deepxde appends
            # "-<step>.pt" to the prefix - confirm against deepxde).
            save_path = os.path.join(path, f"{i}-{j}") + '/'
            os.mkdir(save_path)
            load_path = os.path.join(path, f"{i-1}-{j}") if i and transfer else None
            process[j] = mp.Process(target=wave_subprocess, args=(save_path, gpu, C, load_path))
            process[j].start()
            # Pause between launches (presumably to stagger worker start-up).
            time.sleep(10)

def helm_subprocess(save_path, gpu, A, load_path=None):
    """Worker entry point: train one network on the ``Helmholtz2d`` case.

    Args:
        save_path: Directory that receives ``log.txt`` / ``err.txt``; also
            handed to deepxde as ``model_save_path`` and to ``TesterCallback``.
        gpu: CUDA device index selected with ``torch.cuda.set_device``.
        A: Scalar used for both entries of ``Helmholtz2d(A=(A, A))``.
        load_path: Optional directory. When truthy, weights are restored from
            ``<load_path>/-20000.pt`` before training starts.
    """
    import sys
    # Mirror both output streams to files before anything else prints.
    log_file = os.path.join(save_path, "log.txt")
    sys.stdout = HookedStdout(log_file, sys.stdout)
    err_file = os.path.join(save_path, "err.txt")
    sys.stderr = HookedStdout(err_file, sys.stderr)

    import torch
    import deepxde as dde
    from callbacks import TesterCallback
    # Pin this process to its device and make float32 CUDA tensors the default.
    torch.cuda.set_device(f"cuda:{gpu}")
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
    dde.config.set_default_float('float32')

    pde = Helmholtz2d(A=(A, A))
    # 2 inputs -> five hidden layers of width 100 -> 1 output.
    layer_sizes = [2, 100, 100, 100, 100, 100, 1]
    net = dde.nn.FNN(layer_sizes, "tanh", "Glorot normal")
    net.apply_output_transform(pde.output_transform)
    net = net.float()
    opt = torch.optim.Adam(net.parameters(), lr=1e-3, betas=(0.99, 0.99))

    model = pde.model(net)
    model.compile(opt)
    if load_path:
        checkpoint = os.path.join(load_path, "-20000.pt")
        model.restore(checkpoint, verbose=1)

    tester = TesterCallback(test_data=(pde.test_x, pde.test_y), save_path=save_path, log_every=100)
    model.train(
        iterations=20000,
        display_every=100,
        model_save_path=save_path,
        callbacks=[tester],
    )

def helm_experiment(name, devices):
    """Launch the Helmholtz runs, one worker process per device slot.

    For every ``A`` in the (sliced) sweep and every entry of *devices*, the
    function waits for the previous job on that slot, creates
    ``runs/<name>/<i>-<j>/`` and starts ``helm_subprocess`` there. From the
    second ``A`` onwards the job receives the directory of the previous ``A``
    on the same slot as ``load_path``.

    Args:
        name: Run name; output goes under ``runs/<name>``.
        devices: Sequence of CUDA device indices, one job slot per entry.

    Raises:
        FileExistsError: If ``runs/<name>`` already exists.
    """
    # force=True keeps a second experiment in the same interpreter from
    # failing with "context has already been set".
    mp.set_start_method('spawn', force=True)
    path = os.path.join("runs", name)
    # makedirs also creates a missing "runs" parent; it still raises when the
    # run directory itself exists, so earlier results are never overwritten.
    os.makedirs(path)
    process = [None] * len(devices)

    # The slice keeps only the third-from-last value of 1..19, i.e. A = 17.
    A_list = np.arange(1, 20, 1)[-3:-2]
    for i, A in enumerate(A_list):
        for j, gpu in enumerate(devices):
            # Block until the previous job on this slot has finished; the next
            # job may restore from its checkpoint.
            if process[j] is not None:
                process[j].join()
            # Trailing separator is kept: the path is used as model_save_path,
            # and workers restore "<dir>/-20000.pt" (presumably deepxde appends
            # "-<step>.pt" to the prefix - confirm against deepxde).
            save_path = os.path.join(path, f"{i}-{j}") + '/'
            os.mkdir(save_path)
            load_path = os.path.join(path, f"{i-1}-{j}") if i else None
            process[j] = mp.Process(target=helm_subprocess, args=(save_path, gpu, A, load_path))
            process[j].start()

def wave_cond_subprocess(path, gpu, C, load_path=None):
    """Worker entry point: train a ``Wave1D`` network on a ratio-based loss.

    The model's loss function is replaced by a single term,
    ``mean(pde(x, u)^2) / mean(u^2) / cond_multiplier``, evaluated on the
    case's test points (the inverse of the commented-out ``cond`` quantity in
    the original code).

    Args:
        path: Directory that receives ``log.txt`` / ``err.txt``; also handed to
            deepxde as ``model_save_path``.
        gpu: CUDA device index selected with ``torch.cuda.set_device``.
        C: Forwarded to ``Wave1D(C=...)``.
        load_path: Optional directory. When truthy, weights are restored from
            ``<load_path>/-20000.pt`` before training starts.
    """
    import sys
    # Mirror both output streams to files before anything else prints.
    log_file = os.path.join(path, "log.txt")
    sys.stdout = HookedStdout(log_file, sys.stdout)
    err_file = os.path.join(path, "err.txt")
    sys.stderr = HookedStdout(err_file, sys.stderr)

    import torch
    import deepxde as dde
    from callbacks import TesterCallback
    # Pin this process to its device and make float32 CUDA tensors the default.
    torch.cuda.set_device(f"cuda:{gpu}")
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
    dde.config.set_default_float('float32')

    pde = Wave1D(C=C)
    # 2 inputs -> five hidden layers of width 100 -> 1 output.
    layer_sizes = [2, 100, 100, 100, 100, 100, 1]
    net = dde.nn.FNN(layer_sizes, "tanh", "Glorot normal")
    net.apply_output_transform(pde.output_transform)
    net = net.float()
    opt = torch.optim.Adam(net.parameters(), lr=1e-4, betas=(0.99, 0.99))

    model = pde.model(net, src=False)
    model.compile(opt)

    # Fixed evaluation points; gradients are enabled so pde.pde can differentiate.
    inp = torch.tensor(pde.test_x, dtype=torch.float32, requires_grad=True)
    # Normalisation constant: mean squared source term over mean squared reference.
    src_power = np.square(pde.src_term_numpy(pde.test_x)).mean()
    ref_power = np.square(pde.test_y).mean()
    cond_multiplier = src_power / ref_power

    def losses_wrap(targets, outputs, loss_fn, inputs, model, aux=None):
        # The incoming batch is ignored; the ratio is always computed on `inp`.
        prediction = model.net(inp)
        residual = pde.pde(inp, prediction)
        prediction_power = torch.square(prediction).mean()
        residual_power = torch.square(residual).mean()
        return [residual_power / prediction_power / cond_multiplier]

    model.data.losses = losses_wrap

    if load_path:
        checkpoint = os.path.join(load_path, "-20000.pt")
        model.restore(checkpoint, verbose=1)
    model.train(iterations=20000, display_every=100, model_save_path=path)

def wave_cond_experiment(name, devices):
    """Launch the wave ratio-loss sweep over ``C``, one worker per device slot.

    For every ``C`` in the (sliced) sweep and every entry of *devices*, the
    function waits for the previous job on that slot, creates
    ``runs/<name>/<i>-<j>/`` and starts ``wave_cond_subprocess`` there. From
    the second ``C`` onwards the job receives the directory of the previous
    ``C`` on the same slot as ``load_path``.

    Args:
        name: Run name; output goes under ``runs/<name>``.
        devices: Sequence of CUDA device indices, one job slot per entry.

    Raises:
        FileExistsError: If ``runs/<name>`` already exists.
    """
    # force=True keeps a second experiment in the same interpreter from
    # failing with "context has already been set".
    mp.set_start_method('spawn', force=True)
    path = os.path.join("runs", name)
    # makedirs also creates a missing "runs" parent; it still raises when the
    # run directory itself exists, so earlier results are never overwritten.
    os.makedirs(path)
    process = [None] * len(devices)

    # The first ten values of the full C grid are skipped.
    C_list = np.arange(1.1, 5.05, 0.1)[10:]
    for i, C in enumerate(C_list):
        for j, gpu in enumerate(devices):
            # Block until the previous job on this slot has finished; the next
            # job may restore from its checkpoint.
            if process[j] is not None:
                process[j].join()
            # Trailing separator is kept: the path is used as model_save_path,
            # and workers restore "<dir>/-20000.pt" (presumably deepxde appends
            # "-<step>.pt" to the prefix - confirm against deepxde).
            save_path = os.path.join(path, f"{i}-{j}") + '/'
            os.mkdir(save_path)
            load_path = os.path.join(path, f"{i-1}-{j}") if i else None
            process[j] = mp.Process(target=wave_cond_subprocess, args=(save_path, gpu, C, load_path))
            process[j].start()

def poisson_cond_subprocess(gpu, path, taskid, repeatid, P):
    """Train a ``Poisson1D`` network on a ratio-based loss.

    The model's loss function is replaced by a single term,
    ``-(mean(u^2) / mean(pde(x, u)^2) * P**2)``, evaluated on the case's test
    points.

    Args:
        gpu: Not used in this function.
        path: Passed to deepxde as ``model_save_path``.
        taskid: Not used in this function.
        repeatid: Not used in this function.
        P: Forwarded to ``Poisson1D(P=...)``.
    """
    import deepxde as dde
    dde.config.set_default_float('float32')

    pde = Poisson1D(P=P)
    # 1 input -> five hidden layers of width 100 -> 1 output.
    layer_sizes = [1, 100, 100, 100, 100, 100, 1]
    net = dde.nn.FNN(layer_sizes, "tanh", "Glorot normal")
    net.apply_output_transform(pde.output_transform)
    net = net.float()
    opt = torch.optim.Adam(net.parameters(), lr=1e-3, betas=(0.99, 0.99))

    model = pde.model(net, src=False)
    model.compile(opt)

    # Handle on the stock loss function; not used further in this function.
    pdelosses = model.data.losses
    # Fixed evaluation points; gradients are enabled so pde.pde can differentiate.
    inp = torch.tensor(pde.test_x, dtype=torch.float32, requires_grad=True)
    # source term is P**2 sin(Px) and reference is sinPx
    cond_multiplier = pde.P**2

    def losses_wrap(targets, outputs, loss_fn, inputs, model, aux=None):
        # The incoming batch is ignored; the ratio is always computed on `inp`.
        prediction = model.net(inp)
        residual = pde.pde(inp, prediction)
        prediction_power = torch.square(prediction).mean()
        residual_power = torch.square(residual).mean()
        ratio = prediction_power / residual_power * cond_multiplier
        # Negated, so minimising the loss maximises the ratio.
        return [-ratio]

    model.data.losses = losses_wrap

    model.train(iterations=20000, display_every=100, model_save_path=path)

def poisson_cond_experiment(serv):
    """Queue one ``poisson_cond_subprocess`` task per ``P`` on *serv*.

    ``P`` takes 41 evenly spaced values from 1 to 5 inclusive. *serv* only
    needs an ``add_task`` method accepting a task dict.
    """
    for p_value in np.linspace(1, 5, num=41):
        task = {'target': poisson_cond_subprocess, 'args': {'P': p_value}}
        serv.add_task(task)

def wave_experiment_no_transfer(name="Wave_NoTransfer", path=None):
    """Run the ``Wave1D`` sweep over ``C`` through ``parallel.Server``.

    No checkpoint is restored, so every task trains from a fresh network.

    Args:
        name: Passed to ``Server`` as ``exp_name``.
        path: Passed to ``Server`` as ``path``.
    """
    # `parallel` is imported from ../utils, which is appended to sys.path here.
    sys.path.append("../utils")
    from parallel import Server

    def wave_cond_subprocess(C, **kwargs):
        """Train one Wave1D model; the output directory is read from kwargs['path']."""
        import deepxde as dde
        from callbacks import TesterCallback
        dde.config.set_default_float('float32')

        pde = Wave1D(C=C)
        # 2 inputs -> five hidden layers of width 100 -> 1 output.
        layer_sizes = [2, 100, 100, 100, 100, 100, 1]
        net = dde.nn.FNN(layer_sizes, "tanh", "Glorot normal")
        net.apply_output_transform(pde.output_transform)
        net = net.float()
        opt = torch.optim.Adam(net.parameters(), lr=1e-3, betas=(0.99, 0.99))

        model = pde.model(net)
        model.compile(opt)
        tester = TesterCallback(test_data=(pde.test_x, pde.test_y), save_path=kwargs['path'], log_every=100)
        model.train(
            iterations=20000,
            display_every=100,
            model_save_path=kwargs['path'],
            callbacks=[tester],
        )

    serv = Server(exp_name=name, repeat=7, path=path)

    # Every task shares the same target; only C differs between tasks.
    serv.set_default_task({"target": wave_cond_subprocess})
    for C in np.arange(1.1, 5.05, 0.1):
        serv.add_task({"C": C})

    serv.run()

if __name__ == "__main__":
    # Run names are prefixed with a month.day-hour.minute.second timestamp.
    timestamp = time.strftime('%m.%d-%H.%M.%S')
    if casename == "burger":
        burger_experiment(timestamp + "Burger", gpus)
    elif casename == 'helmholtz':
        helm_experiment(timestamp + "Helmholtz", gpus)
    elif casename == 'wave':
        wave_experiment(timestamp + "Wave", gpus, transfer=True)
    else:
        # The original message had no separator between the text and the
        # case name ("Unknown Casefoo").
        raise ValueError(f"Unknown case: {casename!r}")
    # wave_cond_experiment(time.strftime('%m.%d-%H.%M.%S')+"Wave_NNCond", [0,1,2,3,4,5,6,7])