import os
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import pandas as pd
import sys
import torchdiffeq as ode
import functools
import seaborn as sns
import torch.nn.utils as utils
from utils.cubic_spline import CSpline
from torch.utils.data import DataLoader
import warnings
warnings.filterwarnings("ignore")

class Sepi(nn.Module):  # myModel
    """Pointwise MLP closure mapping 6 input features to 1 output.

    Architecture: Linear(6, 64) -> ReLU -> Linear(64, 64) -> ReLU ->
    Linear(64, 64) -> ReLU -> Linear(64, 1).  Every Linear weight is drawn
    from N(0, 0.1**2) and every bias starts at zero.

    Args:
        device: stored as ``self.device``; this class does not use it itself.
    """

    def __init__(self, device):
        super(Sepi, self).__init__()
        self.device = device
        widths = (6, 64, 64, 64, 1)

        # Layers are created in the same order as a literal Sequential would,
        # so the default-initialisation RNG draws are identical.
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out, bias=True))
            layers.append(nn.ReLU())
        del layers[-1]  # the output layer has no activation
        self.mlp = nn.Sequential(*layers)

        linears = (layer for layer in self.mlp.modules() if isinstance(layer, nn.Linear))
        for layer in linears:
            nn.init.normal_(layer.weight, mean=0, std=0.1)
            nn.init.constant_(layer.bias, val=0)

    def forward(self, s_batch):
        """Apply the MLP along the last dimension of ``s_batch`` (size 6)."""
        return self.mlp(s_batch)

class NODE(nn.Module):  # myModel
    """Right-hand side ``ds/dt = mlp(u, s, s_x, s_xx, s_xxx, s_xxxx)`` for ``odeint``.

    Args:
        u0: spatial part of the forcing; combined with ``csp.fit(t)`` by
            broadcasting (``Model.test`` passes an ``(nx, 1)`` tensor).
        kk: 1-D wavenumbers in FFT order, one per grid point along axis 0.

    The caller must assign ``csp`` (an object with ``fit(t)``) and ``mlp``
    (a callable on the last feature axis of size 6) before integrating.
    """

    def __init__(self, u0, kk):
        super(NODE, self).__init__()
        # Attached by the caller after construction.
        self.csp = None
        self.mlp = None
        self.u0 = u0
        self.kks = kk.unsqueeze(-1)  # column vector to broadcast against s_b

    def forward(self, t, s_b):
        # Total forcing at time t: spatial profile plus the interpolated u(t).
        forcing = self.u0 + self.csp.fit(t).expand(self.u0.shape)

        # Spectral derivatives of orders 0..4 along x (axis 0).
        spectrum = torch.fft.fft(s_b, axis=0)
        derivatives = [
            torch.fft.ifft((1j * self.kks) ** order * spectrum, axis=0).real
            for order in range(5)
        ]

        features = torch.cat([forcing, *derivatives], axis=-1)
        return self.mlp(features)

