import os
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import pandas as pd
import sys
import torchdiffeq as ode
import functools
import seaborn as sns
import torch.nn.utils as utils
from utils.cubic_spline import CSpline
from torch.utils.data import DataLoader
import warnings
warnings.filterwarnings("ignore")

class NODE(nn.Module):  # myModel
    """Neural ODE right-hand side evaluated point-wise on a periodic grid.

    For every grid point the MLP receives the forcing value ``u`` and the
    spatial derivatives of the state ``s`` of order ``0 .. N_ORDERS-1``
    (computed spectrally with the FFT) and returns ``ds/dt`` at that point.

    Attributes:
        mlp: the point-wise network, ``1 + N_ORDERS`` inputs -> 1 output.
        csp: spline-like object exposing ``fit(t)``; must be assigned by the
            caller before the module is integrated (it starts as ``None``).
        u0:  baseline forcing profile added to ``csp.fit(t)``.
        kks: wave numbers stored as a column so they multiply the spectrum of
            a ``(n_grid, 1)`` state element-wise.
    """

    # Derivative orders 0..N_ORDERS-1 of the state are used as MLP features.
    N_ORDERS = 5

    def __init__(self,u0,xs):
        """Build the network.

        Args:
            u0: baseline forcing tensor; ``forward`` expands ``csp.fit(t)`` to
                ``u0.shape`` (the caller in this file passes ``(n_grid, 1)``).
            xs: despite the name, the vector of FFT wave numbers, one per grid
                point (this is what ``Model`` passes in).
        """
        super(NODE, self).__init__()
        # The feature width must equal the number of columns concatenated in
        # forward(): one forcing column plus N_ORDERS derivative columns.
        input_size = 1 + self.N_ORDERS
        hidden_size = 64
        # One time-derivative value per grid point, so the output has the same
        # shape as the (n_grid, 1) state handed to the ODE solver.
        output_size = 1
        self.mlp = nn.Sequential(nn.Linear(input_size, hidden_size, bias=True), nn.ReLU(),
                               nn.Linear(hidden_size, hidden_size, bias=True), nn.ReLU(),
                               nn.Linear(hidden_size, hidden_size, bias=True), nn.ReLU(),
                               nn.Linear(hidden_size, output_size, bias=True))

        for m in self.mlp.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.1)
                nn.init.constant_(m.bias, val=0)

        self.csp = None
        self.u0 = u0
        # forward() needs the wave numbers; keep them as a column so that
        # (n_grid, 1) * (n_grid, 1) stays element-wise instead of broadcasting
        # a 1-D vector into an (n_grid, n_grid) matrix.
        self.kks = torch.as_tensor(xs).reshape(-1, 1)

    def forward(self, t, s_b):
        """Return ``ds/dt`` for state ``s_b`` at time ``t``.

        Args:
            t:   current time, forwarded to ``self.csp.fit``.
            s_b: state with the grid along dim 0 (the caller in this file
                 passes shape ``(n_grid, 1)``).

        Raises:
            RuntimeError: if ``self.csp`` has not been assigned yet.
        """
        if self.csp is None:
            raise RuntimeError("NODE.csp must be assigned before calling forward()")

        # Forcing at time t: baseline profile plus the interpolated signal.
        u_b = self.u0 + self.csp.fit(t).expand(self.u0.shape)

        # Spectral differentiation: multiply the spectrum by (i*k)**order and
        # transform back, keeping the real part.
        s_hat = torch.fft.fft(s_b, dim=0)
        ik = 1j * self.kks
        derivs = [torch.fft.ifft(ik**order * s_hat, dim=0).real
                  for order in range(self.N_ORDERS)]

        features = torch.cat([u_b] + derivs, dim=-1)
        return self.mlp(features)

