import torch
import os
import argparse
import pickle
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from scipy.sparse import diags
from pprint import pprint


# ARGS
parser = argparse.ArgumentParser("SEM")
## Data
parser.add_argument("--equation", type=str, default='Standard', choices=['Standard', 'varcoeff', 'burgers'])
parser.add_argument("--eps", type=float, default=0.1)
parser.add_argument("--b", type=float, default=-1)
parser.add_argument("--num_data", type=int, default=3000)
# --file and --basis_order have no sensible default: the mesh path and the output
# file name are both derived from them, so require them up front instead of
# failing later on None.
parser.add_argument("--file", type=str, required=True, help='Example: --file 3000N18')
parser.add_argument("--basis_order", type=int, required=True, help='P1->d=1, P2->d=2')
parser.add_argument("--kind", type=str, default='train', choices=['train', 'validate'])
parser.add_argument("--bdry", type=str, default='dirichlet', choices=['dirichlet', 'neumann','bdry_layer'])

args = parser.parse_args()
gparams = vars(args)

# NOTE(review): EQUATION, b and BDRY are parsed but not used by the solver in
# this file (standard() builds eps*stiff + conv only) -- confirm whether that is intended.
EQUATION = gparams['equation']
EPS = gparams['eps']
b = gparams['b']

NUM_DATA = gparams['num_data']
FILE = gparams['file']
BASIS_ORDER = gparams['basis_order']
KIND = gparams['kind']
BDRY = gparams['bdry']

# Seed
# Seed NumPy's global RNG with a fixed value chosen per dataset split.
# NOTE(review): nothing visible in this file draws from np.random afterwards
# (forcing coefficients come from the mesh archive), so this may be vestigial.
_SEED_BY_KIND = {'train': 5, 'validate': 10}
_seed = _SEED_BY_KIND.get(KIND)
if _seed is not None:
    np.random.seed(_seed)
else:
    print('error!')

# --file names the dataset '<num_data>N<num_elements>' (e.g. '3000N18').
# Parse it once; both fields are cross-checked against the other inputs below.
_file_fields = FILE.split('N')
try:
    _FILE_NUM_DATA = int(_file_fields[0])
    _FILE_NUM_ELEMENT = int(_file_fields[1])
except (IndexError, ValueError):
    raise SystemExit(
        f"Error!! : --file must look like '<num_data>N<num_elements>' "
        f"(e.g. 3000N18), got {FILE!r}"
    ) from None

# A mismatch would make the output file name misdescribe its contents, so stop
# here rather than writing a mislabeled dataset.
if NUM_DATA != _FILE_NUM_DATA:
    raise SystemExit(
        f"Error!! : --file {FILE} encodes {_FILE_NUM_DATA} samples "
        f"but --num_data is {NUM_DATA}"
    )

# np.load on an .npz returns a lazily-read archive; arrays are fetched by key.
mesh = np.load('mesh_1DP{}/ne{}.npz'.format(BASIS_ORDER, _FILE_NUM_ELEMENT))
NUM_ELEMENT, NUM_PTS, p, c = mesh['ne'], mesh['ng'], mesh['p'], mesh['c']
NUM_BASIS = NUM_PTS

STIFF = mesh['stiff']
CONV = mesh['convection']

# The mesh archive was selected by the element count in --file, so its own
# 'ne' entry must agree; stop if the archive is inconsistent.
if NUM_ELEMENT != _FILE_NUM_ELEMENT:
    raise SystemExit(
        f"Error!! : mesh file reports {NUM_ELEMENT} elements "
        f"but --file {FILE} encodes {_FILE_NUM_ELEMENT}"
    )
    
def f(x, coeff):
    """Evaluate the trigonometric forcing term at ``x``.

    ``coeff`` is unpacked as ``(m0, m1, n0, n1)`` and the value computed is
    ``m0*sin(n0*x) + m1*cos(n1*x)`` (element-wise when ``x`` is an array).

    Returns:
        Tuple ``(values, coeff)``; ``coeff`` is handed back unchanged (same object).
    """
    sin_amp, cos_amp, sin_freq, cos_freq = coeff
    sin_part = sin_amp * np.sin(sin_freq * x)
    cos_part = cos_amp * np.cos(cos_freq * x)
    return sin_part + cos_part, coeff

