import os
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import pandas as pd
import sys
import math
import torchdiffeq as ode
import torch.nn.utils as utils
import functools
import seaborn as sns
from utils.cubic_spline import CSpline
from torch.utils.data import DataLoader
import warnings
# Silence every warning issued through the `warnings` module for the rest of
# this process, including deprecation/user warnings raised by the libraries
# imported above.
warnings.filterwarnings("ignore")

class Sepi(nn.Module):  # myModel
    """Fully connected network mapping 8 input features to 1 output.

    Layout: Linear(8, 64) followed by three Linear(64, 64) layers, each with a
    ReLU after it, then a final Linear(64, 1) with no activation. Every Linear
    layer is re-initialised with weights drawn from N(0, 0.1^2) and zero
    biases.
    """

    def __init__(self, device):
        super().__init__()
        self.device = device  # kept as an attribute; not used inside this class
        widths = [8, 64, 64, 64, 64, 1]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out, bias=True))
            layers.append(nn.ReLU())
        layers.pop()  # the output layer has no activation
        self.mlp = nn.Sequential(*layers)

        for module in self.mlp.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0, std=0.1)
                nn.init.constant_(module.bias, val=0)

    def forward(self, s_batch):
        """Apply the MLP over the last dimension of ``s_batch``."""
        return self.mlp(s_batch)

class NODE(nn.Module):  # myModel
    """ODE right-hand side ``ds/dt = mlp([u(t), D_1 s, ..., D_k s])``.

    ``csp``, ``mlp`` and ``f0`` start as ``None`` and must be assigned by the
    caller before the module is evaluated.
    """

    def __init__(self, kks):
        super().__init__()
        self.csp = None  # needs a .fit(t) method returning the forcing amplitude at t
        self.mlp = None  # callable applied to the stacked features
        self.f0 = None  # forcing profile; the amplitude is expanded to its shape
        self.kks = kks  # Fourier-space multipliers, one feature per entry

    def forward(self, t, s_b):
        """Evaluate the right-hand side at time ``t`` for the field ``s_b``.

        The first feature is the forcing ``f0 * csp.fit(t)``. Each multiplier
        ``k`` in ``kks`` adds ``Re(ifft2(k * fft2(s_b)))``, with both
        transforms taken over dims (0, 1). A trailing singleton dimension of
        the ``mlp`` output is squeezed away.
        """
        forcing = self.f0 * self.csp.fit(t).expand(self.f0.shape)
        spectrum = torch.fft.fft2(s_b, dim=(0, 1))
        features = [forcing.unsqueeze(-1)]
        for multiplier in self.kks:
            derivative = torch.fft.ifft2(multiplier * spectrum, dim=(0, 1)).real
            features.append(derivative.unsqueeze(-1))
        return self.mlp(torch.cat(features, dim=-1)).squeeze(-1)

