import os
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import pandas as pd
import sys
import torchdiffeq as ode
import functools
from utils.cubic_spline import CSpline
from torch.utils.data import DataLoader

class Sepi(nn.Module):  # myModel
    """Fully connected ReLU network: input -> 3 hidden layers -> output.

    The layers live in ``self.mlp`` (an ``nn.Sequential`` of 4 Linear layers
    with a ReLU between consecutive ones). Every Linear weight is drawn from
    N(0, 0.1^2) and every bias is set to zero.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        sizes = [input_size, hidden_size, hidden_size, hidden_size, output_size]
        stack = []
        for n_in, n_out in zip(sizes[:-1], sizes[1:]):
            stack.append(nn.Linear(n_in, n_out, bias=True))
            stack.append(nn.ReLU())
        stack.pop()  # the output layer is linear: drop the trailing ReLU
        self.mlp = nn.Sequential(*stack)

        linear_layers = (mod for mod in self.mlp.modules() if isinstance(mod, nn.Linear))
        for layer in linear_layers:
            nn.init.normal_(layer.weight, mean=0, std=0.1)
            nn.init.constant_(layer.bias, val=0)

    def forward(self, t, x):
        """Apply the MLP to ``x``; ``t`` is accepted for the ODE-solver signature but unused."""
        return self.mlp(x)

class FODE(nn.Module):  # myModel
    """ODE right-hand side: f(t, x) = mlp([x, fit_data.fit(t)]).

    Both attributes are assigned by the caller after construction:
    ``mlp`` is the learned network and ``fit_data`` is an interpolant exposing
    ``fit(t)`` (a CSpline in this file).
    """

    def __init__(self):
        super().__init__()
        self.mlp = None
        self.fit_data = None

    def forward(self, t, x):
        control = self.fit_data.fit(t)
        # the interpolated input is appended after the state along dim 1
        features = torch.cat((x, control), axis=1)
        return self.mlp(features)

class TODE(nn.Module):
    """Fixed vector field g(t, s) = (s ** 3) @ A + fit_data.fit(t).

    ``A`` is the constant 2x2 matrix ``self.true_A``; ``fit_data`` is assigned
    by the caller and must expose ``fit(t)``.
    """

    def __init__(self, device):
        super().__init__()
        self.fit_data = None
        self.true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]]).to(device)

    def forward(self, t, s):
        cubed = s ** 3
        drift = torch.mm(cubed, self.true_A)
        return drift + self.fit_data.fit(t)
    
class Model():
    """Data loading, training, evaluation and plotting for a Sepi derivative model.

    ``args`` is an argparse-style namespace. Attributes read by this class:
    adjoint, method, filename, device, sigma_sd, num, inte, N_dim, hidden_size,
    epochs, batch_size, lr, weight_decay, tle, prob, outime, model_ind, rtol, atol.

    NOTE(review): the data tensors are built from float64 numpy arrays while the
    Sepi/TODE parameters use torch's default dtype, and ``sepi.mlp(ty)`` needs the
    two to match -- presumably the caller sets
    ``torch.set_default_dtype(torch.float64)``; confirm.
    """

    def __init__(self, args):
        self.args = args
        self.adjoint = args.adjoint
        self.method = args.method

        (self.y_train, self.X_train,
         self.y_valid, self.X_valid,
         self.y_test, self.X_test,
         self.ts) = self.read_data()

        # Input: N_dim columns (states followed by the u column);
        # output: N_dim - 1 values compared against the state derivatives.
        self.sepi = Sepi(args.N_dim, args.hidden_size, args.N_dim-1).to(args.device)

    # ------------------------------------------------------------------ data
    def read_data(self):
        """Read the CSV splits, add training noise and convert them to tensors.

        Files come from ``./dataset/data_<filename>/``: u_{train,valid,test}.csv,
        s_{train,valid,test}.csv and t.csv. The first CSV column is dropped
        (presumably the saved pandas index -- confirm). Gaussian noise with std
        ``sigma_sd * mean(|y_train|)`` is added to the training states only.
        Training data is cut to the first ``args.num`` trajectories and every
        split keeps every ``args.inte``-th time step.

        Returns:
            (y_train, X_train, y_valid, X_valid, y_test, X_test, ts), where the
            X tensors are (n_traj, n_steps), the y tensors are
            (n_traj, n_steps, n_state) and ts is (n_steps,).
        """
        args = self.args
        device = args.device
        folder = "./dataset/data_" + args.filename + "/"

        def read(name):
            return pd.read_csv(folder + name + ".csv").values

        X_train = read("u_train")[:, 1:]
        X_valid = read("u_valid")[:, 1:]
        X_test = read("u_test")[:, 1:]
        ts = read("t")[:, 1]
        y_train = read("s_train")[:, 1:]
        y_valid = read("s_valid")[:, 1:]
        y_test = read("s_test")[:, 1:]

        noise_scale = args.sigma_sd * np.abs(y_train).mean()
        y_train = y_train + noise_scale * np.random.randn(*y_train.shape)

        n_steps = len(ts)  # full grid length, i.e. before subsampling by inte
        step = args.inte

        def as_tensor(arr):
            return torch.tensor(arr).to(device)

        def as_states(arr, n_traj):
            # the state CSVs store trajectories stacked row-wise
            return as_tensor(arr).reshape(n_traj, n_steps, -1)

        n_train = X_train.shape[0]
        X_train = as_tensor(X_train)[:args.num][:, ::step]
        y_train = as_states(y_train, n_train)[:args.num][:, ::step]
        X_valid = as_tensor(X_valid)[:, ::step]
        y_valid = as_states(y_valid, X_valid.shape[0])[:, ::step]
        X_test = as_tensor(X_test)[:, ::step]
        y_test = as_states(y_test, X_test.shape[0])[:, ::step]
        ts = as_tensor(ts)[::step]

        return (y_train, X_train, y_valid, X_valid, y_test, X_test, ts)

    # --------------------------------------------------------------- helpers
    def _model_path(self):
        """Checkpoint path shared by train/test/draw."""
        return './model/' + self.args.model_ind + '.pkl'

    def _load_sepi(self):
        """Load the whole pickled Sepi module written by ``train``.

        The checkpoint is a full-module pickle, so it must be loaded with
        ``weights_only=False`` on torch versions that support that keyword
        (torch >= 2.6 defaults it to True). Unpickling runs arbitrary code:
        only load checkpoints you produced yourself.
        """
        import inspect  # local: only needed for this version check

        load_kwargs = {'map_location': self.args.device}
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        return torch.load(self._model_path(), **load_kwargs)

    @staticmethod
    def _central_difference(ty, dt):
        """Central difference of ``ty`` along dim 1; the two endpoints stay zero."""
        left = torch.zeros_like(ty)
        left[:, 1:-1] = (ty[:, 2:] - ty[:, :-2]) / (2 * dt)
        return left

    def _derivative_loss(self, s, u, dt, trunc):
        """Mean squared gap between finite-difference state derivatives and Sepi.

        ``s`` is (batch, T, N_dim - 1) and ``u`` is (batch, T). The first and last
        ``trunc`` time steps are excluded from the mean.
        """
        ty = torch.cat((s, u.unsqueeze(-1)), axis=-1)
        left = self._central_difference(ty, dt)
        right = self.sepi.mlp(ty)
        n = ty.shape[1]
        gap = left[:, :, :self.args.N_dim-1] - right
        # explicit end index: `trunc:-trunc` would be empty for trunc == 0
        return torch.mean(gap[:, trunc:n-trunc]**2)

    def _filtered_spectrum(self, ty, numz):
        """Return (angular wavenumbers, low-pass filtered FFT of ``ty`` along dim 0).

        The middle ``2 * ((n - numz) // 2)`` FFT bins (the highest |k|) are zeroed.
        Works for both even and odd lengths.
        """
        ts = self.ts
        n = ty.shape[0]
        # NOTE(review): for a uniform grid the FFT period is n * dt, i.e.
        # (ts[n-1] - ts[0]) * n / (n - 1); the (n + 1) / n factor below is kept
        # from the original implementation -- confirm it is intended.
        L = (ts[n-1] - ts[0]) * (n + 1) / n
        half = n // 2
        # integer frequencies [0, 1, ..., half, -(n-1)//2, ..., -1] (n values)
        k = torch.cat((torch.arange(0, half + 1), torch.arange(-((n - 1) // 2), 0)), axis=0)
        kk = k.to(device=self.args.device, dtype=torch.get_default_dtype()) * 2*torch.pi/L
        fty = torch.fft.fft(ty, axis=0)
        filt = (n - numz) // 2
        fty[half-filt:half+filt, :] = 0 + 0j
        return kk, fty

    # --------------------------------------------------------------- training
    def train(self):
        """Fit Sepi to central-difference derivatives of random training windows.

        Each epoch samples ``batch_size`` trajectories (with replacement) and one
        window of ``tle`` consecutive steps. Every ``outime`` epochs the loss on
        the full validation set is printed, and the model is saved to
        ``./model/<model_ind>.pkl`` whenever the validation loss improves.
        """
        args = self.args
        device = args.device
        sepi = self.sepi
        X_train, y_train = self.X_train, self.y_train
        X_valid, y_valid = self.X_valid, self.y_valid
        ts = self.ts

        optimizer = optim.Adam(sepi.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        n_traj = X_train.shape[0]
        n_steps = len(ts)
        tle = args.tle
        dt = (ts[-1] - ts[0]) / (n_steps - 1)  # uniform grid spacing
        # at least 1: the endpoints have no central difference (left == 0 there)
        trunc = max(tle // args.prob, 1)
        valid_trunc = max(n_steps // args.prob, 1)

        path = self._model_path()
        os.makedirs(os.path.dirname(path), exist_ok=True)
        best_loss = 1e8

        sepi.train()
        for epoch in range(args.epochs):
            optimizer.zero_grad()
            sn = torch.randint(0, n_traj, (args.batch_size,)).to(device)
            # randint's high bound is exclusive: +1 so the final window is reachable
            st = int(torch.randint(0, n_steps - tle + 1, (1,)))
            loss = self._derivative_loss(y_train[sn, st:st+tle], X_train[sn, st:st+tle], dt, trunc)
            loss.backward()
            optimizer.step()

            if epoch % args.outime == 0:
                with torch.no_grad():  # evaluation only: no autograd graph needed
                    valid_loss = self._derivative_loss(y_valid, X_valid, dt, valid_trunc)
                print('Iter {:04d} | '.format(epoch) + 'Train Loss {:.6f}; Valid Loss {:.6f}'.format(loss.item(), valid_loss.item()))
                if valid_loss < best_loss:
                    best_loss = valid_loss.item()
                    torch.save(sepi, path)

    # ------------------------------------------------------------ evaluation
    def test(self):
        """Roll out FODE on each test trajectory and report the per-trajectory MSE.

        Integration starts from the true state at step ``trunc = n_steps // prob``
        with dopri5. The comparison excludes the last ``trunc`` steps. Errors are
        written to ``./results/<model_ind>errors.csv``.
        """
        args = self.args
        X, Y = self.X_test, self.y_test
        ts = self.ts

        sepi = self._load_sepi()
        fode = FODE()
        fode.mlp = sepi.mlp

        n_samples = X.shape[0]
        ds = Y.shape[1]
        trunc = ds // args.prob
        horizon = ds - 2 * trunc  # number of compared steps
        errors = np.zeros(n_samples)
        Te_loss = 0
        with torch.no_grad():
            for i in range(n_samples):
                fode.fit_data = CSpline(ts, X[i])
                s0 = Y[i, trunc:trunc+1]
                pred = ode.odeint(fode, s0, ts[trunc:], rtol=args.rtol,
                                  atol=args.atol, method='dopri5')[:, 0, :]
                # explicit end indices: `[:-trunc]` would be empty for trunc == 0
                loss = torch.mean((pred[:horizon] - Y[i, trunc:ds-trunc])**2)
                print(str(i)+"test_loss:", loss)
                errors[i] = loss.item()
                Te_loss += loss

        print("Test_loss:", Te_loss/n_samples)
        out_path = "./results/"+args.model_ind+"errors.csv"
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        pd.DataFrame(errors).to_csv(out_path)

    # --------------------------------------------------------- spectral tools
    def fleft(self, ty, numz):
        """Spectral derivative of ``ty`` along dim 0 after low-pass filtering.

        Args:
            ty: (n, C) samples on the first n points of ``self.ts``.
            numz: roughly the number of Fourier modes kept.
        Returns:
            (n, C) real tensor of d(ty)/dt estimates.
        """
        kk, fty = self._filtered_spectrum(ty, numz)
        return torch.fft.ifft(1j * kk.unsqueeze(-1) * fty, axis=0).real

    def ftraj(self, ty, tst, ted, numz, numt=2000):
        """Evaluate the filtered Fourier series of ``ty`` on a dense time grid.

        Args:
            ty: (n, C) samples on the first n points of ``self.ts``.
            tst, ted: start and end of the evaluation grid.
            numz: roughly the number of Fourier modes kept.
            numt: number of evaluation points (default 2000).
        Returns:
            (traj, tss): traj is a (numt, C) CPU tensor in the default dtype and
            tss is the (numt,) evaluation grid.
        """
        kk, fty = self._filtered_spectrum(ty, numz)
        tss = torch.linspace(tst, ted, numt)
        real_dtype = fty.real.dtype
        phase = (tss.to(device=fty.device, dtype=real_dtype).unsqueeze(1)
                 * kk.to(real_dtype).unsqueeze(0))            # (numt, n)
        basis = torch.exp(1j * phase).to(fty.dtype)
        # sum_j fty[j] * exp(i k_j t) for every t at once, then normalise by n
        traj = (basis @ fty).real / ty.shape[0]
        return traj.to(device='cpu', dtype=torch.get_default_dtype()), tss

    # ---------------------------------------------------------------- plotting
    def draw(self, sample=4, start=120, window=60):
        """Plot one validation trajectory and save it as ``figures/<sample>.pdf``.

        Panel 0 shows the TODE vector field on steps ``[start, start + window)``
        together with ``fleft`` estimates from the trajectory shifted by 0, 4, 8
        and 12 steps. Panel 1 shows the true states vs. the FODE rollout from step
        ``n_steps // prob``. Panel 2 is created but left empty.

        Args:
            sample: index of the validation trajectory (default 4).
            start, window: index range plotted in panel 0 (defaults 120, 60).
        """
        args = self.args
        X, Y = self.X_valid, self.y_valid
        ts = self.ts

        sepi = self._load_sepi()
        fode = FODE()
        fode.mlp = sepi.mlp
        tode = TODE(args.device)

        trunc = Y.shape[1] // args.prob
        i = sample
        st2, ed2 = start, start + window
        font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 30}
        os.makedirs("figures", exist_ok=True)

        with torch.no_grad(), plt.ioff():
            csp = CSpline(ts, X[i])
            fode.fit_data = csp
            tode.fit_data = csp

            vefi = tode(ts, Y[i]).cpu()
            fig, ax = plt.subplots(1, 3, figsize=(40, 10))

            ax[0].plot(vefi[st2:ed2, 0], vefi[st2:ed2, 1], 'ko-', label='True', linewidth=2, markersize=4)
            # (shift k, retained Fourier modes, marker colour)
            for k, numz, c in zip([0, 4, 8, 12], [256, 256, 256, 256], ['y', 'b', 'g', 'm']):
                ty = torch.cat((Y[i, k:], X[i, k:].unsqueeze(-1)), axis=-1)
                left = self.fleft(ty, numz).cpu()
                # row j of `left` is time step j + k, hence the shifted window
                ax[0].plot(left[st2-k:ed2-k, 0], left[st2-k:ed2-k, 1], c+'x', label='Fourier', markersize=10)
            ax[0].tick_params(labelsize=30)
            ax[0].legend(prop=font1)
            ax[0].set_xlabel(r'$s_1$', size=25)
            ax[0].set_ylabel(r'$s_2$', size=25)

            s0 = Y[i, trunc:trunc+1]
            pred = ode.odeint(fode, s0, ts[trunc:], rtol=args.rtol,
                              atol=args.atol, method='dopri5')[:, 0, :]
            ax[1].plot(Y[i, trunc:, 0].cpu(), Y[i, trunc:, 1].cpu(), 'bo', label='True')
            ax[1].plot(pred[:, 0].cpu(), pred[:, 1].cpu(), 'r--', linewidth=3, label='Prediction')
            ax[1].tick_params(labelsize=30)
            ax[1].legend(prop=font1)
            ax[1].set_xlabel(r'$s_1$', size=25)
            ax[1].set_ylabel(r'$s_2$', size=25)

            fig.savefig("figures/"+str(i)+".pdf")
            plt.close(fig)  # release the figure; draw() may be called repeatedly
                
                
                
                
                