"""Sweep PDE parameters and record the value of each case's ``cond()``.

The case is chosen with ``--case`` (default ``wave``); see the ``__main__``
block for the accepted values. Text and PNG outputs are written under
``results/``. ``cond()`` presumably returns a (theoretical) condition
number -- the output files are named ``*_thycond*`` -- TODO confirm
against ``cases``.
"""
import os
# Must run before `cases` is imported: presumably `cases` imports DeepXDE,
# which selects its backend from this environment variable at import time.
os.environ["DDEBACKEND"] = "pytorch"

import numpy as np
import matplotlib.pyplot as plt

from cases import Burgers1D, Poisson1D, Wave1D, Helmholtz2d

import argparse
# Arguments are parsed at import time; `casename` selects which routine the
# `__main__` block runs.
parser = argparse.ArgumentParser()
parser.add_argument("--case", type=str, default="wave")
casename = parser.parse_args().case

def burger_cond_calc():
    """Sweep the Burgers viscosity ``nu`` and record ``Burgers1D(nu).cond()``.

    ``nu`` takes 21 log-spaced values in ``[0.01/pi, 1/pi]``. Writes the
    ``nu`` values and the collected ``cond()`` results (one printed list per
    line) to ``results/burger_thycond.txt``. Also writes a plot of ``cond()``
    against ``nu`` (log-scaled x-axis) to ``results/burger_thycond.png``.
    """
    # open()/savefig() do not create missing directories; do it up front so a
    # missing folder does not throw away the whole (slow) sweep at the end.
    os.makedirs("results", exist_ok=True)

    nu_list = np.logspace(-2, 0, 21, base=10) / np.pi
    conds = []

    for i, nu in enumerate(nu_list):
        print("Now At Step", i)
        pde = Burgers1D(nu=nu)
        conds.append(pde.cond())

    with open("results/burger_thycond.txt", "w") as f:
        print(nu_list.tolist(), file=f)
        print(conds, file=f)

    # Plot against the actual nu values rather than the sample index. nu is
    # log-spaced, hence the log x-axis.
    plt.figure()
    plt.xscale('log')
    plt.plot(nu_list, conds)
    plt.savefig("results/burger_thycond.png")
    # Release the figure; pyplot keeps every open figure alive otherwise.
    plt.close()

def poisson_cond_calc():
    """Record ``Poisson1D(mesh, P).cond()`` over a grid of meshes and ``P``.

    Rows correspond to ``mesh`` in ``(2, 4, 8, 16, 32, 64)``. Columns
    correspond to 41 evenly spaced ``P`` in ``[1, 5]``. The resulting 6x41
    table is saved with ``np.savetxt`` to ``results/poisson_thycond.txt``.
    """
    # np.savetxt does not create missing directories.
    os.makedirs("results", exist_ok=True)

    # The P grid is identical for every mesh, so build it once.
    P_list = np.linspace(1, 5, num=41)
    result = []
    for mesh in [2, 4, 8, 16, 32, 64]:
        # .item() converts what is presumably a 0-d tensor/array into a plain
        # Python number so np.savetxt receives a numeric table.
        result.append([Poisson1D(mesh=mesh, P=P).cond().item() for P in P_list])

    np.savetxt("results/poisson_thycond.txt", result)

def wave_cond_calc():
    """Sweep the wave parameter ``C`` and record ``Wave1D(...).cond()``.

    Uses 100 evenly spaced ``C`` in ``[1.1, 5]`` on a ``(50, 50)`` mesh.
    Writes the ``C`` values and the collected results (one printed list per
    line) to ``results/wave_thycond.txt``. Also writes a log-scale plot to
    ``results/wave_thycond.png``.
    """
    # open()/savefig() do not create missing directories; do it before the
    # sweep so a missing folder is not discovered only at the end.
    os.makedirs("results", exist_ok=True)

    C_list = np.linspace(1.1, 5, 100)
    conds = []
    for i, C in enumerate(C_list):
        print(f"Now At step {i}")
        pde = Wave1D(mesh=(50, 50), C=C)
        conds.append(pde.cond())

    with open("results/wave_thycond.txt", "w") as f:
        print(C_list.tolist(), file=f)
        print(conds, file=f)

    plt.figure()
    plt.yscale('log')
    plt.plot(C_list, conds)
    plt.savefig("results/wave_thycond.png")
    # Release the figure; pyplot keeps every open figure alive otherwise.
    plt.close()

def wave_cond_check():
    """Plot ``Wave1D.check_A()`` diagnostics over the ``C`` sweep.

    Uses the same 100 evenly spaced ``C`` in ``[1.1, 5]`` and ``(50, 50)``
    mesh as ``wave_cond_calc``. ``check_A()`` must return at least two
    values per case. The first is plotted as "relerr" and the second as
    "unorm" (TODO confirm their exact meaning in ``cases``). The log-scale
    plot is saved to ``results/wave_err.png``.
    """
    # savefig() does not create missing directories.
    os.makedirs("results", exist_ok=True)

    C_list = np.linspace(1.1, 5, 100)
    err = []
    for C in C_list:
        pde = Wave1D(mesh=(50, 50), C=C)
        err.append(pde.check_A())

    err = np.array(err)
    plt.figure()
    plt.yscale('log')
    plt.plot(C_list, err[:, 0], label="relerr")
    plt.plot(C_list, err[:, 1], label="unorm")
    plt.legend()
    plt.savefig("results/wave_err.png")
    # Release the figure; pyplot keeps every open figure alive otherwise.
    plt.close()

def helmholtz_cond_calc():
    """Sweep the Helmholtz parameter ``A`` and record ``Helmholtz2d(...).cond()``.

    ``A`` runs over the integers 1..19 and is passed as ``A=(A, A)``. Writes
    the ``A`` values and the collected results (one printed list per line)
    to ``results/helm_thycond.txt``. Also writes a log-scale plot to
    ``results/helm_thycond.png``.
    """
    # open()/savefig() do not create missing directories; do it before the
    # sweep so a missing folder is not discovered only at the end.
    os.makedirs("results", exist_ok=True)

    A_list = np.arange(1, 20, 1)
    conds = []
    for i, A in enumerate(A_list):
        print(f"Now At step {i}")
        pde = Helmholtz2d(A=(A, A))
        conds.append(pde.cond())

    with open("results/helm_thycond.txt", "w") as f:
        print(A_list.tolist(), file=f)
        print(conds, file=f)

    plt.figure()
    plt.yscale('log')
    plt.plot(A_list, conds)
    plt.savefig("results/helm_thycond.png")
    # Release the figure; pyplot keeps every open figure alive otherwise.
    plt.close()

if __name__ == "__main__":
    # Map each accepted --case value to the routine that handles it.
    # "wave_check" exposes wave_cond_check, which was otherwise unreachable.
    handlers = {
        'burger': burger_cond_calc,
        'poisson': poisson_cond_calc,
        'wave': wave_cond_calc,
        'wave_check': wave_cond_check,
        'helmholtz': helmholtz_cond_calc,
    }
    handler = handlers.get(casename)
    if handler is None:
        raise ValueError(
            f"Unknown Case Name: {casename!r} (expected one of {sorted(handlers)})"
        )
    handler()