import os
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import pandas as pd
import sys
import torchdiffeq as ode
import functools
from utils.cubic_spline import CSpline
from torch.utils.data import DataLoader

class Sepi(nn.Module):  # myModel
    """Neural-ODE vector field driven by an external (spline-interpolated) input.

    The derivative at time ``t`` is an MLP evaluated on the current state ``x``
    concatenated along dim 1 with ``self.csp.fit(t)``. ``csp`` starts as
    ``None`` and must be assigned by the caller (the training/evaluation code
    assigns one spline per trajectory) before the module is called.

    Args:
        input_size: width of the concatenated ``[x, csp.fit(t)]`` features.
        hidden_size: width of each of the three hidden layers.
        output_size: width of the returned derivative.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Input interpolant; set per trajectory before each ODE solve.
        self.csp = None
        self.mlp = nn.Sequential(
            nn.Linear(input_size, hidden_size, bias=True), nn.ReLU(),
            nn.Linear(hidden_size, hidden_size, bias=True), nn.ReLU(),
            nn.Linear(hidden_size, hidden_size, bias=True), nn.ReLU(),
            nn.Linear(hidden_size, output_size, bias=True),
        )

        # Small random weights and zero biases for every linear layer.
        for m in self.mlp.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.1)
                nn.init.constant_(m.bias, val=0)

    def forward(self, t, x):
        """Return the time derivative of ``x`` at time ``t``.

        ``x`` and ``self.csp.fit(t)`` are concatenated along dim 1, so both
        must agree on every other dimension (presumably ``(batch, features)``
        — confirm against the spline implementation).
        """
        p = self.csp.fit(t)
        px = torch.cat((x, p), dim=1)
        return self.mlp(px)
    
class Model():
    def __init__(self, args):
        """Load the dataset described by ``args`` and build the vector-field net.

        ``args`` is stored as ``self.args``; ``args.adjoint`` and
        ``args.method`` are mirrored as ``self.adjoint`` / ``self.method``.
        Every tensor and spline list returned by ``read_data`` becomes an
        attribute, and ``self.sepi`` is a ``Sepi`` with ``args.N_dim`` inputs
        and ``args.N_dim - 1`` outputs, moved to ``args.device``.
        """
        self.args = args
        self.adjoint, self.method = args.adjoint, args.method

        (self.y_train, self.X_train,
         self.y_valid, self.X_valid,
         self.y_test, self.X_test,
         self.ts,
         self.csp_train, self.csp_valid, self.csp_test) = self.read_data()

        n_dim = args.N_dim
        net = Sepi(n_dim, args.hidden_size, n_dim - 1)
        self.sepi = net.to(args.device)
    
    def read_data(self):
        """Load the CSV dataset named by ``args.filename`` and build input splines.

        Files come from ``./dataset/data_<filename>/``:
        ``u_{train,valid,test}.csv`` (inputs), ``t.csv`` (time grid, second
        column) and ``s_{train,valid,test}.csv`` (states). The first column of
        every file is dropped.

        Processing:

        * Gaussian noise scaled by ``args.sigma_sd * mean(|y_train|)`` is added
          to the training states only.
        * State arrays are reshaped to ``(n_trajectories, len(t), -1)``.
        * The training set is cut to its first ``args.num`` trajectories.
        * All arrays and the time grid are subsampled in time with stride
          ``args.inte``.
        * One ``CSpline`` over the subsampled grid is fitted per input row.

        Returns:
            ``(y_train, X_train, y_valid, X_valid, y_test, X_test, ts,
            csp_train, csp_valid, csp_test)``. The first seven are tensors
            on ``args.device``; the last three are lists of splines.
        """
        args = self.args
        device = args.device
        stride = args.inte
        folder = "./dataset/data_" + args.filename + "/"

        def frame(name):
            return pd.read_csv(folder + name + ".csv").values

        X_train = frame("u_train")[:, 1:]
        X_valid = frame("u_valid")[:, 1:]
        X_test = frame("u_test")[:, 1:]
        ts = frame("t")[:, 1]
        y_train = frame("s_train")[:, 1:]
        y_valid = frame("s_valid")[:, 1:]
        y_test = frame("s_test")[:, 1:]

        # Synthetic observation noise, applied to the training targets only.
        noise_scale = args.sigma_sd * np.abs(y_train).mean()
        y_train = y_train + noise_scale * np.random.randn(*y_train.shape)

        # Full-resolution grid length, needed before subsampling for the reshapes.
        n_time = len(ts)
        n_raw = X_train.shape[0]

        def to_device(arr):
            return torch.tensor(arr).to(device)

        X_train = to_device(X_train)[:args.num][:, ::stride]
        y_train = to_device(y_train).reshape(n_raw, n_time, -1)[:args.num][:, ::stride]
        X_valid = to_device(X_valid)[:, ::stride]
        y_valid = to_device(y_valid).reshape(X_valid.shape[0], n_time, -1)[:, ::stride]
        X_test = to_device(X_test)[:, ::stride]
        y_test = to_device(y_test).reshape(X_test.shape[0], n_time, -1)[:, ::stride]
        ts = to_device(ts)[::stride]

        csp_train = [CSpline(ts, X_train[k]) for k in range(X_train.shape[0])]
        csp_valid = [CSpline(ts, X_valid[k]) for k in range(X_valid.shape[0])]
        csp_test = [CSpline(ts, X_test[k]) for k in range(X_test.shape[0])]

        return (y_train, X_train, y_valid, X_valid, y_test, X_test, ts,
                csp_train, csp_valid, csp_test)
    
    def train(self):
        """Train ``self.sepi`` on random time windows of the training trajectories.

        Training step (each epoch):

        * Sample ``args.batch_size`` training trajectories, with replacement.
        * Sample one window of ``args.ole`` consecutive time points, shared by
          the whole batch.
        * Integrate each sampled trajectory from the first state of the window
          over the window's time points.
        * Minimise the summed squared error over the batch with Adam.

        Validation and checkpointing (every ``args.outime`` epochs):

        * Integrate all validation trajectories over the ``args.ole``-point
          window starting at index ``len(ts) // args.prob``.
        * Print the summed per-trajectory MSE.
        * Pickle the whole module to ``./model/<args.model_ind>.pkl`` whenever
          this validation loss is the lowest seen so far.
        """
        args = self.args
        device = args.device
        ole = args.ole
        ts = self.ts
        y_train, csp_train = self.y_train, self.csp_train
        y_valid, csp_valid = self.y_valid, self.csp_valid
        sepi = self.sepi

        # odeint and odeint_adjoint share a call signature; choose once.
        integrate = ode.odeint_adjoint if args.adjoint else ode.odeint

        def solve(s0, tt):
            # [:, 0] drops the size-1 axis introduced by slicing s0 with 0:1.
            return integrate(sepi, s0, tt, rtol=args.rtol, atol=args.atol,
                             method=args.method)[:, 0]

        optimizer = optim.Adam(sepi.parameters(), lr=args.lr,
                               weight_decay=args.weight_decay)
        n_train = self.X_train.shape[0]
        # randint's upper bound is exclusive; +1 lets the window that ends on
        # the final time point be sampled as well.
        n_starts = len(ts) - ole + 1
        sepi.train()
        # Start at +inf so the first evaluation always produces a checkpoint.
        best_loss = float("inf")

        for epoch in range(args.epochs):
            optimizer.zero_grad()
            loss = torch.tensor(0.0, requires_grad=True)
            sn = torch.randint(0, n_train, (args.batch_size,)).to(device)
            st = int(torch.randint(0, n_starts, (1,)))
            tt = ts[st:st + ole]
            s_batch = y_train[sn, st:st + ole]
            for i, idx in enumerate(sn.tolist()):
                sepi.csp = csp_train[idx]
                pred_s = solve(s_batch[i, 0:1], tt)
                loss = loss + torch.sum((pred_s - s_batch[i]) ** 2)
            loss.backward()
            optimizer.step()

            if epoch % args.outime == 0:
                trunc = len(ts) // args.prob
                tt = ts[trunc:trunc + ole]
                s_batch = y_valid[:, trunc:trunc + ole]
                valid_loss = torch.tensor(0.0)
                # Evaluation only: no autograd graph is needed.
                with torch.no_grad():
                    for i, csp in enumerate(csp_valid):
                        sepi.csp = csp
                        pred_s = solve(s_batch[i, 0:1], tt)
                        valid_loss = valid_loss + torch.mean((pred_s - s_batch[i]) ** 2)

                print('Iter {:04d} | '.format(epoch) + 'Train Loss {:.6f}; Valid Loss {:.6f}'.format(loss.item(), valid_loss.item()))
                # Remember the best loss so only genuine improvements overwrite
                # the checkpoint that test()/draw() load.
                if valid_loss.item() < best_loss:
                    best_loss = valid_loss.item()
                    torch.save(sepi, './model/' + args.model_ind + '.pkl')
        
    def test(self):
        """Evaluate the checkpointed model on every test trajectory.

        Steps:

        * Load the module pickled by ``train`` from
          ``./model/<args.model_ind>.pkl``.
        * Integrate each test trajectory from time index
          ``len(ts) // args.prob`` to the end of the grid.
        * Print each trajectory's MSE and the mean MSE.
        * Write the per-trajectory MSEs to
          ``./results/<args.model_ind>errors.csv``.
        """
        args = self.args
        ts, y_test = self.ts, self.y_test
        n_test = self.X_test.shape[0]

        # NOTE(review): torch >= 2.6 defaults torch.load(weights_only=True),
        # which refuses to unpickle a full nn.Module; such versions need
        # weights_only=False here (only safe for trusted checkpoint files).
        sepi = torch.load('./model/' + args.model_ind + '.pkl')
        solver = ode.odeint_adjoint if args.adjoint else ode.odeint

        start = len(ts) // args.prob
        horizon = ts[start:]
        errors = np.zeros(n_test)

        with torch.no_grad():
            total = torch.tensor(0.0)
            for i, csp in enumerate(self.csp_test):
                sepi.csp = csp
                initial = y_test[i, start:start + 1]
                pred_s = solver(sepi, initial, horizon, rtol=args.rtol,
                                atol=args.atol, method=args.method)[:, 0]
                mse = torch.mean((pred_s - y_test[i, start:]) ** 2)
                errors[i] = mse.item()
                total = total + mse
                print(str(i) + "_Test_loss:", mse.item())

            print("Test_loss:", total / n_test)
            pd.DataFrame(errors).to_csv("./results/" + args.model_ind + "errors.csv")
    
    def draw(self):
        """Plot predictions of the checkpointed model, one figure per test trajectory.

        Steps (per test trajectory):

        * Load the pickled module from ``./model/<args.model_ind>.pkl`` (once,
          before the loop).
        * Integrate the trajectory with ``dopri5`` from time index
          ``Y.shape[1] // args.prob`` to the end of the grid.
        * On ``ax[0]``, plot the first state component against time (true
          values as blue dots, prediction as a red dashed line).
        * On ``ax[1]``, plot state component 0 against component 1 (labelled
          ``s_1`` / ``s_2``), true versus predicted.
        * Save the figure to ``figures/<i>.png``.
        """
        # NOTE(review): device, X_valid, y_valid, csp_valid, N_dim and numz are
        # not used in the lines shown here; args.numz must still exist on args.
        args = self.args; device = args.device; ts = self.ts
        X_test = self.X_test; y_test = self.y_test; csp_test = self.csp_test
        X_valid = self.X_valid; y_valid = self.y_valid; csp_valid = self.csp_valid

        sepi = torch.load('./model/'+args.model_ind+'.pkl')

        # Plot the test split.
        X = X_test; Y = y_test; csps = csp_test
        data_size = Y.shape[1]; N_dim = args.N_dim; numz=args.numz
        # Same start index as test(): integration begins at data_size // prob.
        trunc = data_size//args.prob

        with torch.no_grad():
            with plt.ioff():
                for i in range(X.shape[0]):
                    sepi.csp = csps[i]

                    font1 = {'family':'Times New Roman', 'weight':'normal','size':30}
                    # NOTE(review): figures are never closed (no plt.close(fig)),
                    # so memory grows with the number of test trajectories; ax[2]
                    # is not drawn on in the lines shown here.
                    fig, ax = plt.subplots(1, 3, figsize=(40,10))

                    s0 = Y[i,trunc:trunc+1]
                    # NOTE(review): solver is hard-coded to 'dopri5' and ignores
                    # args.method / args.adjoint, unlike train() and test().
                    pred = ode.odeint(sepi, s0, ts[trunc:], rtol=args.rtol, 
                                      atol=args.atol, method='dopri5')[:,0,:]

                    # NOTE(review): ts is passed to matplotlib without .cpu(); if
                    # args.device is a GPU this likely fails — confirm.
                    ax[0].plot(ts[trunc:], Y[i,trunc:,0].cpu(), 'bo')
                    ax[0].plot(ts[trunc:], pred[:,0].cpu(), 'r--')

                    ax[1].plot(Y[i,trunc:,0].cpu(),Y[i,trunc:,1].cpu(),'bo', label='True')
                    ax[1].plot(pred[:,0].cpu(),pred[:,1].cpu(),'r--', linewidth=3, label='Prediction')
                    ax[1].tick_params(labelsize=30)
                    ax[1].legend(prop=font1)
                    ax[1].set_xlabel(r'$s_1$',size=25)
                    ax[1].set_ylabel(r'$s_2$',size=25)

                    plt.savefig("figures/"+str(i)+".png")
                