import random
import torch
import time
import datetime
import subprocess
import os
import argparse
import gc
import sys
import pickle
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.sparse import diags
from pprint import pprint


import matplotlib.pyplot as plt
from net.network import *


# ARGS
# Command-line interface: every option is declared once in the table below
# (flag, add_argument keyword arguments) and registered in table order.
parser = argparse.ArgumentParser("SEM")
_CLI_OPTIONS = (
    ## Data
    ("--equation", dict(type=str, default='Standard', choices=['Standard', 'varcoeff', 'burgers','bdrylayer'])),
    ("--eps", dict(type=float, default=0.1)),
    ("--b", dict(type=float, default=-1)),
    ("--file", dict(type=str, default='3000N32', help='Example: --file 3000N18_dirichlet')),
    ("--basis_order", dict(type=int, default=1, help='P1->d=1, P2->d=2')),
    ("--bdry", dict(type=str, default='dirichlet', choices=['dirichlet', 'neumann'])),
    ## Train parameters
    ("--pretrained", dict(type=str, default=None)),
    ("--model", dict(type=str, default='Net2D', choices=['NetA','Net2D','DeepONet'])),
    ("--batch_size", dict(type=int, default=None)),
    ("--blocks", dict(type=int, default=0)),
    ("--ks", dict(type=int, default=5)),
    ("--filters", dict(type=int, default=32, choices=[8, 16, 32, 64])),
    ("--loss", dict(type=str, default='MSE', choices=['MAE', 'MSE', 'RMSE', 'RelMSE'])),
    ("--epochs", dict(type=int, default=80000)),
    ("--pre_epochs", dict(type=int, default=0)),
)
for _flag, _spec in _CLI_OPTIONS:
    parser.add_argument(_flag, **_spec)

args = parser.parse_args()
# gparams is the namespace's own attribute dict, so writes to gparams are
# visible through args as well.
gparams = vars(args)

#Equation
# Problem settings taken from the command line, plus the mesh file they select.
EQUATION = gparams['equation']
EPS = gparams['eps']
b = gparams['b']
FILE = gparams['file']
# --file stems are parsed as '<num_data>N<num_elements>[_suffix]'
# (e.g. '3000N32' or, as in the --file help text, '3000N18_dirichlet').
_file_fields = FILE.split('N')
NUM_DATA = int(_file_fields[0])
# Strip the optional '_suffix' before converting; int('18_dirichlet') would
# raise ValueError for the very example given in the --file help string.
_FILE_NUM_ELEMENT = int(_file_fields[1].split('_')[0])
BASIS_ORDER = gparams['basis_order']
BDRY = gparams['bdry']

mesh = np.load('mesh_1DP{}/ne{}.npz'.format(BASIS_ORDER, _FILE_NUM_ELEMENT))
NUM_ELEMENT, NUM_PTS, p, c = mesh['ne'], mesh['ng'], mesh['p'], mesh['c']
# One basis function per mesh point.
NUM_BASIS = NUM_PTS

STIFF = mesh['stiff']
CONV = mesh['convection']

# Sanity check: the element count stored in the mesh file should match the
# one encoded in --file. This only warns; execution continues as before.
if NUM_ELEMENT != _FILE_NUM_ELEMENT:
    print("Error!! : Please check --file with --num_data and --N")

#Model
# Network classes selectable through --model.
models = {'NetA': NetA,
          'Net2D': Net2D,
          }
# The CLI accepts more --model choices than are registered here (e.g.
# 'DeepONet'); report that as a usage error instead of a bare KeyError.
# NOTE(review): if net.network provides a DeepONet class, add it to `models`.
if gparams['model'] not in models:
    parser.error(
        "--model {} has no implementation registered in `models` "
        "(available: {})".format(gparams['model'], ', '.join(models))
    )
MODEL = models[gparams['model']]
BLOCKS = int(gparams['blocks'])
KERNEL_SIZE = int(gparams['ks'])
FILTERS = int(gparams['filters'])
# (ks - 1) // 2 padding; for odd kernel sizes this is the usual
# length-preserving choice at stride 1.
PADDING = (KERNEL_SIZE - 1)//2

#Train
# Optimisation settings derived from the command line.
EPOCHS = int(gparams['epochs'])
pre_EPOCHS = int(gparams['pre_epochs'])
LOSS = gparams['loss']
D_in = 1
D_out = NUM_BASIS
# Without --batch_size the whole training set forms a single batch.
BATCH_SIZE = NUM_DATA if gparams['batch_size'] is None else gparams['batch_size']