class Model():
    """Load the dataset, build the `NODE` right-hand side, and train / test it.

    ``args`` is an argparse-style namespace. Attributes read by this class:
    ``L``, ``filename``, ``sigma_sd``, ``num``, ``epochs``, ``batch_size``,
    ``lr``, ``weight_decay``, ``rtol``, ``atol``, ``outime``, ``model_ind``,
    ``prob`` and ``method``. ``args.device`` is overwritten with ``'cpu'``.
    """

    def __init__(self, args):
        """Read the data and construct the network stored in ``self.sepi``."""
        # NOTE: the device is forced to CPU regardless of what the caller set.
        args.device = 'cpu'
        self.args = args
        (self.y_train, self.X_train, self.y_valid, self.X_valid,
         self.y_test, self.X_test, self.xs, self.ts,
         self.csp_train, self.csp_valid) = self.read_data()

        kk = self._wavenumbers(len(self.xs), args.L).to(args.device)
        # Baseline forcing profile sin(2*pi*x/L), as a column vector.
        u00 = torch.sin(2*np.pi*self.xs/args.L).unsqueeze(-1)
        self.sepi = NODE(u00, kk).to(args.device)

    @staticmethod
    def _wavenumbers(n_grid, length):
        """Return FFT-ordered wave numbers for ``n_grid`` points on a domain of size ``length``.

        Layout: ``0 .. n/2-1``, then ``0`` at index ``n/2`` (the Nyquist slot
        for even ``n``), then ``-n/2+1 .. -1``; everything scaled by
        ``2*pi/length``.
        """
        half = n_grid / 2
        kk = torch.cat((torch.arange(0, half), torch.tensor([0.0]),
                        torch.arange(-half + 1, 0)))
        return kk * 2 * np.pi / length

    def read_data(self):
        """Load the CSV dataset under ``./dataset/data_<filename>/``.

        The first column of every CSV is dropped (assumed to be the pandas
        index column - TODO confirm). Gaussian noise scaled by
        ``args.sigma_sd * mean(|s_train|)`` is added to the training targets,
        and only the first ``args.num`` training samples are kept.

        Returns:
            Tuple ``(y_train, X_train, y_valid, X_valid, y_test, X_test, xs,
            ts, csp_train, csp_valid)``. The ``y_*`` tensors are reshaped to
            ``(-1, len(xs), len(ts))``; ``csp_*`` are lists with one
            ``CSpline(ts, X[i])`` per sample.
        """
        args = self.args
        lt = 1  # temporal subsampling stride applied to targets and to ts
        device = args.device
        data_dir = "./dataset/data_" + args.filename + "/"

        def load(name):
            return pd.read_csv(data_dir + name + ".csv").values

        X_train = load("u_train")[:, 1:]
        X_valid = load("u_valid")[:, 1:]
        X_test = load("u_test")[:, 1:]
        y_train = load("s_train")[:, 1:]
        y_valid = load("s_valid")[:, 1:]
        y_test = load("s_test")[:, 1:]
        xs = load("x")[:, 1]
        ts = load("t")[:, 1]

        # Noise is applied to the training targets only, with a standard
        # deviation relative to their mean absolute value.
        noise = args.sigma_sd*np.abs(y_train).mean() * np.random.randn(y_train.shape[0], y_train.shape[1])
        y_train = y_train + noise

        n_x, n_t = len(xs), len(ts)
        y_train = y_train.reshape(-1, n_x, n_t)[:args.num]
        y_valid = y_valid.reshape(-1, n_x, n_t)
        y_test = y_test.reshape(-1, n_x, n_t)

        X_train = torch.tensor(X_train).to(device)[:args.num]
        y_train = torch.tensor(y_train)[:, :, ::lt].to(device)
        X_valid = torch.tensor(X_valid).to(device)
        y_valid = torch.tensor(y_valid)[:, :, ::lt].to(device)
        X_test = torch.tensor(X_test).to(device)
        y_test = torch.tensor(y_test)[:, :, ::lt].to(device)
        xs = torch.tensor(xs).to(device)
        ts = torch.tensor(ts).to(device)[::lt]

        # One interpolant of the input signal per sample.
        csp_train = [CSpline(ts, X_train[i]) for i in range(X_train.shape[0])]
        csp_valid = [CSpline(ts, X_valid[i]) for i in range(X_valid.shape[0])]

        return(y_train,X_train,y_valid,X_valid,y_test,X_test,xs,ts,csp_train,csp_valid)

    def train(self):
        """Train ``self.sepi`` on short windows and checkpoint the best model.

        Each epoch draws ``args.batch_size`` training samples (with
        replacement) and one shared window start, integrates ``tle`` time
        points with the Euler method, and minimises the summed squared error.
        Every ``args.outime`` epochs the mean squared error over all
        validation samples (window starting at index 0) is computed; the whole
        module is saved to ``./model/<model_ind>.pkl`` whenever it improves.
        Loss histories are written to ``./results``.
        """
        args = self.args
        epochs = args.epochs
        batch_size = args.batch_size
        device = args.device

        y_train = self.y_train; y_valid = self.y_valid
        csp_train = self.csp_train; csp_valid = self.csp_valid
        ts = self.ts
        sepi = self.sepi

        optimizer = optim.Adam(sepi.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        sepi.train()

        tle = 2  # time points per integration window (initial state + one step)
        method = 'euler'
        N_tr = self.X_train.shape[0]
        N_va = self.X_valid.shape[0]
        # Start from +inf so the first validation pass always produces a
        # checkpoint, whatever the scale of the loss.
        best_loss = float('inf')

        # Output locations are written below; make sure they exist.
        os.makedirs('./model', exist_ok=True)
        os.makedirs('./results', exist_ok=True)

        Tr_loss = []
        Va_loss = []
        for epoch in range(epochs):
            optimizer.zero_grad()
            loss = torch.tensor(0.0).to(device)
            sn = torch.randint(0, N_tr, (batch_size,)).tolist()
            # randint's upper bound is exclusive: +1 makes the last valid
            # window start (len(ts) - tle) reachable.
            st = int(torch.randint(0, len(ts) - tle + 1, (1,)))
            for sni in sn:
                sepi.csp = csp_train[sni]
                s0 = y_train[sni, :, st:st+1]
                pred = ode.odeint(sepi, s0, ts[st:st+tle], rtol=args.rtol, atol=args.atol, method=method)[:,:,0]
                loss = loss + torch.sum((y_train[sni, :, st:st+tle].t()-pred)**2)

            loss.backward()
            optimizer.step()
            Tr_loss.append(loss.item())

            if epoch%args.outime == 0:
                with torch.no_grad():
                    valid_loss = torch.tensor(0.0).to(device)
                    for sni in range(N_va):
                        sepi.csp = csp_valid[sni]
                        s0 = y_valid[sni, :, 0:1]
                        pred = ode.odeint(sepi, s0, ts[0:tle], rtol=args.rtol, atol=args.atol, method=method)[:,:,0]
                        valid_loss = valid_loss + torch.mean((y_valid[sni, :, 0:tle].t()-pred)**2)

                    print('Iter {:04d} | '.format(epoch) + 'Train Loss {:.6f}; Valid Loss {:.6f}'.format(loss.item(), valid_loss.item()))
                    Va_loss.append(valid_loss.item())

                    if valid_loss.item() < best_loss:
                        best_loss = valid_loss.item()
                        torch.save(sepi, './model/'+args.model_ind+'.pkl')

        Trdf = pd.DataFrame(Tr_loss)
        Vadf = pd.DataFrame(Va_loss)
        Trdf.to_csv("./results/Tr_loss_dr.csv")
        Vadf.to_csv("./results/Va_loss_dr.csv")

    def test(self):
        """Evaluate the saved checkpoint on the test set.

        For every test sample the model is integrated from time index
        ``len(ts)//args.prob`` to the end with ``args.method``; the mean
        absolute error is printed and all errors are written to
        ``./results/<model_ind>errors.csv``.
        """
        args = self.args
        X_test = self.X_test; y_test = self.y_test
        ts = self.ts
        trunc = len(ts)//args.prob

        # NOTE(review): torch.load unpickles a whole module - only load
        # checkpoints you produced yourself. Recent PyTorch releases default
        # to weights_only=True, which may reject this file - confirm against
        # the installed version.
        sepi = torch.load('./model/'+args.model_ind+'.pkl')
        N_te = X_test.shape[0]
        errors = torch.zeros(N_te)
        with torch.no_grad():
            for Ni in range(N_te):
                sepi.csp = CSpline(ts, X_test[Ni])
                s0 = y_test[Ni,:,trunc:trunc+1]
                pred = ode.odeint(sepi, s0, ts[trunc:], rtol=args.rtol, atol=args.atol, method=args.method)[:,:,0]
                loss = (y_test[Ni].t()[trunc:]-pred).abs().mean()
                errors[Ni] = loss.item()

                print('error=', errors[Ni])

        os.makedirs('./results', exist_ok=True)
        pd_er = pd.DataFrame(errors.numpy())
        pd_er.to_csv("./results/"+args.model_ind+"errors.csv")
        print('Testing error:',torch.mean(errors))
        