#!/usr/bin/env python
# coding: utf-8

# In[11]:


import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader 
import numpy as np
from scipy import io
import matplotlib.pyplot as plt
import argparse
import os
import copy
import time

def cal_domain_grad(model, XTGrid, device):
    """Return the gradient of the squared PDE residual w.r.t. the input points.

    The residual is ``u_t + u*u_x + 0.0025*u_xxx`` where ``u`` is column 0 of
    ``model.forward(XTGrid)``; column 0 of ``XTGrid`` is treated as ``x`` and
    column 1 as ``t``.

    Args:
        model: object exposing ``forward(points)`` that returns a 2-D tensor.
        XTGrid: ``(N, 2)`` tensor of points; it must require gradients.
        device: device the points are moved to before evaluation.

    Returns:
        A detached ``(N, 2)`` tensor holding d(residual**2)/dx and
        d(residual**2)/dt for every point.
    """
    XTGrid = XTGrid.to(device)
    uf = model.forward(XTGrid)[:, 0]

    def _grad(outputs, create_graph=True):
        # grad_outputs of ones gives per-point derivatives in a single call,
        # because every output entry depends only on its own input row.
        return torch.autograd.grad(outputs=outputs,
                                   inputs=XTGrid,
                                   grad_outputs=torch.ones_like(outputs),
                                   create_graph=create_graph,
                                   allow_unused=True)[0]

    uf_x, uf_t = _grad(uf).T
    uf_xx = _grad(uf_x)[:, 0]
    uf_xxx = _grad(uf_xx)[:, 0]

    loss = (uf_t + uf * uf_x + 0.0025 * uf_xxx) ** 2

    # The result is only consumed as values (sampler step directions), so the
    # final derivative does not need a differentiable graph of its own.
    return _grad(loss, create_graph=False).detach()

class LPINNSampler():
    """Collocation sampler that moves points with noisy, normalised gradient steps.

    Each call to ``update`` applies ``L`` steps of
    ``x <- x + tau * g/|g| + beta * sqrt(2*tau) * N(0, I)`` and clamps the
    first column to [-1, 1] and the second column to [0, 1].
    """

    def __init__(self, fixed_uniform, device, Nf, X_star, patience=10, min_delta=0.001, beta=0.2, tau=0.002, L=1):
        # ``patience`` and ``min_delta`` are accepted but not used in this class.
        self.device = device
        self.Nf = Nf
        self.cnt = 0
        self.beta, self.tau, self.L = beta, tau, L
        # Two independent copies: the moving point cloud and its starting state.
        self.XTGrid = self._private_copy(fixed_uniform)
        self.initial = self._private_copy(fixed_uniform)
        self.X_star = torch.tensor(X_star, dtype=torch.float32, requires_grad=True).to(self.device)
        self.traj = []
        self.criterion = nn.MSELoss()

    def _private_copy(self, points):
        """Deep-copy ``points`` into a float32 tensor that requires grad, on ``self.device``."""
        return torch.tensor(copy.deepcopy(points), dtype=torch.float32, requires_grad=True).to(self.device)

    def update(self, phy_lf, model):
        """Advance the stored points by ``self.L`` noisy gradient steps.

        Args:
            phy_lf: callable ``(model, points, device) -> (N, 2)`` gradient tensor.
            model: forwarded unchanged to ``phy_lf``.
        """
        points = self.XTGrid.clone().detach().requires_grad_(True)

        for _ in range(self.L):
            raw = phy_lf(model, points, self.device)
            # Row-wise Euclidean norm; the 1e-16 offset keeps it non-zero.
            norm = torch.sqrt(torch.sum((raw + 1e-16) ** 2, axis=1)).reshape(-1, 1)
            direction = raw / norm
            with torch.no_grad():
                noise_scale = self.beta * torch.sqrt(torch.tensor(2 * self.tau, device=self.device))
                moved = points + self.tau * direction + noise_scale * torch.randn(points.shape, device=self.device)
                moved[:, 0].clamp_(min=-1, max=1)  # keep x inside the domain
                moved[:, 1].clamp_(min=0, max=1)   # keep t inside the domain
            points = moved.clone().detach().requires_grad_(True)

        self.XTGrid = points.detach()

