import random
import torch
import time
import datetime
import subprocess
import os
import argparse
import gc
import sys
import pickle
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.sparse import diags
from pprint import pprint


import matplotlib.pyplot as plt
from net.network import *


# ARGS
def _build_parser():
    """Create the command-line parser for this training script.

    Options are registered in three groups: the random seed, the data /
    equation description, and the network + optimisation settings.
    """
    cli = argparse.ArgumentParser("SEM")
    cli.add_argument("--seed", type=int, default=0)

    ## Data
    cli.add_argument(
        "--equation", type=str, default='Standard',
        choices=['Standard', 'varcoeff', 'burgers', 'bdrylayer'],
    )
    cli.add_argument("--eps", type=float, default=1)
    cli.add_argument("--b", type=float, default=-1)
    # 2^5-1, 2^6-1
    cli.add_argument(
        "--file", type=str, default='3000N32',
        help='Example: --file 2000N31',
    )
    cli.add_argument(
        "--cut_train_data", type=int, default=3000,
        help='Number of training data',
    )
    cli.add_argument(
        "--basis_order", type=int, default=1,
        help='P1->d=1, P2->d=2',
    )
    cli.add_argument(
        "--bdry", type=str, default='dirichlet',
        choices=['dirichlet', 'neumann'],
    )

    ## Train parameters
    cli.add_argument("--pretrained", type=str, default=None)
    cli.add_argument(
        "--model", type=str, default='NetA',
        choices=['NetA', 'Net2D', 'DeepONet'],
    )
    cli.add_argument("--batch_size", type=int, default=None)
    cli.add_argument("--depth_trunk", type=int, default=2)
    cli.add_argument("--width_trunk", type=int, default=30)
    cli.add_argument("--depth_branch", type=int, default=2)
    cli.add_argument("--width_branch", type=int, default=30)
    cli.add_argument("--act", type=str, default=None)
    cli.add_argument(
        "--loss", type=str, default='MSE',
        choices=['MAE', 'MSE', 'RMSE', 'RelMSE'],
    )
    cli.add_argument("--epochs", type=int, default=80000)
    cli.add_argument("--pre_epochs", type=int, default=0)
    cli.add_argument("--lr", type=float, default=1e-4)
    return cli


parser = _build_parser()
args = parser.parse_args()
# gparams is the namespace's own attribute dict, so entries added to it later
# are also visible as attributes of args.
gparams = vars(args)


# Setup seeds
# One seed (from --seed) is pushed into every random number generator this
# script touches: Python's, NumPy's, and PyTorch's CPU and CUDA generators.
random_seed = gparams['seed']
print(f"running with random seed : {random_seed}")
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)


#Equation
# Problem description taken from the command line.
EQUATION = gparams['equation']
EPS = gparams['eps']          # multiplies u_xx in the residual used by closure()
b = gparams['b']              # multiplies u_x in the residual used by closure()
FILE = gparams['file']        # '<num_data>N<num_element>[_suffix]', e.g. '3000N32'
CUT_TRAIN_DATA = gparams['cut_train_data']
NUM_DATA = int(FILE.split('N')[0])
BASIS_ORDER = gparams['basis_order']
BDRY = gparams['bdry']

# Element count encoded in --file. Parsed once (ignoring an optional
# '_suffix') so that the mesh path and the sanity check below always agree;
# previously the check parsed the raw tail and raised ValueError on suffixed
# file names.
_FILE_NUM_ELEMENT = int(FILE.split('N')[1].split('_')[0])

mesh = np.load('../../mesh_1DP{}/ne{}.npz'.format(BASIS_ORDER, _FILE_NUM_ELEMENT))
# presumably: 'ne' = number of elements, 'ng' = number of nodes,
# 'p' = node coordinates, 'c' = connectivity -- TODO confirm against the
# mesh generator.
NUM_ELEMENT, NUM_PTS, p, c = mesh['ne'], mesh['ng'], mesh['p'], mesh['c']
NUM_BASIS = NUM_PTS
# Node coordinates as an (N, 1) float tensor on the GPU.
p = torch.FloatTensor(p).cuda().reshape(-1, 1)
# Rows of p whose 'gfl' entry equals 1 (presumably the boundary nodes).
bdry_p = p[mesh['gfl'] == 1].reshape(-1, 1)

