import os
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import pandas as pd
import sys
import functools
import seaborn as sns
from torch.utils.data import DataLoader

################################################################
#  1d fourier layer
################################################################
class SpectralConv1d(nn.Module):
    """1D Fourier layer: FFT, linear transform on the lowest modes, inverse FFT.

    Args:
        in_channels: number of channels of the input signal.
        out_channels: number of channels of the output signal.
        modes1: number of low-frequency Fourier modes that carry learnable
            weights. If the input is too short to provide that many modes,
            ``forward`` only uses the modes that exist.
    """

    def __init__(self, in_channels, out_channels, modes1):
        super(SpectralConv1d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  # Number of Fourier modes to multiply, at most floor(N/2) + 1

        self.scale = (1 / (in_channels*out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, dtype=torch.cfloat))

    # Complex multiplication
    def compl_mul1d(self, input, weights):
        """Contract the channel axis of ``input`` with ``weights``, mode by mode.

        (batch, in_channel, x), (in_channel, out_channel, x) -> (batch, out_channel, x)
        """
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x):
        """Apply the spectral convolution along the last axis of ``x``.

        Args:
            x: real tensor laid out as (batch, in_channel, x).

        Returns:
            Real tensor of shape (batch, out_channels, x.size(-1)).
        """
        batchsize = x.shape[0]
        n_points = x.size(-1)

        # Compute Fourier coefficients up to factor of e^(- something constant)
        x_ft = torch.fft.rfft(x)
        n_freq = x_ft.size(-1)  # equals n_points // 2 + 1

        # A short signal may expose fewer modes than the layer was built
        # with; clamp so the slices of x_ft and weights1 always agree.
        modes = min(self.modes1, n_freq)

        # Multiply relevant Fourier modes; all higher modes stay zero.
        out_ft = torch.zeros(batchsize, self.out_channels, n_freq, device=x.device, dtype=torch.cfloat)
        out_ft[:, :, :modes] = self.compl_mul1d(x_ft[:, :, :modes], self.weights1[:, :, :modes])

        # Return to physical space
        return torch.fft.irfft(out_ft, n=n_points)

class MLP(nn.Module):
    """Two pointwise (kernel size 1) 1D convolutions with a GELU in between.

    Maps ``in_channels`` -> ``mid_channels`` -> ``out_channels`` along the
    channel axis of a (batch, channel, x) tensor.
    """

    def __init__(self, in_channels, out_channels, mid_channels):
        super().__init__()
        kernel_size = 1  # pointwise: every position is transformed independently
        self.mlp1 = nn.Conv1d(in_channels, mid_channels, kernel_size)
        self.mlp2 = nn.Conv1d(mid_channels, out_channels, kernel_size)

    def forward(self, x):
        """Return ``mlp2(gelu(mlp1(x)))``."""
        hidden = F.gelu(self.mlp1(x))
        return self.mlp2(hidden)

class FNO1d(nn.Module):
    """1D Fourier neural operator built from four Fourier blocks.

    Pipeline:
    1. ``forward`` appends a grid coordinate to the input channels and lifts
       the result to ``width`` channels with ``self.p``.
    2. Four blocks compute ``mlp_k(conv_k(u)) + w_k(u)``; a GELU follows every
       block except the last.
    3. ``self.q`` projects from ``width`` channels to a single output channel.

    input shape: (batchsize, x=s, c=2)
    output shape: (batchsize, x=s, c=1)
    """

    def __init__(self, modes, width):
        super().__init__()

        self.modes1 = modes
        self.width = width
        # Stored for padding non-periodic inputs; forward() does not apply it.
        self.padding = 8

        n_modes, channels = self.modes1, self.width

        # 3 input features: the 2 input channels plus the grid coordinate
        # that forward() concatenates.
        self.p = nn.Linear(3, channels)

        # Spectral branch of each block.
        self.conv0 = SpectralConv1d(channels, channels, n_modes)
        self.conv1 = SpectralConv1d(channels, channels, n_modes)
        self.conv2 = SpectralConv1d(channels, channels, n_modes)
        self.conv3 = SpectralConv1d(channels, channels, n_modes)

        # Pointwise MLP applied to the output of each spectral branch.
        self.mlp0 = MLP(channels, channels, channels)
        self.mlp1 = MLP(channels, channels, channels)
        self.mlp2 = MLP(channels, channels, channels)
        self.mlp3 = MLP(channels, channels, channels)

        # Pointwise linear bypass of each block.
        self.w0 = nn.Conv1d(channels, channels, 1)
        self.w1 = nn.Conv1d(channels, channels, 1)
        self.w2 = nn.Conv1d(channels, channels, 1)
        self.w3 = nn.Conv1d(channels, channels, 1)

        # Projection to the single output channel.
        self.q = MLP(channels, 1, channels*2)

    def forward(self, x):
        """Map a (batch, x, 2) input to a (batch, x, 1) output."""
        coords = self.get_grid(x.shape, x.device)
        hidden = self.p(torch.cat((x, coords), dim=-1))
        hidden = hidden.permute(0, 2, 1)  # channels-first for the conv layers

        blocks = (
            (self.conv0, self.mlp0, self.w0),
            (self.conv1, self.mlp1, self.w1),
            (self.conv2, self.mlp2, self.w2),
            (self.conv3, self.mlp3, self.w3),
        )
        final_idx = len(blocks) - 1
        for idx, (spectral, mixer, bypass) in enumerate(blocks):
            hidden = mixer(spectral(hidden)) + bypass(hidden)
            # No activation after the last block.
            if idx != final_idx:
                hidden = F.gelu(hidden)

        return self.q(hidden).permute(0, 2, 1)

    def get_grid(self, shape, device):
        """Return a (batch, size_x, 1) tensor of evenly spaced points in [0, 1].

        ``shape[0]`` is used as the batch size and ``shape[1]`` as the number
        of grid points; the same grid is repeated for every batch entry.
        """
        n_batch, n_points = shape[0], shape[1]
        coords = torch.tensor(np.linspace(0, 1, n_points), dtype=torch.float)
        return coords.reshape(1, n_points, 1).repeat([n_batch, 1, 1]).to(device)
    
    
