import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch.nn.functional as F
import torch.nn as nn
import random
import h5py
import os
import math
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import MessagePassing
from torch.utils.tensorboard import SummaryWriter
from scipy.io import loadmat
import struct

# Run on the first CUDA GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Directory that holds the dataset files (note the trailing slash).
# Alternative dataset: "real_dynamics_data/sin_force/"
data_root = "real_dynamics_data/const_force/"


def _fully_connected_edges(num_nodes):
    """Return the edge index of a complete directed graph without self-loops.

    The result is an int64 tensor of shape ``(2, num_nodes * (num_nodes - 1))``
    whose columns are the ``(j, k)`` pairs with ``j != k`` in row-major order.
    For ``num_nodes < 2`` the result is an empty ``(2, 0)`` tensor.
    """
    off_diagonal = ~torch.eye(num_nodes, dtype=torch.bool)
    # nonzero() lists the True positions in row-major order, one (j, k) per row.
    return off_diagonal.nonzero().t().contiguous()


def get_data(data_type, batch_size):
    """Build a shuffled graph DataLoader for one split of the dataset in ``data_root``.

    ``xyw1.mat`` supplies the node features and ``xyw2.mat`` the regression
    targets. After the axis reversal both arrays are indexed as
    ``[sample, particle, channel]``. Each sample becomes one fully connected
    graph whose node features are all channels of ``xyw1`` and whose targets
    are the first two channels of ``xyw2``.

    Args:
        data_type: ``'train'`` selects the first 80% of the samples; any other
            value selects the remaining 20%.
        batch_size: Number of graphs per batch.

    Returns:
        A ``DataLoader`` over the selected samples with shuffling enabled.
    """
    def load(name):
        array = loadmat(os.path.join(data_root, name + ".mat"))[name]
        return torch.from_numpy(array.transpose(2, 1, 0).astype(np.float32))

    xyw1, xyw2 = load("xyw1"), load("xyw2")
    num_samples = xyw1.shape[0]

    # Rows whose channel 2 is (near) zero are not counted as particles. Only the
    # first `count` rows of a sample are kept, which assumes that such rows sit
    # at the end of each sample -- TODO confirm against the data generator.
    num_particles = (xyw1[:, :, 2].abs() > 1e-5).sum(dim=1).tolist()

    split = int(num_samples * 0.8)
    start, end = (0, split) if data_type == 'train' else (split, num_samples)

    data_list = [
        Data(
            x=xyw1[i, :num_particles[i], :],
            y=xyw2[i, :num_particles[i], :2],
            edge_index=_fully_connected_edges(num_particles[i]),
        )
        for i in range(start, end)
    ]

    return DataLoader(data_list, batch_size=batch_size, shuffle=True)


class ResidualBlock(nn.Module):
    """Two-layer fully connected block with a skip connection and a final ReLU.

    The main path is Linear -> ReLU -> Linear. The skip path is the identity
    when ``inchannel == outchannel`` and a single Linear projection otherwise.
    """

    def __init__(self, inchannel, outchannel):
        super().__init__()
        # The attribute names and layer positions determine the state_dict keys.
        self.left = nn.Sequential(
            nn.Linear(inchannel, outchannel),
            nn.ReLU(inplace=True),
            nn.Linear(outchannel, outchannel),
        )
        if inchannel == outchannel:
            # An empty Sequential returns its input unchanged.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(nn.Linear(inchannel, outchannel))

    def forward(self, x):
        """Return ``relu(left(x) + shortcut(x))``."""
        summed = self.left(x) + self.shortcut(x)
        return F.relu(summed)


