import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
from utilities4 import *

import operator
from functools import reduce
from functools import partial

from timeit import default_timer

torch.manual_seed(0)
np.random.seed(0)

def conv(in_planes, output_channels, kernel_size, stride, dropout_rate):
    """Build one encoder stage: Conv2d -> BatchNorm2d -> LeakyReLU -> Dropout.

    Args:
        in_planes: Number of channels of the incoming feature map.
        output_channels: Number of channels produced by the convolution.
        kernel_size: Side length of the square convolution kernel.
        stride: Convolution stride.
        dropout_rate: Drop probability handed to ``nn.Dropout``.

    Returns:
        An ``nn.Sequential`` holding the four layers in the order above.
    """
    # (k - 1) // 2 preserves the spatial size for odd kernels at stride 1.
    pad = (kernel_size - 1) // 2
    layers = [
        # The convolution has no bias term; BatchNorm follows it directly.
        nn.Conv2d(
            in_planes,
            output_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=pad,
            bias=False,
        ),
        nn.BatchNorm2d(output_channels),
        nn.LeakyReLU(0.1, inplace=True),
        nn.Dropout(dropout_rate),
    ]
    return nn.Sequential(*layers)


def deconv(input_channels, output_channels):
    """Build one decoder stage: ConvTranspose2d followed by LeakyReLU(0.1).

    The transposed convolution uses kernel 4, stride 2 and padding 1, which
    doubles the height and width of its input.

    Args:
        input_channels: Number of channels of the incoming feature map.
        output_channels: Number of channels produced by the stage.

    Returns:
        An ``nn.Sequential`` holding the two layers.
    """
    upsample = nn.ConvTranspose2d(
        input_channels,
        output_channels,
        kernel_size=4,
        stride=2,
        padding=1,
    )
    activation = nn.LeakyReLU(0.1, inplace=True)
    return nn.Sequential(upsample, activation)


def output_layer(input_channels, output_channels, kernel_size, stride, dropout_rate):
    """Build the final projection: a single biased Conv2d with no activation.

    Args:
        input_channels: Number of channels of the incoming feature map.
        output_channels: Number of channels of the prediction.
        kernel_size: Side length of the square convolution kernel.
        stride: Convolution stride.
        dropout_rate: Accepted for signature parity with ``conv`` but not
            used; no dropout layer is created here.

    Returns:
        The ``nn.Conv2d`` layer.
    """
    pad = (kernel_size - 1) // 2
    return nn.Conv2d(
        input_channels,
        output_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=pad,
    )


class U_net(nn.Module):
    """Five-level convolutional encoder/decoder with skip connections.

    The encoder applies five stride-2 ``conv`` stages (64, 128, 256, 512 and
    1024 channels); levels 3-5 each add a stride-1 refinement stage. The
    decoder applies five ``deconv`` stages, and after every one the result is
    concatenated channel-wise with the encoder feature map (or, at the last
    level, the raw input) of the same resolution.

    Because the resolution is halved five times and doubled five times, the
    concatenations only line up when the input height and width survive five
    exact halvings (for odd ``kernel_size`` that means multiples of 32).

    Args:
        input_channels: Channels of the input tensor (N, C, H, W).
        output_channels: Channels of the prediction.
        kernel_size: Kernel size of every encoder conv and of the output conv.
        dropout_rate: Drop probability used in every encoder stage.
    """

    def __init__(self, input_channels, output_channels, kernel_size, dropout_rate):
        super().__init__()
        self.input_channels = input_channels

        def encoder_stage(n_in, n_out, stride):
            # Every encoder stage shares the same kernel size and dropout.
            return conv(n_in, n_out, kernel_size=kernel_size, stride=stride,
                        dropout_rate=dropout_rate)

        # Encoder. Attribute names and creation order are kept stable so that
        # parameter ordering and previously pickled models stay compatible.
        self.conv1 = encoder_stage(input_channels, 64, 2)
        self.conv2 = encoder_stage(64, 128, 2)
        self.conv3 = encoder_stage(128, 256, 2)
        self.conv3_1 = encoder_stage(256, 256, 1)
        self.conv4 = encoder_stage(256, 512, 2)
        self.conv4_1 = encoder_stage(512, 512, 1)
        self.conv5 = encoder_stage(512, 1024, 2)
        self.conv5_1 = encoder_stage(1024, 1024, 1)

        # Decoder. Input widths are skip channels + previous decoder channels:
        # 768 = 512 + 256, 384 = 256 + 128, 192 = 128 + 64, 96 = 64 + 32.
        self.deconv4 = deconv(1024, 256)
        self.deconv3 = deconv(768, 128)
        self.deconv2 = deconv(384, 64)
        self.deconv1 = deconv(192, 32)
        self.deconv0 = deconv(96, 16)

        # The last skip connection is the raw input itself.
        self.output_layer = output_layer(
            16 + input_channels,
            output_channels,
            kernel_size=kernel_size,
            stride=1,
            dropout_rate=dropout_rate,
        )

    def forward(self, x):
        """Map an (N, input_channels, H, W) tensor to (N, output_channels, H, W)."""
        # Contracting path: keep every level for the skip connections.
        enc1 = self.conv1(x)
        enc2 = self.conv2(enc1)
        enc3 = self.conv3_1(self.conv3(enc2))
        enc4 = self.conv4_1(self.conv4(enc3))
        enc5 = self.conv5_1(self.conv5(enc4))

        # Expanding path: upsample, then prepend the matching encoder output
        # along the channel axis before the next upsampling stage.
        up = self.deconv4(enc5)
        skip_and_stage = (
            (enc4, self.deconv3),
            (enc3, self.deconv2),
            (enc2, self.deconv1),
            (enc1, self.deconv0),
        )
        for skip, stage in skip_and_stage:
            up = stage(torch.cat((skip, up), 1))

        return self.output_layer(torch.cat((x, up), 1))


