import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils import data
from torch.autograd import Variable


import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
from utilities4 import *

import operator
from functools import reduce
from functools import partial

from timeit import default_timer
import scipy.io

# Fix the NumPy and PyTorch RNG seeds so runs are reproducible.
np.random.seed(0)
torch.manual_seed(0)

# Prefer a CUDA device when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class ConvLSTMCell(nn.Module):
    """One convolutional LSTM cell operating on 2-D feature maps.

    The current input and the previous hidden state are concatenated along the
    channel axis and passed through a single Conv2d that emits all four gate
    pre-activations (input, forget, output, candidate) in one shot.

    Args:
        input_size: (height, width) of the feature maps; used to size the
            initial hidden/cell state in ``init_hidden``.
        input_dim: number of channels of the input tensor.
        hidden_dim: number of channels of the hidden and cell states.
        kernel_size: (kh, kw) convolution kernel. Padding is (kh // 2, kw // 2),
            which preserves the spatial size for odd kernel sizes.
        bias: whether the gate convolution has a bias term.
    """

    def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias):
        super(ConvLSTMCell, self).__init__()

        self.height, self.width = input_size
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # One conv produces all 4 gates: output channels are split into
        # hidden_dim-sized chunks in forward().
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        """Advance the cell by one time step.

        Args:
            input_tensor: 4-D tensor (batch, input_dim, height, width).
            cur_state: pair (h_cur, c_cur), each (batch, hidden_dim, height, width).

        Returns:
            (h_next, c_next); for odd kernel sizes these have the same shape
            as the incoming state.
        """
        h_cur, c_cur = cur_state

        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis

        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size):
        """Return zero-initialised (h, c) states for ``batch_size`` samples.

        The tensors are allocated on the same device and with the same dtype
        as the cell's convolution weights, so the cell works on CPU, on any
        GPU, and in reduced precision without a device/dtype mismatch.
        """
        weight = self.conv.weight
        shape = (batch_size, self.hidden_dim, self.height, self.width)
        return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
                torch.zeros(shape, device=weight.device, dtype=weight.dtype))


class ConvLSTM(nn.Module):
    """Stack of ConvLSTMCell layers unrolled over a sequence of 2-D frames.

    Args:
        input_size: (height, width) of every frame.
        input_dim: channel count of the input frames.
        hidden_dim: hidden channels per layer; an int (reused for every layer)
            or a list of length ``num_layers``.
        kernel_size: (kh, kw) tuple, or a list of such tuples (one per layer).
        num_layers: number of stacked cells.
        batch_first: if True the input is (b, t, c, h, w), otherwise (t, b, c, h, w).
        bias: whether the gate convolutions use a bias term.
        return_all_layers: stored on the module but not used by ``forward``,
            which always returns only the last layer's output sequence.

    Raises:
        ValueError: if ``kernel_size`` is not a tuple / list of tuples, or the
            per-layer lists do not all have length ``num_layers``.
    """

    def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        super(ConvLSTM, self).__init__()

        self._check_kernel_size_consistency(kernel_size)

        # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')

        self.height, self.width = input_size

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        cell_list = []
        for i in range(0, self.num_layers):
            # Layer 0 consumes the raw input; deeper layers consume the previous layer's hidden states.
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]

            cell_list.append(ConvLSTMCell(input_size=(self.height, self.width),
                                          input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dim[i],
                                          kernel_size=self.kernel_size[i],
                                          bias=self.bias))

        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        """Run every layer over the whole sequence, starting from zero state.

        Args:
            input_tensor: 5-D tensor, (b, t, c, h, w) if ``batch_first`` else
                (t, b, c, h, w).
            hidden_state: must be None; passing initial states is not implemented.

        Returns:
            The last layer's hidden-state sequence, stacked along dim 1, i.e.
            batch-first (b, t, hidden_dim[-1], h, w) regardless of ``batch_first``.

        Raises:
            NotImplementedError: if ``hidden_state`` is not None.
        """
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w)
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        # Stateful ConvLSTM (user-supplied initial states) is not supported.
        if hidden_state is not None:
            raise NotImplementedError()
        hidden_state = self._init_hidden(batch_size=input_tensor.size(0))

        seq_len = input_tensor.size(1)
        cur_layer_input = input_tensor

        for cell, (h, c) in zip(self.cell_list, hidden_state):
            output_inner = []
            for t in range(seq_len):
                h, c = cell(input_tensor=cur_layer_input[:, t], cur_state=[h, c])
                output_inner.append(h)
            # This layer's output sequence becomes the next layer's input.
            cur_layer_input = torch.stack(output_inner, dim=1)

        return cur_layer_input

    def _init_hidden(self, batch_size):
        """Return one zero (h, c) pair per layer."""
        return [cell.init_hidden(batch_size) for cell in self.cell_list]

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        """Raise ValueError unless kernel_size is a tuple or a list of tuples."""
        if not (isinstance(kernel_size, tuple) or
                (isinstance(kernel_size, list) and all(isinstance(elem, tuple) for elem in kernel_size))):
            raise ValueError('`kernel_size` must be tuple or list of tuples')

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        """Repeat a non-list ``param`` ``num_layers`` times; return lists unchanged."""
        if not isinstance(param, list):
            param = [param] * num_layers
        return param