#Save file
# Run directory: train/P<order>/<file>/<model>_epochs<E>_<YYYYMMDDTHHMMSS>
cur_time = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
FOLDER = f'{gparams["model"]}_epochs{EPOCHS}_{cur_time}'
PATH = os.path.join('train', f'P{BASIS_ORDER}', FILE, FOLDER)





# CREATE PATHING
# Create the run directory (and its 'pics' subfolder) on first use. If the
# directory already exists, continue only when --pretrained was given.
if not os.path.isdir(PATH):
    # makedirs creates PATH itself as an intermediate directory.
    os.makedirs(os.path.join(PATH, 'pics'))
elif args.pretrained is None:
    print("\n\nPATH ALREADY EXISTS!\n\nEXITING\n\n")
    # sys.exit works without the `site` module (unlike the interactive
    # `exit` helper) and the non-zero status marks this as an abort.
    sys.exit(1)
else:
    # NOTE(review): nothing in this part of the script actually loads the
    # checkpoint named by --pretrained; confirm where that is meant to happen.
    print("\n\nPATH ALREADY EXISTS!\n\nLOADING MODEL\n\n")

# Build the network selected by --model: D_in input channels, FILTERS hidden
# filters, D_out outputs (one per basis function).
_model_kwargs = {
    'kernel_size': KERNEL_SIZE,
    'padding': PADDING,
    'blocks': BLOCKS,
}
model_FEM = MODEL(D_in, FILTERS, D_out, **_model_kwargs)

# SEND TO GPU (or CPU)
# NOTE: this call is unconditional, so a CUDA device is required.
model_FEM.cuda()
    

# KAIMING INITIALIZATION
def weights_init(m):
    """Initialise a module in place; intended for use with ``Module.apply``.

    Only ``nn.Conv1d`` modules are modified: the weight gets Kaiming-normal
    initialisation and the bias, when the layer has one, is set to zero.
    Every other module type is left untouched.

    Args:
        m: the module handed over by ``apply``.
    """
    if not isinstance(m, nn.Conv1d):
        return
    # torch.nn.init.xavier_uniform_(m.weight)
    torch.nn.init.kaiming_normal_(m.weight)
    # Conv1d(bias=False) stores bias as None; zeros_ would fail on it.
    if m.bias is not None:
        torch.nn.init.zeros_(m.bias)