class Net2d(nn.Module):
    """Channels-last wrapper around ``U_net``.

    The wrapped network works on (N, C, H, W) tensors, while this module
    accepts and returns (N, H, W, C) tensors; the two permutes in ``forward``
    convert between the layouts.

    Args:
        in_dim: Number of input channels (size of the last input axis).
        out_dim: Number of output channels (size of the last output axis).
    """

    def __init__(self, in_dim, out_dim):
        super(Net2d, self).__init__()
        self.net = U_net(input_channels=in_dim, output_channels=out_dim, kernel_size=3, dropout_rate=0)

    def forward(self, x):
        """Run the wrapped network on a channels-last batch.

        Args:
            x: Tensor of shape (N, H, W, in_dim).

        Returns:
            Tensor of shape (N, H, W, out_dim). It is a permuted view and is
            therefore generally not contiguous.
        """
        x = x.permute(0, 3, 1, 2)  # channels-last -> channels-first
        x = self.net(x)
        x = x.permute(0, 2, 3, 1)  # channels-first -> channels-last
        return x

    def count_params(self):
        """Return the total number of scalar entries over all parameters.

        ``numel()`` is used rather than folding ``p.size()`` with
        ``operator.mul``: a 0-dim parameter has an empty size, and reducing
        an empty sequence without an initial value raises ``TypeError``.
        """
        return sum(p.numel() for p in self.parameters())




# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# Number of trajectories taken from the front (train) and the back (test)
# of the dataset.
ntrain = 180
ntest = 20

# Channel counts of the network input and output.
in_dim = 1
out_dim = 1

batch_size = 50

# Optimisation schedule: the learning rate is multiplied by
# ``scheduler_gamma`` every ``scheduler_step`` epochs.
epochs = 200
learning_rate = 0.001
scheduler_step = 40
scheduler_gamma = 0.5

# Arguments forwarded to HsLoss for the training loss.
loss_k = 2
loss_group = True

print(epochs, learning_rate, scheduler_step, scheduler_gamma)

# Run tag shared by every artefact written for this configuration.
path = ''.join([
    'KF_unet_N', str(ntrain),
    '_k', str(loss_k),
    '_g', str(loss_group),
    '_ep', str(epochs),
])
path_model = 'model/' + path
path_train_err = 'results/' + path + 'train.txt'
path_test_err = 'results/' + path + 'test.txt'
path_image = 'image/' + path

# Spatial subsampling stride applied to the raw data, and the grid size the
# samples are reshaped to.
sub = 1
S = 64

# Time window: snapshots T_in .. T_out - 1 are the prediction targets, each
# paired with the snapshot one step earlier as input.
T_in = 100
T = 400
T_out = T_in + T
step = 1  # not referenced elsewhere in this script


t1 = default_timer()

# Load the raw trajectories and subsample the two trailing (spatial) axes.
data = torch.tensor(
    np.load('../graph-pde/data/KFvorticity_Re40_N200_T500.npy'),
    dtype=torch.float,
)[..., ::sub, ::sub]
print(data.shape)

# One-step-ahead pairs: the input window is the target window shifted back
# by one snapshot. Trajectory and time axes are flattened into one sample axis.
input_frames = slice(T_in - 1, T_out - 1)
target_frames = slice(T_in, T_out)

train_a = data[:ntrain, input_frames].reshape(ntrain * T, S, S)
train_u = data[:ntrain, target_frames].reshape(ntrain * T, S, S)

test_a = data[-ntest:, input_frames].reshape(ntest * T, S, S)
test_u = data[-ntest:, target_frames].reshape(ntest * T, S, S)

print(train_a.shape)
print(train_u.shape)
assert train_u.shape[2] == S

train_set = torch.utils.data.TensorDataset(train_a, train_u)
test_set = torch.utils.data.TensorDataset(test_a, test_u)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

t2 = default_timer()
print('preprocessing finished, time used:', t2 - t1)

# A CUDA device is required from here on.
device = torch.device('cuda')
model = Net2d(in_dim, out_dim).to(device)
# model = torch.load('model/KF_vel_N20_ep200_m12_w32')
print(model.count_params())

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=scheduler_step, gamma=scheduler_gamma
)

