import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch.nn.functional as F
import torch.nn as nn
import random
import h5py
import os
import math
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import MessagePassing
from torch.utils.tensorboard import SummaryWriter
from scipy.io import loadmat
import struct

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def get_data(data_type, batch_size):
    """Build a shuffled ``DataLoader`` of two-node graphs from the .mat files.

    Every sample is a fully connected graph with two nodes.  Node features
    ``x`` have shape (2, 3) and hold the ``(x, y, w)`` columns read from the
    ``*00`` / ``*10`` files; targets ``y`` have shape (2, 2) and hold the
    ``(x, y)`` columns read from the ``*01`` / ``*11`` files.

    The files are read from the ``dynamics_data`` directory relative to the
    current working directory.  All twelve columns must have the same length,
    otherwise ``torch.stack`` raises.

    :param data_type: ``'train'`` selects the first 90% of the samples; any
        other value selects the remaining 10%.
    :param batch_size: batch size forwarded to the ``DataLoader``.
    :return: a ``DataLoader`` with ``shuffle=True`` over the selected samples.
    """

    def load_column(name):
        # Each file stores one variable under a key equal to the file stem.
        values = loadmat("dynamics_data/" + name + ".mat")[name]
        return torch.from_numpy(values.reshape(-1).astype(np.float32))

    def load_nodes(names):
        # Stack six flat columns into (num_samples, 2 nodes, 3 values).
        return torch.stack([load_column(n) for n in names], dim=1).reshape(-1, 2, 3)

    features = load_nodes(('x00', 'y00', 'w00', 'x10', 'y10', 'w10'))
    # Only the first two columns of each target node are predicted.
    targets = load_nodes(('x01', 'y01', 'w01', 'x11', 'y11', 'w11'))[:, :, :2]

    num_samples = features.shape[0]
    split = int(num_samples * 0.9)
    if data_type == 'train':
        start, end = 0, split
    else:
        start, end = split, num_samples

    # The connectivity is identical for every sample: both directed edges
    # between the two nodes, no self loops.
    edge_index = torch.tensor([[j, k]
                               for j in range(2)
                               for k in range(2)
                               if j != k]).transpose(0, 1).contiguous()

    # Clone so that every sample owns its storage instead of viewing the
    # large stacked tensors.
    data_list = [
        Data(x=x.clone(), y=y.clone(), edge_index=edge_index.clone())
        for x, y in zip(features[start:end], targets[start:end])
    ]

    return DataLoader(data_list, batch_size=batch_size, shuffle=True)


class ResidualBlock(nn.Module):
    """Fully connected residual block: ``relu(left(x) + shortcut(x))``.

    ``left`` is Linear -> ReLU -> Linear.  ``shortcut`` is the identity (an
    empty ``nn.Sequential``) when the input and output widths match and a
    single ``nn.Linear`` projection otherwise.
    """

    def __init__(self, inchannel, outchannel):
        super().__init__()
        # Main branch; submodule names and order define the state_dict keys.
        self.left = nn.Sequential(
            nn.Linear(inchannel, outchannel),
            nn.ReLU(inplace=True),
            nn.Linear(outchannel, outchannel),
        )
        # Skip branch: project only when the widths differ.
        if inchannel == outchannel:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(nn.Linear(inchannel, outchannel))

    def forward(self, x):
        """Return the activated sum of the main and skip branches."""
        main_branch = self.left(x)
        skip_branch = self.shortcut(x)
        return F.relu(main_branch + skip_branch)