# Run weights_init over the model via `apply`; weights_init itself only
# modifies nn.Conv1d modules and ignores everything else.
model_FEM.apply(weights_init)


    
class Dataset(Dataset):
    """Samples read from ``data/P<basis_order>/<kind>/<file>.pkl`` plus load vectors.

    NOTE: this class deliberately keeps the name ``Dataset`` (callers below
    rely on it), which shadows the ``torch.utils.data.Dataset`` base class
    after this definition.

    Args:
        gparams: run settings; the keys ``'file'`` (pickle stem) and
            ``'basis_order'`` (selects the ``P<order>`` folder) are read.
        mesh: mapping that provides ``'<kind>_load_vectors'``, indexable by
            sample.
        kind: data split name used in both the folder and the mesh key
            (the script uses ``'train'`` and ``'validate'``).
    """

    def __init__(self, gparams, mesh, kind='train'):
        self.pickle_file = gparams['file']
        # Take the basis order from the gparams argument rather than a module
        # global so the class does not depend on script-level state.
        data_file = f"data/P{gparams['basis_order']}/{kind}/{self.pickle_file}.pkl"
        # SECURITY: pickle.load executes arbitrary code embedded in the file;
        # only point this at data files you generated or trust.
        with open(data_file, 'rb') as f:
            self.data = pickle.load(f)
        self.load_vector = mesh[f'{kind}_load_vectors']

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of float tensors.

        Columns 0, 1 and 2 of ``self.data`` become ``'coeff_u'``,
        ``'f_value'`` and ``'coeff_f'``; the first two get a leading
        singleton axis. ``'load_vec_f'`` is row ``idx`` of the load vectors.
        """
        coeff_u = torch.FloatTensor(self.data[idx, 0]).unsqueeze(0)
        f_value = torch.FloatTensor(self.data[idx, 1]).unsqueeze(0)
        coeff_f = torch.FloatTensor(self.data[idx, 2])
        load_vec_f = torch.FloatTensor(self.load_vector[idx])
        return {'coeff_u': coeff_u, 'f_value': f_value, 'coeff_f': coeff_f, 'load_vec_f': load_vec_f}

    def __len__(self):
        """Number of samples in the loaded pickle."""
        return len(self.data)

# Training batches are reshuffled every epoch; validation order is fixed.
_train_dataset = Dataset(gparams, mesh, kind='train')
trainloader = DataLoader(_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# lg_dataset ends up bound to the validation split.
lg_dataset = Dataset(gparams, mesh, kind='validate')
validateloader = DataLoader(lg_dataset, batch_size=BATCH_SIZE, shuffle=False)

def init_optim(model):
    """Create the L-BFGS optimizer used to train ``model``.

    Settings: history of 10 updates, at most 10 function evaluations per
    ``step`` call, and gradient/change tolerances of 1e-15. All other
    options keep their library defaults.
    """
    return torch.optim.LBFGS(
        model.parameters(),
        history_size=10,
        tolerance_grad=1E-15,
        tolerance_change=1E-15,
        max_eval=10,
    )

optimizer = init_optim(model_FEM)

# Mesh matrices as float32 tensors on the GPU (converted once, reused by
# every loss evaluation).
STIFF = torch.tensor(STIFF).cuda().float()
CONV = torch.tensor(CONV).cuda().float()

# Sum-reduced squared error between the two sides of the weak form.
criterion_wf = torch.nn.MSELoss(reduction="sum")

def weak_form(eps, coeff_u, load_vec_f, stiff, conv):
    """Evaluate both sides of the discrete weak form for a batch.

    With ``A = eps * stiff + conv`` the left-hand side is the matrix-vector
    product ``A @ u`` for every coefficient vector ``u`` in the batch; the
    right-hand side is the supplied load vector.

    Args:
        eps: scalar multiplying the stiffness matrix.
        coeff_u: coefficients, shape (num_f, 1, N+1).
        load_vec_f: load vectors, shape (num_f, N+1).
        stiff, conv: matrices broadcastable against ``coeff_u``
            (presumably (N+1, N+1) -- confirm against the mesh file).

    Returns:
        (LHS, RHS), each of shape (num_f, N+1), with RHS on the same device
        as ``coeff_u``.
    """
    ## LHS
    # Broadcasting the singleton axis of coeff_u against the rows of the
    # operator gives the same products as an explicit repeat, without
    # copying the batch N+1 times first.
    operator = eps * stiff + conv
    LHS = torch.sum(operator * coeff_u, dim=-1)

    ## RHS
    # Follow the device of the prediction rather than hard-coding CUDA, so
    # the function also works on CPU tensors.
    RHS = load_vec_f.to(coeff_u.device)
    return LHS, RHS



def closure(model, eps, f_value, load_vec_f, stiff, conv):
    """Run the model on ``f_value`` and score it with the weak-form loss.

    Args:
        model: callable mapping ``f_value`` to predicted coefficients.
        eps, stiff, conv: passed through to ``weak_form``.
        f_value: model input batch.
        load_vec_f: load vectors forming the right-hand side.

    Returns:
        (loss, pred_coeff_u): the squared weak-form residual summed over all
        basis functions and all samples in the batch, and the raw model
        output.
    """
    pred_coeff_u = model(f_value)
    LHS, RHS = weak_form(eps, pred_coeff_u, load_vec_f, stiff, conv)

    ## Loss
    # criterion_wf is a sum-reduced MSE, so a single call over the full
    # (num_f, num_basis) arrays equals the per-basis sums added together,
    # without a Python loop or a CPU-side buffer of per-basis losses.
    loss = criterion_wf(LHS, RHS)

    return loss, pred_coeff_u

def rel_L2_error(pred, true):
    """Relative L2 error of ``pred`` against ``true`` along the last axis.

    Computes sqrt(sum((true - pred)^2) / sum(true^2)) with both sums taken
    over the final dimension, i.e. one value per leading index.
    """
    residual = true - pred
    err_sq = torch.sum(residual ** 2, dim=-1)
    ref_sq = torch.sum(true ** 2, dim=-1)
    return (err_sq / ref_sq) ** 0.5

def log_gparams(gparams):
    """Write the run settings into the run directory ``PATH``.

    Every entry is written to ``parameters.txt`` as a ``key:value`` line,
    except ``'losses'``, which (if present) is saved as ``losses.csv``
    through a pandas DataFrame. Existing files are overwritten.

    Args:
        gparams: mapping of setting names to values.
    """
    # Address the files through PATH instead of chdir-ing into it: a failed
    # write can then never leave the process in the wrong working directory.
    with open(os.path.join(PATH, 'parameters.txt'), 'w') as f:
        for k, v in gparams.items():
            if k == 'losses':
                pd.DataFrame(v).to_csv(os.path.join(PATH, 'losses.csv'))
            else:
                f.write(f"{k}:{v}\n")


def log_path(path):
    """Append ``path`` as one line to ``../../paths.txt``.

    The registry file is resolved relative to the current working directory
    and is created if it does not exist yet.
    """
    with open("../../paths.txt", "a") as registry:
        # print writes str(path) followed by a newline; the with-block
        # closes the file.
        print(path, file=registry)
# Record where this run lives and the settings it started with.
log_path(PATH)
log_gparams(gparams)
################################################
# Training: one L-BFGS step per batch and epoch; every 500 epochs the model
# is validated, checkpointed and a progress line is printed.
time0 = time.time()
losses = []
train_rel_L2_errors = []
test_rel_L2_errors = []
for epoch in range(1, EPOCHS + 1):
    loss_total = 0
    num_samples = 0
    train_rel_L2_error = 0

    for batch_idx, sample_batch in enumerate(trainloader):
        # The error evaluation at the end of each batch switches to eval
        # mode, so train mode has to be restored per batch, not per epoch.
        model_FEM.train()
        coeff_u = sample_batch['coeff_u']
        f_value = sample_batch['f_value'].cuda()
        load_vec_f = sample_batch['load_vec_f'].cuda()

        def lbfgs_closure():
            # L-BFGS re-evaluates the objective several times within one
            # step (up to max_eval), so the closure must clear the
            # gradients, recompute the loss and backpropagate on each call.
            optimizer.zero_grad()
            batch_loss, _ = closure(model_FEM, EPS, f_value, load_vec_f, STIFF, CONV)
            batch_loss.backward()
            return batch_loss

        # step() returns the loss from the first closure evaluation, i.e.
        # the loss before this update.
        loss = optimizer.step(lbfgs_closure)
        loss_total += np.round(float(loss.item()), 4)
        num_samples += coeff_u.shape[0]

        with torch.no_grad():
            model_FEM.eval()
            # squeeze(1) drops only the singleton channel axis; a bare
            # squeeze() would also drop the batch axis for size-1 batches.
            u_pred = model_FEM(f_value).squeeze(1).cpu()
            coeff_u = coeff_u.squeeze(1)
            train_rel_L2_error += torch.sum(rel_L2_error(u_pred, coeff_u))

    train_rel_L2_error /= num_samples

    if epoch % 500 == 0:
        ## Test
        num_samples = 0
        test_rel_L2_error = 0
        model_FEM.eval()
        with torch.no_grad():
            for batch_idx, sample_batch in enumerate(validateloader):
                coeff_u = sample_batch['coeff_u']
                f_value = sample_batch['f_value'].cuda()
                # Count samples from the batch axis before any squeezing.
                num_samples += coeff_u.shape[0]
                u_pred = model_FEM(f_value).squeeze(1).cpu()
                coeff_u = coeff_u.squeeze(1)
                test_rel_L2_error += torch.sum(rel_L2_error(u_pred, coeff_u))
        test_rel_L2_error /= num_samples

        ##Save and print
        losses.append(loss_total)
        train_rel_L2_errors.append(train_rel_L2_error)
        test_rel_L2_errors.append(test_rel_L2_error)
        torch.save({'model_state_dict': model_FEM.state_dict(),
                    'losses': losses,
                    'train_rel_L2_errors': train_rel_L2_errors,
                    'test_rel_L2_errors': test_rel_L2_errors
        }, PATH + '/model.pt')
        print("Epoch {0:4d}: weak_form_loss {1:4.1f}, train_rel_error {2:.5f}, test_rel_error {3:.5f}".format(epoch, loss_total, train_rel_L2_error, test_rel_L2_error))

        
# Final checkpoint: overwrites the periodic one with the end-of-training
# weights and the recorded metric histories.
_final_checkpoint = {
    'model_state_dict': model_FEM.state_dict(),
    'losses': losses,
    'train_rel_L2_errors': train_rel_L2_errors,
    'test_rel_L2_errors': test_rel_L2_errors,
}
torch.save(_final_checkpoint, PATH + '/model.pt')

train_t = time.time() - time0
# Count trainable parameters only.
NPARAMS = sum(param.numel() for param in model_FEM.parameters() if param.requires_grad)

# Add run statistics to the settings and rewrite the parameter log.
gparams.update({
    'train_time': train_t,
    'nParams': NPARAMS,
    'batchSize': BATCH_SIZE,
    'path': PATH,
})

log_gparams(gparams)