class Model():
    """Training / evaluation wrapper around an ``FNO1d`` network.

    ``args`` must provide ``filename``, ``num``, ``L`` and ``device`` for
    construction, plus ``batch_size``, ``lr``, ``weight_decay``, ``epochs``
    and ``outime`` for ``train``.
    """

    def __init__(self, args):
        """Load the three data splits and build the network on ``args.device``."""
        self.args = args

        X1, p1, X2, p2, X3, p3, x, t = self.read_data()
        self.X2, self.p2, self.x, self.t = X2, p2, x, t
        n_x = len(x)

        # Train/valid: flattened (input at step k, target at step k+1) pairs.
        X_train, y_train = self._one_step_pairs(X1, p1, n_x)
        X_valid, y_valid = self._one_step_pairs(X2, p2, n_x)

        # Test: whole trajectories, laid out as (sample, time, x, channel).
        X_test = self._stack_channels(X3, p3).permute(0, 2, 1, 3)
        y_test = X3.unsqueeze(-1).permute(0, 2, 1, 3)

        print('X_train', X_train.shape)
        print('y_train', y_train.shape)
        print('X_valid', X_valid.shape)
        print('y_valid', y_valid.shape)
        print('X_test', X_test.shape)
        print('y_test', y_test.shape)
        self.X_train = X_train; self.y_train = y_train
        self.X_valid = X_valid; self.y_valid = y_valid
        self.X_test = X_test; self.y_test = y_test

        modes = 9
        width = 16
        self.sepi = FNO1d(modes, width).to(args.device)

    @staticmethod
    def _stack_channels(s, p):
        """Stack two (sample, x, time) tensors into (sample, x, time, 2)."""
        return torch.cat((s.unsqueeze(-1), p.unsqueeze(-1)), dim=-1)

    @staticmethod
    def _one_step_pairs(s, p, n_x):
        """Build flattened one-step-ahead (input, target) pairs.

        Inputs hold ``(s, p)`` at every time index but the last; targets hold
        ``s`` at the following time index. Sample and time axes are merged.
        """
        inputs = Model._stack_channels(s, p)[:, :, :-1].permute(0, 2, 1, 3)
        inputs = inputs.reshape(-1, n_x, 2)
        targets = s.unsqueeze(-1)[:, :, 1:].permute(0, 2, 1, 3).reshape(-1, n_x, 1)
        return inputs, targets

    def _checkpoint_path(self):
        """Return the path used by both ``train`` (save) and ``test`` (load)."""
        return './model/'+str(self.args.filename)+'_FNO.pkl'

    @staticmethod
    def _load_model(path, device):
        """Load a fully pickled module from ``path`` onto ``device``.

        NOTE(security): this unpickles arbitrary objects; only load
        checkpoints that come from a trusted source.
        """
        try:
            # Recent PyTorch defaults to weights_only=True, which rejects
            # full-module checkpoints such as the ones train() writes.
            return torch.load(path, map_location=device, weights_only=False)
        except TypeError:
            # Older PyTorch releases do not know the weights_only keyword.
            return torch.load(path, map_location=device)

    def read_data(self):
        """Read the CSV files under ``./dataset/data_<filename>/``.

        For each of the train/valid/test splits, ``s_<split>.csv`` is reshaped
        to (sample, len(x), len(t)) and ``u_<split>.csv`` is broadcast along x
        and offset by ``sin(2*pi*x/args.L)``. Only the first ``args.num`` rows
        of each split are kept, and the time axis is subsampled with stride 5.
        The first CSV column is skipped in every file.

        Returns:
            Tuple ``(X1, p1, X2, p2, X3, p3, x, t)`` of float32 tensors, where
            1/2/3 denote train/valid/test and ``t`` is the subsampled time grid.
        """
        args = self.args
        lt = 5  # stride used to subsample the time axis
        num = args.num
        data_dir = "./dataset/data_"+args.filename

        def load(name):
            return pd.read_csv(data_dir+"/"+name+".csv").values

        x = torch.tensor(load("x")[:, 1]).to(torch.float32)
        t = torch.tensor(load("t")[:, 1]).to(torch.float32)
        n_x, n_t = len(x), len(t)
        u0 = torch.sin(2*np.pi*x/args.L)

        fields = []
        for split in ("train", "valid", "test"):
            s = load("s_"+split)[:, 1:].reshape(-1, n_x, n_t)
            s = torch.tensor(s[:num]).to(torch.float32)
            p = torch.tensor(load("u_"+split)[:, 1:][:num]).to(torch.float32)
            # (sample, time) -> (sample, x, time): same values at every x.
            p = p.repeat(n_x, 1, 1).permute(1, 0, 2)
            # (x,) -> (sample, x, time): same profile for every sample and time.
            p = p + u0.repeat(s.shape[0], n_t, 1).permute(0, 2, 1)
            fields.append(s[..., ::lt])
            fields.append(p[..., ::lt])

        return (*fields, x, t[::lt])

    def train(self):
        """Train ``self.sepi`` with Adam and periodically validate and checkpoint.

        The per-batch loss is the summed squared error plus ``1e-5`` times the
        sum of the L2 norms of all parameters. Every ``args.outime`` epochs the
        validation error is printed and the whole module is saved to
        ``./model/<filename>_FNO.pkl``.
        """
        args = self.args
        device = args.device
        X_valid = self.X_valid.to(device); y_valid = self.y_valid.to(device)
        sepi = self.sepi.to(device)

        train_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(self.X_train, self.y_train),
            batch_size=args.batch_size, shuffle=True)

        save_path = self._checkpoint_path()
        # torch.save does not create missing directories.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        optimizer = torch.optim.Adam(sepi.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        for ep in range(args.epochs):
            sepi.train()
            for x, y in train_loader:
                x = x.to(device); y = y.to(device)
                optimizer.zero_grad()
                out = sepi(x)

                loss = torch.sum((out-y)**2)
                # Accumulate the penalty on the same device as the model output.
                wei = out.new_zeros(())
                for param in sepi.parameters():
                    wei = wei + 1e-5*torch.norm(param, p=2)
                loss = loss + wei
                loss.backward()

                optimizer.step()

            if ep%args.outime==0:
                sepi.eval()
                with torch.no_grad():
                    pred = sepi(X_valid)
                    test_l2 = torch.sum((y_valid-pred)**2)/y_valid.shape[0]
                    print("####################")
                    print(ep, test_l2)
                    print("####################")
                    torch.save(sepi, save_path)

    def test(self):
        """Roll the saved model out over the test trajectories.

        Starting from the first time index of ``X_test``, each prediction is
        fed back as the first input channel, together with the second channel
        of ``X_test`` at the same time index. Prints the mean absolute error
        against ``y_test``.

        Returns:
            The mean absolute error as a 0-dim tensor.
        """
        device = self.args.device
        X_test = self.X_test; y_test = self.y_test; ts = self.t

        sepi = self._load_model(self._checkpoint_path(), device)
        sepi.eval()

        with torch.no_grad():
            preds = torch.zeros_like(y_test)
            preds[:, 0] = y_test[:, 0]
            x0 = X_test[:, 0].to(device)

            for i in range(1, len(ts)):
                pred = sepi(x0)
                preds[:, i] = pred
                x0 = torch.cat((pred, X_test[:, i, :, 1:2].to(device)), axis=-1)

            loss = torch.mean((preds-y_test).abs())

        print("test_loss:", loss)
        return loss
        