class DisplacementLayer(MessagePassing):
    """Message-passing layer that sums a learned 2-vector over incoming edges.

    For every edge a 4-value feature (relative position, distance and column
    2 of ``x_j``) is pushed through a residual MLP; the per-edge outputs are
    aggregated with ``aggr='add'``.
    """

    def __init__(self):
        super().__init__(aggr='add')
        # (in, out) widths of the residual stack; the final Linear maps to 2 outputs.
        widths = ((4, 64), (64, 64), (64, 128), (128, 128), (128, 256), (256, 256))
        blocks = [ResidualBlock(n_in, n_out) for n_in, n_out in widths]
        self.dis = nn.Sequential(*blocks, nn.Linear(256, 2))
        # Not read anywhere in this class; presumably retained so that
        # existing checkpoints keep loading -- confirm before removing.
        self.rot = nn.Parameter(torch.zeros(1), requires_grad=True)

    def forward(self, x, edge_index):
        """Propagate ``x`` over ``edge_index`` on a square (N, N) graph."""
        num_nodes = x.size(0)
        return self.propagate(edge_index, size=(num_nodes, num_nodes), x=x)

    def message(self, x_i, x_j):
        """Return the network output for every edge.

        ``x_i`` / ``x_j`` are the per-edge node features supplied by
        ``propagate`` (torch_geometric's i/j convention -- confirm which end
        is the receiver).  All but the last column are treated as position.
        An analytic velocity kernel that used to live here as commented-out
        code has been replaced by the learned network.
        """
        offset = x_j[:, :-1] - x_i[:, :-1]
        distance = offset.norm(dim=1, keepdim=True)
        strength_j = x_j[:, 2].unsqueeze(1)
        edge_feature = torch.cat((offset, distance, strength_j), dim=1)
        return self.dis(edge_feature)

    def update(self, aggr_out):
        """Pass the aggregated messages through unchanged."""
        return aggr_out


class DynamicsNet(nn.Module):
    """Integrates node positions with classic RK4 using a learned velocity.

    The velocity of every node is produced by ``edge_conv``; the last column
    of the node features is carried along unchanged.
    """

    def __init__(self):
        super().__init__()
        self.edge_conv = DisplacementLayer()

    def forward(self, x0, edge_index, dt, length_t):
        """Advance ``x0`` over a time span of ``length_t``.

        :param x0: node features; all but the last column are integrated.
        :param edge_index: graph connectivity forwarded to ``edge_conv``.
        :param dt: upper bound for the step size.
        :param length_t: total time to integrate over.
        :return: the integrated columns (``x0`` without its last column).
        """

        def shifted(state, delta):
            # Offset the integrated columns, keep the last column as is.
            return torch.cat((state[:, :-1] + delta, state[:, -1].unsqueeze(1)), dim=1)

        # Use the smallest number of equal steps whose size does not exceed dt.
        n_steps = int(np.ceil(length_t / dt))
        h = length_t / n_steps

        state = x0
        for _ in range(n_steps):
            k1 = self.edge_conv(state, edge_index)
            k2 = self.edge_conv(shifted(state, 0.5 * h * k1), edge_index)
            k3 = self.edge_conv(shifted(state, 0.5 * h * k2), edge_index)
            k4 = self.edge_conv(shifted(state, h * k3), edge_index)
            state = shifted(state, h * (k1 + 2 * k2 + 2 * k3 + k4) / 6)
        return state[:, :-1]


# Training schedule.
n_epoch = 2000
batch_size = 64
# Adam learning rate, and the StepLR decay factor / period (in epochs).
lr, lr_gamma, lr_step = 1e-3, 0.8, 20
# Integration settings passed to DynamicsNet.forward: maximum step size and
# total time advanced per network call.
dt, length_t = 0.1, 0.2

net = DynamicsNet().to(device)
# NOTE: despite its name this is the L1 (mean absolute error) loss.
mse_loss = F.l1_loss