class Model():
    """Data loading, training and ODE-rollout evaluation for the ``Sepi`` closure.

    ``args`` is the parsed argument namespace.  Attributes read by this class:
    ``filename``, ``device``, ``sigma_sd``, ``num``, ``pretr``, ``model_ind``,
    ``epochs``, ``batch_size``, ``L``, ``lr``, ``weight_decay``, ``tle``,
    ``numz``, ``prob``, ``outime``, ``rtol``, ``atol`` and ``method``.

    The spectral helpers assume an even number of x grid points and an even
    time-window length; otherwise the wavenumber vectors built below do not
    match the FFT length.
    """

    def __init__(self, args):
        self.args = args

        (self.y_train, self.X_train, self.y_valid, self.X_valid,
         self.y_test, self.X_test, self.xs, self.ts) = self.read_data()

        if args.pretr:
            self.sepi = self._load_checkpoint('./model/'+args.model_ind+'.pkl')
        else:
            self.sepi = Sepi(args.device).to(args.device)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _load_checkpoint(path, map_location=None):
        """Load a whole pickled module as written by ``torch.save`` in ``train``.

        PyTorch >= 2.6 defaults ``torch.load`` to ``weights_only=True``, which
        rejects full-module pickles, so opt out explicitly.  Older releases do
        not know the keyword and raise ``TypeError``; fall back for them.
        Only load checkpoints you produced yourself: unpickling can run code.
        """
        try:
            return torch.load(path, map_location=map_location, weights_only=False)
        except TypeError:
            return torch.load(path, map_location=map_location)

    @staticmethod
    def _space_wavenumbers(nx, L):
        """Angular wavenumbers in FFT order for ``nx`` points on period ``L``.

        The entry at index ``nx/2`` (the Nyquist mode) is set to 0.
        """
        return torch.cat((torch.arange(0, nx/2), torch.tensor([0]),
                          torch.arange(-nx/2+1, 0)))*2*np.pi/L

    @staticmethod
    def _forcing(X, xs, L):
        """Return ``X[n, t] + sin(2*pi*x/L)`` with shape (N, len(xs), T, 1).

        ``X`` holds one row of T control values per sample.
        """
        u0 = torch.sin(2*np.pi*xs/L)
        return (X[:, None, :] + u0[None, :, None]).unsqueeze(-1)

    @staticmethod
    def _features(u, s, kk):
        """Concatenate ``u`` with the 0th..4th spectral x-derivatives of ``s``.

        ``u`` and ``s`` are (batch, nx, T, 1) with x on axis 1; the result is
        (batch, nx, T, 6), the input layout expected by ``Sepi``.
        """
        s_tildex = torch.fft.fft(s, dim=1)
        kks = kk.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand(s.shape)
        sxs = [torch.fft.ifft((1j*kks)**i*s_tildex, dim=1).real for i in range(5)]
        return torch.cat((u, *sxs), dim=-1)

    @staticmethod
    def _time_derivative(s, ts, tl, numz):
        """Spectral d/dt of ``s`` (batch, nx, tl, 1) along axis 2 -> (batch, nx, tl).

        The ``2*((tl-numz)//2)`` Fourier modes around index ``tl//2`` (the
        highest frequencies) are zeroed before differentiating.
        """
        # NOTE(review): period used for the time wavenumbers.  For a uniform
        # grid starting at t=0 the FFT period of tl samples is tl*dt, whereas
        # this gives (tl**2-1)/tl*dt -- confirm the intended convention.
        tL = ts[tl-1]*(tl+1)/tl
        tkk = torch.cat((torch.arange(0, tl/2+1), torch.arange(-tl/2+1, 0)),
                        dim=0).to(s.device) * 2*torch.pi/tL
        tkk2 = tkk.unsqueeze(0).unsqueeze(0).expand(s.shape[0], s.shape[1], tl)
        fty = torch.fft.fft(s, dim=2).squeeze(-1)
        filt = (tl-numz)//2
        fty[:, :, tl//2-filt:tl//2+filt] = 0 + 0j
        return torch.fft.ifft(1j*tkk2*fty, dim=2).real

    def _validation_loss(self, sepi, X_valid, y_valid, xs, ts, kk):
        """Return the full-length validation MSE as a Python float (no gradients)."""
        args = self.args
        tl = len(ts)
        trunc = tl//args.prob
        with torch.no_grad():
            uv = self._forcing(X_valid, xs, args.L)
            sv = y_valid.unsqueeze(-1)
            right = sepi(self._features(uv, sv, kk)).squeeze(-1)
            left = self._time_derivative(sv, ts, tl, args.numz)
            # Explicit end index: [trunc:-trunc] would be empty when trunc == 0.
            return torch.mean(((left-right)[:, :, trunc:tl-trunc])**2).item()

    @staticmethod
    def _plot_sample(Ni, xs, ts, keep, s3, preds, control):
        """Write reference, prediction and |error| maps plus the control to figures/<Ni>.pdf."""
        with plt.ioff():
            xx_mesh, tt_mesh = np.meshgrid(xs, ts[keep])
            fig, ax = plt.subplots(3, 2, figsize=(40, 13))
            try:
                fig.set_tight_layout(True)
                smin = s3.min(); smax = s3.max(); ls = 30
                panels = (
                    (s3[keep], dict(vmin=smin, vmax=smax)),       # reference
                    (preds[keep], dict(vmin=smin, vmax=smax)),    # prediction
                    # Absolute error: a signed error with vmin=0 would clip
                    # every negative deviation to the bottom colour.
                    ((preds-s3)[keep].abs(), dict(vmin=0)),
                )
                for row, (field, scale) in enumerate(panels):
                    cax = ax[row, 1].pcolormesh(tt_mesh, xx_mesh, field, **scale)
                    fig.colorbar(cax, ax=ax[row, 1])
                    ax[row, 1].tick_params(labelsize=ls)
                    ax[row, 1].set_ylabel(r"$x$", size=ls+15)
                ax[2, 1].set_xlabel(r"$t$", size=ls+15)
                # control parameter
                ax[2, 0].plot(ts, control)
                ax[2, 0].tick_params(labelsize=ls)
                ax[2, 0].set_ylabel(r"$u_6$", size=ls+15)
                fig.savefig("figures/"+str(Ni)+".pdf")
            finally:
                # One figure per test sample: close it or they accumulate.
                plt.close(fig)

    # --------------------------------------------------------------- public API

    def read_data(self):
        """Read the CSV dataset under ``./dataset/data_<filename>/`` into tensors.

        Returns ``(y_train, X_train, y_valid, X_valid, y_test, X_test, xs, ts)``.
        ``s_*.csv`` rows are reshaped to (samples, len(xs), len(ts)).  Gaussian
        noise with std ``sigma_sd * mean|y_train|`` is added to the training
        targets, and only the first ``num`` training samples are kept.
        Train/valid tensors and the grids are moved to ``args.device``; the
        test tensors stay on the CPU.
        """
        args = self.args
        lt = 1  # temporal subsampling stride (1 keeps every time step)
        device = args.device
        root = "./dataset/data_"+args.filename+"/"

        def load(name, cols=slice(1, None)):
            # Column 0 is dropped (presumably the index written alongside the data).
            return pd.read_csv(root+name).values[:, cols]

        X_train = load("u_train.csv")
        X_valid = load("u_valid.csv")
        X_test = load("u_test.csv")
        y_train = load("s_train.csv")
        y_valid = load("s_valid.csv")
        y_test = load("s_test.csv")
        xs = load("x.csv", 1)
        ts = load("t.csv", 1)
        nx, nt = len(xs), len(ts)

        # Relative noise on the training targets only.
        noise = args.sigma_sd*np.abs(y_train).mean() * np.random.randn(y_train.shape[0], y_train.shape[1])
        y_train = (y_train + noise).reshape(-1, nx, nt)[:args.num]
        y_valid = y_valid.reshape(-1, nx, nt)
        y_test = y_test.reshape(-1, nx, nt)

        X_train = torch.tensor(X_train).to(device)[:args.num]
        y_train = torch.tensor(y_train)[:, :, ::lt].to(device)
        X_valid = torch.tensor(X_valid).to(device)
        y_valid = torch.tensor(y_valid)[:, :, ::lt].to(device)
        X_test = torch.tensor(X_test)
        y_test = torch.tensor(y_test)[:, :, ::lt]
        xs = torch.tensor(xs).to(device)
        ts = torch.tensor(ts).to(device)[::lt]
        return (y_train, X_train, y_valid, X_valid, y_test, X_test, xs, ts)

    def train(self):
        """Fit ``self.sepi`` to the filtered spectral time derivative of the data.

        Each epoch draws ``batch_size`` random trajectories and one random
        window of ``tle`` time steps, then builds the features
        ``(u, s, s_x, ..., s_xxxx)``.  It minimises the MSE between ``sepi``
        and ds/dt, ignoring ``tle//prob`` samples at each end of the window.
        Every ``outime`` epochs the whole validation set is scored, and the
        best model so far is pickled to ``./model/<model_ind>.pkl``.  Loss
        curves are written to ``./results/``.
        """
        args = self.args
        device = args.device
        tle = args.tle
        numz = args.numz

        X_train = self.X_train; y_train = self.y_train
        X_valid = self.X_valid; y_valid = self.y_valid
        xs = self.xs; ts = self.ts
        sepi = self.sepi

        optimizer = optim.Adam(sepi.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        sepi.train()

        # Data preprocessing
        N = X_train.shape[0]
        u = self._forcing(X_train, xs, args.L)
        s = y_train.unsqueeze(-1)
        kk = self._space_wavenumbers(len(xs), args.L).to(device)

        # Interior of each training window that enters the loss.  An explicit
        # end index keeps the whole window when tle//prob == 0 ([0:-0] is empty).
        trunc = tle//args.prob
        keep = slice(trunc, tle-trunc)

        # Start at +inf so the first validation always writes a checkpoint;
        # test() and the pretr path load it unconditionally.
        best_loss = float('inf')
        Tr_loss = []
        Va_loss = []
        for epoch in range(args.epochs):
            optimizer.zero_grad()
            sn = torch.randint(0, N, (args.batch_size,)).to(device)
            st = torch.randint(0, len(ts)-tle, (1,)).to(device)
            u_b = u[sn, :, st:st+tle]
            s_b = s[sn, :, st:st+tle]

            right = sepi(self._features(u_b, s_b, kk)).squeeze(-1)
            left = self._time_derivative(s_b, ts, tle, numz)
            loss = torch.mean(((left-right)[:, :, keep])**2)

            loss.backward()
            optimizer.step()
            Tr_loss.append(loss.item())

            if epoch % args.outime == 0:
                valid_loss = self._validation_loss(sepi, X_valid, y_valid, xs, ts, kk)
                print('Iter {:04d} | '.format(epoch) + 'Train Loss {:.6f}; Valid Loss {:.6f}'.format(loss.item(), valid_loss))
                Va_loss.append(valid_loss)

                if valid_loss < best_loss:
                    best_loss = valid_loss
                    torch.save(sepi, './model/'+args.model_ind+'.pkl')

        pd.DataFrame(Tr_loss).to_csv("./results/Tr_loss_ks.csv")
        pd.DataFrame(Va_loss).to_csv("./results/Va_loss_ks.csv")

    def test(self):
        """Roll out the learned right-hand side with ``odeint`` on every test sample.

        For each sample, the state at time index ``trunc = len(ts)//prob`` is
        integrated over ``ts[trunc:]`` on the CPU.  The mean absolute error
        against the reference is recorded, and a comparison figure is saved to
        ``figures/<i>.pdf``.  Per-sample errors go to
        ``./results/<model_ind>errors.csv``.
        """
        args = self.args
        X_test = self.X_test; y_test = self.y_test
        xs = self.xs.cpu(); ts = self.ts.cpu()

        # Integration runs on the CPU, so map the checkpoint there directly.
        sepi = self._load_checkpoint('./model/'+args.model_ind+'.pkl', map_location='cpu')

        # Data preprocessing
        N = X_test.shape[0]
        s = y_test.unsqueeze(-1)
        kk = self._space_wavenumbers(len(xs), args.L)

        tle = len(ts)
        trunc = tle//args.prob
        keep = slice(trunc, tle-trunc)  # interior shown in the figures

        # The RHS is the same for every sample; only the control spline changes.
        node = NODE(torch.sin(2*np.pi*xs/args.L).unsqueeze(-1), kk)
        node.mlp = sepi.mlp.cpu()

        errors = torch.zeros(N)
        with torch.no_grad():
            for Ni in range(N):
                ni = torch.tensor([Ni])
                s_b = s[ni]                       # copy, shape (1, nx, T, 1)
                s3 = s_b[0].squeeze(-1).t()       # reference, shape (T, nx)
                s0 = s_b[0, :, trunc:trunc+1, 0]  # initial state, shape (nx, 1)
                node.csp = CSpline(ts, X_test[ni][0])

                print('node', str(Ni), '...')
                preds = torch.zeros(len(ts), len(xs))
                start = time.time()
                pred = ode.odeint(node, s0, ts[trunc:], rtol=args.rtol, atol=args.atol, method=args.method)
                preds[trunc:] = pred[:, :, 0]
                end = time.time()

                error = (preds-s3)[trunc:].abs().mean()
                errors[Ni] = error
                print("test time=", round(end-start, 2), 'error=', error)

                self._plot_sample(Ni, xs, ts, keep, s3, preds, X_test[Ni])

        pd.DataFrame(errors.numpy()).to_csv("./results/"+args.model_ind+"errors.csv")
        print('Testing error:', torch.mean(errors))
        