import random
import torch
import time
import datetime
import subprocess
import os
import argparse
import gc
import sys
import pickle
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.sparse import diags
from pprint import pprint


import matplotlib.pyplot as plt
from net.network import *


# ARGS
parser = argparse.ArgumentParser("SEM")
add_arg = parser.add_argument  # shorthand for the long option list below

## Data
add_arg("--equation", type=str, default='Standard',
        choices=['Standard', 'varcoeff', 'burgers', 'bdrylayer'])
add_arg("--eps", type=float, default=1)
add_arg("--b", type=float, default=-1)
add_arg("--file", type=str, default='3000N32', help='Example: --file 2000N31')
add_arg("--cut_train_data", type=int, default=3000, help='Number of training data')
add_arg("--basis_order", type=int, default=1, help='P1->d=1, P2->d=2')
add_arg("--bdry", type=str, default='dirichlet', choices=['dirichlet', 'neumann'])

## Train parameters
add_arg("--pretrained", type=str, default=None)
add_arg("--model", type=str, default='NetA', choices=['NetA', 'Net2D', 'DeepONet'])
add_arg("--batch_size", type=int, default=None)  # None -> full batch (see below)
add_arg("--depth_trunk", type=int, default=2)
add_arg("--width_trunk", type=int, default=30)
add_arg("--depth_branch", type=int, default=2)
add_arg("--width_branch", type=int, default=30)
add_arg("--act", type=str, default=None)
add_arg("--loss", type=str, default='MSE', choices=['MAE', 'MSE', 'RMSE', 'RelMSE'])
add_arg("--epochs", type=int, default=80000)
add_arg("--pre_epochs", type=int, default=0)
add_arg("--lr", type=float, default=1e-4)

args = parser.parse_args()
# vars() returns args.__dict__ itself, so keys added to gparams later also
# appear on args.
gparams = vars(args)

#Equation
EQUATION = gparams['equation']
EPS = gparams['eps']
b = gparams['b']
FILE = gparams['file']
CUT_TRAIN_DATA = gparams['cut_train_data']
# --file looks like '<num_data>N<num_elements>' (e.g. '2000N31'), optionally
# followed by '_<suffix>'.
NUM_DATA = int(FILE.split('N')[0])
# Parse the element count once, stripping any '_<suffix>', so that the mesh
# lookup and the sanity check below agree (a suffix used to crash the check).
NE_FROM_FILE = int(FILE.split('N')[1].split('_')[0])
BASIS_ORDER = gparams['basis_order']
BDRY = gparams['bdry']


mesh = np.load('../../mesh_1DP{}/ne{}.npz'.format(BASIS_ORDER, NE_FROM_FILE))
NUM_ELEMENT, NUM_PTS, p, c = mesh['ne'], mesh['ng'], mesh['p'], mesh['c']
NUM_BASIS = NUM_PTS
# Mesh points as a (num_pts, 1) float32 tensor on the GPU.
p = torch.FloatTensor(p).cuda().reshape(-1, 1)

# The mesh file's own element count must match the N encoded in --file;
# training on a mismatched mesh would silently produce a wrong model.
if NUM_ELEMENT != NE_FROM_FILE:
    sys.exit(
        f"Error!! : mesh reports ne={NUM_ELEMENT}, but --file {FILE} "
        f"implies N={NE_FROM_FILE}; please check --file"
    )

#Model
# Map each --model choice to its network class (from net.network).
models = {'NetA': NetA, 'Net2D': Net2D, 'DeepONet': DeepONet}
MODEL = models[gparams['model']]
# Trunk / branch depth and width, plus activation, for the model constructor.
d_t, w_t = gparams['depth_trunk'], gparams['width_trunk']
d_b, w_b = gparams['depth_branch'], gparams['width_branch']
act = gparams['act']

#Train
EPOCHS = int(gparams['epochs'])
pre_EPOCHS = int(gparams['pre_epochs'])
LR = gparams['lr']
LOSS = gparams['loss']
D_in = 1
D_out = NUM_BASIS

requested_batch = gparams['batch_size']
if requested_batch is None:
    # Full batch: the whole (cut) training set and the whole validation set.
    BATCH_SIZE_train, BATCH_SIZE_validate = CUT_TRAIN_DATA, NUM_DATA
else:
    # Training batches never exceed the number of training samples;
    # validation always uses the requested size.
    BATCH_SIZE_train = min(requested_batch, CUT_TRAIN_DATA)
    BATCH_SIZE_validate = requested_batch
    
#Save file
# Timestamp like 20240102T030405 (local time, seconds resolution).
cur_time = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
FOLDER = f'{gparams["model"]}_epochs{EPOCHS}_{cur_time}'
PATH = os.path.join('../../train', 'deeponet', FILE, FOLDER)





# CREATE PATHING
if not os.path.isdir(PATH):
    # makedirs creates PATH itself as an intermediate directory of PATH/pics.
    os.makedirs(os.path.join(PATH, 'pics'))