class L_INFSampler():
    """Sampler that redraws uniform points and pushes them along sign(gradient).

    Every ``update`` draws ``Nf`` fresh points uniformly on [-1, 1] x [0, 1],
    then applies ``n_iter`` steps of ``x <- x + step_size * sign(g)``, clamping
    back into the domain after each step.
    """

    def __init__(self, device, Nf, step_size = 0.05 , n_iter = 20):
        self.Nf = Nf
        self.device = device
        self.n_iter = n_iter
        self.step_size = step_size

    def update(self, grad_f, model):
        """Resample the collocation points and store them in ``self.XTGrid``.

        Args:
            grad_f: callable ``(model, points, device) -> (N, 2)`` gradient tensor.
            model: forwarded unchanged to ``grad_f``.
        """
        # Fresh uniform start: x drawn first, then t.
        xs = torch.empty(self.Nf, 1, dtype=torch.float32, device=self.device).uniform_(-1, 1)
        ts = torch.empty(self.Nf, 1, dtype=torch.float32, device=self.device).uniform_(0, 1)
        self.XTGrid = torch.cat((xs, ts), dim=1)

        points = self.XTGrid.clone().detach().requires_grad_(True)

        for _ in range(self.n_iter):
            direction = torch.sign(grad_f(model, points, self.device))
            with torch.no_grad():
                moved = points + self.step_size * direction
                moved[:, 0].clamp_(min=-1, max=1)  # clamp the x axis
                moved[:, 1].clamp_(min=0, max=1)   # clamp the t axis
            points = moved.clone().detach().requires_grad_(True)

        self.XTGrid = points.detach()

class R3Sampler(nn.Module):
    """Retain-and-resample sampler.

    Points whose loss is above the batch mean are kept; the remaining slots
    are refilled with uniform draws on [-1, 1] x [0, 1] so that the grid
    always holds ``Nf`` points.
    """

    def __init__(self,Nf, fixed_uniform, X_star, device):
        super().__init__()

        self.device = device
        self.Nf = Nf
        self.X_star = X_star
        # Private float32 copy of the starting grid, tracked for gradients.
        self.XTGrid = torch.tensor(copy.deepcopy(fixed_uniform), dtype = torch.float32,
                                   device=self.device, requires_grad=True)

    def update(self, loss):
        """Keep the above-mean-loss points and top the grid back up to ``Nf`` rows.

        Args:
            loss: per-point loss tensor aligned with the rows of ``self.XTGrid``.
        """
        with torch.no_grad():
            keep = (loss > loss.mean()).to('cpu')
            retained = self.XTGrid[keep].detach()

            n_missing = self.Nf - retained.shape[0]
            fresh_x = torch.zeros(n_missing, 1, dtype = torch.float32, device=self.device).uniform_(-1, 1)
            fresh_t = torch.zeros(n_missing, 1, dtype = torch.float32, device=self.device).uniform_(0, 1)
            fresh = torch.cat((fresh_x, fresh_t), dim=1)

            # Retained points come first, new draws are appended after them.
            merged = torch.cat((retained, fresh), dim=0)
            self.XTGrid = torch.tensor(merged, dtype = torch.float32, device=self.device, requires_grad=True)
    
