
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
from utilities3 import *

import operator
from functools import reduce
from functools import partial

from timeit import default_timer

from Adam import Adam

# Seed the global NumPy and PyTorch RNGs (PyTorch's drives torch.rand weight
# init and DataLoader shuffling) so runs are repeatable.
np.random.seed(0)
torch.manual_seed(0)


################################################################
# 3d fourier layers
################################################################
# Number of time frames (the config section re-assigns the same value).
T = 50
# Alias used by FNO3d.dX for the 2*pi/L spectral derivative factors.
pi = np.pi

def compl_mul3d(a, b):
    """Mix channels of Fourier coefficients independently at every mode.

    Contracts the input-channel axis of ``a`` with the first axis of ``b``:
    ``a`` is (batch, in_channel, x, y, t), ``b`` is
    (in_channel, out_channel, x, y, t), and the result is
    (batch, out_channel, x, y, t).
    """
    return torch.einsum("bixyz,ioxyz->boxyz", a, b)


class SpectralConv3d(nn.Module):
    """3-D Fourier layer: real FFT -> per-mode channel mixing -> inverse FFT.

    Only low-frequency coefficients are kept: the first and last ``modes1``
    bins on axis 2, the first and last ``modes2`` bins on axis 3 (four corner
    blocks, each with its own complex weight tensor), and the first
    ``2 * modes3`` bins of the half-spectrum on axis 4. Every other
    coefficient is zeroed before transforming back.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2, modes3):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Number of Fourier modes to multiply, at most floor(N/2) + 1.
        self.modes1 = modes1
        self.modes2 = modes2
        # The budget on the last (real-FFT) axis is twice the argument.
        self.modes3 = 2 * modes3

        self.scale = 1 / (in_channels * out_channels)
        weight_shape = (in_channels, out_channels, self.modes1, self.modes2, self.modes3)
        # Created one after another in this order so RNG consumption is fixed.
        self.weights1 = nn.Parameter(self.scale * torch.rand(*weight_shape, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.rand(*weight_shape, dtype=torch.cfloat))
        self.weights3 = nn.Parameter(self.scale * torch.rand(*weight_shape, dtype=torch.cfloat))
        self.weights4 = nn.Parameter(self.scale * torch.rand(*weight_shape, dtype=torch.cfloat))

    def forward(self, x):
        nbatch = x.shape[0]
        nx, ny, nt = x.size(2), x.size(3), x.size(4)
        m1, m2, m3 = self.modes1, self.modes2, self.modes3

        # Orthonormal real FFT over the three trailing axes.
        coeffs = torch.fft.rfftn(x, dim=[2, 3, 4], norm="ortho")

        spectrum = torch.zeros(nbatch, self.out_channels, nx, ny, nt // 2 + 1,
                               device=x.device, dtype=torch.cfloat)
        # (axis-2 window, axis-3 window, weight) per corner, applied in order
        # so later blocks overwrite earlier ones wherever they overlap.
        corners = (
            (slice(None, m1), slice(None, m2), self.weights1),
            (slice(-m1, None), slice(None, m2), self.weights2),
            (slice(None, m1), slice(-m2, None), self.weights3),
            (slice(-m1, None), slice(-m2, None), self.weights4),
        )
        for win_x, win_y, weight in corners:
            region = (slice(None), slice(None), win_x, win_y, slice(None, m3))
            spectrum[region] = compl_mul3d(coeffs[region], weight)

        # Back to physical space at the original resolution.
        return torch.fft.irfftn(spectrum, s=(nx, ny, nt), dim=[2, 3, 4], norm="ortho")

class FNO3d(nn.Module):
    """Fourier neural operator on an (x, y, t) grid with optional time derivative.

    ``forward`` does the following:

    - Appends 3 coordinate channels (``get_grid``) and lifts with ``fc0``.
    - Zero-pads the last (time) axis by ``self.padding``.
    - Applies four Fourier layers (``SpectralConv3d`` plus a pointwise
      ``Conv3d`` skip; GELU after the first three).
    - Crops the padding and projects pointwise with fc1 -> tanh -> fc2.

    With ``gradient=True`` it also returns spectral derivatives of the last
    hidden features (``dX``), pushed through the pointwise head by the chain
    rule (``dQ``).
    """
    def __init__(self, modes1, modes2, modes3, width):
        super(FNO3d, self).__init__()

        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3
        self.width = width
        self.width2 = width*4  # hidden size of the fc1 -> tanh -> fc2 head
        self.in_dim = 4  # channels entering fc0: data channels + 3 grid coordinates from get_grid
        self.out_dim = 1  # dQ assumes this is 1 (it reshapes fc2.weight to (width2, 1))
        self.padding = 10  # pad the domain if input is non-periodic
        
        self.fc0 = nn.Linear(self.in_dim, self.width)

        self.conv0 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.conv1 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.conv2 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.conv3 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.w0 = nn.Conv3d(self.width, self.width, 1)
        self.w1 = nn.Conv3d(self.width, self.width, 1)
        self.w2 = nn.Conv3d(self.width, self.width, 1)
        self.w3 = nn.Conv3d(self.width, self.width, 1)

        self.fc1 = nn.Linear(self.width, self.width2)
        self.fc2 = nn.Linear(self.width2, self.out_dim)

    def forward(self, x, gradient=False):
        """Evaluate the operator.

        Args:
            x: channels-last input whose dims 1-3 are read by ``get_grid`` as
               the x, y, t sizes; with in_dim = 4 and 3 appended grid
               channels, the last axis must have size 1.
            gradient: if True, also return the derivative tuple from ``dQ``.

        Returns:
            ``out`` of shape (batch, x, y, t, out_dim), or ``(out, DX)`` when
            ``gradient`` is True. Because ``self.T_only`` is set to True
            below, ``DX`` is ``(0, 0, wtQ, 0, 0)`` with ``wtQ`` of shape
            (batch, x, y, t).
        """
        # NOTE(review): reset on every call, so the x/y/Hessian branches of
        # dX/dQ are unreachable through forward().
        self.T_only = True

        grid = self.get_grid(x.shape, x.device)
        x = torch.cat((x, grid), dim=-1)
        x = self.fc0(x)
        x = x.permute(0, 4, 1, 2, 3)  # channels-first: (batch, width, x, y, t)
        x = F.pad(x, [0, self.padding])  # zero-pad only the last (time) axis, on the right

        x1 = self.conv0(x)
        x2 = self.w0(x)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv1(x)
        x2 = self.w1(x)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv2(x)
        x2 = self.w2(x)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv3(x)
        x2 = self.w3(x)
        x = x1 + x2

        if gradient:
            # Differentiate the still-padded features; dX crops the padding itself.
            DX = self.dX(x)

        x = x[..., :-self.padding]  # NOTE(review): yields an empty tensor if self.padding == 0
        x = x.permute(0, 2, 3, 4, 1)  # back to channels-last: (batch, x, y, t, width)
        x1 = self.fc1(x)  # pre-activation kept for the chain rule in dQ
        x = torch.tanh(x1)
        x = self.fc2(x)

        if gradient:
            DX = self.dQ(x1, DX)
            return x, DX

        return x

    def get_grid(self, shape, device):
        """Coordinate channels of shape (batch, size_x, size_y, size_z, 3).

        Each coordinate comes from np.linspace(0, 1, n), so both endpoints
        0 and 1 are included.
        """
        batchsize, size_x, size_y, size_z = shape[0], shape[1], shape[2], shape[3]
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1, 1).repeat([batchsize, 1, size_y, size_z, 1])
        gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1, 1).repeat([batchsize, size_x, 1, size_z, 1])
        gridz = torch.tensor(np.linspace(0, 1, size_z), dtype=torch.float)
        gridz = gridz.reshape(1, 1, 1, size_z, 1).repeat([batchsize, size_x, size_y, 1, 1])
        return torch.cat((gridx, gridy, gridz), dim=-1).to(device)

    def dX(self, w):
        """Spectral derivatives of the padded hidden features ``w``.

        ``w`` is (batch, channel, x, y, t_padded), as produced inside
        ``forward``. A full 3-D FFT is taken, multiplied by
        ``1j * k * 2*pi/L`` along the chosen axis, and inverted with
        ``irfftn`` using the first nt//2 + 1 bins of the last axis. Results
        are cropped by ``self.padding`` along t.

        Returns ``(wx, wy, wt, wxx, wyy)``. When ``self.T_only`` is True only
        ``wt`` is computed and the other entries are the integer 0.
        """
        Lx = 1
        Ly = 1
        # NOTE(review): hard-coded period of the padded time axis; it appears
        # to be T + padding (50 + 10) in this script, i.e. unit spacing per
        # frame. Keep in sync if T or self.padding change.
        Lt = (50+10)

        batchsize = w.size(0)
        channel = w.size(1)
        nx = w.size(-3)
        ny = w.size(-2)
        nt = w.size(-1)

        device = w.device
        w = w.reshape(batchsize, channel, nx, ny, nt)

        w_h = torch.fft.fftn(w, dim=[-3, -2, -1])

        # Integer wavenumbers along t in FFT order: 0..nt//2-1, then negatives.
        k_t = torch.cat((torch.arange(start=0, end=nt // 2, step=1, device=device),
                         torch.arange(start=-nt // 2, end=0, step=1, device=device)),
                        0).reshape(1, 1, nt).repeat(nx, ny, 1).reshape(1, 1, nx, ny, nt)
        wt_h = 1j * k_t * w_h * (2 * pi / Lt)
        # No `s=` is passed, so the output length along t is 2*(nt//2):
        # presumably nt is even here (60 in this script) — confirm if changed.
        wt = torch.fft.irfftn(wt_h[..., :nt // 2 + 1], dim=[-3, -2, -1])[..., :-self.padding]
        if self.T_only:
            return (0, 0, wt, 0, 0)


        k_x = torch.cat((torch.arange(start=0, end=nx // 2, step=1, device=device),
                         torch.arange(start=-nx // 2, end=0, step=1, device=device)),
                        0).reshape(nx, 1, 1).repeat(1, ny, nt).reshape(1, 1, nx, ny, nt)
        k_y = torch.cat((torch.arange(start=0, end=ny // 2, step=1, device=device),
                         torch.arange(start=-ny // 2, end=0, step=1, device=device)),
                        0).reshape(1, ny, 1).repeat(nx, 1, nt).reshape(1, 1, nx, ny, nt)
        wx_h = 1j * k_x * w_h * (2 * pi / Lx)
        wy_h = 1j * k_y * w_h * (2 * pi / Ly)
        wxx_h = 1j * k_x * wx_h * (2 * pi / Lx)
        wyy_h = 1j * k_y * wy_h * (2 * pi / Ly)

        wx = torch.fft.irfftn(wx_h[..., :nt // 2 + 1], dim=[-3, -2, -1])[..., :-self.padding]
        wy = torch.fft.irfftn(wy_h[..., :nt // 2 + 1], dim=[-3, -2, -1])[..., :-self.padding]
        wxx = torch.fft.irfftn(wxx_h[..., :nt // 2 + 1], dim=[-3, -2, -1])[..., :-self.padding]
        wyy = torch.fft.irfftn(wyy_h[..., :nt // 2 + 1], dim=[-3, -2, -1])[..., :-self.padding]

        return (wx, wy, wt, wxx, wyy)

    def dQ(self, X1, DX):
        """Chain rule through the pointwise head Q(h) = fc2(tanh(fc1(h))).

        Args:
            X1: fc1 pre-activation, (batch, x, y, t, width2).
            DX: tuple from ``dX``; tensors in it are (batch, width, x, y, t).

        The gradient of Q with respect to the hidden channel i is
        DQ[b, i] = sum_m fc2.weight[0, m] * sech^2(X1[b, m]) * fc1.weight[m, i].
        Its contraction with ``wt`` gives the time derivative of the output.
        Biases drop out of the derivative.

        Returns ``(wxQ, wyQ, wtQ, wxxQ, wyyQ)`` with entries (batch, x, y, t).
        When ``self.T_only`` is True only ``wtQ`` is computed and the others
        are 0.
        """
        # X1 (batch, x,y,t,m)
        X1 = X1.permute(0, 4, 1, 2, 3)  # (b,m,x,y,t)
        (wx, wy, wt, wxx, wyy) = DX  # DX (batch, i, x,y,t)

        b = X1.shape[0]
        x = X1.shape[-3]
        y = X1.shape[-2]
        t = X1.shape[-1]
        i = self.width
        m = self.width2

        ### Gradient: D(f o g) = D(f(g)) o Dg
        DW1 = self.fc1.weight #(m, i)
        Dtanh = 1/torch.cosh(X1)**2  # tanh'(z) = sech^2(z); (b,m,x,y,t)

        DW2 = self.fc2.weight.reshape(m,1)  # (m, 1); only valid when out_dim == 1
        ## DQ = torch.einsum("mi,bmxyz,m->bixyz", DW1, Dtanh, DW2)
        DW12 = DW1 * DW2  # broadcast: DW12[m, i] = W1[m, i] * W2[m]
        DQ = torch.einsum("mi,bmxyz->bixyz", DW12, Dtanh)

        wtQ = torch.einsum("bixyz,bixyz->bxyz", DQ, wt)
        if self.T_only:
            return (0, 0, wtQ, 0, 0)

        wxQ = torch.einsum("bixyz,bixyz->bxyz", DQ, wx)
        wyQ = torch.einsum("bixyz,bixyz->bxyz", DQ, wy)

        # ### Hessian: D^2(f o g) = Dg o Hf o Dg + Df o Hg
        Htanh = -2*Dtanh*torch.tanh(X1)  # tanh''(z) = -2 sech^2(z) tanh(z)
        H2 = DW2.reshape(1,m,1,1,1)*Htanh # (b,m,x,y,t)

        wxx1 = torch.einsum("bixyz,mi,bmxyz,mj,bjxyz->bxyz", wx,DW1,H2,DW1,wx) # (b,x,y,t)
        # wxx1 = torch.einsum("bixyz,mi->bmxyz", wx, DW1)
        # wxx1 = torch.einsum("bmxyz,bmxyz->bxyz", wxx1**2, H2)
        wxx2 = torch.einsum("bixyz,bixyz->bxyz", DQ.reshape(b,i,x,y,t), wxx)
        wxxQ = wxx1 + wxx2
        wyy1 = torch.einsum("bixyz,mi,bmxyz,mj,bjxyz->bxyz", wy,DW1,H2,DW1,wy) # (b,x,y,t)
        # wyy1 = torch.einsum("bixyz,mi->bmxyz", wy, DW1)
        # wyy1 = torch.einsum("bmxyz,bmxyz->bxyz", wyy1**2, H2)
        wyy2 = torch.einsum("bixyz,bixyz->bxyz", DQ.reshape(b,i,x,y,t), wyy)
        wyyQ = wyy1 + wyy2
        return (wxQ, wyQ, wtQ, wxxQ, wyyQ)
        # return wtQ
        # return wtQ
    
    
    

################################################################
# configs
################################################################

# ---- run mode --------------------------------------------------------------
# Exactly one stage runs: pretraining over the training set, or fine-tuning.
pretrain = False
finetune = not pretrain

# Both splits are read from the same file.
TRAIN_PATH = TEST_PATH = '../data/ns_V1e-3_N5000_T50.mat'

ntrain = 4800
# Fine-tuning works on a single instance with batch size 1.
ntest, batch_size = (200, 10) if pretrain else (1, 1)

modes, width = 8, 20

epochs, learning_rate = 200, 0.001
scheduler_step, scheduler_gamma = 40, 0.5
print(epochs, learning_rate, scheduler_step, scheduler_gamma)

# ---- output locations ------------------------------------------------------
path = 'test'
# Alternative descriptive run name:
# path = 'ns_fourier_V100_N'+str(ntrain)+'_ep' + str(epochs) + '_m' + str(modes) + '_w' + str(width)
path_model = f'model/{path}'
path_train_err = f'results/{path}train.txt'
path_test_err = f'results/{path}test.txt'
path_image = f'image/{path}'

runtime = np.zeros(2)
t1 = default_timer()  # start of the preprocessing timer

# ---- discretisation --------------------------------------------------------
sub = 1          # spatial subsampling stride applied when slicing the data
S = 64 // sub    # spatial resolution after subsampling
T_in = 1         # number of input time frames
T = 50           # number of time frames predicted

################################################################
# load data
################################################################

reader = MatReader(TRAIN_PATH)
data = reader.read_field('u')

train_a = data[:ntrain, ::sub, ::sub, :T_in]
train_u = data[:ntrain, ::sub, ::sub, :]

# Test window: `ntest` consecutive instances ending `a` instances before the
# end of `data`. Slicing by `ntest` keeps the reshape below valid when
# ntest > 1 (pretrain mode); for ntest = 1 it is the single instance at
# index len(data) - a - 1.
# NOTE(review): if the file holds 5000 instances (as its name suggests),
# ntest = 200 makes this window overlap the last training instances.
# test_a = data[-ntest:, ::sub, ::sub, :T_in]
# test_u = data[-ntest:, ::sub, ::sub, :]
a = 5
test_end = data.shape[0] - a
test_a = data[test_end - ntest:test_end, ::sub, ::sub, :T_in]
test_u = data[test_end - ntest:test_end, ::sub, ::sub, :]

print(train_u.shape)
print(test_u.shape)
assert (S == train_u.shape[-2])
assert (T == train_u.shape[-1])

# The normalizer is built from train_a and used to encode both splits.
a_normalizer = UnitGaussianNormalizer(train_a)
train_a = a_normalizer.encode(train_a)
test_a = a_normalizer.encode(test_a)

# Repeat the input frame(s) along a new time axis of length T.
train_a = train_a.reshape(ntrain, S, S, 1, T_in).repeat([1, 1, 1, T, 1])
test_a = test_a.reshape(ntest, S, S, 1, T_in).repeat([1, 1, 1, T, 1])

train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size,
                                          shuffle=False)

t2 = default_timer()

print('preprocessing finished, time used:', t2 - t1)
device = torch.device('cuda')

def get_forcing(S, device='cuda'):
    """Return the forcing term sampled on an S x S periodic grid.

    f(x1, x2) = 0.1 * (cos(2*pi*(x1 + x2)) + sin(2*pi*(x1 + x2))), with
    x1, x2 in {0, 1/S, ..., (S-1)/S}. The endpoint 1 is dropped because it
    coincides with 0 on a periodic grid.

    Args:
        S: number of grid points along each spatial axis.
        device: where to place the result. Defaults to 'cuda', so existing
            callers still receive a GPU tensor.

    Returns:
        float32 tensor of shape (1, S, S, 1), broadcastable over batch and time.
    """
    coords = torch.tensor(np.linspace(0, 1, S + 1)[:-1], dtype=torch.float)
    x1 = coords.reshape(S, 1).repeat(1, S)
    x2 = coords.reshape(1, S).repeat(S, 1)
    phase = 2 * np.pi * (x1 + x2)
    return (0.1 * (torch.cos(phase) + torch.sin(phase))).reshape(1, S, S, 1).to(device)

# Forcing field on the S x S grid, built once on the GPU by get_forcing and
# passed to PINO_loss in the training loops below.
forcing = get_forcing(S)

def _fft_wavenumbers(n, device, dtype):
    """Integer FFT wavenumbers of length n in FFT order (0, 1, ..., -2, -1)."""
    k = torch.arange(n, device=device)
    return torch.where(k < (n + 1) // 2, k, k - n).to(dtype)


def FDM_NS_vorticity(w, Dw, v=1/1000):
    """Residual of the 2-D vorticity equation (without forcing).

    Computes ``wt + ux * wx + uy * wy - v * lap(w)`` on a periodic
    [0, Lx) x [0, Ly) domain. The velocity is recovered spectrally from the
    stream function psi with -lap(psi) = w, ux = dpsi/dy and uy = -dpsi/dx.
    Spatial derivatives are spectral; the time derivative comes from ``Dw``.

    Args:
        w: vorticity, reshaped to (batch, nx, ny, nt); axes 1 and 2 are the
           periodic x and y directions. Any nx, ny (odd or even, non-square)
           are supported.
        Dw: 5-tuple ``(wxQ, wyQ, wtQ, wxxQ, wyyQ)``. Only ``wtQ`` is used;
            it must broadcast against (batch, nx, ny, nt).
        v: viscosity coefficient.

    Returns:
        Tensor of shape (batch, nx, ny, nt).
    """
    Lx = 1
    Ly = 1

    batchsize, nx, ny, nt = w.size(0), w.size(1), w.size(2), w.size(3)
    device = w.device
    w = w.reshape(batchsize, nx, ny, nt)
    _, _, wtQ, _, _ = Dw

    w_h = torch.fft.fft2(w, dim=[1, 2])

    # Per-axis wavenumbers, broadcastable to (1, nx, ny, 1).
    k_x = _fft_wavenumbers(nx, device, w.dtype).reshape(1, nx, 1, 1)
    k_y = _fft_wavenumbers(ny, device, w.dtype).reshape(1, 1, ny, 1)
    cx = 2 * np.pi / Lx
    cy = 2 * np.pi / Ly

    # Negative Laplacian in Fourier space, each axis with its own scaling.
    lap = (cx * k_x) ** 2 + (cy * k_y) ** 2
    # Avoid 0/0 at the mean mode when inverting the Laplacian for psi; the
    # mean of psi is multiplied by k = 0 below, so it never affects u.
    lap_inv_safe = lap.clone()
    lap_inv_safe[0, 0, 0, 0] = 1.0
    f_h = w_h / lap_inv_safe

    ux_h = 1j * cy * k_y * f_h
    uy_h = -1j * cx * k_x * f_h
    wx_h = 1j * cx * k_x * w_h
    wy_h = 1j * cy * k_y * w_h
    # Use the unmodified `lap` so the mean mode of lap(w) is exactly zero.
    wlap_h = -lap * w_h

    def to_real(z_h):
        # Spectrum of a real field: keep the half-spectrum along y and invert
        # with explicit output sizes (also correct for odd nx / ny).
        return torch.fft.irfft2(z_h[:, :, :ny // 2 + 1], s=(nx, ny), dim=[1, 2])

    ux, uy, wx, wy, wlap = (to_real(z_h) for z_h in (ux_h, uy_h, wx_h, wy_h, wlap_h))

    # The forcing is not subtracted here; PINO_loss compares against it.
    return wtQ + (ux * wx + uy * wy - v * wlap)

# Time weights for the physics loss: rise linearly from 0.5 at the first frame
# to 1.5 at the last, so later frames count more.
mask = torch.tensor(np.linspace(0, 1, T, endpoint=True), dtype=torch.float, device='cuda').reshape(1, 1, 1, T)
mask = mask + 0.5
def PINO_loss(u, forcing, Dw=None):
    """Time-weighted relative Lp mismatch between the PDE residual and the forcing.

    Args:
        u: predicted vorticity, reshaped to (batch, nx, ny, nt); nt must equal
           T so the module-level ``mask`` broadcasts.
        forcing: tensor of shape (1, nx, ny, 1), as returned by get_forcing.
        Dw: derivative tuple from FNO3d.forward(..., gradient=True), forwarded
            to FDM_NS_vorticity (required in practice despite the default).

    Returns:
        The value of ``LpLoss(size_average=True)`` on the masked fields.
    """
    batchsize = u.size(0)
    nx = u.size(1)
    ny = u.size(2)
    nt = u.size(3)

    u = u.reshape(batchsize, nx, ny, nt)
    lploss = LpLoss(size_average=True)

    Du = FDM_NS_vorticity(u, Dw)
    # Tile the forcing to the actual batch of `u`, not the configured
    # batch_size, so a short final batch still lines up.
    f = forcing.repeat(batchsize, 1, 1, nt)
    loss_f = lploss(Du*mask, f*mask)  # mask up-weights later time frames
    return loss_f


################################################################
# training and evaluation
################################################################

# Summed relative Lp loss (size_average=False); the finetune stage builds its own.
myloss = LpLoss(size_average=False)
if pretrain:
    # Continue training from a saved whole-module checkpoint.
    # Alternatives: FNO3d(modes, modes, modes, width).cuda() for a fresh model,
    # or torch.load('model/FNO3d').cuda().
    model = torch.load('../model/FNO3d-adjoint').cuda()
    print(count_params(model))

    optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)

    for ep in range(epochs):
        model.train()
        t1 = default_timer()
        train_f = 0
        train_l2 = 0
        for x, y in train_loader:
            x, y = x.cuda(), y.cuda()
            # Actual batch size, so a short final batch does not break the reshape.
            bs = x.shape[0]

            optimizer.zero_grad()
            out, Dw = model(x, gradient=True)
            out = out.reshape(bs, S, S, T)

            # Data loss on the full trajectory plus the weighted physics residual.
            loss = myloss(out, y)
            loss_f = PINO_loss(out, forcing, Dw)
            loss_pino = loss + loss_f*0.2
            loss_pino.backward()

            optimizer.step()
            train_f += loss_f.item()
            train_l2 += loss.item()

        scheduler.step()

        model.eval()
        test_l2 = 0.0
        test_l2_T = 0.0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.cuda(), y.cuda()
                bs = x.shape[0]

                out = model(x).view(bs, S, S, T)
                test_l2 += myloss(out.view(bs, -1), y.view(bs, -1)).item()
                test_l2_T += myloss(out[..., -1], y.view(bs, S, S, T)[..., -1]).item()

        # PINO_loss uses LpLoss(size_average=True), presumably a per-batch
        # mean, so average it over batches. The summed Lp losses are
        # averaged over samples.
        train_f /= len(train_loader)
        train_l2 /= ntrain
        test_l2 /= ntest
        test_l2_T /= ntest

        t2 = default_timer()
        print(ep, t2 - t1, train_f, train_l2, test_l2, test_l2_T)

    torch.save(model, '../model/FNO3d-adjoint1')



# Test-time fine-tuning: optimise the loaded model on the test instance using
# the initial-condition loss, the physics residual, and an anchor term that
# keeps predictions close to the model's pre-fine-tuning output.
if finetune:
    # model = torch.load('../model/FNO3d-adjoint0').cuda()
    model = torch.load('../model/FNO3d-adjoint').cuda()
    # model = FNO3d(modes, modes, modes, width).cuda()
    print(count_params(model))

    optimizer = Adam(model.parameters(), lr=5e-4, weight_decay=0)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=scheduler_gamma)

    myloss = LpLoss(size_average=False)

    model.train()

    # Record the loaded model's prediction as the anchor. After this loop,
    # x and y hold the last test batch; with the finetune config (ntest = 1,
    # batch_size = 1) that is the single test instance.
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.cuda(), y.cuda()
            y = y.reshape(batch_size, S, S, T)
            out_anchor = model(x, gradient=False)
            print(x.shape)
            print(y.shape)

        # print('start')
        # for i in range(100):
        #     t1 = default_timer()
        #     out = model(x, gradient=False)
        #     t2 = default_timer()
        #     print(i, (t2 - t1)/100)

    for ep in range(10000):
        t1 = default_timer()
        optimizer.zero_grad()

        out, Dw = model(x, gradient=True)
        out = out.reshape(batch_size,S,S,T)

        loss = myloss(out, y)  # full-trajectory error: logged only, not optimised
        loss_ic = myloss(out[..., 0], y[..., 0])  # error on the first time frame
        loss_f = PINO_loss(out, forcing, Dw)
        # NOTE(review): out is (b, S, S, T) while out_anchor keeps the model's
        # trailing output axis; this relies on LpLoss comparing flattened
        # tensors — confirm.
        loss_anchor = myloss(out, out_anchor)
        loss_pino = loss_ic + loss_f*0.2 + loss_anchor*0.1
        loss_pino.backward()

        optimizer.step()
        scheduler.step()
        test_l2 = loss.item()
        test_ic = loss_ic.item()
        test_f = loss_f.item()
        test_l2_T = myloss(out[..., -1], y[..., -1]).item()  # error on the final frame

        t2 = default_timer()
        print(ep, t2-t1, test_f, test_ic, test_l2, test_l2_T)

    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
