import os
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import pandas as pd
import sys
import torchdiffeq as ode
import functools
import seaborn as sns
import torch.nn.utils as utils
from utils.cubic_spline import CSpline
from torch.utils.data import DataLoader

class Sepi(nn.Module):  # myModel
    """Learnable polynomial right-hand side ``dx/dt = F(u, x, dx, ..., d^q x)``.

    ``forward(t, x)`` builds a library of ``q + 2`` candidate terms per grid
    point -- the input signal ``u`` plus the finite-difference derivatives of
    ``x`` of order ``0..q`` -- and returns a polynomial in those terms whose
    coefficients are the trainable tensors ``param1`` (linear), ``param2``
    (quadratic) and, when ``layers == 3``, ``param3`` (cubic).

    Args:
        device: device on which parameters, stencils and work buffers live.
        u0: 1-D tensor added to the interpolated input at every call; its
            length fixes the number of grid points.
        layers: polynomial degree switch; ``3`` enables the cubic terms.
        dt: grid spacing used to scale the finite differences.

    Attributes:
        fit_data: object exposing ``fit(t)``; it must be assigned by the
            caller before ``forward`` is invoked (it is ``None`` initially).
    """

    def __init__(self, device, u0, layers, dt):
        super(Sepi, self).__init__()
        self.device = device
        self.fit_data = None
        self.lent = len(u0)
        self.u0 = u0
        self.dt = dt
        self.layers = layers
        self.conv1d = torch.nn.functional.conv1d

        q = 4
        self.q = q
        # Finite-difference stencils for orders 0..4 (alternating binomial
        # coefficients). They are placed on `device` so that conv1d does not
        # hit a CPU/GPU mismatch when the state lives on an accelerator.
        stencils = ([1], [-1, 1], [1, -2, 1], [-1, 3, -3, 1], [1, -4, 6, -4, 1])
        self.kks = [torch.tensor(s).double().to(device) for s in stencils]

        bn = q + 2
        # Create the tensors directly on `device`: calling `.to(device)` on an
        # nn.Parameter returns a plain (non-leaf) tensor when a copy is made,
        # which would silently drop it from `parameters()`.
        self.param1 = nn.Parameter(torch.zeros(bn, device=device))
        self.param2 = nn.Parameter(torch.zeros(bn, bn, device=device))
        if layers == 3:
            self.param3 = nn.Parameter(torch.zeros(bn, bn, bn, device=device))

    def partial(self, o, order):
        """Return the ``order``-th finite difference of ``o`` divided by ``dt**order``.

        ``o`` is indexed as ``o[:, a:b]`` and flattened before the
        convolution, so it is expected to hold a single row of ``lent``
        values. The last ``order`` entries are prepended (wrap-around
        padding) so the output has the same length as the input.
        """
        lent = self.lent
        # Prepend the tail so the valid convolution yields `lent` outputs.
        padded = torch.hstack((o[:, lent - order:lent], o))
        diff = self.conv1d(padded.view(1, 1, -1),
                           self.kks[order].view(1, 1, -1))[0, 0]
        return diff / (self.dt ** order)

    def forward(self, t, x):
        """Evaluate the learned right-hand side at time ``t`` and state ``x``.

        Returns a 1-D tensor with one value per grid point.
        """
        q = self.q
        lent = self.lent

        # Input signal at time t, tiled over the grid, plus the fixed profile.
        ut = self.fit_data.fit(t).repeat(lent, 1)[:, 0]
        u = ut + self.u0

        # Library rows: [u, x, d1 x, ..., dq x]. The buffer uses torch's
        # default dtype, so values are cast on assignment.
        fi = torch.zeros(q + 2, lent).to(self.device)
        fi[0] = u
        for i in range(q + 1):
            fi[i + 1] = self.partial(x, i)

        # Linear terms.
        FF = torch.einsum("k,ki->i", self.param1, fi)

        # Quadratic terms; j <= i so each unordered pair appears once.
        for i in range(q + 2):
            for j in range(i + 1):
                FF = FF + self.param2[i, j] * (fi[i] * fi[j])

        # Cubic terms; k <= j <= i, each weighted by its own scalar
        # coefficient param3[i, j, k].
        if self.layers == 3:
            for i in range(q + 2):
                for j in range(i + 1):
                    for k in range(j + 1):
                        FF = FF + self.param3[i, j, k] * (fi[i] * fi[j] * fi[k])

        return FF
    