class CLSTM(nn.Module):
    """ConvLSTM encoder followed by a 3x3 conv mapping hidden channels back to ``channels``.

    Args:
        input_size: (height, width) of the frames.
        channels: number of channels of each input/output frame.
        hidden_dim: list of hidden channel counts, one per ConvLSTM layer;
            defaults to ``[64]``. The list is copied, so the caller's object is
            never shared with the module.
        num_layers: number of stacked ConvLSTM layers.
    """

    def __init__(self, input_size=(64, 64), channels=1, hidden_dim=None, num_layers=1):
        super(CLSTM, self).__init__()
        # Avoid a shared mutable default; copy so ConvLSTM never aliases caller data.
        hidden_dim = [64] if hidden_dim is None else list(hidden_dim)
        self.clstm = ConvLSTM(input_size=input_size,
                              input_dim=channels,
                              hidden_dim=hidden_dim,
                              kernel_size=(3, 3),
                              num_layers=num_layers,
                              batch_first=True,
                              bias=True,
                              return_all_layers=False)

        self.output_layer = nn.Conv2d(in_channels=hidden_dim[-1], out_channels=channels,
                                      kernel_size=3, padding=1)

    def forward(self, xx):
        """Map a frame sequence to a same-length sequence of output frames.

        Args:
            xx: (batch, time, channels, H, W) tensor (the ConvLSTM is batch-first).

        Returns:
            (batch, time, channels, H, W) tensor.
        """
        batchsize, timelength = xx.shape[0], xx.shape[1]
        # Use the actual spatial size instead of a hard-coded 64x64 grid.
        height, width = xx.shape[-2], xx.shape[-1]
        hidden = self.clstm(xx)  # (batch, time, hidden_dim[-1], H, W)
        # Fold time into the batch so a single Conv2d is applied to every frame.
        out = self.output_layer(hidden.reshape(batchsize * timelength, -1, height, width))
        return out.reshape(batchsize, timelength, -1, height, width)

    def count_params(self):
        """Return the total number of scalar entries across all parameters."""
        # numel() also handles 0-dim parameters, where reduce() over an empty size would raise.
        return sum(p.numel() for p in self.parameters())




# Number of trajectories used for training / testing.
ntrain = 900
ntest = 100

# Mini-batch sizes of the training and test loaders.
batch_size = 50
batch_size2 = 50

# Optimisation schedule (learning rate is multiplied by scheduler_gamma every scheduler_step epochs).
epochs = 50
learning_rate = 0.0005
scheduler_step = 10
scheduler_gamma = 0.5

# Settings forwarded to HsLoss (from utilities4) as k= and group=.
loss_k = 1
loss_group = True

print(epochs, learning_rate, scheduler_step, scheduler_gamma)

# Run tag embedded in every output file name.
path = f"KF_w_convlstm500_N{ntrain}_k{loss_k}_g{loss_group}_ep{epochs}"
path_model = f"model/{path}"
path_train_err = f"results/{path}train.txt"
path_test_err = f"results/{path}test.txt"
path_image = f"image/{path}"

runtime = np.zeros(2)
t1 = default_timer()