elif args.pretrained is None:
    print("\n\nPATH ALREADY EXISTS!\n\nEXITING\n\n")
    sys.exit()
else:
    # NOTE(review): nothing in this script loads args.pretrained; the message
    # below is printed but the model is still freshly initialised.
    print("\n\nPATH ALREADY EXISTS!\n\nLOADING MODEL\n\n")

model_FEM = MODEL(d_t, w_t, d_b, w_b, act, NUM_PTS, 1, 1)

# SEND TO GPU (.cuda() requires a CUDA device; there is no CPU fallback)
model_FEM.cuda()
    

# KAIMING INITIALIZATION
def weights_init(m):
    """Kaiming-normal init for ``nn.Conv1d`` weights; zero their biases.

    Meant to be passed to ``nn.Module.apply``. Modules of any other type are
    left untouched.

    Args:
        m: a single ``nn.Module``.
    """
    # NOTE(review): only Conv1d layers are re-initialised; if the network is
    # built from other layer types this is a no-op -- confirm intended.
    if isinstance(m, nn.Conv1d):
        torch.nn.init.kaiming_normal_(m.weight.data)
        # Conv1d(..., bias=False) stores bias=None, which zeros_ cannot handle.
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

# Run weights_init on every submodule of model_FEM (nn.Module.apply recurses).
model_FEM.apply(weights_init)


    
class Dataset(Dataset):
    """Split-specific dataset: a pickled sample table plus mesh load vectors.

    Item ``idx`` is a dict of float32 tensors built from columns 0, 1, 2 of
    row ``idx`` of the pickled data and row ``idx`` of
    ``mesh['<kind>_load_vectors']``. For ``kind == 'train'`` only the first
    ``CUT_TRAIN_DATA`` rows count toward the length.
    """

    def __init__(self, gparams, mesh, kind='train'):
        self.kind = kind
        self.pickle_file = gparams['file']
        pkl_path = f'../../data/P{BASIS_ORDER}/{kind}/{self.pickle_file}.pkl'
        # NOTE: pickle.load can execute arbitrary code -- trusted files only.
        with open(pkl_path, 'rb') as handle:
            self.data = pickle.load(handle)
        self.load_vector = mesh[f'{kind}_load_vectors']

    def __getitem__(self, idx):
        coeff_u, f_value, coeff_f = (
            torch.FloatTensor(self.data[idx, col]) for col in (0, 1, 2)
        )
        # coeff_u and f_value get a leading singleton axis; the rest do not.
        return {
            'coeff_u': coeff_u.unsqueeze(0),
            'f_value': f_value.unsqueeze(0),
            'coeff_f': coeff_f,
            'load_vec_f': torch.FloatTensor(self.load_vector[idx]),
        }

    def __len__(self):
        rows = self.data[:CUT_TRAIN_DATA] if self.kind == 'train' else self.data
        return len(rows)

lg_dataset = Dataset(gparams, mesh, kind='train')
trainloader = DataLoader(lg_dataset, batch_size=BATCH_SIZE_train, shuffle=True)
lg_dataset = Dataset(gparams, mesh, kind='validate')
validateloader = DataLoader(lg_dataset, batch_size=BATCH_SIZE_validate, shuffle=False)


optimizer = torch.optim.Adam(params=model_FEM.parameters(), lr=LR, weight_decay=1e-5)


_mse = torch.nn.MSELoss()


def _rmse_loss(pred, true):
    """Root of the mean squared error over all elements."""
    return torch.sqrt(_mse(pred, true))


def _rel_mse_loss(pred, true):
    """Mean over samples of the squared relative L2 error along the last axis."""
    return torch.mean(torch.sum((true - pred) ** 2, dim=-1) / torch.sum(true ** 2, dim=-1))


# Honour --loss (argparse restricts it to these keys) instead of always
# using MSE. Each entry is a callable(pred, true) -> scalar tensor.
_LOSS_FUNCS = {
    'MAE': torch.nn.L1Loss(),
    'MSE': _mse,
    'RMSE': _rmse_loss,
    'RelMSE': _rel_mse_loss,
}
loss_func = _LOSS_FUNCS[LOSS]




def closure(model, f_value, coeff_u):
    """Forward ``model`` on the mesh points ``p`` and score it against ``coeff_u``.

    Args:
        model: network called as ``model(p, f_value)``; ``p`` is the
            module-level (num_pts, 1) tensor of mesh points.
        f_value: input batch passed to the model.
        coeff_u: target coefficients.

    Returns:
        ``(loss, prediction)`` where ``loss = loss_func(prediction, coeff_u)``.
    """
    prediction = model(p, f_value)
    return loss_func(prediction, coeff_u), prediction

def rel_L2_error(pred, true):
    """Relative L2 error ``||true - pred|| / ||true||`` along the last axis.

    Reduces only the last dimension, so a (batch, n) input yields a (batch,)
    tensor. A zero-norm ``true`` row produces nan/inf.
    """
    residual_sq = (true - pred).pow(2).sum(dim=-1)
    reference_sq = true.pow(2).sum(dim=-1)
    return (residual_sq / reference_sq) ** 0.5