# Evaluation metrics plus the configurable training loss. LpLoss and HsLoss
# come from the utilities4 star import.
lploss = LpLoss(size_average=False)
h1loss = HsLoss(k=1, group=False, size_average=False)
h2loss = HsLoss(k=2, group=False, size_average=False)
myloss = HsLoss(k=loss_k, group=loss_group, size_average=False)

# Disabled code. Everything between the triple quotes below is a bare string
# literal used as a block comment: Python evaluates it as a no-op expression
# statement, so none of it runs. It holds (1) the training loop with per-epoch
# test metrics, (2) saving the trained model, and (3) a 1000-step
# autoregressive rollout from a saved model. With it disabled, the script
# goes straight from the setup above to loading a pretrained model below.
'''

for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_l2 = 0
    for x, y in train_loader:
        x = x.to(device).view(batch_size, S, S, in_dim)
        y = y.to(device).view(batch_size, S, S, out_dim)

        out = model(x).reshape(batch_size, S, S, out_dim)
        loss = myloss(out, y)
        train_l2 += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    test_l2 = 0
    test_h1 = 0
    test_h2 = 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device).view(batch_size, S, S, in_dim)
            y = y.to(device).view(batch_size, S, S, out_dim)

            out = model(x).reshape(batch_size, S, S, out_dim)
            test_l2 += lploss(out, y).item()
            test_h1 += h1loss(out, y).item()
            test_h2 += h2loss(out, y).item()


    t2 = default_timer()
    scheduler.step()
    print(ep, t2 - t1, train_l2/(ntrain*T), test_l2/(ntest*T), test_h1/(ntest*T), test_h2/(ntest*T) )

torch.save(model, path_model)


#model.eval()

model = torch.load('model/KF_unet500_N900_k0_gTrue_ep50'.replace('KF_unet500_N900_k0_gTrue_ep50', 'KF_unet_N180_k0_gTrue_ep200')).cuda() if False else torch.load('model/KF_unet_N180_k0_gTrue_ep200').cuda()
test_a = test_a[0,:,:]

T = 1000
pred = torch.zeros(S,S,T)
out = test_a.reshape(1,S,S).cuda()
with torch.no_grad():
    for i in range(T):
        out = model(out.reshape(1,S,S,in_dim))
        pred[:,:,i] = out.view(S,S)

print("complete")
scipy.io.savemat('pred/'+path+'.mat', mdict={'pred': pred.cpu().numpy()})

'''
# ---------------------------------------------------------------------------
# Long autoregressive rollout with a pretrained model
# ---------------------------------------------------------------------------
# Imported here explicitly: the file's import header never imports scipy or
# os, so `scipy.io.savemat` would otherwise only resolve if the utilities4
# star import happens to export the name `scipy`.
import os

import scipy.io

# Create the output directory before the rollout so a missing folder cannot
# discard the whole computation at the final save.
os.makedirs('pred', exist_ok=True)

# NOTE(security): torch.load unpickles the file and can execute arbitrary
# code; only load checkpoints from a trusted source.
# NOTE(compat): this loads a fully pickled module. PyTorch >= 2.6 defaults to
# weights_only=True, in which case weights_only=False must be passed here.
# NOTE(review): eval() is never called, so BatchNorm/Dropout run in whatever
# mode the pickled module was saved in - confirm that this is intended.
model = torch.load('model/KF_unet500_N900_k0_gTrue_ep50').cuda()

# Initial condition: snapshot 100 of the last trajectory, keeping every 4th
# grid point along both spatial axes.
# Alternative source: np.load('../graph-pde/data/KFvorticity_N1_Re40_T1000.npy')[100]
test_a = np.load('../graph-pde/data/KFvorticity_Re500_N25_part1.npy')[-1,100,::4,::4]
print(test_a.shape)
# Use S instead of a literal 64 so the grid size is defined in one place.
test_a = torch.tensor(test_a, dtype=torch.float).reshape(-1, S, S, 1).cuda()

# Named `initial_state` rather than `input` to avoid shadowing the builtin.
initial_state = test_a[0].reshape(1, S, S, in_dim)

T = 10000  # number of rollout steps (overrides the training window length)
pred = torch.zeros(S, S, T, in_dim)  # accumulated on the CPU
out = initial_state
with torch.no_grad():
    for i in range(T):
        # Feed the previous prediction back in as the next input.
        out = model(out.reshape(1, S, S, in_dim))
        # The model output is a permuted view; reshape copes with
        # non-contiguous memory where view may not.
        pred[:, :, i] = out.reshape(S, S, in_dim)

        if i % 100 == 0:
            # Periodic progress print of the field mean.
            print(i, torch.mean(out))
print("complete")
print(torch.mean(pred))

scipy.io.savemat('pred/KF_w_unet500_T10000_k0_gTrue_ep50_m20_w64.mat', mdict={'pred': pred.cpu().numpy()})
