import os
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import pandas as pd
import sys
import torchdiffeq as ode
import functools
import seaborn as sns
import torch.nn.utils as utils
from utils.cubic_spline import CSpline
from torch.utils.data import DataLoader
import warnings
warnings.filterwarnings("ignore")

class Sepi(nn.Module):  # myModel
    """Fully connected network mapping 6 features per sample to 1 output.

    Layout: Linear(6, 64) -> ReLU -> Linear(64, 64) -> ReLU ->
    Linear(64, 64) -> ReLU -> Linear(64, 1). Every Linear weight is
    re-initialised from N(0, 0.1**2) and every bias is set to zero.

    Args:
        device: stored on the instance; not used by ``forward``.
    """

    def __init__(self, device):
        super().__init__()
        self.device = device
        widths = [6, 64, 64, 64, 1]
        layers = []
        for depth, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if depth > 0:
                # A ReLU sits between consecutive Linear layers only.
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out, bias=True))
        self.mlp = nn.Sequential(*layers)

        for layer in self.mlp:
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, mean=0, std=0.1)
                nn.init.constant_(layer.bias, val=0)

    def forward(self, s_batch):
        """Apply the MLP to the last dimension of ``s_batch`` (size 6)."""
        return self.mlp(s_batch)

class NODE(nn.Module):  # myModel
    """Evaluate a learned right-hand side at a single time instant.

    ``csp`` (an object with a ``fit(t)`` method) and ``mlp`` (a callable)
    are left as ``None`` here and must be assigned before ``forward`` is
    called; ``Model.test`` assigns them.

    Args:
        u0: spatial profile added to ``csp.fit(t)`` to build the first
            feature channel; ``csp.fit(t)`` is expanded to ``u0.shape``.
        kk: 1-D wavenumber vector; stored as a column (``kk.unsqueeze(-1)``).
    """

    def __init__(self,u0,kk):
        super().__init__()
        self.csp = None
        self.mlp = None
        self.u0 = u0
        self.kks = kk.unsqueeze(-1)

    def forward(self, t, s_b):
        """Return ``mlp`` applied to ``[u0 + csp.fit(t), d^0 s, ..., d^4 s]``.

        The spatial derivatives of orders 0..4 are taken spectrally along
        dim 0 of ``s_b`` (multiplication by ``(1j*k)**order`` in Fourier
        space), keeping the real part. All six channels are concatenated
        along the last dimension.
        """
        forcing = self.u0 + self.csp.fit(t).expand(self.u0.shape)
        spectrum = torch.fft.fft(s_b, dim=0)
        ik = 1j * self.kks
        derivatives = [torch.fft.ifft(ik ** order * spectrum, dim=0).real
                       for order in range(5)]
        features = torch.cat([forcing, *derivatives], dim=-1)
        return self.mlp(features)
    
class genc(nn.Module):
    """Reference right-hand side ``D * s_xx + K * s**2 + u`` (the "dr" equation).

    ``s_xx`` is computed spectrally along dim 1 of the state. ``fit_data``
    (an object with a ``fit(t)`` method) must be assigned before
    ``forward`` is called; ``Model.test`` assigns it.

    Args:
        kk: wavenumbers, broadcast against ``fft(x, dim=1)``.
        u0: spatial profile added to ``fit_data.fit(t)`` (expanded to
            ``u0.shape``) to build the additive ``u`` term.
    """

    def __init__(self, kk, u0):
        super().__init__()
        self.D = 0.01  # coefficient of the second-derivative (diffusion) term
        self.K = 0.01  # coefficient of the quadratic x**2 term
        self.kk = kk
        self.u0 = u0
        self.fit_data = None

    def forward(self, t, x):
        """Evaluate the right-hand side for state ``x`` at time ``t``."""
        forcing = self.u0 + self.fit_data.fit(t).expand(self.u0.shape)
        spectrum = torch.fft.fft(x, dim=1)
        # -k**2 in Fourier space is the second spatial derivative.
        second_derivative = torch.fft.ifft(-(self.kk ** 2) * spectrum, dim=1).real
        return self.D * second_derivative + self.K * x ** 2 + forcing