class Model:
    """Train and evaluate the ``Sepi`` right-hand-side network on gridded data.

    Data are read from ``./dataset/data_<args.filename>/``. Fields ``s`` are
    held as ``(sample, time, x, y)`` tensors. The forcing is
    ``u[n, t, x, y] = f0[n, x, y] * X[n, t]``. The network is fit so that the
    spectral time derivative of ``s`` matches ``sepi([u, D_1 s, ..., D_7 s])``,
    where the ``D_k`` are the Fourier multipliers in ``self.kks``.

    ``args`` must provide: filename, L, device, sigma_sd, num, epochs,
    batch_size, lr, weight_decay, tle, numz, prob, outime, model_ind, rtol,
    atol, method.
    """

    def __init__(self, args):
        self.args = args

        (self.y_train, self.X_train, self.f0_train,
         self.y_valid, self.X_valid, self.f0_valid,
         self.xs, self.ys, self.ts) = self.read_data()

        # Fourier multipliers applied to fft2(s) over the (x, y) axes. The x and
        # y grids may differ in length, and either length may be odd.
        self.kks = self._spectral_operators(len(self.xs), len(self.ys),
                                            args.L, args.device)

        self.sepi = Sepi(args.device).to(args.device)

    # ------------------------------------------------------------------
    # Spectral helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _fft_index(n, nyquist_positive=False, device=None):
        """Return the integer DFT frequency indices of a length-``n`` transform.

        Indices are in FFT order: non-negative first, then negative. For even
        ``n`` the Nyquist index is ``-n/2`` by default (the ``numpy.fft.fftfreq``
        convention), or ``+n/2`` when ``nyquist_positive`` is True. Odd ``n``
        gives ``0..(n-1)/2`` followed by ``-(n-1)/2..-1`` in both cases.
        """
        n_pos = n // 2 + 1 if nyquist_positive else (n + 1) // 2
        return torch.cat((torch.arange(0, n_pos, device=device),
                          torch.arange(-(n - n_pos), 0, device=device)))

    @staticmethod
    def _spectral_operators(nx, ny, L, device):
        """Build the Fourier multipliers used as spatial features.

        Returns ``[ikx, (ikx)^2, iky, (iky)^2, (ikx)(iky), ikx/|k|^2, iky/|k|^2]``
        as ``(nx, ny)`` complex tensors on ``device``. The wavenumbers are
        ``2*pi/L * index``, with x varying along dim 0 and y along dim 1, and
        ``|k|^2`` at the zero mode is replaced by 1.
        """
        k_x = (2 * math.pi / L * Model._fft_index(nx, device=device)).unsqueeze(1).expand(nx, ny)
        k_y = (2 * math.pi / L * Model._fft_index(ny, device=device)).unsqueeze(0).expand(nx, ny)
        # Negative Laplacian in Fourier space.
        lap = k_x ** 2 + k_y ** 2
        lap[0, 0] = 1.0  # avoid division by zero at the mean mode
        kx1 = 1j * k_x
        ky1 = 1j * k_y
        kx2 = (1j * k_x) ** 2
        ky2 = (1j * k_y) ** 2
        kxy = (1j * k_x) * (1j * k_y)
        dlap = (1.0 + 0j) / lap
        return [kx1, kx2, ky1, ky2, kxy, kx1 * dlap, ky1 * dlap]

    @staticmethod
    def _spatial_features(u_b, s_b, kks):
        """Stack ``[u, D_1 s, ..., D_k s]`` along a new last axis.

        ``s_b`` is transformed with ``fft2`` over dims (2, 3). Each operator in
        ``kks`` multiplies the spectrum, broadcasting over the leading dims,
        and the real part of the inverse transform is kept.
        """
        s_hat = torch.fft.fft2(s_b, dim=(2, 3))
        derivs = [torch.fft.ifft2(k * s_hat, dim=(2, 3)).real for k in kks]
        return torch.stack([u_b] + derivs, dim=-1)

    @staticmethod
    def _spectral_time_derivative(s, period, filt):
        """Differentiate ``s`` along dim 1 with a filtered FFT.

        The series is treated as periodic with length ``period``. The ``2*filt``
        coefficients centred on index ``n//2`` (the highest frequencies) are
        zeroed before the spectrum is multiplied by ``i*omega``. Only the real
        part of the inverse transform is returned. Works for odd and even
        lengths.
        """
        n = s.shape[1]
        omega = Model._fft_index(n, nyquist_positive=True) * 2 * torch.pi / period
        view_shape = [1] * s.dim()
        view_shape[1] = n
        omega = omega.reshape(view_shape).to(s.device)
        s_hat = torch.fft.fft(s, dim=1)
        s_hat[:, n // 2 - filt:n // 2 + filt] = 0 + 0j
        return torch.fft.ifft(1j * omega * s_hat, dim=1).real

    @staticmethod
    def _interior(x, trunc):
        """Drop ``trunc`` entries from both ends of dim 1.

        ``x[:, trunc:-trunc]`` would be empty for ``trunc == 0``; this keeps
        everything in that case.
        """
        return x[:, trunc:x.shape[1] - trunc]

    @staticmethod
    def _forcing(f0, X):
        """Return ``u[n, t, x, y] = f0[n, x, y] * X[n, t]``."""
        return f0.unsqueeze(1) * X.unsqueeze(-1).unsqueeze(-1)

    # ------------------------------------------------------------------
    # I/O helpers
    # ------------------------------------------------------------------
    def _read_values(self, name):
        """Read ``./dataset/data_<filename>/<name>`` and return its raw values."""
        path = os.path.join('.', 'dataset', 'data_' + self.args.filename, name)
        return pd.read_csv(path).values

    def _model_path(self):
        """Path of the checkpoint written by ``train`` and read by ``test``."""
        return './model/' + self.args.model_ind + '.pkl'

    def _load_model(self):
        """Load the checkpoint saved by ``train`` onto the CPU.

        The checkpoint is a whole pickled ``nn.Module`` saved with
        ``torch.save``. It therefore needs ``weights_only=False`` on PyTorch
        versions where that defaults to True. Unpickling can execute arbitrary
        code, so only load checkpoints this project produced.
        """
        path = self._model_path()
        try:
            return torch.load(path, map_location='cpu', weights_only=False)
        except TypeError:
            # Older PyTorch releases do not accept ``weights_only``.
            return torch.load(path, map_location='cpu')

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def read_data(self):
        """Load the training/validation data and grids from CSV.

        The first CSV column is dropped. Gaussian noise with standard deviation
        ``args.sigma_sd * mean(|y_train|)`` is added to ``y_train`` before only
        the first ``args.num`` training samples are kept.

        Returns:
            ``(y_train, X_train, f0_train, y_valid, X_valid, f0_valid, xs, ys,
            ts)`` as tensors, where ``y_*`` is (sample, time, x, y), ``f0_*`` is
            (sample, x, y), and ``X_*`` holds one row per sample.
        """
        args = self.args
        # Temporal subsampling stride, applied to y, X and ts together.
        # test() reads its data unsubsampled, so keep lt = 1 unless it is
        # updated as well.
        lt = 1
        X_train = self._read_values('u_train.csv')[:, 1:]
        X_valid = self._read_values('u_valid.csv')[:, 1:]
        y_train = self._read_values('s_train.csv')[:, 1:]
        y_valid = self._read_values('s_valid.csv')[:, 1:]
        f0_train = self._read_values('f0_train.csv')[:, 1:]
        f0_valid = self._read_values('f0_valid.csv')[:, 1:]
        xs = self._read_values('x.csv')[:, 1]
        ys = self._read_values('y.csv')[:, 1]
        ts = self._read_values('t.csv')[:, 1]
        n_t, n_x, n_y = len(ts), len(xs), len(ys)

        noise_scale = args.sigma_sd * np.abs(y_train).mean()
        y_train = y_train + noise_scale * np.random.randn(*y_train.shape)

        y_train = torch.tensor(y_train.reshape(-1, n_t, n_x, n_y)[:args.num])[:, ::lt]
        y_valid = torch.tensor(y_valid.reshape(-1, n_t, n_x, n_y))[:, ::lt]
        f0_train = torch.tensor(f0_train.reshape(-1, n_x, n_y)[:args.num])
        f0_valid = torch.tensor(f0_valid.reshape(-1, n_x, n_y))
        X_train = torch.tensor(X_train)[:args.num, ::lt]
        X_valid = torch.tensor(X_valid)[:, ::lt]
        xs = torch.tensor(xs)
        ys = torch.tensor(ys)
        ts = torch.tensor(ts)[::lt]
        return (y_train, X_train, f0_train, y_valid, X_valid, f0_valid, xs, ys, ts)

    def train(self):
        """Fit ``self.sepi`` with Adam on random time windows of the training set.

        Each epoch samples ``args.batch_size`` trajectories and a window of
        ``args.tle`` consecutive steps, then minimises the mean squared gap
        between the filtered spectral time derivative and the network output.
        ``args.tle // args.prob`` steps are dropped at each end of the window.
        Every ``args.outime`` epochs the loss is evaluated on the full
        validation series. The best model is saved to
        ``./model/<model_ind>.pkl``. The loss curves are written to
        ``./results/``.
        """
        args = self.args
        X_train, y_train, f0_train = self.X_train, self.y_train, self.f0_train
        X_valid, y_valid, f0_valid = self.X_valid, self.y_valid, self.f0_valid
        ts, kks = self.ts, self.kks
        sepi = self.sepi
        device = args.device
        batch_size = args.batch_size
        tle, numz, prob = args.tle, args.numz, args.prob
        n_t = len(ts)

        optimizer = optim.Adam(sepi.parameters(), lr=args.lr,
                               weight_decay=args.weight_decay)
        sepi.train()

        # Training data (kept on the CPU; batches are moved to `device`).
        u = self._forcing(f0_train, X_train)
        s = y_train
        n_train = X_train.shape[0]
        # NOTE(review): for a uniform grid starting at t=0 the DFT period of
        # tle samples would be tle*dt; the original (tle+1)/tle factor is kept
        # unchanged. Confirm intent.
        train_period = ts[tle - 1] * (tle + 1) / tle
        train_filt = (tle - numz) // 2
        train_trunc = tle // prob

        # The validation features and targets do not depend on the model, so
        # they are built once here instead of at every evaluation.
        with torch.no_grad():
            sv = y_valid
            n_valid = sv.shape[0]
            kks_cpu = [k.cpu() for k in kks]
            valid_features = self._spatial_features(
                self._forcing(f0_valid, X_valid), sv, kks_cpu)
            valid_period = ts[-1] * (n_t + 1) / n_t
            # NOTE(review): the filter width uses the training window length
            # `tle` rather than len(ts). This is kept as in the original; confirm.
            valid_left = self._spectral_time_derivative(
                sv, valid_period, (tle - numz) // 2)
            valid_trunc = n_t // prob
        bs = 10  # validation samples pushed through the network at once

        best_loss = torch.tensor(1e8)
        Tr_loss = []
        Va_loss = []
        for epoch in range(args.epochs):
            optimizer.zero_grad()
            sn = torch.randint(0, n_train, (batch_size,))
            # Upper bound is exclusive: "+ 1" lets the last window be drawn.
            st = int(torch.randint(0, n_t - tle + 1, (1,)))
            u_b = u[sn, st:st + tle].to(device)
            s_b = s[sn, st:st + tle].to(device)

            right = sepi(self._spatial_features(u_b, s_b, kks)).squeeze(-1)
            left = self._spectral_time_derivative(s_b, train_period, train_filt)
            loss = torch.mean(self._interior(left - right, train_trunc) ** 2)

            loss.backward()
            optimizer.step()
            Tr_loss.append(loss.item())

            if epoch % args.outime == 0:
                with torch.no_grad():
                    # Iterate over every chunk, including a final partial one.
                    right_v = torch.cat(
                        [sepi(valid_features[i:i + bs].to(device)).squeeze(-1).cpu()
                         for i in range(0, n_valid, bs)],
                        dim=0)
                    valid_loss = torch.mean(
                        self._interior(valid_left - right_v, valid_trunc) ** 2)

                    print('Iter {:04d} | '.format(epoch) + 'Train Loss {:.6f}; Valid Loss {:.6f}'.format(loss.item(), valid_loss.item()))
                    Va_loss.append(valid_loss.item())

                    if valid_loss < best_loss:
                        best_loss = valid_loss
                        torch.save(sepi, self._model_path())

        pd.DataFrame(Tr_loss).to_csv("./results/Tr_loss_ns.csv")
        pd.DataFrame(Va_loss).to_csv("./results/Va_loss_ns.csv")

    def test(self):
        """Integrate the learned right-hand side on every test sample.

        Each sample starts from the observed field at index
        ``trunc = len(ts) // args.prob`` and is integrated over ``ts[trunc:]``
        with ``ode.odeint``. The mean absolute error over the integrated steps
        is printed and the figures are drawn with ``draw``. The per-sample
        errors are written to ``./results/<model_ind>errors.csv``.
        """
        args = self.args
        xs, ys, ts = self.xs, self.ys, self.ts
        n_t, n_x, n_y = len(ts), len(xs), len(ys)

        X_test = torch.tensor(self._read_values('u_test.csv')[:, 1:])
        y_test = torch.tensor(self._read_values('s_test.csv')[:, 1:].reshape(-1, n_t, n_x, n_y))
        f0_test = torch.tensor(self._read_values('f0_test.csv')[:, 1:].reshape(-1, n_x, n_y))

        trunc = n_t // args.prob
        sepi = self._load_model()

        node = NODE([k.cpu() for k in self.kks])
        node.mlp = sepi.mlp.cpu()

        n_test = X_test.shape[0]
        errors = torch.zeros(n_test)
        with torch.no_grad():
            for Ni in range(n_test):
                target = y_test[Ni]
                # Copies, so the spline/solver never alias the tensors drawn later.
                node.csp = CSpline(ts.cpu(), X_test[Ni].clone().cpu())
                node.f0 = f0_test[Ni].clone()
                s0 = target[trunc].clone()

                print('node', str(Ni), '...')
                preds = torch.zeros(n_t, n_x, n_y)
                start = time.time()
                preds[trunc:] = ode.odeint(node, s0, ts[trunc:], rtol=args.rtol,
                                           atol=args.atol, method=args.method)
                end = time.time()

                error = (preds - target)[trunc:].abs().mean()
                errors[Ni] = error
                print("test time=", round(end - start, 2), 'error=', error)

                self.draw(X_test, y_test, Ni, preds, f0_test)

        pd.DataFrame(errors.numpy()).to_csv("./results/" + args.model_ind + "errors.csv")
        print('Testing error:', torch.mean(errors))

    def draw(self, X_test, y_test, Ni, preds, f0_test):
        """Save two figures for test sample ``Ni`` under ``figures/``.

        ``<Ni>a.png`` has three rows of heatmaps at time indices 199, 599 and
        999: the truth, the prediction rounded to 3 decimals, and their
        absolute difference. Indices past the end of the series are skipped.
        ``<Ni>b.png`` shows the field at time index 0, the profile
        ``f0_test[Ni]``, and ``X_test[Ni]`` against ``ts``. Each figure is
        closed after saving so figures do not accumulate across calls.
        """
        xs, ys, ts = self.xs, self.ys, self.ts
        xss = np.round(xs.cpu().numpy(), 3)
        yss = np.round(ys.cpu().numpy(), 3)
        u = y_test[Ni].cpu().numpy()
        tss = [k for k in (199, 599, 999) if k < len(u)]

        zh = 32
        font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': zh}
        cmapc = "rainbow"

        with plt.ioff():
            pru_tr = np.round(preds.detach().cpu().numpy(), 3)
            fig = plt.figure(figsize=(30, 20))
            rows = [
                ([u[k] for k in tss], {}),
                ([pru_tr[k] for k in tss], {}),
                ([np.abs(u[k] - pru_tr[k]) for k in tss], {'vmin': 0}),
            ]
            j = 1
            for frames, extra in rows:
                for frame in frames:
                    fig.add_subplot(3, len(tss), j)
                    j += 1
                    sns.heatmap(pd.DataFrame(frame, index=xss, columns=yss),
                                center=0, cmap=cmapc, **extra)
            plt.savefig("figures/" + str(Ni) + "a.png")
            plt.close(fig)

        with plt.ioff():
            fig = plt.figure(figsize=(20, 15))
            gs = fig.add_gridspec(2, 2)
            fig.add_subplot(gs[0, 0])
            sns.heatmap(pd.DataFrame(y_test[Ni, 0].cpu().numpy(), index=xss, columns=yss),
                        center=0, cmap=cmapc)

            fig.add_subplot(gs[0, 1])
            sns.heatmap(pd.DataFrame(f0_test[Ni].cpu().numpy(), index=xss, columns=yss),
                        center=0, cmap=cmapc)

            ax3 = fig.add_subplot(gs[1, :])
            ax3.plot(ts.cpu().numpy(), X_test[Ni].cpu().numpy())
            plt.tick_params(labelsize=zh)
            plt.ylabel(r"$u_4$", font1)
            plt.xlabel(r"$t$", font1)

            plt.savefig("figures/" + str(Ni) + "b.png")
            plt.close(fig)