# Spatial subsampling factor and resulting grid size (assumes a 64x64 raw grid).
sub = 1
S = 64 // sub

# Targets are frames T_in .. T_in+T-1; each sample is a window of `step` consecutive frames.
T_in = 100
T = 400
T_out = T_in + T
step = 5
if T % step:
    raise ValueError(f"T ({T}) must be divisible by step ({step})")
n_windows = T // step  # windows per trajectory (renamed from `iter`, which shadowed the builtin)

# Alternative dataset: 'data/KFvorticity_Re40_N200_T500.npy'
# Named `vorticity` so it no longer shadows the imported `torch.utils.data` module.
# Indexed below as [trajectory, time, x, y] -- presumably; confirm against the data file.
vorticity = np.load('data/KFvorticity_Re500_N1000_T500.npy')
vorticity = torch.tensor(vorticity, dtype=torch.float)
print(vorticity.shape)

# Inputs are the target frames shifted one step earlier, so the model learns a one-step map.
train_a = vorticity[:ntrain, T_in-1:T_out-1, ::sub, ::sub].reshape(ntrain*n_windows, step, 1, S, S)
train_u = vorticity[:ntrain, T_in:T_out, ::sub, ::sub].reshape(ntrain*n_windows, step, 1, S, S)
test_a = vorticity[-ntest:, T_in-1:T_out-1, ::sub, ::sub].reshape(ntest*n_windows, step, 1, S, S)
test_u = vorticity[-ntest:, T_in:T_out, ::sub, ::sub].reshape(ntest*n_windows, step, 1, S, S)

print(train_a.shape)
print(train_u.shape)

train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size2, shuffle=False)

t2 = default_timer()

print('preprocessing finished, time used:', t2-t1)
# Use `device` from the top of the script (CUDA when available, else CPU)
# instead of overriding it with a hard-coded CUDA device.
model = CLSTM().to(device)

print(model.count_params())
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)


# Summed (size_average=False) losses; the training loop normalises them per frame.
lploss = LpLoss(size_average=False)
h1loss = HsLoss(k=1, group=False, size_average=False)
h2loss = HsLoss(k=2, group=False, size_average=False)
myloss = HsLoss(k=loss_k, group=loss_group, size_average=False)


for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_l2 = 0
    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)

        out = model(x)
        # Flatten (batch, step, 1, S, S) into per-frame samples. Use the actual
        # batch length so a partial final batch does not break the reshape.
        n_frames = x.shape[0] * x.shape[1]
        loss = myloss(out.reshape(n_frames, S, S, -1), y.reshape(n_frames, S, S, -1))
        train_l2 += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    test_l2 = 0
    test_h1 = 0
    test_h2 = 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)

            out = model(x)
            n_frames = x.shape[0] * x.shape[1]
            out_frames = out.reshape(n_frames, S, S, -1)
            y_frames = y.reshape(n_frames, S, S, -1)
            test_l2 += lploss(out_frames, y_frames).item()
            test_h1 += h1loss(out_frames, y_frames).item()
            test_h2 += h2loss(out_frames, y_frames).item()

    t2 = default_timer()
    scheduler.step()
    # Summed losses divided by the number of frames (trajectories * T).
    print(ep, t2 - t1, train_l2 / (ntrain * T), test_l2 / (ntest * T), test_h1 / (ntest * T), test_h2 / (ntest * T))

torch.save(model, path_model)




# Autoregressive rollout from the first test window: after each call, the
# newest predicted frame is appended and the oldest input frame dropped.
model.eval()
x0 = test_a[0]  # first test window, (step, 1, S, S); keeps `test_a` intact

T_rollout = 10000  # separate name so the training horizon `T` is not clobbered
pred = torch.zeros(S, S, T_rollout)

x_in = x0.reshape(1, step, 1, S, S).to(device)
with torch.no_grad():
    for i in range(T_rollout):
        print(i)
        x_out = model(x_in)
        pred[:, :, i] = x_out[0, -1, 0, :, :]
        # Slide the input window forward by one frame.
        x_in = torch.cat([x_in[:, 1:], x_out[:, -1:]], dim=1)


scipy.io.savemat('pred/'+path+'.mat', mdict={'pred': pred.cpu().numpy()})


