import os
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import pandas as pd
import sys
import torchdiffeq as ode
import functools
from utils.cubic_spline import CSpline
from torch.utils.data import DataLoader

class Sepi(nn.Module):  # myModel
    """Fully connected network with three hidden ReLU layers.

    The network is stored in ``self.mlp`` as an ``nn.Sequential`` of
    ``Linear -> ReLU`` pairs followed by a final ``Linear``. Every linear
    weight is drawn from N(0, 0.1^2) and every bias is set to zero.

    Args:
        input_size: Width of the input features.
        hidden_size: Width shared by the three hidden layers.
        output_size: Width of the output features.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        widths = [input_size, hidden_size, hidden_size, hidden_size, output_size]

        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out, bias=True))
            stack.append(nn.ReLU())
        stack.pop()  # the output layer has no activation
        self.mlp = nn.Sequential(*stack)

        # Re-initialise after all layers exist, in layer order.
        linear_layers = [layer for layer in self.mlp if isinstance(layer, nn.Linear)]
        for layer in linear_layers:
            nn.init.normal_(layer.weight, mean=0, std=0.1)
            nn.init.constant_(layer.bias, val=0)

    def forward(self, t, x):
        """Apply the network to ``x``; ``t`` is accepted but not used."""
        return self.mlp(x)

class FODE(nn.Module):  # myModel
    """ODE right-hand side driven by an external time-dependent signal.

    Both attributes start as ``None`` and must be assigned before use:

    * ``mlp`` - callable applied to the concatenated ``[x, p]`` tensor.
    * ``fit_data`` - object exposing ``fit(t)``; its result is joined to the
      state along dimension 1.
    """

    def __init__(self):
        super().__init__()
        self.fit_data = None
        self.mlp = None

    def forward(self, t, x):
        """Return ``mlp([x, fit_data.fit(t)])`` for time ``t`` and state ``x``."""
        forcing = self.fit_data.fit(t)
        augmented = torch.cat([x, forcing], dim=1)
        return self.mlp(augmented)
    
class Model():
    """Train, evaluate and plot a network that models a time derivative.

    The network (``Sepi``) is trained so that its output matches a spectral
    (FFT-based) time derivative of the first ``N_dim - 1`` data channels.
    For testing and plotting, the trained ``mlp`` is wrapped in ``FODE`` and
    integrated with ``torchdiffeq``.
    """

    def __init__(self, args):
        """Store the configuration, load the data sets and build the network.

        Args:
            args: Namespace-like configuration. Fields read by this class:
                ``adjoint``, ``method``, ``device``, ``filename``,
                ``sigma_sd``, ``num``, ``inte``, ``N_dim``, ``hidden_size``,
                ``epochs``, ``batch_size``, ``lr``, ``weight_decay``,
                ``tle``, ``numz``, ``prob``, ``outime``, ``model_ind``,
                ``rtol`` and ``atol``.
        """
        self.args = args
        self.adjoint = args.adjoint
        self.method = args.method

        (self.y_train, self.X_train, self.y_valid, self.X_valid,
         self.y_test, self.X_test, self.ts) = self.read_data()

        self.sepi = Sepi(args.N_dim, args.hidden_size, args.N_dim-1).to(args.device)

    def read_data(self):
        """Load the CSV data sets, add noise to the training targets, subsample.

        Files are read from ``./dataset/data_<args.filename>/``. The first
        column of every file is discarded (presumably a saved index column -
        TODO confirm). Gaussian noise with standard deviation
        ``args.sigma_sd * mean(|y_train|)`` is added to the training targets
        only. Training data is limited to the first ``args.num`` samples and
        every series is subsampled in time with stride ``args.inte``.

        Returns:
            Tuple ``(y_train, X_train, y_valid, X_valid, y_test, X_test, ts)``
            of tensors on ``args.device``. The ``y_*`` tensors are reshaped
            to ``(samples, time, -1)``; the ``X_*`` tensors keep the 2-D CSV
            layout, subsampled along the second axis.
        """
        args = self.args
        device = args.device
        stride = args.inte
        base = "./dataset/data_" + args.filename

        def load(name):
            return pd.read_csv(base + "/" + name + ".csv").values[:, 1:]

        X_train = load("u_train")
        X_valid = load("u_valid")
        X_test = load("u_test")
        ts = pd.read_csv(base + "/t.csv").values[:, 1]
        y_train = load("s_train")
        y_valid = load("s_valid")
        y_test = load("s_test")

        # Noise scale is relative to the mean magnitude of the clean targets.
        noise = args.sigma_sd*np.abs(y_train).mean() * np.random.randn(y_train.shape[0], y_train.shape[1])
        y_train = y_train + noise

        # Reshape with the full time axis first, then subsample in time.
        n_steps = len(ts)
        n_train = X_train.shape[0]
        X_train = torch.tensor(X_train).to(device)[:args.num][:, ::stride]
        y_train = torch.tensor(y_train).to(device).reshape(n_train, n_steps, -1)[:args.num][:, ::stride]
        X_valid = torch.tensor(X_valid).to(device)[:, ::stride]
        y_valid = torch.tensor(y_valid).to(device).reshape(X_valid.shape[0], n_steps, -1)[:, ::stride]
        X_test = torch.tensor(X_test).to(device)[:, ::stride]
        y_test = torch.tensor(y_test).to(device).reshape(X_test.shape[0], n_steps, -1)[:, ::stride]
        ts = torch.tensor(ts).to(device)[::stride]

        return (y_train, X_train, y_valid, X_valid, y_test, X_test, ts)

    def train(self):
        """Fit ``self.sepi`` against a spectral time derivative of the data.

        Each iteration draws ``args.batch_size`` random training samples and
        one random window of ``args.tle`` consecutive time steps. The window
        ``[y, X]`` is transformed with an FFT along time, a band of
        coefficients around index ``tle // 2`` is zeroed, and the result is
        multiplied by ``i * k`` and transformed back to obtain a time
        derivative. The loss is the mean squared difference between the
        first ``N_dim - 1`` derivative channels and ``sepi.mlp([y, X])``,
        ignoring ``tle // prob`` samples at each end of the window.

        Every ``args.outime`` iterations the same residual is evaluated on
        the full validation series, and the model is saved to
        ``./model/<args.model_ind>.pkl`` whenever the validation loss
        improves on the best value seen so far.
        """
        args = self.args
        device = args.device
        X_train = self.X_train; y_train = self.y_train
        X_valid = self.X_valid; y_valid = self.y_valid
        ts = self.ts

        sepi = self.sepi
        optimizer = optim.Adam(sepi.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        n_train = X_train.shape[0]
        sepi.train()

        batch_size = args.batch_size
        tle = args.tle
        numz = args.numz
        prob = args.prob
        N_dim = args.N_dim
        ds = len(ts)
        best_loss = 1e8

        # Loop-invariant quantities.
        trunc_train = tle//prob
        trunc_valid = ds//prob
        filt = (tle-numz)//2

        for epoch in range(args.epochs):
            optimizer.zero_grad()
            sn = torch.randint(0, n_train, (batch_size,)).to(device)
            st = torch.randint(0, ds-tle, (1,)).to(device)
            u_batch = X_train[sn, st:st+tle].unsqueeze(-1)
            s_batch = y_train[sn, st:st+tle]

            # Wave numbers in FFT order for the length assumed for this window.
            L = (ts[st+tle-1]-ts[st])*(tle+1)/tle
            kk = (torch.cat((torch.arange(0, tle/2+1), torch.arange(-tle/2+1, 0)), dim=0)).to(device) * 2*torch.pi/L
            kk2 = kk.unsqueeze(0).unsqueeze(-1).expand(batch_size, tle, N_dim)

            ty = torch.cat((s_batch, u_batch), dim=-1)
            fty = torch.fft.fft(ty, dim=1)
            # Zero the band of coefficients centred on index tle//2.
            fty[:, tle//2-filt:tle//2+filt, :] = 0 + 0j
            left = torch.fft.ifft(1j*kk2*fty, dim=1).real
            right = sepi.mlp(ty)
            # Drop both window ends before averaging the squared residual.
            loss = torch.mean(((left[:, :, :N_dim-1]-right)[:, trunc_train:-trunc_train])**2)

            loss.backward()
            optimizer.step()

            if epoch % args.outime == 0:
                with torch.no_grad():
                    L = ts[-1]*(ds+1)/ds
                    kk = (torch.cat((torch.arange(0, ds/2+1), torch.arange(-ds/2+1, 0)), dim=0)).to(device) * 2*torch.pi/L
                    kk2 = kk.unsqueeze(0).unsqueeze(-1).expand(X_valid.shape[0], ds, N_dim)
                    ty = torch.cat((y_valid, X_valid.unsqueeze(-1)), dim=-1)
                    fty = torch.fft.fft(ty, dim=1)
                    # NOTE(review): this band is sized and centred with `tle`
                    # although the spectrum here has length `ds`; kept as in
                    # the original - confirm whether `ds` was intended.
                    fty[:, tle//2-filt:tle//2+filt, :] = 0 + 0j
                    left = torch.fft.ifft(1j*kk2*fty, dim=1).real
                    right = sepi.mlp(ty)
                    valid_loss = torch.mean(((left[:, :, :N_dim-1]-right)[:, trunc_valid:-trunc_valid])**2)
                    print('Iter {:04d} | '.format(epoch) + 'Train Loss {:.6f}; Valid Loss {:.6f}'.format(loss.item(), valid_loss.item()))
                    current = valid_loss.item()
                    if current < best_loss:
                        torch.save(sepi, './model/'+args.model_ind+'.pkl')
                        # Remember the best score so that later, worse models
                        # do not overwrite the saved checkpoint.
                        best_loss = current

    def test(self):
        """Integrate the saved model on every test sample and report the MSE.

        For each sample the input series is wrapped in ``CSpline`` and handed
        to ``FODE``; the ODE is integrated with ``dopri5`` from time index
        ``trunc = len(ts) // args.prob`` onward, starting at the true state at
        that index. The error is the mean squared difference on the time
        indices ``[trunc, len(ts) - trunc)``. Per-sample errors are printed
        and written to ``./results/<args.model_ind>errors.csv``.
        """
        args = self.args
        X = self.X_test; Y = self.y_test
        ts = self.ts

        # NOTE: torch.load unpickles a complete module object; only load
        # checkpoint files from a trusted source.
        sepi = torch.load('./model/'+args.model_ind+'.pkl')
        fode = FODE()
        fode.mlp = sepi.mlp

        n_samples = X.shape[0]
        trunc = Y.shape[1]//args.prob
        Te_loss = 0
        errors = np.zeros(n_samples)
        with torch.no_grad():
            for i in range(n_samples):
                fode.fit_data = CSpline(ts, X[i])
                s0 = Y[i, trunc:trunc+1]
                pred = ode.odeint(fode, s0, ts[trunc:], rtol=args.rtol,
                                  atol=args.atol, method='dopri5')[:, 0, :]
                loss = torch.mean((pred[:-trunc]-Y[i, trunc:-trunc])**2)
                print(str(i)+"_test_loss:", loss)
                errors[i] = loss.item()
                Te_loss += loss

        print("Test_loss:", Te_loss/n_samples)
        pd.DataFrame(errors).to_csv("./results/"+args.model_ind+"errors.csv")

    def draw(self):
        """Plot prediction against ground truth for test sample 1.

        The saved model is integrated exactly as in ``test``. The first three
        state channels are drawn in side-by-side panels (true values as blue
        dots, prediction as a red dashed line), the figure is written to
        ``figures/1.pdf`` and the sample's test loss is printed.
        """
        args = self.args
        X = self.X_test; Y = self.y_test; ts = self.ts

        # NOTE: torch.load unpickles a complete module object; only load
        # checkpoint files from a trusted source.
        sepi = torch.load('./model/'+args.model_ind+'.pkl')
        fode = FODE()
        fode.mlp = sepi.mlp

        trunc = Y.shape[1]//args.prob

        with torch.no_grad():
            with plt.ioff():
                i = 1
                fode.fit_data = CSpline(ts, X[i])

                s0 = Y[i, trunc:trunc+1]
                pred = ode.odeint(fode, s0, ts[trunc:], rtol=args.rtol,
                                  atol=args.atol, method='dopri5')[:, 0, :]

                font1 = {'weight': 'normal', 'size': 30}
                fig = plt.figure(figsize=(40, 6))

                true = Y[i, trunc:].cpu().numpy(); pre = pred.cpu().numpy()
                times = ts[trunc:].cpu()

                # One panel per state channel, labelled x, y, z.
                panels = ((131, r'$x$'), (132, r'$y$'), (133, r'$z$'))
                for channel, (position, label) in enumerate(panels):
                    ax = fig.add_subplot(position)
                    ax.plot(times, true[:, channel], 'bo', label='True')
                    ax.plot(times, pre[:, channel], 'r--', linewidth=3, label='Prediction')
                    ax.tick_params(labelsize=30)
                    ax.legend(prop=font1)
                    ax.set_xlabel(label, size=35)

                plt.savefig("figures/"+str(i)+".pdf")
                loss = torch.mean((pred[:-trunc]-Y[i, trunc:-trunc])**2)
                print(str(i)+"_test_loss:", loss)
                
                    
        
            
        
        
        