import argparse
import numpy as np
import os
import random
import torch
from systems import *
import torch.backends.cudnn as cudnn
from utils import *
import matplotlib.pyplot as plt
import pdb
import pandas as pd
import pickle as pkl
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Command-line options. The flags are registered in the same order as before,
# so the generated --help text is unchanged.
parser = argparse.ArgumentParser()
_add = parser.add_argument

# Which PDE system to generate, plus sampling / loss settings.
_add('--system', default='convection', type=str)
_add('--seed', default=0, type=int)
_add('--N_f', default=1000, type=int, help='Number of collocation points to sample.')
_add('--L', default=1.0, type=float, help='Multiplier on loss f.')

# Grid resolution and PDE coefficients / initial condition.
_add('--xgrid', default=256, type=int, help='Number of points in the xgrid.')
_add('--nt', default=100, type=int, help='Number of points in the tgrid.')
_add('--nu', default=1.0, type=float)
_add('--rho', default=1.0, type=float)
_add('--beta', default=1.0, type=float)
_add('--u0_str', default='1+sin(x)')
_add('--source', type=float, default=0)

# Drop the temporary alias so the module namespace matches the original.
del _add

args = parser.parse_args()

# CUDA support: use the GPU when PyTorch can see one, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# PDE coefficients for the sweep below. nu and rho are fixed at 0 here;
# beta is rebound by the sweep loop.
# NOTE(review): these hard-coded values take precedence over the --nu, --rho
# and --beta command-line flags, which are parsed but not used in this script.
rho = 0
nu = 0
beta = 0

# Spatial grid on [0, 2*pi) with the right endpoint excluded (presumably
# because the domain is periodic -- TODO confirm against the solver), and a
# time grid on [0, 1] that includes both endpoints.
x = np.linspace(0, 2 * np.pi, args.xgrid, endpoint=False).reshape(-1, 1)
t = np.linspace(0, 1, args.nt).reshape(-1, 1)
X, T = np.meshgrid(x, t)
# (nt * xgrid, 2) array of (x, t) pairs in row-major order of the (nt, xgrid)
# mesh, so x varies fastest.
X_star = np.hstack((X.flatten()[:, None], T.flatten()[:, None]))

# open(..., 'wb') does not create missing directories, so create the output
# directory (relative to the current working directory) first.
os.makedirs('./convection', exist_ok=True)
with open('./convection/convection_grid.pkl', 'wb') as f:
    pkl.dump(X_star, f)

# Sweep beta over 1.5, 2.5, ..., 60.5 (60 values, step 1.0). For each value,
# store the solver's solution as a column vector. nu and rho keep the
# module-level values set above.
betas = [step * 0.5 for step in range(3, 123, 2)]

u_total_vals = []
for beta in betas:
    print('nu', nu, 'beta', beta, 'rho', rho)

    if 'convection' in args.system or 'diffusion' in args.system:
        # This solver returns two arrays; the second one (u_v) is what gets saved.
        u_vals, u_v = convection_diffusion_discrete_solution(args.u0_str, nu, beta, args.source, args.xgrid, args.nt)
    elif 'rd' in args.system:
        # The reaction solvers return a single array. Bind it to u_v; previously
        # only u_vals was set, so reading u_v below raised NameError.
        u_v = reaction_diffusion_discrete_solution(args.u0_str, nu, rho, args.xgrid, args.nt)
    elif 'reaction' in args.system:
        u_v = reaction_solution(args.u0_str, rho, args.xgrid, args.nt)
    else:
        # Previously this only printed a warning and then crashed with a
        # NameError on u_v; fail fast with an explicit message instead.
        raise ValueError(
            f"System is not specified: {args.system!r} matches none of "
            "'convection', 'diffusion', 'rd', 'reaction'."
        )

    # Flatten to a (-1, 1) column vector before stacking.
    Exact = u_v.flatten()[:, None]
    u_total_vals.append(Exact)

# One entry per beta value, stacked along a new leading axis.
u_total_vals = np.array(u_total_vals)

with open('./convection/convection_beta1_5to60_5_interval1.pkl', 'wb') as f:
    pkl.dump(u_total_vals, f)
    