# Best-effort consistency warning: the run continues even on a mismatch.
if NUM_ELEMENT != _FILE_NUM_ELEMENT:
    print("Error!! : Please check --file with --num_data and --N")

#Model
# Lookup table from the --model choice to the network class of that name
# (the classes come from the net.network star import).
models = dict(NetA=NetA, Net2D=Net2D, DeepONet=DeepONet)
MODEL = models[gparams['model']]

# Architecture hyper-parameters: depth / width of the trunk and branch parts
# and the activation name, later passed positionally to MODEL.
d_t, w_t = gparams['depth_trunk'], gparams['width_trunk']
d_b, w_b = gparams['depth_branch'], gparams['width_branch']
act = gparams['act']

#Train
EPOCHS = int(gparams['epochs'])
pre_EPOCHS = int(gparams['pre_epochs'])
LR, LOSS = gparams['lr'], gparams['loss']
D_in, D_out = 1, NUM_BASIS

# Batch sizes:
#   * no --batch_size          -> full batch (all training samples in one
#                                 step, all NUM_DATA validation samples at once)
#   * --batch_size given       -> training batches are capped at
#                                 CUT_TRAIN_DATA, validation uses the requested
#                                 size unchanged
if gparams['batch_size'] is None:
    BATCH_SIZE_train, BATCH_SIZE_validate = CUT_TRAIN_DATA, NUM_DATA
else:
    BATCH_SIZE_train = min(gparams['batch_size'], CUT_TRAIN_DATA)
    BATCH_SIZE_validate = gparams['batch_size']

    
#Save file
# Compact timestamp of the form YYYYMMDDTHHMMSS (second resolution), used to
# give every run its own output folder.
cur_time = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
FOLDER = '{}_epochs{}_{}'.format(gparams["model"], EPOCHS, cur_time)
# Run directory: ../../train/PIDeepONet/<file>/<model>_epochs<N>_<timestamp>
PATH = os.path.join('../../train', 'PIDeepONet', FILE, FOLDER)





# CREATE PATHING
# A fresh run creates PATH together with its 'pics' sub-folder in one call
# (os.makedirs also builds the intermediate directories).
if not os.path.isdir(PATH):
    os.makedirs(os.path.join(PATH, 'pics'))
elif args.pretrained is None:
    # Refuse to overwrite the results of an earlier run.
    print("\n\nPATH ALREADY EXISTS!\n\nEXITING\n\n")
    # sys.exit() rather than the interactive-only exit() helper, which is
    # injected by the site module and is not guaranteed to exist.
    sys.exit()
else:
    # NOTE(review): only the message is printed here; nothing in the visible
    # script actually loads weights for --pretrained -- confirm whether the
    # loading step is missing.
    print("\n\nPATH ALREADY EXISTS!\n\nLOADING MODEL\n\n")

# Instantiate the selected architecture. The positional arguments are the trunk
# depth/width, the branch depth/width, the activation name and NUM_PTS,
# followed by two literal 1s whose meaning is defined by the constructors in
# net.network (not visible here -- TODO confirm).
model_FEM=MODEL(d_t, w_t, d_b, w_b, act, NUM_PTS, 1, 1)

# SEND TO GPU (this script calls .cuda() unconditionally; there is no CPU fallback)
model_FEM.cuda()
    