class Model():
    """Data loading, training and evaluation harness around ``Sepi``.

    On construction the CSV dataset selected by ``args.filename`` is read,
    one ``CSpline`` interpolator is built per input sample, and a ``Sepi``
    network is either created or loaded from ``./model/drpde_PDENET.pkl``
    (when ``args.pretr`` is truthy).

    NOTE: checkpoints are whole pickled modules read with ``torch.load``;
    unpickling executes arbitrary code, so only load files you trust.
    """

    def __init__(self, args):
        self.args = args
        self.adjoint = args.adjoint

        self.y_train,self.X_train,self.y_test,self.X_test,\
            self.xs,self.ts,self.csp_train,self.csp_test = self.read_data()

        xs = self.xs; ts = self.ts
        # Fixed spatial profile handed to Sepi and added to its input signal.
        u0 = torch.sin(2*np.pi*xs/args.L)
        layers = 2
        # Spacing between the first two entries of `ts`; Sepi uses it to
        # scale its finite differences.
        dt = ts[1]-ts[0]
        if args.pretr:
            self.sepi = torch.load('./model/drpde_PDENET.pkl')
        else:
            self.sepi = Sepi(args.device, u0, layers, dt).to(args.device)

    def read_data(self):
        """Load the dataset from ``./dataset/data_<filename>/`` as tensors.

        The first column of every CSV is dropped. Gaussian noise with
        standard deviation ``args.sigma_sd`` is added to the training
        targets, and only the first ``args.num`` training samples are kept.

        Returns:
            ``(y_train, X_train, y_test, X_test, xs, ts, csp_train, csp_test)``
            where the ``y_*`` tensors are reshaped to
            ``(-1, len(xs), len(ts))`` and ``csp_*`` are lists holding one
            ``CSpline(ts, X[i])`` per input row.
        """
        args = self.args
        lt = 1  # stride applied along the last axis of the targets and to ts
        device = args.device
        data_dir = "./dataset/data_" + args.filename
        X_train = pd.read_csv(data_dir + "/u_train.csv").values[:,1:]
        X_test = pd.read_csv(data_dir + "/u_test.csv").values[:,1:]
        y_train = pd.read_csv(data_dir + "/s_train.csv").values[:,1:]
        y_test = pd.read_csv(data_dir + "/s_test.csv").values[:,1:]
        xs = pd.read_csv(data_dir + "/x.csv").values[:,1]
        ts = pd.read_csv(data_dir + "/t.csv").values[:,1]

        # Noise is drawn for the full table before truncation to args.num.
        noise = args.sigma_sd * np.random.randn(y_train.shape[0], y_train.shape[1])
        y_train = y_train + noise
        y_train = y_train.reshape(-1,len(xs),len(ts))[:args.num]
        y_test = y_test.reshape(-1,len(xs),len(ts))
        X_train = torch.tensor(X_train).to(device)[:args.num]
        y_train = torch.tensor(y_train)[:,:,::lt].to(device)
        X_test = torch.tensor(X_test).to(device)
        y_test = torch.tensor(y_test)[:,:,::lt].to(device)
        xs = torch.tensor(xs).to(device)
        ts = torch.tensor(ts).to(device)

        # The interpolators are built on the un-strided time grid.
        csp_train = [CSpline(ts, X_train[i]) for i in range(X_train.shape[0])]
        csp_test = [CSpline(ts, X_test[i]) for i in range(X_test.shape[0])]

        ts = ts[::lt]
        return(y_train,X_train,y_test,X_test,xs,ts,csp_train,csp_test)

    def train(self):
        """Fit ``self.sepi`` on short random trajectory windows.

        Each epoch performs ``args.batch_size`` optimizer steps. Every step
        picks a random window start, integrates ``args.batch_num`` randomly
        chosen training samples over ``tle`` time points and minimises the
        summed squared error (plus an optional L1 penalty). After
        ``args.warm_up`` epochs (or immediately when ``args.pretr``),
        coefficients whose magnitude is below ``thr`` are zeroed together
        with their gradients. Every ``args.outime`` epochs the model is
        evaluated on random test samples and checkpointed.

        Returns:
            ``(Tr_losses, Te_losses)``: per-epoch training losses and the
            test losses of each evaluation; ``-1`` marks a NaN loss.
        """
        args = self.args
        epochs = args.epochs
        batch_size = args.batch_size
        device = args.device
        tle = args.tle1
        method = args.method1
        thr = args.thr
        X_train = self.X_train; y_train = self.y_train
        X_test = self.X_test; y_test = self.y_test
        ts = self.ts
        csp_train = self.csp_train
        csp_test = self.csp_test

        dots = 20  # number of schedule segments over the whole run
        # The solver step size moves linearly from delta_t1 to delta_t2.
        delta_t = args.delta_t1; det = (args.delta_t2-args.delta_t1)/epochs
        sepi = self.sepi
        optimizer = optim.Adam(sepi.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        N = X_train.shape[0]
        sepi.train()
        best_loss = torch.tensor(1e8).to(device)

        Tr_losses = []
        Te_losses = []
        for epoch in range(epochs):
            delta_t = delta_t + det
            Tr_loss = torch.zeros(1).to(device)
            # Short-term
            st = torch.randint(0,len(ts)-tle,(1,batch_size)).to(device)
            for j in range(batch_size):
                optimizer.zero_grad()
                loss = torch.tensor(0.0).to(device)
                k = st[0,j]
                st2 = torch.randint(0,N,(args.batch_num,)).to(device)
                for i in st2:
                    y0 = y_train[i,:,k].unsqueeze(0)
                    tt = ts[k:(k+tle)]
                    true_y = y_train[i,:,k:(k+tle)]
                    # The network reads the current sample's input signal
                    # through this attribute.
                    sepi.fit_data = csp_train[i]
                    pred_y = ode.odeint(sepi, y0, tt, rtol=args.rtol, atol=args.atol,\
                                method=method, options={'step_size': delta_t}).squeeze(1).t()

                    loss = loss + torch.sum((pred_y-true_y)**2)

                # L1 regularization (sum of absolute values; torch.norm's
                # default would be the L2/Frobenius norm).
                if args.l1_alpha > 0:
                    for p in sepi.parameters():
                        loss = loss + args.l1_alpha*p.abs().sum()

                loss.backward()

                # Prevent gradient explosion
                utils.clip_grad_norm_(sepi.parameters(), 1000)
                if epoch>args.warm_up or args.pretr:
                    # Hard-threshold small coefficients and freeze them for
                    # this step by zeroing their gradients as well.
                    for p in (sepi.param1, sepi.param2):
                        small = p.data.abs()<thr
                        p.data[small] = 0
                        p.grad[small] = 0

                optimizer.step()
                Tr_loss += loss.detach()

            # Adaptive learning strategies
            for dot in range(1,dots):
                if epoch == epochs//dots * dot:
                    if optimizer.param_groups[0]['lr']<1e-3:
                        optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] * 10
                    thr = min(optimizer.param_groups[0]['lr']*0.1,1e-3)
                    if tle<args.tle2:
                        tle = tle + 1
                    print("####################")
                    print("tle->",tle,"thr->",thr,"lr->",optimizer.param_groups[0]['lr'],"method->",method)
                    print("####################")

            # record loss
            print("epoch:",epoch,", st_loss:", np.round(Tr_loss.item(),4))
            # torch.isnan works on any device, unlike a round-trip via .numpy().
            if torch.isnan(Tr_loss).any():
                Tr_losses.append(-1)
            else:
                Tr_losses.append(Tr_loss.item())
            # testing set
            if epoch%args.outime==0 and epoch>0:
                try:
                    Te_loss = torch.zeros(1).to(device)
                    st3 = torch.randint(0,X_test.shape[0],(args.batch_num,)).to(device)
                    with torch.no_grad():
                        for i in st3:
                            sepi.fit_data = csp_test[i]
                            s0 = y_test[i,:,0].unsqueeze(0)
                            # Evaluation uses args.delta_t, not the annealed
                            # training step size.
                            pred_y = ode.odeint(sepi, s0, ts, rtol=args.rtol, atol=args.atol,\
                                        method=method, options={'step_size': args.delta_t}).squeeze(1).t()
                            Te_loss = Te_loss + torch.sum((pred_y-y_test[i])**2)
                    if torch.isnan(Te_loss).any():
                        Te_losses.append(-1)
                    else:
                        Te_losses.append(Te_loss.item())
                    print("####################")
                    # Average over the samples actually evaluated.
                    print("epoch:",epoch,", Test_loss:", np.round(Te_loss.item()/max(len(st3),1),4))
                    if Te_loss < best_loss:
                        torch.save(sepi, './model/drpde_PDENET_b.pkl')
                        best_loss = Te_loss
                    torch.save(sepi, './model/drpde_PDENET.pkl')
                    print("parame1: ", sepi.param1)
                    print("parame2: ", sepi.param2)
                    print("####################")
                except Exception as err:
                    # Evaluation is best-effort: keep training and still
                    # checkpoint, but report what went wrong. Catching
                    # Exception (not everything) lets Ctrl-C propagate.
                    torch.save(sepi, './model/drpde_PDENET.pkl')
                    print("parame1: ", sepi.param1)
                    print("parame2: ", sepi.param2)
                    print("test error:", repr(err))

        return(Tr_losses, Te_losses)

    def _rollout(self, sepi, splines, ys):
        """Euler-integrate each sample of ``ys`` from its first snapshot over ``self.ts``."""
        args = self.args
        preds = torch.zeros_like(ys).to(args.device)
        for i in range(len(splines)):
            sepi.fit_data = splines[i]
            s0 = ys[i,:,0].unsqueeze(0)
            pred_y = ode.odeint(sepi, s0, self.ts, rtol=args.rtol, atol=args.atol,\
                        method='euler', options={'step_size': args.delta_t}).squeeze(1).t()
            preds[i] = pred_y
        return preds

    def test(self):
        """Load the latest checkpoint and print the MSE on both data splits.

        Every sample is integrated over the full time grid from its first
        snapshot with the fixed-step Euler method.
        """
        y_train = self.y_train
        y_test = self.y_test

        with torch.no_grad():
            sepi = torch.load('./model/drpde_PDENET.pkl')
            preds_train = self._rollout(sepi, self.csp_train, y_train)
            preds_test = self._rollout(sepi, self.csp_test, y_test)
        print("train_loss:",torch.mean((preds_train-y_train)**2))
        print("test_loss:",torch.mean((preds_test-y_test)**2))
        