def log_gparams(gparams, path=None):
    """Write run parameters into ``path``.

    Every entry except ``'losses'`` is written as a ``key:value`` line to
    ``<path>/parameters.txt`` (overwriting it). A ``'losses'`` entry, if
    present, is written to ``<path>/losses.csv`` via ``pandas.DataFrame``.

    Args:
        gparams: mapping of parameter names to values.
        path: output directory; defaults to the module-level run directory
            ``PATH``.
    """
    if path is None:
        path = PATH
    # Join paths instead of os.chdir(): chdir changes process-wide state and
    # was never restored if a write raised.
    with open(os.path.join(path, 'parameters.txt'), 'w') as f:
        for k, v in gparams.items():
            if k == 'losses':
                pd.DataFrame(v).to_csv(os.path.join(path, 'losses.csv'))
            else:
                f.write(f"{k}:{v}\n")


def log_path(path):
    """Append ``path`` as one line to the shared ``../../paths.txt`` run index."""
    with open("../../paths.txt", "a") as index_file:
        index_file.write(f"{path!s}\n")


log_path(PATH)
log_gparams(gparams)
################################################
time0 = time.time()
losses = []
train_rel_L2_errors = []
test_rel_L2_errors = []
for epoch in range(1, EPOCHS + 1):
    loss_total = 0.0
    num_samples = 0
    train_rel_L2_error = 0

    for sample_batch in trainloader:
        # The monitoring pass at the end of each batch switches to eval mode,
        # so train mode must be restored before every optimisation step, not
        # only once per epoch.
        model_FEM.train()
        optimizer.zero_grad()
        coeff_u = sample_batch['coeff_u'].cuda()
        f_value = sample_batch['f_value'].cuda()

        loss, u_pred = closure(model_FEM, f_value, coeff_u)
        loss.backward()
        # Adam needs no closure; the gradients are already computed.
        optimizer.step()
        # Accumulate the exact batch loss: rounding each term to 4 decimals
        # flushes small losses to zero.
        loss_total += loss.item()
        num_samples += coeff_u.shape[0]

        # Monitoring only: relative error of the just-updated model on this batch.
        with torch.no_grad():
            model_FEM.eval()
            _, u_pred = closure(model_FEM, f_value, coeff_u)
            u_pred = u_pred.squeeze().cpu()
            coeff_u = coeff_u.squeeze().cpu()
            train_rel_L2_error += torch.sum(rel_L2_error(u_pred, coeff_u))

    train_rel_L2_error /= num_samples

    if epoch % 500 == 0:
        ## Test
        model_FEM.eval()
        num_samples = 0
        test_rel_L2_error = 0
        with torch.no_grad():
            for sample_batch in validateloader:
                coeff_u = sample_batch['coeff_u'].cuda()
                f_value = sample_batch['f_value'].cuda()
                # Count before squeeze(): for a batch of one, squeeze() also
                # drops the batch axis and shape[0] would no longer be 1.
                num_samples += coeff_u.shape[0]
                _, u_pred = closure(model_FEM, f_value, coeff_u)
                u_pred = u_pred.squeeze().cpu()
                coeff_u = coeff_u.squeeze().cpu()
                test_rel_L2_error += torch.sum(rel_L2_error(u_pred, coeff_u))
        test_rel_L2_error /= num_samples

        ## Save and print
        losses.append(loss_total)
        train_rel_L2_errors.append(train_rel_L2_error)
        test_rel_L2_errors.append(test_rel_L2_error)
        torch.save({'model_state_dict': model_FEM.state_dict(),
                    'losses': losses,
                    'train_rel_L2_errors': train_rel_L2_errors,
                    'test_rel_L2_errors': test_rel_L2_errors
        }, PATH + '/model.pt')
        # Scientific notation so small loss sums do not print as 0.0.
        print("Epoch {0:4d}: weak_form_loss {1:.4e}, train_rel_error {2:.5f}, test_rel_error {3:.5f}".format(epoch, loss_total, train_rel_L2_error, test_rel_L2_error))

        
# Final checkpoint; overwrites the last periodic one at the same path.
checkpoint = {
    'model_state_dict': model_FEM.state_dict(),
    'losses': losses,
    'train_rel_L2_errors': train_rel_L2_errors,
    'test_rel_L2_errors': test_rel_L2_errors,
}
torch.save(checkpoint, PATH + '/model.pt')

# Wall-clock time from the start of training through the final save.
train_t = time.time() - time0
NPARAMS = sum(param.numel() for param in model_FEM.parameters() if param.requires_grad)

# Record run summary alongside the CLI parameters, then rewrite parameters.txt.
gparams.update(
    train_time=train_t,
    nParams=NPARAMS,
    batchSize_train=BATCH_SIZE_train,
    batchSize_validate=BATCH_SIZE_validate,
    path=PATH,
)

log_gparams(gparams)