# KAIMING INITIALIZATION
def weights_init(m):
    """Kaiming-initialise a single ``nn.Conv1d`` layer in place.

    Intended to be passed to ``Module.apply``. Modules that are not
    ``nn.Conv1d`` are left untouched.

    Args:
        m: Any module; only ``nn.Conv1d`` instances are modified.

    The weight is drawn with ``kaiming_normal_`` and the bias, when the layer
    has one, is set to zero. Layers created with ``bias=False`` are supported
    (their ``bias`` attribute is ``None`` and is skipped).
    """
    if not isinstance(m, nn.Conv1d):
        return
    torch.nn.init.kaiming_normal_(m.weight.data)
    # Conv1d(..., bias=False) stores bias as None; zeros_(None) would raise.
    if m.bias is not None:
        torch.nn.init.zeros_(m.bias)

# Run weights_init over the network through its apply() method (for an
# nn.Module this visits every submodule as well as the module itself).
# weights_init only modifies nn.Conv1d layers; all other layers keep their
# default initialisation.
model_FEM.apply(weights_init)


    
class Dataset(Dataset):
    """Pickled sample set for one split (``kind`` is 'train' or 'validate').

    Samples are read from ``../../data/P<BASIS_ORDER>/<kind>/<file>.pkl`` and
    the matching load vectors from ``mesh['<kind>_load_vectors']``.

    Each item is a dict of float tensors:
        'coeff_u'    -- column 0 of the sample, with a leading axis of size 1
        'f_value'    -- column 1 of the sample, with a leading axis of size 1
        'coeff_f'    -- column 2 of the sample
        'load_vec_f' -- the load vector stored at the same index

    For the 'train' split the reported length is limited to the first
    CUT_TRAIN_DATA samples; other splits expose every sample.
    """

    def __init__(self, gparams, mesh, kind='train'):
        self.kind = kind
        self.pickle_file = gparams['file']
        pkl_path = f'../../data/P{BASIS_ORDER}/{kind}/' + self.pickle_file + '.pkl'
        with open(pkl_path, 'rb') as handle:
            self.data = pickle.load(handle)
        self.load_vector = mesh[f'{kind}_load_vectors']

    def __len__(self):
        visible = self.data[:CUT_TRAIN_DATA] if self.kind == 'train' else self.data
        return len(visible)

    def __getitem__(self, idx):
        as_float = torch.FloatTensor
        sample = {
            'coeff_u': as_float(self.data[idx, 0]).unsqueeze(0),
            'f_value': as_float(self.data[idx, 1]).unsqueeze(0),
            'coeff_f': as_float(self.data[idx, 2]),
            'load_vec_f': as_float(self.load_vector[idx]),
        }
        return sample

# Training split: reshuffled every epoch, BATCH_SIZE_train samples per step.
trainloader = DataLoader(
    Dataset(gparams, mesh, kind='train'),
    batch_size=BATCH_SIZE_train,
    shuffle=True,
)

# Validation split: fixed order. lg_dataset stays bound to this split.
lg_dataset = Dataset(gparams, mesh, kind='validate')
validateloader = DataLoader(
    lg_dataset,
    batch_size=BATCH_SIZE_validate,
    shuffle=False,
)


# Adam over all network parameters with the learning rate from --lr.
optimizer = torch.optim.Adam(model_FEM.parameters(), lr=LR)


# Mean-squared error is used for both the residual and the boundary term.
# NOTE(review): the --loss option is not consulted here -- confirm whether
# the other choices were meant to be supported.
loss_func = torch.nn.MSELoss()