class DisplacementLayer(MessagePassing):
    """Message-passing layer that sums a learned 2-vector over the edges of each node.

    Every edge message is produced by ``self.dis`` from four inputs: the
    difference between the two endpoints' leading columns, the norm of that
    difference, and column 2 of the ``x_j`` endpoint. The first layer takes
    4 inputs, so node features are expected to have 3 columns.
    """

    def __init__(self):
        super().__init__(aggr='add')
        widths = (4, 64, 64, 128, 128, 256, 256)
        trunk = [ResidualBlock(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
        self.dis = nn.Sequential(*trunk, nn.Linear(256, 2, bias=False))
        # Not read anywhere in this class; it is still a registered parameter and
        # therefore part of the state_dict.
        self.rot = nn.Parameter(torch.zeros(1), requires_grad=True)

    def forward(self, x, edge_index):
        """Aggregate the edge messages for every node of ``x``."""
        num_nodes = x.size(0)
        return self.propagate(edge_index, size=(num_nodes, num_nodes), x=x)

    def message(self, x_i, x_j):
        """Compute the per-edge 2-vector from the two endpoint feature rows."""
        offset = x_j[:, :-1] - x_i[:, :-1]
        distance = offset.norm(dim=1, keepdim=True)
        strength = x_j[:, 2].unsqueeze(1)
        edge_features = torch.cat((offset, distance, strength), dim=1)
        return self.dis(edge_features)

    def update(self, aggr_out):
        """Return the aggregated messages unchanged."""
        return aggr_out


class RealDynamicsNet(nn.Module):
    """Learned velocity field integrated in time with the classical RK4 scheme.

    The velocity of each node is the sum of an edge-aggregated term
    (``neighbor_influence``) and a per-node term (``global_influence``).
    ``forward`` uses that velocity to advance every column of the state except
    the last one, which is carried through unchanged.
    """

    def __init__(self):
        super().__init__()
        self.neighbor_influence = DisplacementLayer()
        self.global_influence = nn.Sequential(ResidualBlock(3, 64),
                                              ResidualBlock(64, 64),
                                              ResidualBlock(64, 128),
                                              ResidualBlock(128, 128),
                                              ResidualBlock(128, 256),
                                              ResidualBlock(256, 256),
                                              nn.Linear(256, 2, bias=False))

    def cal_uv(self, x0, edge_index):
        """Return the per-node velocity for state ``x0`` on graph ``edge_index``."""
        neighbor_influence = self.neighbor_influence(x0, edge_index)
        global_influence = self.global_influence(x0)
        return neighbor_influence + global_influence

    @staticmethod
    def _shifted(state, displacement):
        """Return ``state`` with ``displacement`` added to all but its last column."""
        return torch.cat((state[:, :-1] + displacement, state[:, -1:]), dim=1)

    def forward(self, x0, edge_index, dt, length_t):
        """Advance the state by ``length_t`` and return the advanced columns.

        Args:
            x0: Node state; the last column is held fixed during integration.
            edge_index: Graph connectivity passed on to ``cal_uv``.
            dt: Largest allowed sub-step; must be non-zero.
            length_t: Total time to integrate over. It is split into
                ``ceil(length_t / dt)`` equal sub-steps.

        Returns:
            ``x0`` without its last column after integration. When the number
            of sub-steps is not positive (e.g. ``length_t == 0``) the input
            columns are returned unchanged.
        """
        n_steps = int(np.ceil(length_t / dt))
        if n_steps <= 0:
            # Nothing to integrate; also avoids dividing by zero below.
            return x0[:, :-1]
        h = length_t / n_steps
        for _ in range(n_steps):
            k1 = self.cal_uv(x0, edge_index)
            k2 = self.cal_uv(self._shifted(x0, 0.5 * h * k1), edge_index)
            k3 = self.cal_uv(self._shifted(x0, 0.5 * h * k2), edge_index)
            k4 = self.cal_uv(self._shifted(x0, h * k3), edge_index)
            x0 = self._shifted(x0, h * (k1 + 2 * k2 + 2 * k3 + k4) / 6)
        return x0[:, :-1]


# --- Optimisation schedule ---
n_epoch = 2000
batch_size = 64
lr = 1e-3         # initial learning rate handed to the optimizer
lr_step = 20      # number of epochs between learning-rate decays
lr_gamma = 0.8    # factor the learning rate is multiplied by at each decay

# --- Time integration ---
length_t = 0.2    # time span covered by one forward pass of the network
dt = 0.1          # upper bound on the integrator sub-step within that span

# --- Model and loss ---
net = RealDynamicsNet().to(device)
# NOTE: despite its name, this is the L1 (mean absolute error) loss.
mse_loss = F.l1_loss


def _epoch_loss(loader, optimizer=None):
    """Run ``net`` once over ``loader`` and return the mean of the per-batch losses.

    When ``optimizer`` is given, each batch is followed by a backward pass and
    an optimizer step; otherwise the batches are only evaluated.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    total = 0.0
    n_batches = 0
    for data_batch in loader:
        data = data_batch.x.to(device)
        edge_index = data_batch.edge_index.to(device)
        target = data_batch.y.to(device)
        if optimizer is not None:
            optimizer.zero_grad()
        loss = mse_loss(net(data, edge_index, dt, length_t), target)
        if optimizer is not None:
            loss.backward()
            optimizer.step()
        total += loss.item()
        n_batches += 1
    if n_batches == 0:
        raise ValueError("data loader produced no batches")
    return total / n_batches


def train():
    """Train the module-level ``net`` and keep the checkpoint with the lowest test loss.

    Runs ``n_epoch`` epochs of Adam with a step-decayed learning rate. After
    every epoch the mean train and test losses are written to TensorBoard under
    ``runs/real_dynamics``, and the weights are saved to
    ``real_dynamics_const.pt`` whenever the test loss improves.
    """
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step, gamma=lr_gamma)
    data_loader = get_data('train', batch_size)
    test_data_loader = get_data('test', batch_size)
    # Start at infinity so the first finite test loss always produces a checkpoint.
    lowest_test_loss = float('inf')

    writer = SummaryWriter('runs/real_dynamics')
    try:
        for epoch in range(1, n_epoch + 1):
            net.train()
            train_loss = _epoch_loss(data_loader, optimizer)
            scheduler.step()

            net.eval()
            with torch.no_grad():
                test_loss = _epoch_loss(test_data_loader)

            writer.add_scalars('loss', {'train_loss': train_loss,
                                        'test_loss': test_loss}, epoch)
            writer.flush()
            if test_loss < lowest_test_loss:
                torch.save(net.state_dict(), "real_dynamics_const.pt")
                lowest_test_loss = test_loss
    finally:
        # Close the event file even when training is interrupted or fails.
        writer.close()


def _read_vorticity_grid(path, n_grid, header_floats=57):
    """Read three ``n_grid x n_grid`` float32 fields from a binary file.

    The file is expected to hold ``header_floats`` 4-byte values that are
    skipped, followed by the mesh x coordinates, the mesh y coordinates and the
    vorticity values, each stored as ``n_grid * n_grid`` native-endian floats.

    Returns:
        ``(meshx, meshy, vor)`` as float64 arrays of shape ``(n_grid, n_grid)``.

    Raises:
        ValueError: If the file is too short to contain all three fields.
    """
    n_values = 3 * n_grid * n_grid
    with open(path, 'rb') as f:
        f.seek(4 * header_floats)
        payload = f.read(4 * n_values)
    if len(payload) != 4 * n_values:
        raise ValueError("{} is too short: expected {} float32 values after the header"
                         .format(path, n_values))
    fields = np.frombuffer(payload, dtype=np.float32).astype(np.float64)
    meshx, meshy, vor = fields.reshape(3, n_grid, n_grid)
    return meshx, meshy, vor


def test():
    """Roll out the saved model on four vortices and plot the visited positions.

    Loads ``real_dynamics_const.pt`` into the module-level ``net``, applies the
    network 50 times (each call advances ``length_t``), and scatters all
    positions over a contour plot of the field stored in ``tecplot_vor.dat``
    under ``data_root``. Everything runs on the CPU.
    """
    net.load_state_dict(torch.load("real_dynamics_const.pt", map_location='cpu'))
    net.to('cpu')
    net.eval()

    # The inputs stay on the CPU so they live on the same device as the model.
    vort_pos = math.pi + torch.tensor([-0.3, -0.5,
                                       0.3, -0.4,
                                       0.3, 0.5,
                                       -0.3, 0.4]).reshape(4, 2)
    num_particle = vort_pos.shape[0]
    vort_value = 0.75 * torch.tensor([1, -1, 1., -1]).reshape(4, 1)
    x = torch.cat((vort_pos, vort_value), dim=1)
    edge_index = torch.tensor([[j, k]
                               for j in range(num_particle)
                               for k in range(num_particle)
                               if j != k]).transpose(0, 1).contiguous()

    vort_poses = [vort_pos]
    with torch.no_grad():
        for _ in range(50):
            vort_pos = net(x, edge_index, dt, length_t)
            # The strengths are held fixed; only the positions are fed back.
            x = torch.cat((vort_pos, vort_value), dim=1)
            vort_poses.append(vort_pos)
    trajectory = torch.cat(vort_poses).numpy()

    plt.scatter(trajectory[:, 0], trajectory[:, 1], c='red', s=10)

    Ngrid = 1024
    meshx, meshy, vor = _read_vorticity_grid(os.path.join(data_root, "tecplot_vor.dat"), Ngrid)
    d = 5
    levels = np.arange(-50, 50 + d, d)
    plt.contourf(meshx, meshy, vor, cmap='RdYlBu', levels=levels, extend='both', alpha=0.1)
    plt.colorbar()
    plt.show()


if __name__ == '__main__':
    # Training is switched off by default, so test() relies on the checkpoint
    # file written by an earlier train() run; uncomment to retrain first.
    # train()
    test()