class RADSampler(nn.Module):
    """Residual-based adaptive distribution sampler.

    ``update`` draws ``Nf`` uniform candidates on [-1, 1] x [0, 1], evaluates
    the absolute PDE residual ``|u_t + u*u_x + 0.0025*u_xxx|`` at each one and
    resamples ``Nf`` points (with replacement) with probability proportional
    to ``residual**k / mean(residual**k) + c``.
    """

    def __init__(self, k, c, device, Nf):
        super(RADSampler, self).__init__()
        self.device = device
        self.Nf = Nf
        self.k = k
        self.c = c
        self.XTGrid = self._uniform_candidates()

    def _uniform_candidates(self):
        """Draw ``Nf`` uniform points as a float32 leaf tensor that requires grad."""
        x_new = torch.zeros(self.Nf, 1, dtype=torch.float32, device=self.device).uniform_(-1, 1)
        t_new = torch.zeros(self.Nf, 1, dtype=torch.float32, device=self.device).uniform_(0, 1)
        return torch.cat((x_new, t_new), dim=1).requires_grad_(True)

    def update(self, model):
        """Resample ``self.XTGrid`` according to the current residual of ``model``.

        Args:
            model: object exposing ``forward(points)`` returning a 2-D tensor
                whose column 0 is the predicted solution.
        """
        XTGrid = self._uniform_candidates()
        uf = model.forward(XTGrid)[:, 0]

        def _grad(outputs, create_graph=True):
            return torch.autograd.grad(outputs=outputs,
                                       inputs=XTGrid,
                                       grad_outputs=torch.ones_like(outputs),
                                       create_graph=create_graph,
                                       allow_unused=True)[0]

        uf_x, uf_t = _grad(uf).T
        uf_xx = _grad(uf_x)[:, 0]
        # Only residual values are needed (as sampling weights), so the last
        # derivative is taken without building a further graph. The original
        # code also read a module-level ``device`` here, which fails when the
        # class is used outside the script entry point.
        uf_xxx = _grad(uf_xx, create_graph=False)[:, 0]

        with torch.no_grad():
            err = torch.abs(uf_t + uf * uf_x + 0.0025 * uf_xxx)
            err = (err ** self.k) / ((err ** self.k).mean()) + self.c
            err_norm = err / err.sum()
            indice = torch.multinomial(err_norm, self.Nf, replacement=True)

        # Store a fresh leaf so no autograd graph is kept alive by the sampler.
        self.XTGrid = XTGrid.detach()[indice].requires_grad_(True)

    
