import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from scipy.io import loadmat
import torch.nn as nn
import random
import h5py
import os
import torch.utils.data
import math
import struct


def _write_split(path, data, exist, coord):
    """Write one dataset split to an HDF5 file at *path*.

    The three arrays are stored as float32 datasets named 'data', 'exist'
    and 'coord'. The file is closed even if writing fails part-way.
    """
    with h5py.File(path, "w") as hf:
        hf.create_dataset('data', data=np.asarray(data, dtype=np.float32))
        hf.create_dataset('exist', data=exist.numpy().astype(np.float32))
        hf.create_dataset('coord', data=coord.numpy().astype(np.float32))


def gen(data_root="real_dynamics_data/const_force/", train_fraction=0.85):
    """Build the vortex-detection train/test HDF5 files from raw simulation output.

    Reads ``xyw1.mat`` and the per-sample grid files
    ``grid_data/N<sample index, 5 digits>t000.dat`` under *data_root*, derives
    coarse per-cell labels, and writes ``train.h5`` and ``test.h5`` into a
    ``detection_data`` directory next to this script.

    Each output file holds three float32 datasets:

    * ``data``  -- (N, 1, 200, 200) grid field read from the ``.dat`` files.
    * ``exist`` -- (N, 1, 25, 25), 1 in every coarse cell that contains a
      particle, 0 elsewhere.
    * ``coord`` -- (N, 2, 25, 25), per-cell position label of the particle
      (see the review note in the labelling loop).

    Args:
        data_root: Directory containing ``xyw1.mat`` and ``grid_data/``.
            The default is the constant-force data set; the sinusoidal-force
            data lives in ``"real_dynamics_data/sin_force/"``.
        train_fraction: Fraction of samples (taken from the front) written to
            ``train.h5``; the remainder goes to ``test.h5``.

    Raises:
        ValueError: If a grid file holds fewer bytes than one full frame.
    """
    # After the transpose the array is indexed [sample, particle, component].
    # Presumably the components are (x, y, strength) -- TODO confirm.
    xyw1 = loadmat(os.path.join(data_root, "xyw1.mat"))['xyw1']
    xyw1 = torch.from_numpy(xyw1.transpose(2, 1, 0).astype(np.float32))
    num_samples = xyw1.shape[0]
    # A particle slot counts as used when its third component is non-negligible.
    # The labelling loop below assumes used slots come first in each sample.
    num_particles = (xyw1[:, :, 2].abs() > 1e-5).sum(dim=1).long()

    grid_size = 200
    grid_size8 = grid_size // 8  # label grid: one cell per 8x8 block of pixels
    dx = math.pi * 2 / 512.
    dx8 = dx * 8
    # Coordinate of the first grid point: half the grid width below pi.
    grid_start = math.pi - (grid_size // 2) * dx

    # Each .dat file starts with grid_size * grid_size native-endian float32s.
    frame_bytes = 4 * grid_size * grid_size
    frames = []
    for i in range(num_samples):
        frame_path = os.path.join(data_root, 'grid_data',
                                  'N' + str(i + 1).zfill(5) + 't000.dat')
        with open(frame_path, 'rb') as f:
            raw = f.read(frame_bytes)
        if len(raw) != frame_bytes:
            raise ValueError("%s: expected %d bytes, got %d"
                             % (frame_path, frame_bytes, len(raw)))
        frames.append(
            np.frombuffer(raw, dtype=np.float32).reshape(1, grid_size, grid_size))
    data = np.stack(frames, axis=0)  # (num_samples, 1, grid_size, grid_size)

    exist = torch.zeros(num_samples, 1, grid_size8, grid_size8)
    coord = torch.zeros(num_samples, 2, grid_size8, grid_size8)
    for i in range(num_samples):
        # Swap the first two components so index 0 addresses the row axis.
        vortex_pos = xyw1[i, :, [1, 0]]
        for j in range(int(num_particles[i])):
            vortex_cell = ((vortex_pos[j] - grid_start) / dx8).floor().long()
            exist[i, 0, vortex_cell[0], vortex_cell[1]] = 1
            # NOTE(review): unlike the cell index above, this does not subtract
            # grid_start, so the stored value is the in-cell fraction plus the
            # constant grid_start / dx8. Kept as-is because consumers of the
            # generated files may rely on it -- confirm before changing.
            # NOTE(review): positions outside the grid are not range-checked;
            # negative cell indices wrap around silently.
            coord[i, :, vortex_cell[0], vortex_cell[1]] = vortex_pos[j] / dx8 - vortex_cell

    out_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "detection_data")
    os.makedirs(out_dir, exist_ok=True)

    split = int(num_samples * train_fraction)
    _write_split(os.path.join(out_dir, "train.h5"),
                 data[:split], exist[:split], coord[:split])
    _write_split(os.path.join(out_dir, "test.h5"),
                 data[split:], exist[split:], coord[split:])


# Script entry point: regenerate the detection dataset when run directly.
if __name__ == '__main__':
    gen()