def _mean_batch_loss(loader, optimizer=None):
    """Run one pass over ``loader`` and return the mean per-batch loss.

    When ``optimizer`` is given, every batch also performs a backward pass
    and an optimizer step; otherwise the pass is forward-only (the caller is
    responsible for ``net.eval()`` / ``torch.no_grad()``).

    :raises ValueError: if ``loader`` yields no batches.
    """
    total = 0.0
    n_batches = 0
    for batch in loader:
        inputs = batch.x.to(device)
        edge_index = batch.edge_index.to(device)
        target = batch.y.to(device)
        if optimizer is not None:
            optimizer.zero_grad()
        loss = mse_loss(net(inputs, edge_index, dt, length_t), target)
        if optimizer is not None:
            loss.backward()
            optimizer.step()
        total += loss.item()
        n_batches += 1
    if n_batches == 0:
        raise ValueError("data loader produced no batches")
    return total / n_batches


def train():
    """Train ``net`` for ``n_epoch`` epochs and keep the best checkpoint.

    Uses Adam with a ``StepLR`` schedule.  After each epoch the mean
    per-batch train and test losses are logged to TensorBoard under
    ``../runs/dynamics`` and, whenever the test loss improves, the weights
    are written to ``dynamics.pt``.
    """
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step, gamma=lr_gamma)
    data_loader = get_data('train', batch_size)
    test_data_loader = get_data('test', batch_size)
    # Infinity guarantees the first finite test loss is always checkpointed.
    lowest_test_loss = float('inf')

    writer = SummaryWriter('../runs/dynamics')
    try:
        for epoch in range(n_epoch):
            net.train()
            train_loss = _mean_batch_loss(data_loader, optimizer)
            scheduler.step()

            net.eval()
            with torch.no_grad():
                test_loss = _mean_batch_loss(test_data_loader)

            writer.add_scalars('loss', {'train_loss': train_loss,
                                        'test_loss': test_loss}, epoch + 1)
            writer.flush()
            if test_loss < lowest_test_loss:
                torch.save(net.state_dict(), "dynamics.pt")
                lowest_test_loss = test_loss
    finally:
        # Close the event file even when training is interrupted or fails.
        writer.close()