class PINN(nn.Module):
    """Physics-informed network for ``u_t + u*u_x + 0.0025*u_xxx = 0``.

    The network maps ``(x, t)`` rows to the solution value. Training combines
    an initial-condition loss, a boundary loss that matches the prediction at
    the lower and upper spatial boundary, and the PDE residual evaluated on
    collocation points chosen by ``sampling_method``: ``'fixed'``, ``'rar'``,
    ``'r3'``, ``'lpinn'``, ``'l_inf'`` or ``'rad'``.

    NOTE(review): ``train`` below replaces ``nn.Module.train(mode)``, so
    ``eval()`` / ``train(False)`` on this object do not behave like the
    standard PyTorch calls. The name is kept for existing callers.
    """

    def __init__(self, k, c, t, X_star, u_star, exact_u, space_domain, time_domain, Layers, N0, Nb, Nf, Nt, beta, tau, L,
                 Activation=nn.Tanh(), model_name="PINN.model", device='cpu', display_freq=1000, sampling_method='fixed',
                 run_id=None):
        """Build the network, the IC/BC point sets and every sampler.

        Args:
            k, c: exponent and offset passed to ``RADSampler``.
            t: tensor whose first dimension gives the number of boundary points.
            X_star: ``(N, 2)`` evaluation grid of ``(x, t)`` rows.
            u_star: reference solution aligned with ``X_star``.
            exact_u: reference solution on the 2-D grid, used for the test error.
            space_domain, time_domain: ``[lower, upper]`` pairs.
            Layers: layer widths, input and output included.
            N0, Nb, Nt: stored for reference; not used by this class.
            Nf: number of collocation points.
            beta, tau, L: parameters passed to ``LPINNSampler``.
            Activation: module inserted between the linear layers.
            model_name: path the best model is saved to.
            device: device used for the network and the point sets.
            display_freq: logging / evaluation interval in iterations.
            sampling_method: collocation strategy (see class docstring).
            run_id: optional tag appended to the files written by ``train``.
                When ``None`` a module-level ``i`` is used if present, else 0.
        """
        super(PINN, self).__init__()

        LBs = [space_domain[0], time_domain[0]]
        UBs = [space_domain[1], time_domain[1]]

        self.LBs = torch.tensor(LBs, dtype=torch.float32).to(device)
        self.UBs = torch.tensor(UBs, dtype=torch.float32).to(device)

        self.Layers = Layers
        self.in_dim = Layers[0]
        self.out_dim = Layers[-1]
        self.Activation = Activation

        self.device = device

        # Uniform collocation grid reused by the 'fixed' strategy and used as
        # the starting point of the r3 / lpinn samplers.
        x_init = torch.zeros(Nf, 1, dtype=torch.float32, device=self.device).uniform_(-1, 1)
        t_init = torch.zeros(Nf, 1, dtype=torch.float32, device=self.device).uniform_(0, 1)
        self.fixed_uniform = torch.cat((x_init, t_init), dim=1).requires_grad_(True)

        self.N0 = N0
        self.Nb = Nb
        self.Nf = Nf
        self.Nt = Nt
        self.beta = beta
        self.tau = tau
        self.L = L
        self.t = t
        self.X_star = X_star
        self.u_star = u_star
        self.exact_u = exact_u
        self.run_id = run_id

        self.XT0, self.u0 = self.initial_condition(self.LBs[0], self.UBs[0])
        self.XTbL, self.XTbU = self.boundary_condition(self.LBs[0], self.UBs[0])

        self.XT0 = self.XT0.to(device)
        self.u0 = self.u0.to(device)

        self.XTbL = self.XTbL.to(device)
        self.XTbU = self.XTbU.to(device)

        self._nn = self.build_model()
        self._nn.to(self.device)
        self.Loss = torch.nn.MSELoss(reduction='mean')

        self.model_name = model_name
        self.display_freq = display_freq

        self.k = k
        self.c = c
        self.method = sampling_method
        self.r3_sampler = R3Sampler(self.Nf, self.fixed_uniform, self.X_star, device=self.device)
        self.lpinn_sampler = LPINNSampler(fixed_uniform=self.fixed_uniform, device=self.device, Nf=self.Nf,
                                          beta=self.beta, tau=self.tau, L=self.L, X_star=self.X_star)
        self.l_inf_sampler = L_INFSampler(device=self.device, Nf=self.Nf)
        self.rad_sampler = RADSampler(k=self.k, c=self.c, device=self.device, Nf=self.Nf)

        # Only an already-existing folder is tolerated; other failures
        # (permissions, invalid path) surface here rather than at save time.
        try:
            os.makedirs("../models/" + self.method)
        except FileExistsError:
            print('Folder exists!!')

    @staticmethod
    def _as_cpu_tensor(value):
        """Return ``value`` as a CPU tensor, converting plain numbers if needed."""
        if isinstance(value, torch.Tensor):
            return value.cpu()
        return torch.tensor(value).cpu()

    @staticmethod
    def _grad_wrt(outputs, inputs):
        """Differentiable per-point derivative of ``outputs`` w.r.t. ``inputs``."""
        return torch.autograd.grad(outputs=outputs,
                                   inputs=inputs,
                                   grad_outputs=torch.ones_like(outputs),
                                   create_graph=True,
                                   allow_unused=True)[0]

    def _as_collocation_leaf(self, points):
        """Copy ``points`` into a float32 leaf tensor on ``self.device`` that requires grad."""
        return points.detach().clone().to(device=self.device, dtype=torch.float32).requires_grad_(True)

    def build_model(self):
        """Assemble the fully connected network described by ``self.Layers``."""
        seq = nn.Sequential()
        last_idx = len(self.Layers) - 2
        for layer_idx in range(len(self.Layers) - 1):
            this_module = nn.Linear(self.Layers[layer_idx], self.Layers[layer_idx + 1])
            nn.init.xavier_normal_(this_module.weight)
            seq.add_module("Linear" + str(layer_idx), this_module)
            # No activation after the output layer.
            if layer_idx != last_idx:
                seq.add_module("Activation" + str(layer_idx), self.Activation)
        return seq

    def forward(self, x):
        """Evaluate the network on rows of ``x``; returns shape ``(-1, out_dim)``."""
        x = x.to(self.device)
        x = x.reshape((-1, self.in_dim))
        return torch.reshape(self._nn.forward(x), (-1, self.out_dim))

    def initial_condition(self, LB, UB):
        """Select the rows of ``X_star`` / ``u_star`` with ``t == 0`` and ``LB <= x < UB``."""
        LB = self._as_cpu_tensor(LB)
        UB = self._as_cpu_tensor(UB)

        xs = self.X_star[:, 0]
        indices = (xs >= LB) & (xs < UB) & (self.X_star[:, 1] == 0.)
        return self.X_star[indices], self.u_star[indices]

    def boundary_condition(self, LB, UB):
        """Build matching point sets on the lower and upper spatial boundary.

        Times are ``self.t.shape[0]`` equally spaced values in [0, 1).
        """
        LB = self._as_cpu_tensor(LB)
        UB = self._as_cpu_tensor(UB)

        n_t = self.t.shape[0]
        tb = torch.tensor(np.linspace(0, 1, n_t, endpoint=False), dtype=torch.float32).reshape(-1, 1)
        XTL = torch.cat((LB * torch.ones((n_t, 1)), tb), dim=1)
        XTL.requires_grad_()
        XTU = torch.cat((UB * torch.ones((n_t, 1)), tb), dim=1)
        XTU.requires_grad_()

        return XTL, XTU

    def IC_loss(self):
        """Mean squared error between the prediction and ``u0`` on the initial points."""
        u0_pred = self.forward(self.XT0)[:, 0].reshape(-1)
        return self.Loss(u0_pred, self.u0)

    def BC_loss(self):
        """Mean squared mismatch between the two spatial boundaries."""
        ub_l = self.forward(self.XTbL)
        ub_u = self.forward(self.XTbU)
        return torch.mean((ub_l - ub_u) ** 2)

    def classical_sampling(self):
        """Return collocation points for the 'fixed' and 'rar' strategies.

        Raises:
            ValueError: if ``self.method`` is neither ``'fixed'`` nor ``'rar'``.
        """
        if self.method == 'fixed':
            return self._as_collocation_leaf(self.fixed_uniform)

        if self.method == 'rar':
            # Fresh uniform draw on every call.
            x_new = torch.zeros(self.Nf, 1, dtype=torch.float32, device=self.device).uniform_(-1, 1)
            t_new = torch.zeros(self.Nf, 1, dtype=torch.float32, device=self.device).uniform_(0, 1)
            return torch.cat((x_new, t_new), dim=1).requires_grad_(True)

        raise ValueError("Unknown sampling method: {!r}".format(self.method))

    def physics_loss(self, XTGrid):
        """Return the per-point squared PDE residual at ``XTGrid``.

        ``XTGrid`` must require gradients; column 0 is ``x`` and column 1 is ``t``.
        """
        XTGrid = XTGrid.to(self.device)
        uf = self.forward(XTGrid)[:, 0]
        uf_x, uf_t = self._grad_wrt(uf, XTGrid).T
        uf_xx = self._grad_wrt(uf_x, XTGrid)[:, 0]
        uf_xxx = self._grad_wrt(uf_xx, XTGrid)[:, 0]

        return (uf_t + uf * uf_x + 0.0025 * uf_xxx) ** 2

    def _next_collocation_points(self, it, prev_residual):
        """Advance the configured sampler and return this iteration's points.

        ``prev_residual`` is the per-point residual from the previous
        iteration; it is only read by the 'r3' strategy when ``it > 0``.
        """
        method = self.method
        if method == 'r3':
            if it > 0:
                with torch.no_grad():
                    self.r3_sampler.update(prev_residual)
            points = self.r3_sampler.XTGrid
        elif method == 'lpinn':
            if it > 0:
                self.lpinn_sampler.update(cal_domain_grad, self._nn)
            self.lpinn_sampler.cnt += 1
            points = self.lpinn_sampler.XTGrid
        elif method == 'l_inf':
            self.l_inf_sampler.update(cal_domain_grad, self._nn)
            points = self.l_inf_sampler.XTGrid
        elif method == 'rad':
            self.rad_sampler.update(self._nn)
            points = self.rad_sampler.XTGrid
        else:
            return self.classical_sampling()
        return self._as_collocation_leaf(points)

    def train(self, n_iters, weights=(1.0, 1.0, 1.0)):
        """Optimise the network with Adam for ``n_iters`` iterations.

        Args:
            n_iters: number of optimisation steps.
            weights: multipliers for the (initial, boundary, residual) losses.

        Returns:
            ``(training_losses, rel_error)``: the total loss per iteration and
            the relative L2 test error in percent, one slot per
            ``display_freq`` iterations (slot 0 and unreached slots stay -10).

        Side effects: saves the whole model to ``self.model_name`` whenever the
        total loss improves, and every ``display_freq`` iterations prints the
        losses and saves the current grid and pointwise error under
        ``../models/<method>/``.
        """
        optimizer = optim.Adam(list(self.parameters()), lr=1e-3)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5000, gamma=0.9, last_epoch=-1)
        min_loss = 999999.0
        training_losses = [-10] * n_iters
        # One slot per evaluation; tied to display_freq so that evaluations do
        # not overwrite each other when the interval is not 1000.
        rel_error = [-10] * (1 + n_iters // self.display_freq)

        # Backward-compatible run tag: the script entry point historically
        # exposed the repeat index as a module-level ``i``.
        run_tag = self.run_id if self.run_id is not None else globals().get('i', 0)
        out_prefix = "../models/" + self.method + '/' + str(self.Nf)
        out_suffix = '_' + str(self.beta) + '_' + str(self.L) + '_' + str(run_tag)
        exact_u = self.exact_u.detach().cpu()

        loss = None
        start = time.time()
        for it in range(n_iters):
            total_ICloss = self.IC_loss()
            total_BCloss = self.BC_loss()

            XTGrid = self._next_collocation_points(it, loss)

            optimizer.zero_grad()
            loss = self.physics_loss(XTGrid)
            total_physics_loss = loss.mean()

            total_loss = weights[0] * total_ICloss + weights[1] * total_BCloss \
                + weights[2] * total_physics_loss

            total_loss.backward()
            optimizer.step()
            scheduler.step()

            total_value = float(total_loss)
            if total_value < min_loss:
                torch.save(self, self.model_name)
                min_loss = total_value

            training_losses[it] = total_value

            if (it + 1) % self.display_freq == 0:
                with torch.no_grad():
                    outputs = self.forward(self.X_star).reshape(exact_u.shape).cpu()
                    error = exact_u - outputs
                    rel_l2 = float(np.linalg.norm(error.numpy()) / np.linalg.norm(exact_u.numpy()) * 100)
                    rel_error[(it + 1) // self.display_freq] = rel_l2
                print("Iteration Number = {}".format(it + 1))
                print("\tIC Loss = {}".format(float(total_ICloss)))
                print("\tBC Loss = {}".format(float(total_BCloss)))
                print("\tPhysics Loss = {}".format(float(total_physics_loss)))
                print("\tTraining Loss = {}".format(total_value))
                print("\tRelative L2 error (test) = {}".format(rel_l2))

                torch.save(XTGrid, out_prefix + "_grid_" + str(it + 1) + out_suffix)
                torch.save(error, out_prefix + "_error_" + str(it + 1) + out_suffix)

        end = time.time()
        print(end - start)
        return training_losses, rel_error

if __name__ == "__main__":
    # Script entry point: parse options, load the reference data and train
    # one PINN per repeat, saving loss and error histories under ../models/.
    #
    # Deliberately kept at module scope (not wrapped in a function): other
    # code in this file may look up `device`, `exact_u` and the loop variable
    # `i` as module globals.
    parser = argparse.ArgumentParser()
    parser.add_argument('--nodes', type=int, default = 128, help='The number of nodes per hidden layer in the neural network')
    parser.add_argument('--N0', type=int, default = 100, help='The number of points to use on the initial condition')
    parser.add_argument('--Nb', type=int, default = 100, help='The number of points to use on the boundary condition')
    parser.add_argument('--Nf', type=int, default = 1000, help='The number of collocation points to use')
    parser.add_argument('--Nt', type=int, default = 1000, help='The number of points to use to calculate the MSE loss')
    parser.add_argument('--epochs', type=int, default = 200000, help='The number of epochs to train the neural network')
    parser.add_argument('--model-name', type=str, default='PINN_model', help='File name to save the model')
    parser.add_argument('--display-freq', type=int, default=1000, help='How often to display loss information')
    parser.add_argument('--layers', type=int, default = 4, help='The number of hidden layers in the neural network')
    # beta and tau are real-valued; with type=int any fractional value given
    # on the command line (e.g. --beta 0.5) was rejected by argparse.
    parser.add_argument('--beta', type=float, default=0.2, help='High residual concentration parameter')
    parser.add_argument('--tau', type=float, default=0.002, help='Langevin step size')
    parser.add_argument('--L', type=int, default=1, help='The number of Langevin iteration')
    parser.add_argument('--method', type=str, default='lpinn', help='sampling method')
    # Extra option whose value is never read; presumably lets the script
    # tolerate the `-f <file>` argument passed by a Jupyter kernel.
    parser.add_argument('-f')
    args = parser.parse_args()

    data = np.load('../data/KDV.npz')

    # Reference solution reshaped to a 101 x 200 grid with matching 1-D axes.
    x = torch.tensor(data['input_x'][:200].reshape(-1,1), dtype = torch.float32)
    t = torch.tensor(data['input_t'].reshape(101,-1)[:,0].reshape(-1, 1), dtype = torch.float32)
    exact_u = torch.tensor(data['output'].reshape(101,200), dtype = torch.float32)

    X, T = np.meshgrid(x,t)
    X_star = torch.tensor(np.hstack((X.flatten()[:,None], T.flatten()[:,None])), dtype = torch.float32)
    u_star = torch.flatten(exact_u)

    os.makedirs("../models/", exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    NHiddenLayers = args.layers

    boundaries = [-1, 1]
    t_domain = [0., 1.]

    Layers = [2] + [args.nodes] * NHiddenLayers + [1]
    Activation = nn.Tanh()

    k = 1
    c = 1

    repeat = [0,1,2,3,4]
    sampling_me = args.method

    # Shared part of every output file name; the repeat index is appended.
    run_suffix = sampling_me+'_'+str(args.layers)+'_'+str(args.Nf)+'_'+ str(args.beta)+'_'+ str(args.L)

    for i in repeat:
        pinn = PINN(  k = k,
                      c = c,
                      t= t,
                      X_star = X_star,
                      u_star = u_star,
                      exact_u = exact_u,
                      space_domain = boundaries,
                      time_domain = t_domain,
                      Layers = Layers,
                      N0 = args.N0,
                      Nb = args.Nb,
                      Nf = args.Nf,
                      Nt = args.Nt,
                      beta = args.beta,
                      tau = args.tau,
                      L = args.L,
                      Activation = Activation,
                      device = device,
                      model_name = "../models/" + args.model_name + ".model_" + run_suffix + '_' + str(i),
                      display_freq = args.display_freq, sampling_method = sampling_me )

        Losses_train, Losses_rel_l2 = pinn.train(args.epochs, weights = (100, 1, 1)) # initial, boundary, residual

        torch.save(Losses_train, "../models/" + args.model_name + ".loss_" + run_suffix + '_' + str(i))
        torch.save(Losses_rel_l2, "../models/" + args.model_name + ".rel_l2_" + run_suffix + '_' + str(i))