class Model():
    """Data loading, training and evaluation driver for ``Sepi``.

    ``read_data`` loads the ``u_*`` / ``s_*`` CSV splits, ``train`` fits
    ``Sepi`` so that its output matches a spectrally estimated time
    derivative of ``s``, and ``test`` compares the learned right-hand side
    (wrapped in ``NODE``) against the reference right-hand side ``genc``.

    Args:
        args: hyper-parameter namespace; each method's docstring lists the
            attributes it reads.
    """

    def __init__(self, args):
        self.args = args

        (self.y_train, self.X_train, self.y_valid, self.X_valid,
         self.y_test, self.X_test, self.xs, self.xs_train, self.ts,
         self.lx) = self.read_data()

        self.sepi = Sepi(args.device).to(args.device)

    def read_data(self):
        """Load the CSV splits, add training noise and build tensors.

        Reads ``args.filename``, ``args.sigma_sd``, ``args.num`` and
        ``args.device``. Gaussian noise with standard deviation
        ``sigma_sd * mean(|s_train|)`` is added to the training state only.

        Returns:
            tuple: ``(y_train, X_train, y_valid, X_valid, y_test, X_test,
            xs, xs_train, ts, lx)``. The ``y_*`` arrays are reshaped to
            ``(samples, len(xs), len(ts))``. ``y_train`` and ``xs_train`` are
            subsampled by ``lx`` in space, and training arrays are truncated
            to ``args.num`` samples. Test tensors stay on the CPU.
        """
        args = self.args
        lx = 2  # spatial subsampling factor for the training state
        lt = 1  # temporal subsampling factor
        device = args.device
        prefix = "./dataset/data_" + args.filename + "/"

        def load(name):
            return pd.read_csv(prefix + name + ".csv").values

        QX_train = load("u_train")[:, 1:]
        X_valid = load("u_valid")[:, 1:]
        X_test = load("u_test")[:, 1:]
        Qy_train = load("s_train")[:, 1:]
        y_valid = load("s_valid")[:, 1:]
        y_test = load("s_test")[:, 1:]
        xs = load("x")[:, 1]
        ts = load("t")[:, 1]

        noise = args.sigma_sd * np.abs(Qy_train).mean() * np.random.randn(*Qy_train.shape)
        Qy_train = (Qy_train + noise).reshape(-1, len(xs), len(ts))[:args.num]
        y_valid = y_valid.reshape(-1, len(xs), len(ts))
        y_test = y_test.reshape(-1, len(xs), len(ts))

        X_train = torch.tensor(QX_train).to(device)[:args.num]
        y_train = torch.tensor(Qy_train)[:, ::lx, ::lt].to(device)
        X_valid = torch.tensor(X_valid).to(device)
        y_valid = torch.tensor(y_valid)[:, :, ::lt].to(device)
        X_test = torch.tensor(X_test)
        y_test = torch.tensor(y_test)[:, :, ::lt]
        xs = torch.tensor(xs).to(device)
        # xs is already a tensor: clone() instead of torch.tensor(tensor).
        xs_train = xs[::lx].clone()
        ts = torch.tensor(ts).to(device)[::lt]
        return (y_train, X_train, y_valid, X_valid, y_test, X_test,
                xs, xs_train, ts, lx)

    def train(self):
        """Fit ``self.sepi`` with Adam on random time windows of the training set.

        Reads ``args.epochs``, ``batch_size``, ``device``, ``L``, ``lr``,
        ``weight_decay``, ``tle`` (window length), ``numz`` (temporal filter
        width), ``prob`` (each window end is trimmed by ``len // prob``
        samples) and ``outime`` (validation period), plus ``model_ind``.

        Side effects: saves the whole module to
        ``./model/<model_ind>.pkl`` whenever the validation loss improves.
        Writes the per-epoch train losses and per-report validation losses
        to ``./results/Tr_loss_dr.csv`` / ``Va_loss_dr.csv``.
        """
        args = self.args
        device = args.device
        L = args.L
        batch_size = args.batch_size
        tle = args.tle
        numz = args.numz
        prob = args.prob

        X_train, y_train = self.X_train, self.y_train
        X_valid, y_valid = self.X_valid, self.y_valid
        xs, xs_train, ts = self.xs, self.xs_train, self.ts
        sepi = self.sepi

        optimizer = optim.Adam(sepi.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        sepi.train()

        # Data preprocessing: (N, nx, nt, 1) input and state fields.
        N = X_train.shape[0]
        u = self._forcing(X_train, xs_train, L)
        s = y_train.unsqueeze(-1)
        # The full-grid wavenumbers are truncated to the modes the
        # subsampled training grid can represent.
        kk = self._wavenumbers(len(xs), L, self.lx).to(device)
        kk_train = self._wavenumbers(len(xs_train), L).to(device)

        # Validation features and targets do not depend on the network, so
        # build them once rather than at every report.
        with torch.no_grad():
            sv = y_valid.unsqueeze(-1)
            valid_features = self._spectral_features(self._forcing(X_valid, xs, L), sv, kk)
            valid_left = self._time_derivative(sv, ts, numz)

        os.makedirs("./model", exist_ok=True)
        os.makedirs("./results", exist_ok=True)
        ckpt_path = './model/' + args.model_ind + '.pkl'

        best_loss = float("inf")
        Tr_loss = []
        Va_loss = []
        for epoch in range(args.epochs):
            optimizer.zero_grad()
            sn = torch.randint(0, N, (batch_size,)).to(device)
            # The +1 lets the window end on the last time sample. The old
            # bound never reached it and raised when len(ts) == tle.
            st = int(torch.randint(0, len(ts) - tle + 1, (1,)))
            u_b = u[sn, :, st:st + tle]
            s_b = s[sn, :, st:st + tle]

            right = sepi(self._spectral_features(u_b, s_b, kk_train)).squeeze(-1)
            left = self._time_derivative(s_b, ts, numz)
            loss = self._trimmed_mse(left - right, prob)

            loss.backward()
            optimizer.step()
            Tr_loss.append(loss.item())

            if epoch % args.outime == 0:
                with torch.no_grad():
                    right = sepi(valid_features).squeeze(-1)
                    valid_loss = self._trimmed_mse(valid_left - right, prob).item()
                    print('Iter {:04d} | '.format(epoch) + 'Train Loss {:.6f}; Valid Loss {:.6f}'.format(loss.item(), valid_loss))
                    Va_loss.append(valid_loss)

                    if valid_loss < best_loss:
                        best_loss = valid_loss
                        torch.save(sepi, ckpt_path)

        pd.DataFrame(Tr_loss).to_csv("./results/Tr_loss_dr.csv")
        pd.DataFrame(Va_loss).to_csv("./results/Va_loss_dr.csv")

    def test(self):
        """Compare the saved network's RHS with ``genc`` on the test split.

        For every test sample, a ``CSpline`` is fitted to its input series
        and shared by ``NODE`` and ``genc``. The mean squared difference of
        the two right-hand sides is averaged over time steps
        ``0 .. len(ts)-2``. Per-sample errors are written to
        ``./results/<model_ind>errors.csv``, and their mean is printed.
        """
        args = self.args
        X_test, y_test = self.X_test, self.y_test
        xs = self.xs.cpu()
        ts = self.ts.cpu()
        L = args.L
        N = X_test.shape[0]
        sepi = self._load_checkpoint('./model/' + args.model_ind + '.pkl')

        kk = self._wavenumbers(len(xs), L, self.lx)
        gend = genc(kk.clone().unsqueeze(0), torch.sin(2 * np.pi * xs / L).unsqueeze(0))
        node = NODE(torch.sin(2 * np.pi * xs / L).unsqueeze(-1), kk)
        node.mlp = sepi.mlp.cpu()

        # The inner loop visits len(ts)-1 steps, so average over that count.
        n_steps = max(len(ts) - 1, 1)
        errors = torch.zeros(N)
        with torch.no_grad():
            for Ni in range(N):
                csp_test = CSpline(ts, X_test[Ni])
                node.csp = csp_test
                gend.fit_data = csp_test

                loss = 0
                for j in range(len(ts) - 1):
                    t = ts[j]
                    x = y_test[Ni, :, j].unsqueeze(0)
                    pred = node.forward(t, x.t())
                    true = gend.forward(t, x).t()
                    loss += torch.mean((pred - true) ** 2)
                errors[Ni] = loss / n_steps

        os.makedirs("./results", exist_ok=True)
        pd.DataFrame(errors.numpy()).to_csv("./results/" + args.model_ind + "errors.csv")
        print('Testing error:', torch.mean(errors))

    @staticmethod
    def _wavenumbers(n, L, lx=None):
        """Return the FFT wavenumbers ``2*pi*k/L`` for ``n`` points, Nyquist set to 0.

        With ``lx``, modes the ``lx``-subsampled grid cannot represent are
        zeroed as well.
        """
        kk = torch.cat((torch.arange(0, n / 2), torch.tensor([0]),
                        torch.arange(-n / 2 + 1, 0))) * 2 * np.pi / L
        if lx is not None:
            kk[n // (2 * lx):-n // (2 * lx) + 1] = 0
        return kk

    @staticmethod
    def _forcing(X, xs, L):
        """Return ``X[n, t] + sin(2*pi*xs[i]/L)`` with shape ``(N, nx, nt, 1)``."""
        spatial = torch.sin(2 * np.pi * xs / L).reshape(1, -1, 1)
        return (X.unsqueeze(1) + spatial).unsqueeze(-1)

    @staticmethod
    def _spectral_features(u_part, s_part, kk):
        """Concatenate ``u_part`` with spatial derivatives 0..4 of ``s_part``.

        Derivatives are taken spectrally along dim 1 of ``s_part``; both
        inputs end in a singleton channel, and the result has 6 channels.
        """
        s_hat = torch.fft.fft(s_part, dim=1)
        kks = kk.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand(s_part.shape)
        derivs = [torch.fft.ifft((1j * kks) ** i * s_hat, dim=1).real for i in range(5)]
        return torch.cat([u_part] + derivs, dim=-1)

    @staticmethod
    def _time_derivative(s_part, ts, numz):
        """Spectral, low-pass-filtered time derivative of ``s_part`` along dim 2.

        The trailing singleton channel of ``s_part`` is dropped. The
        ``2*((nt-numz)//2)`` modes centred on index ``nt//2`` are zeroed
        before differentiating.
        """
        n_b, n_x, n_t = s_part.shape[:3]
        # NOTE(review): window period estimate; assumes uniformly spaced ts.
        # Confirm the (n_t+1)/n_t factor against how ts is sampled.
        period = ts[n_t - 1] * (n_t + 1) / n_t
        tkk = (torch.cat((torch.arange(0, n_t / 2 + 1), torch.arange(-n_t / 2 + 1, 0)), dim=0)
               ).to(s_part.device) * 2 * torch.pi / period
        tkk2 = tkk.unsqueeze(0).unsqueeze(0).expand(n_b, n_x, n_t)
        fty = torch.fft.fft(s_part, dim=2).squeeze(-1)
        filt = (n_t - numz) // 2
        fty[:, :, n_t // 2 - filt:n_t // 2 + filt] = 0 + 0j
        return torch.fft.ifft(1j * tkk2 * fty, dim=2).real

    @staticmethod
    def _trimmed_mse(residual, prob):
        """Mean of ``residual**2`` after dropping ``len // prob`` samples at both ends of dim 2.

        An explicit end index is used: ``[t:-t]`` with ``t == 0`` would be
        empty and make the mean NaN.
        """
        n_t = residual.shape[2]
        trunc = n_t // prob
        return torch.mean(residual[:, :, trunc:n_t - trunc] ** 2)

    @staticmethod
    def _load_checkpoint(path):
        """Load a module pickled by ``train`` onto the CPU.

        The checkpoint is a full pickled module, so it is loaded with
        ``weights_only=False`` where supported. Only load trusted files.
        """
        import inspect
        kwargs = {"map_location": "cpu"}
        # torch>=2.6 defaults to weights_only=True, which rejects pickled
        # modules; releases before the keyword existed do not accept it.
        if "weights_only" in inspect.signature(torch.load).parameters:
            kwargs["weights_only"] = False
        return torch.load(path, **kwargs)
        
        
        
        
        
        
        