def test():
    """Roll out four vortices with the trained network and animate tracers.

    Loads the weights from ``dynamics.pt``, advances a fixed four-vortex
    configuration for 200 frames with ``net`` and, in lock-step, advects
    ``num_tracer`` randomly placed tracer particles through the velocity
    field induced by the current vortices.  The result is shown as a
    matplotlib animation (tracers in blue, vortices in red).

    The rollout runs on ``device``; every frame is copied back to the CPU
    because matplotlib cannot consume CUDA tensors.
    """
    num_tracer = 2000
    n_frames = 200
    net.load_state_dict(torch.load("dynamics.pt", map_location=lambda storage, location: storage))

    def compute_tracer_uv(xy, vort):
        """
        Velocity induced at each tracer by all vortices:
        strength * (-dy, dx) / (2 * pi * r^2) * (1 - exp(-r^2 / EPS^2)),
        with (dx, dy) = tracer - vortex, summed over the vortices.

        :param xy: tensor(num_particle, 2), tracer positions
        :param vort: tensor(num_vortex, 3), columns are x, y and strength
        :return: uv(num_particle, 2)
        """
        EPS = math.pi / 64
        vort_pos = vort[:, :2]
        vort_value = vort[:, 2]
        r_ij2 = (xy.unsqueeze(1) - vort_pos).norm(dim=2) ** 2  # shape (num_particle, num_vortex)
        # Both velocity components share the same denominator and cutoff.
        denom = r_ij2 * math.pi * 2
        cutoff = 1.0 - (-r_ij2 / EPS ** 2).exp()
        u = vort_value.unsqueeze(0) * (vort_pos[:, 1] - xy[:, 1].unsqueeze(1)) / denom * cutoff
        v = vort_value.unsqueeze(0) * (xy[:, 0].unsqueeze(1) - vort_pos[:, 0]) / denom * cutoff
        # Sum the contributions of all vortices -> shape (num_particle, 1) each.
        return torch.cat((u.sum(dim=1, keepdim=True), v.sum(dim=1, keepdim=True)), dim=1)

    def rk4_tracer(xy0, vort, dt, length_t):
        """Advect tracers by ``length_t`` with RK4 while ``vort`` stays frozen."""
        n_steps = int(np.ceil(length_t / dt))
        h = length_t / n_steps
        for _ in range(n_steps):
            uv0 = compute_tracer_uv(xy0, vort)
            xy1 = xy0 + 0.5 * h * uv0
            uv1 = compute_tracer_uv(xy1, vort)
            xy2 = xy0 + 0.5 * h * uv1
            uv2 = compute_tracer_uv(xy2, vort)
            xy3 = xy0 + h * uv2
            uv3 = compute_tracer_uv(xy3, vort)
            xy0 = xy0 + h * (uv0 + 2 * uv1 + 2 * uv2 + uv3) / 6
        return xy0

    pos_x = torch.empty(num_tracer, 1).uniform_(-0.5, 0.5)
    pos_y = torch.empty(num_tracer, 1).uniform_(-1.5, 1.5)
    xy = torch.cat((pos_x, pos_y), dim=1)
    vort_pos = torch.tensor([[0, 1],
                             [0, -1],
                             [0, 0.3],
                             [0, -0.3]
                             ])
    vort_value = torch.tensor([1, -1, 1, -1.0]).reshape(4, 1) * 0.8
    x = torch.cat((vort_pos, vort_value), dim=1)
    edge_index = torch.tensor([[j, k]
                               for j in range(4)
                               for k in range(4)
                               if j != k]).transpose(0, 1).contiguous()

    # Frame 0 is recorded from the CPU tensors before anything is moved.
    xys = [xy]
    vorts = [vort_pos]

    net.to(device)
    net.eval()
    # Every tensor that meets the network (or each other) must live on the
    # same device as the network, otherwise the rollout fails on CUDA hosts.
    xy = xy.to(device)
    x = x.to(device)
    vort_value = vort_value.to(device)
    edge_index = edge_index.to(device)
    with torch.no_grad():
        for _ in range(n_frames):
            xy = rk4_tracer(xy, x, dt, length_t)
            vort_pos = net(x, edge_index, dt, length_t)
            x = torch.cat((vort_pos, vort_value), dim=1)
            xys.append(xy.cpu())
            vorts.append(vort_pos.cpu())

    fig, ax = plt.subplots()
    sc1 = ax.scatter(xys[0][:, 0], xys[0][:, 1], c='blue', s=1)
    sc2 = ax.scatter(vorts[0][:, 0], vorts[0][:, 1], c='red', s=10)
    plt.xlim(-1, 10)
    plt.ylim(-2, 2)

    def animate(i):
        sc1.set_offsets(xys[i])
        sc2.set_offsets(vorts[i])

    # Keep a reference so the animation is not garbage-collected before show().
    ani = animation.FuncAnimation(fig, animate, frames=len(xys), interval=5)
    plt.show()


def _rollout_vortices(vort_pos, vort_value, n_frames):
    """Advance a vortex configuration ``n_frames`` times with ``net`` on the CPU.

    :param vort_pos: CPU tensor (num_vortex, 2) of initial positions.
    :param vort_value: CPU tensor (num_vortex, 1) of vortex strengths.
    :param n_frames: number of network calls to chain.
    :return: tensor ((n_frames + 1) * num_vortex, 2) with the initial and all
        predicted positions concatenated along dim 0.
    """
    num_vortex = vort_pos.shape[0]
    # Fully connected directed graph without self loops.
    edge_index = torch.tensor([[j, k]
                               for j in range(num_vortex)
                               for k in range(num_vortex)
                               if j != k]).transpose(0, 1).contiguous()
    x = torch.cat((vort_pos, vort_value), dim=1)

    trajectory = [vort_pos]
    # Model and inputs are both kept on the CPU so they can never mismatch
    # and the result can be handed to matplotlib directly.
    net.to('cpu')
    net.eval()
    with torch.no_grad():
        for _ in range(n_frames):
            vort_pos = net(x, edge_index, dt, length_t)
            x = torch.cat((vort_pos, vort_value), dim=1)
            trajectory.append(vort_pos)
    return torch.cat(trajectory)