def closure(model, f_value):
    """Evaluate the physics-informed loss for one batch.

    Args:
        model: network called as ``model(points, f_value)``.
        f_value: batch of forcing samples; ``f_value[k][0]`` is compared with
            the residual of sample ``k`` (presumably its values at the mesh
            nodes -- TODO confirm).

    Returns:
        ``(loss, pred_coeff_u)`` where ``pred_coeff_u`` is the model output at
        the global node tensor ``p`` and ``loss`` is the sum of

        * the residual term: batch mean of ``MSE(-EPS*u_xx + b*u_x, f)``,
          computed with autograd derivatives with respect to the node
          coordinates. It is only evaluated while ``model.training`` is true;
          otherwise it contributes 0.
        * the boundary term: MSE between the model output at ``bdry_p`` and
          zero.
    """
    n_samples = f_value.shape[0]
    # Fresh differentiable copy of the node coordinates for this call.
    coords = Variable(p, requires_grad=True)
    pred_coeff_u = model(coords, f_value)

    residual_loss = 0
    if model.training:
        for k in range(n_samples):
            u_k = pred_coeff_u[k]
            # First derivative du/dx at every node (graph kept for u_xx and
            # for back-propagation through the residual).
            seed_u = torch.ones(u_k.size()).cuda()
            (grad_u,) = torch.autograd.grad(
                u_k, coords, create_graph=True, grad_outputs=seed_u
            )
            u_x = grad_u[..., 0]
            # Second derivative d2u/dx2.
            seed_ux = torch.ones(u_x.size()).cuda()
            (grad_ux,) = torch.autograd.grad(
                u_x, coords, create_graph=True, grad_outputs=seed_ux
            )
            u_xx = grad_ux[..., 0]
            residual = -EPS * u_xx + b * u_x
            residual_loss += loss_func(residual, f_value[k][0]).clone()
        residual_loss /= n_samples

    # Boundary term: drive the prediction at the boundary points to zero.
    val_bdry = model(bdry_p.cuda(), f_value).squeeze(1)
    zero_target = torch.zeros(val_bdry.shape).cuda()
    loss_bdry = loss_func(val_bdry, zero_target)
    return residual_loss + loss_bdry, pred_coeff_u

def rel_L2_error(pred, true):
    """Return the relative L2 error ``||true - pred|| / ||true||``.

    Both norms are taken over the last axis, so the result has the shape of
    the inputs with that axis removed. No guard is applied for an all-zero
    ``true`` (the division then yields nan/inf).
    """
    err_sq = torch.sum((true - pred) ** 2, dim=-1)
    ref_sq = torch.sum(true ** 2, dim=-1)
    return (err_sq / ref_sq) ** 0.5

def log_gparams(gparams):
    """Write the run parameters into the run directory ``PATH``.

    Every ``key:value`` pair is written as one line of
    ``PATH/parameters.txt`` (the file is overwritten on each call). A
    ``'losses'`` entry is not written there; it is converted to a DataFrame
    and saved as ``PATH/losses.csv`` instead.

    Args:
        gparams: mapping of parameter names to values.

    Raises:
        FileNotFoundError: if ``PATH`` does not exist.

    The files are addressed through joined paths rather than by changing the
    working directory, so the process cwd is never modified -- even when
    writing fails part-way -- and later relative paths keep resolving.
    """
    with open(os.path.join(PATH, 'parameters.txt'), 'w') as f:
        for k, v in gparams.items():
            if k == 'losses':
                pd.DataFrame(v).to_csv(os.path.join(PATH, 'losses.csv'))
            else:
                f.write(f"{k}:{v}\n")


def log_path(path):
    """Append ``path`` as one line to the shared index ``../../paths.txt``.

    The location is relative to the current working directory. The file is
    created if missing and closed by the ``with`` block.
    """
    with open("../../paths.txt", "a") as index_file:
        index_file.writelines([str(path), '\n'])
# Record this run's output folder in the shared ../../paths.txt index, then
# write the initial hyper-parameters to PATH/parameters.txt. The parameter
# file is written a second time after training, once timing and size
# information has been added to gparams.
log_path(PATH)
log_gparams(gparams)
################################################
def _save_checkpoint():
    """Write the model weights and the metric histories to PATH/model.pt."""
    torch.save({'model_state_dict': model_FEM.state_dict(),
                'losses': losses,
                'train_rel_L2_errors': train_rel_L2_errors,
                'test_rel_L2_errors': test_rel_L2_errors
    }, PATH + '/model.pt')