def standard(n, eps, stiff, conv, kind, load_vectors=None):
    """Solve the assembled linear system for the n-th forcing sample.

    Builds the matrix ``eps * stiff + conv`` and solves it against the n-th
    load vector stored under ``mesh[f'{kind}_load_vectors']``.

    Args:
        n: Index of the sample whose load vector is used.
        eps: Scalar multiplying the stiffness matrix.
        stiff: Stiffness matrix (must be square for ``np.linalg.solve``).
        conv: Matrix added to ``eps * stiff``.
        kind: Dataset split name; selects the ``'<kind>_load_vectors'`` key.
        load_vectors: Optional pre-fetched ``mesh[f'{kind}_load_vectors']``.
            Each lookup on an ``NpzFile`` re-reads the whole array from disk,
            so callers in a loop should fetch it once and pass it here. When
            omitted, it is read from ``mesh``.

    Returns:
        The coefficient vector ``u`` solving ``(eps*stiff + conv) u = load``.

    Raises:
        numpy.linalg.LinAlgError: If the system matrix is singular.
    """
    if load_vectors is None:
        load_vectors = mesh[f'{kind}_load_vectors']
    # Named ``load`` (not ``b``) so it does not shadow the module-level --b value.
    load = load_vectors[n]
    system = eps * stiff + conv
    return np.linalg.solve(system, load)


def create_data(num_data, p, eps, stiff, conv, kind):
    """Build ``num_data`` samples of ``[coeff_u, f_value, coeff_f]``.

    Args:
        num_data: Number of samples; rows ``0..num_data-1`` of the per-split
            arrays in ``mesh`` are used.
        p: Points at which the forcing ``f`` is evaluated.
        eps, stiff, conv, kind: Forwarded to ``standard``.

    Returns:
        Object array of shape ``(num_data, 3)``. Row ``n`` holds the solution
        coefficients, the forcing values at ``p`` and the forcing coefficients.
    """
    # mesh comes from np.load on an .npz: every mesh[key] lookup re-reads the
    # full array, so fetch both per-split arrays once instead of per sample.
    load_vectors = mesh[f'{kind}_load_vectors']
    coeff_fs = mesh[f'{kind}_coeff_fs']

    # Fill an explicit (num_data, 3) object array element by element.
    # np.array(list_of_lists, dtype=object) would instead stack the entries into
    # a deeper array whenever all three happen to have the same length.
    data = np.empty((num_data, 3), dtype=object)
    for n in tqdm(range(num_data)):
        f_value, coeff_f = f(p, coeff_fs[n])
        data[n, 0] = standard(n, eps, stiff, conv, kind, load_vectors=load_vectors)
        data[n, 1] = f_value
        data[n, 2] = coeff_f
    return data

def save_obj(data, name, kind, basis_order):
    """Pickle ``data`` to ``<cwd>/data/P<basis_order>/<kind>/<name>.pkl``.

    Creates the target directory if needed and overwrites an existing file of
    the same name.

    Returns:
        The path of the written file.
    """
    path = os.path.join(os.getcwd(), 'data', f'P{basis_order}', kind)
    # exist_ok avoids the check-then-create race of a separate isdir() test.
    os.makedirs(path, exist_ok=True)
    out_file = os.path.join(path, name + '.pkl')
    # Handle is not called ``f`` so it does not shadow the module-level f().
    with open(out_file, 'wb') as fh:
        pickle.dump(data, fh, pickle.HIGHEST_PROTOCOL)
    return out_file
         
# Generate the requested split and pickle it under data/P<basis_order>/<kind>/<FILE>.pkl.
data = create_data(num_data=NUM_DATA, p=p, eps=EPS, stiff=STIFF, conv=CONV, kind=KIND)
save_obj(data, name=FILE, kind=KIND, basis_order=BASIS_ORDER)