def _overlay_vorticity_field(path, n_grid=1024, header_floats=57):
    """Draw the vorticity field stored in ``path`` as faint filled contours.

    The file is read as native-endian float32: ``header_floats`` leading
    values are skipped (presumably a Tecplot header -- confirm), followed by
    three ``n_grid`` x ``n_grid`` arrays holding mesh x, mesh y and vorticity.
    A colorbar for the contours is added to the current figure.

    :raises ValueError: if the file holds fewer values than expected.
    """
    n_cells = n_grid * n_grid
    with open(path, 'rb') as f:
        f.seek(4 * header_floats)
        raw = f.read(4 * 3 * n_cells)
    if len(raw) != 4 * 3 * n_cells:
        raise ValueError("{}: expected {} bytes of field data, got {}".format(
            path, 4 * 3 * n_cells, len(raw)))
    # float64 matches what the previous tuple-of-floats conversion produced.
    fields = np.frombuffer(raw, dtype=np.float32).astype(np.float64)
    meshx, meshy, vor = fields.reshape(3, n_grid, n_grid)

    d = 5
    levels = np.arange(-50, 50 + d, d)
    plt.contourf(meshx, meshy, vor, cmap='RdYlBu', levels=levels, extend='both', alpha=0.1)
    plt.colorbar()


def test2():
    """Plot a 70-frame rollout of four vortices over a stored vorticity field.

    Loads ``dynamics.pt``, rolls the four-vortex configuration out on the CPU
    and scatters every visited position on top of the field read from
    ``dynamics_data/tecplot_vor.dat``.
    """
    net.load_state_dict(torch.load("dynamics.pt", map_location=lambda storage, location: storage))

    vort_pos = torch.tensor([2.741593, 2.541593,
                             3.541593, 2.541593,
                             3.541593, 3.741593,
                             2.741593, 3.741593]).reshape(4, 2)
    vort_value = 0.75 * torch.tensor([1, -1, 1., -1]).reshape(4, 1)
    vort_poses = _rollout_vortices(vort_pos, vort_value, 70)

    # Scatter first so that the colorbar added afterwards belongs to the contours.
    plt.scatter(vort_poses[:, 0], vort_poses[:, 1], c='red', s=10)
    _overlay_vorticity_field('dynamics_data/tecplot_vor.dat')
    plt.show()


def test3():
    """Plot a 100-frame rollout of two equal vortices over a stored field.

    Loads ``dynamics.pt``, rolls the two-vortex configuration (offsets around
    (pi, pi)) out on the CPU and scatters every visited position on top of
    the field read from ``dynamics_data/tecplot_vor2.dat``.
    """
    net.load_state_dict(torch.load("dynamics.pt", map_location=lambda storage, location: storage))

    vort_pos = math.pi + torch.tensor([-0.4, -0.6,
                                       0.4, 0.6]).reshape(2, 2)
    vort_value = 0.75 * torch.tensor([1., 1.]).reshape(2, 1)
    vort_poses = _rollout_vortices(vort_pos, vort_value, 100)

    # Scatter first so that the colorbar added afterwards belongs to the contours.
    plt.scatter(vort_poses[:, 0], vort_poses[:, 1], c='red', s=10)
    _overlay_vorticity_field('dynamics_data/tecplot_vor2.dat')
    plt.show()


if __name__ == '__main__':
    # Script entry point: choose the routine to run by (un)commenting exactly
    # one of the calls below.  test(), test2() and test3() load "dynamics.pt",
    # which train() writes.
    # train()
    test()
    # test2()
    # test3()