import torch
import os
import argparse
import pickle
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from scipy.sparse import diags
import scipy.interpolate
import scipy
from pprint import pprint

# ARGS
parser = argparse.ArgumentParser("SEM")
## Data
parser.add_argument("--equation", type=str, default='Standard', choices=['Standard', 'varcoeff', 'burgers'])
parser.add_argument("--eps", type=float, default=0.1)
parser.add_argument("--b", type=float, default=-1)
parser.add_argument("--num_data", type=int, default=3000)
# --file and --basis_order have no usable default: the script builds mesh and
# pickle paths from both (FILE.split('N'), 'mesh_1DP{BASIS_ORDER}/...'), so
# omitting either one only failed later with an unhelpful AttributeError /
# FileNotFoundError. Let argparse reject the invocation up front instead.
parser.add_argument("--file", type=str, required=True, help='Example: --file 3000N18')
parser.add_argument("--basis_order", type=int, required=True, help='P1->d=1, P2->d=2')
parser.add_argument("--kind", type=str, default='train', choices=['train', 'validate'])
parser.add_argument("--bdry", type=str, default='dirichlet', choices=['dirichlet', 'neumann','bdry_layer'])
parser.add_argument("--exact", type=int, default=1000,
                    help='number of elements of the reference ("exact") mesh to interpolate from')


args = parser.parse_args()
gparams = vars(args)

# NOTE: EQUATION, EPS, b and BDRY are parsed but not referenced elsewhere in this script.
EQUATION = gparams['equation']
EPS = gparams['eps']
b = gparams['b']

NUM_DATA = gparams['num_data']
FILE = gparams['file']
BASIS_ORDER = gparams['basis_order']
KIND = gparams['kind']
BDRY = gparams['bdry']
NUM_ELEMENT_exact = gparams['exact']


# Seed NumPy's legacy global RNG with a fixed value chosen per data split.
# --kind is restricted to these keys by argparse `choices`, so the fallback
# branch is not reachable from the command line.
_SEED_BY_KIND = {'train': 5, 'validate': 10}
if KIND in _SEED_BY_KIND:
    np.random.seed(_SEED_BY_KIND[KIND])
else:
    print('error!')

# Load exact data for interpolation: the fine ("exact") mesh nodes p_exact and
# the reference samples data_exact, which create_data interpolates onto the
# coarse mesh selected by --file.
with np.load('mesh_1DP{}/ne{}.npz'.format(BASIS_ORDER, NUM_ELEMENT_exact)) as mesh_exact:
    # Fail loudly here: otherwise p_exact / data_exact stay undefined and
    # create_data later dies with an opaque NameError.
    if NUM_ELEMENT_exact != mesh_exact['ne']:
        raise SystemExit(
            "Error!! : mesh_1DP{}/ne{}.npz stores ne={}, expected {}".format(
                BASIS_ORDER, NUM_ELEMENT_exact, mesh_exact['ne'], NUM_ELEMENT_exact))
    # Indexing an NpzFile returns a fully loaded ndarray, so p_exact remains
    # valid after the archive is closed by the `with` block.
    p_exact = mesh_exact['p']

# NOTE(review): the exact-data pickle name is hard-coded to the '3000N' prefix,
# independent of --num_data; confirm this matches how that data was generated.
pickle_file = f'3000N{NUM_ELEMENT_exact}'
# Bind the handle to its own name so the global `f` (the forcing function
# defined below) is never shadowed by a file object.
with open(f'data/P{BASIS_ORDER}/{KIND}/' + pickle_file + '.pkl', 'rb') as exact_file:
    data_exact = pickle.load(exact_file)

# Coarse mesh selected by --file, whose format is '<num_data>N<num_elements>'
# (e.g. 3000N18). It is deliberately left open (no context manager) because
# create_data reads its '<kind>_coeff_fs' array later through the global `mesh`.
file_parts = FILE.split('N')
file_num_data, file_num_element = int(file_parts[0]), int(file_parts[1])