# Main optimisation loop.
#   * every epoch: one pass over trainloader, accumulating the (rounded)
#     loss and the relative L2 error of the predictions on the training data
#   * every 500 epochs: evaluate on validateloader, append to the histories,
#     checkpoint and print a summary line
time0 = time.time()
losses = []
train_rel_L2_errors = []
test_rel_L2_errors = []
for epoch in range(1, EPOCHS+1):
    time_one_epoch = time.time()
    loss_total = 0
    num_samples = 0
    train_rel_L2_error = 0

    for sample_batch in trainloader:
        # Train mode must be restored for EVERY batch: the error evaluation
        # at the end of the previous batch switched the model to eval mode,
        # and closure() skips the PDE residual term unless model.training is
        # set. Setting it once per epoch silently trained every batch after
        # the first on the boundary loss only.
        model_FEM.train()
        optimizer.zero_grad()
        coeff_u = sample_batch['coeff_u'].cuda()
        f_value = sample_batch['f_value'].cuda()

        loss, u_pred = closure(model_FEM, f_value)

        loss.backward()

        # Gradients are already computed; step() takes no closure here
        # (the bound method loss.item used to be passed by mistake).
        optimizer.step()
        # Per-batch loss is rounded to 4 decimals before accumulation.
        loss_total += np.round(float(loss.item()), 4)
        num_samples += coeff_u.shape[0]

        # Relative error of the updated model on the same batch.
        model_FEM.eval()
        with torch.no_grad():
            _, u_pred = closure(model_FEM, f_value)
            u_pred = u_pred.squeeze().detach().cpu()
            coeff_u = coeff_u.squeeze().detach().cpu()
            train_rel_L2_error += torch.sum(rel_L2_error(u_pred, coeff_u))

    train_rel_L2_error /= num_samples
    time_one_epoch = time.time()-time_one_epoch
    if epoch < 10:
        print("1Epoch takes {} seconds".format(time_one_epoch))
    if epoch % 500 == 0:
        ## Test
        num_samples = 0
        test_rel_L2_error = 0
        model_FEM.eval()
        with torch.no_grad():
            for sample_batch in validateloader:
                coeff_u = sample_batch['coeff_u'].cuda()
                f_value = sample_batch['f_value'].cuda()
                # Count samples BEFORE squeeze(): for a batch of size 1 the
                # squeeze removes the batch axis and shape[0] would become
                # the number of nodes.
                num_samples += coeff_u.shape[0]
                _, u_pred = closure(model_FEM, f_value)
                u_pred = u_pred.squeeze().detach().cpu()
                coeff_u = coeff_u.squeeze().detach().cpu()
                test_rel_L2_error += torch.sum(rel_L2_error(u_pred, coeff_u))
        test_rel_L2_error /= num_samples

        ##Save and print
        losses.append(loss_total)
        train_rel_L2_errors.append(train_rel_L2_error)
        test_rel_L2_errors.append(test_rel_L2_error)
        _save_checkpoint()
        print("Epoch {0:4d}: loss {1:1.8f}, train_rel_error {2:.5f}, test_rel_error {3:.5f}".format(epoch, loss_total, train_rel_L2_error, test_rel_L2_error))


# Final checkpoint, so the weights of the last epoch are stored even when
# EPOCHS is not a multiple of 500.
_save_checkpoint()
        
# Wall-clock duration of the whole training loop, in seconds.
train_t = time.time() - time0
# Number of trainable scalars in the network.
NPARAMS = sum(
    param.numel() for param in model_FEM.parameters() if param.requires_grad
)

# Extend the parameter record with the run statistics and rewrite it.
gparams.update(
    train_time=train_t,
    nParams=NPARAMS,
    batchSize_train=BATCH_SIZE_train,
    batchSize_validate=BATCH_SIZE_validate,
    path=PATH,
)

log_gparams(gparams)
