import os
import sys
# Set before deepxde is imported (it is imported lazily inside the worker
# functions below); presumably deepxde reads this variable at import time to
# select the PyTorch backend -- confirm against the installed deepxde version.
os.environ["DDEBACKEND"] = "pytorch"

import argparse
# NOTE: the command line is parsed at module import time, not under the
# __main__ guard, so `casename` is a module-level global.
parser = argparse.ArgumentParser()
parser.add_argument("--case", type=str, default="wave")
casename = parser.parse_args().case

# The path is relative to the current working directory, so the local imports
# below only resolve when the script is launched from a directory where
# '../utils' exists. Presumably `parallel`, `cases` and `callbacks` live there.
sys.path.append('../utils')
from parallel import Server
from cases import Wave1D, Poisson1D
import numpy as np
import torch
from callbacks import TesterCallback

def wave_cond_subprogress(gpu, path, taskid, repeatid, C):
    """Train a network on the Wave1D case by minimising an inverse condition estimate.

    The model's regular loss function is replaced by ``1 / cond`` where
    ``cond = mean(u^2) / mean(pde(x, u)^2) * cond_multiplier`` is evaluated
    on the case's test points, so minimising the loss maximises ``cond``.

    Args:
        gpu: Not used inside this function; presumably part of the calling
            convention of the task runner -- confirm against ``Server``.
        path: Passed to ``model.train`` as ``model_save_path`` and used as the
            base for locating the previous task's checkpoint.
        taskid: Index of this task. When non-zero, weights are restored from
            the checkpoint of task ``taskid - 1`` before training.
        repeatid: Repeat index, used to pick the matching checkpoint directory.
        C: Forwarded to ``Wave1D(C=C)``.
    """
    import deepxde as dde
    dde.config.set_default_float('float32')

    pde = Wave1D(C=C)
    net = dde.nn.FNN([2] + 5*[100] + [1], "tanh", "Glorot normal")
    net.apply_output_transform(pde.output_transform)
    net = net.float()
    opt = torch.optim.Adam(net.parameters(), lr=1e-3, betas=(0.99, 0.99))

    model = pde.model(net, src=False)
    model.compile(opt)

    # Fixed evaluation points; requires_grad so pde.pde can differentiate
    # the network output with respect to them.
    inp = torch.tensor(pde.test_x, dtype=torch.float32, requires_grad=True)
    # Constant scale factor: mean squared source term over mean squared
    # pde.test_y (presumably the reference solution -- confirm in `cases`).
    cond_multiplier = np.square(pde.src_term_numpy(pde.test_x)).mean() / np.square(pde.test_y).mean()

    def losses_wrap(targets, outputs, loss_fn, inputs, model, aux=None):
        # The arguments supplied by the training loop are ignored; the loss is
        # always evaluated on the fixed test points captured above.
        du = model.net(inp)
        du_norm = torch.square(du).mean()
        df = pde.pde(inp, du)
        df_norm = torch.square(df).mean()
        cond = du_norm / df_norm * cond_multiplier
        return [1/cond]
    model.data.losses = losses_wrap

    if taskid != 0:
        # Warm start from the previous task. Assumes task directories are
        # siblings named "<taskid>-<repeatid>" and that the 20000-iteration
        # checkpoint is stored there as "-20000.pt" -- depends on how the task
        # runner builds `path`; confirm.
        model.restore(os.path.join(path, "..", f"{taskid-1}-{repeatid}", "-20000.pt"), verbose=1)
    model.train(iterations=20000, display_every=100, model_save_path=path)

def poisson_cond_subprogress(gpu, path, taskid, repeatid, P):
    """Train a network on the Poisson1D case by minimising an inverse condition estimate.

    The model's regular loss function is replaced by
    ``mean(pde(x, u)^2) / mean(u^2)`` evaluated on the case's test points.
    After training, the recorded train and test loss histories are written
    to ``train_loss.txt`` and ``test_loss.txt`` under ``path``.

    Args:
        gpu: Not used inside this function.
        path: Passed to ``model.train`` as ``model_save_path`` and used as the
            directory for the loss-history text files.
        taskid: Not used inside this function.
        repeatid: Not used inside this function.
        P: Forwarded to ``Poisson1D(P=P)``.
    """
    import deepxde as dde
    dde.config.set_default_float('float32')

    problem = Poisson1D(P=P)

    layer_sizes = [1] + 5*[100] + [1]
    network = dde.nn.FNN(layer_sizes, "tanh", "Glorot normal")
    network.apply_output_transform(problem.output_transform)
    network = network.float()
    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, betas=(0.99, 0.99))

    model = problem.model(network, src=False)
    model.compile(optimizer)

    # Fixed evaluation points; requires_grad so the PDE operator can
    # differentiate the network output with respect to them.
    eval_points = torch.tensor(problem.test_x, dtype=torch.float32, requires_grad=True)

    def inverse_condition(targets, outputs, loss_fn, inputs, model, aux=None):
        # The arguments supplied by the training loop are ignored; the loss is
        # the reciprocal of mean(u^2) / mean(pde(x, u)^2) on the fixed points.
        prediction = model.net(eval_points)
        prediction_norm = torch.square(prediction).mean()
        residual = problem.pde(eval_points, prediction)
        residual_norm = torch.square(residual).mean()
        return [residual_norm / prediction_norm]

    model.data.losses = inverse_condition

    history, _train_state = model.train(
        iterations=5000,
        display_every=100,
        model_save_path=path,
    )
    for filename, values in (
        ("train_loss.txt", history.loss_train),
        ("test_loss.txt", history.loss_test),
    ):
        np.savetxt(os.path.join(path, filename), values)



def wave_cond_calc(serv):
    """Queue one wave task on ``serv`` per value of ``C``.

    ``C`` runs over ``np.arange(1.1, 5.05, 0.1)``: it starts at 1.1, advances
    in steps of 0.1 and stops before 5.05.
    """
    for c_value in np.arange(1.1, 5.05, 0.1):
        task = {'target': wave_cond_subprogress, 'args': {'C': c_value}}
        serv.add_task(task)

def poisson_cond_calc(serv):
    """Queue one Poisson task on ``serv`` per value of ``P``.

    ``P`` takes 41 evenly spaced values from 1 to 5, both ends included.
    """
    for p_value in np.linspace(1, 5, num=41):
        task = {'target': poisson_cond_subprogress, 'args': {'P': p_value}}
        serv.add_task(task)

if __name__ == "__main__":
    # NOTE(review): the experiment name is "Poisson_NNCond" for both cases,
    # including --case wave. Left unchanged because it may determine where
    # results are stored; confirm whether wave runs should use another name.
    serv = Server(exp_name="Poisson_NNCond", repeat=1)

    # Queue the parameter sweep selected by --case.
    if casename == 'wave':
        wave_cond_calc(serv)
    elif casename == 'poisson':
        poisson_cond_calc(serv)
    else:
        # A separator keeps the offending value readable in the message
        # (the original concatenation produced e.g. "Unknown Case Namefoo").
        raise ValueError(f'Unknown Case Name: {casename}')

    serv.run()