mesh = np.load('mesh_1DP{}/ne{}.npz'.format(BASIS_ORDER, file_num_element))
NUM_ELEMENT, NUM_PTS, p, c, gfl = mesh['ne'], mesh['ng'], mesh['p'], mesh['c'], mesh['gfl']

NUM_BASIS = NUM_PTS

# The output pickle is named after --file, so continuing on a mismatch would
# write a dataset whose file name misreports its contents; stop instead.
if NUM_DATA != file_num_data or NUM_ELEMENT != file_num_element:
    raise SystemExit(
        "Error!! : Please check --file against --num_data and the mesh "
        "(--file {} -> {} samples, {} elements; got --num_data {}, mesh ne={})".format(
            FILE, file_num_data, file_num_element, NUM_DATA, NUM_ELEMENT))
    
def f(x, coeff):
    """Evaluate the forcing term m0*sin(n0*x) + m1*cos(n1*x).

    Args:
        x: evaluation point(s); scalar or array (NumPy broadcasting applies).
        coeff: 4-item sequence unpacked as (m0, m1, n0, n1).

    Returns:
        Tuple ``(values, coeff)`` -- the evaluated term and the very same
        ``coeff`` object that was passed in.
    """
    amp_sin, amp_cos, freq_sin, freq_cos = coeff
    sine_part = amp_sin * np.sin(freq_sin * x)
    cosine_part = amp_cos * np.cos(freq_cos * x)
    return sine_part + cosine_part, coeff


def create_data(num_data, p, kind):
    """Build ``num_data`` samples on the nodes ``p``.

    For each sample index ``n``:
      * the forcing ``f`` is evaluated at ``p`` using row ``n`` of the
        ``'{kind}_coeff_fs'`` array stored in the module-level ``mesh``;
      * the first entry of ``data_exact[n]``, given on the exact-mesh nodes
        ``p_exact``, is interpolated onto ``p`` with a cubic spline.

    Relies on the module-level globals ``mesh``, ``p_exact``, ``data_exact``
    and the function ``f``.

    Args:
        num_data: number of samples to build (indices ``0 .. num_data-1``).
        p: nodes at which the solution and the forcing are evaluated.
        kind: split name; selects the ``'{kind}_coeff_fs'`` coefficient array.

    Returns:
        Object ndarray built from a list of ``[coeff_u, f_value, coeff_f]`` rows.
    """
    # Read the coefficient table once: every NpzFile lookup re-reads and
    # decodes the array from the archive, so indexing mesh[...] inside the
    # loop reloaded the whole table for every single sample.
    coeff_fs = mesh[f'{kind}_coeff_fs']
    # CubicSpline needs increasing abscissae; the ordering of p_exact is the
    # same for every sample, so compute it (and the sorted nodes) once.
    order = p_exact.argsort()
    x_sorted = p_exact[order]

    data = []
    for n in tqdm(range(num_data)):
        f_value, coeff_f = f(p, coeff_fs[n])
        interp = scipy.interpolate.CubicSpline(x_sorted, data_exact[n][0][order])
        coeff_u = interp(p)
        data.append([coeff_u, f_value, coeff_f])
    return np.array(data, dtype=object)

def save_obj(data, name, kind, basis_order):
    """Pickle ``data`` to ``<cwd>/data/P<basis_order>/<kind>/<name>.pkl``.

    Missing directories are created. The highest available pickle protocol
    is used. An existing file at that path is overwritten.

    Args:
        data: any picklable object.
        name: file stem (without the ``.pkl`` extension).
        kind: sub-directory name (e.g. the data split).
        basis_order: integer used to form the ``P<basis_order>`` directory.
    """
    path = os.path.join(os.getcwd(), 'data', f'P{basis_order}', kind)
    # exist_ok=True avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, name + '.pkl'), 'wb') as f:
        pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)

# Build the dataset on the coarse-mesh nodes, then write it to
# data/P<BASIS_ORDER>/<KIND>/<FILE>.pkl under the current directory.
data = create_data(num_data=NUM_DATA, p=p, kind=KIND)
save_obj(data, name=FILE, kind=KIND, basis_order=BASIS_ORDER)
