import numpy as np
import struct
import torch
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from scipy.io import loadmat
import torch
import math
import h5py
import os
from detection import DetectionNet


def _read_grid(path, grid_size=200):
    """Read a raw ``grid_size`` x ``grid_size`` float32 grid from ``path``.

    Values are read as native-endian float32 (the layout of a ``struct``
    ``'f'`` format); only the first ``grid_size ** 2`` values are used.

    Returns:
        A float32 ``np.ndarray`` of shape ``(1, 1, grid_size, grid_size)``.

    Raises:
        ValueError: if the file holds fewer than ``grid_size ** 2`` values.
    """
    count = grid_size * grid_size
    # np.fromfile opens and closes the file itself, so no handle can leak.
    values = np.fromfile(path, dtype=np.float32, count=count)
    if values.size != count:
        raise ValueError(f"{path}: expected {count} float32 values, found {values.size}")
    return values.reshape((1, 1, grid_size, grid_size))


def _clipped_window(cell, range_size, limit):
    """Return ``(left, right, bottom, up)`` slice bounds around an integer cell.

    Each bound is ``cell[axis] -/+ range_size`` clamped to ``[0, limit]``.
    The bounds are used as half-open slices, so the window covers
    ``[cell - range_size, cell + range_size)`` on each axis.
    """
    row, col = int(cell[0]), int(cell[1])

    def clamp(v):
        return min(max(v, 0), limit)

    return (clamp(row - range_size), clamp(row + range_size),
            clamp(col - range_size), clamp(col + range_size))


def _refine_particle(density, grid, coord, dx, range_size, tol=1e-5, max_iter=100):
    """Refine a particle position by an iterated density-weighted centroid.

    Args:
        density: square 2-D CPU tensor of grid values.
        grid: ``(H, W, 2)`` CPU tensor; expected to hold ``(row, col) * dx``
            for every cell (as built in ``gen``).
        coord: initial ``(row, col)`` position in the same units as ``grid``.
        dx: cell size, used to map ``coord`` back to a cell index.
        range_size: half-width (in cells) of the averaging window.
        tol: stop once the centroid moves less than this between iterations.
        max_iter: hard cap on iterations so a non-converging (oscillating)
            update cannot hang the script.

    Returns:
        ``(coord, value)``: the refined position, and the sum of ``density``
        over the window around it multiplied by ``dx ** 2``. ``coord`` may be
        NaN if a window has zero total density.
    """
    # NOTE(review): the upper clamp is size - 1 although the bounds are
    # exclusive slice ends, so the last row/column is never included. Kept
    # unchanged so regenerated data matches existing datasets.
    limit = density.shape[0] - 1
    old_coord = coord + 1  # forces at least one iteration
    iterations = 0
    while (coord - old_coord).norm() > tol and iterations < max_iter:
        old_coord = coord
        left, right, bottom, up = _clipped_window((coord / dx).long(), range_size, limit)
        window = density[left:right, bottom:up]
        weighted = window.unsqueeze(-1) * grid[left:right, bottom:up, :]
        coord = weighted.reshape(-1, 2).sum(dim=0) / window.sum()
        iterations += 1
    left, right, bottom, up = _clipped_window((coord / dx).long(), range_size, limit)
    value = density[left:right, bottom:up].sum() * dx ** 2
    return coord, value


def _write_split(path, data_list, target_list, indices):
    """Write ``data_<i>`` / ``target_<i>`` float32 datasets for each ``i`` in
    ``indices`` to a new (overwritten) HDF5 file at ``path``."""
    with h5py.File(path, "w") as hf:
        for i in indices:
            hf.create_dataset('data_' + str(i), data=data_list[i].numpy().astype(np.float32))
            hf.create_dataset('target_' + str(i), data=target_list[i].numpy().astype(np.float32))


def gen():
    """Build train/test HDF5 datasets pairing input particles with detections.

    For every sample ``i`` this:

    * takes the particle table ``xyw1[i]``; rows whose column 2 has
      ``|value| > 1e-5`` count as particles;
    * runs ``DetectionNet`` on the raw 200x200 grid
      ``grid_data/N<i+1, 5 digits>t001.dat``;
    * refines each detected position with an iterated density-weighted
      centroid and assigns it to the nearest particle of ``xyw1[i]``.

    A sample is kept only if all of the following hold:

    * the detector finds exactly as many particles as ``xyw1[i]`` has;
    * no target entry is (near) zero;
    * every target value is finite.

    The first 85% of kept samples go to ``real_dynamics_data/train.h5`` and
    the rest to ``test.h5``. Each file holds ``data_<i>`` / ``target_<i>``
    datasets, where ``target`` rows are ``(x, y, integrated density)``.
    """
    device = 'cuda:0'
    device = torch.device(device) if torch.cuda.is_available() else torch.device('cpu')
    net = DetectionNet()
    # NOTE(review): "detection.pt" is resolved against the current working
    # directory, unlike the data paths below, which are relative to this file.
    net.load_state_dict(torch.load("detection.pt", map_location=lambda storage, location: storage))
    net.to(device)
    net.eval()

    script_dir = os.path.dirname(os.path.realpath(__file__))
    data_root = os.path.join(script_dir, "real_dynamics_data", "const_force")

    # Reordered to (sample, particle, field). Column 2 acts as the particle
    # weight below (presumably (x, y, w) -- confirm against the .mat export).
    xyw1 = loadmat(os.path.join(data_root, "xyw1.mat"))['xyw1'].transpose(2, 1, 0).astype(np.float32)
    xyw1 = torch.from_numpy(xyw1)
    num_samples = xyw1.shape[0]
    # Rows with a non-negligible weight count as real particles; the kept data
    # assumes these are the leading rows of each sample.
    num_particles = (xyw1[:, :, 2].abs() > 1e-5).sum(dim=1)

    # Grid geometry is identical for every sample, so it is built once.
    # NOTE(review): dx = 2*pi/512 with a pi-centred origin suggests a crop of
    # a 512-cell periodic domain -- confirm.
    grid_size = 200
    grid_size8 = grid_size // 8  # coarse lattice indexed into the detector output
    dx = math.pi * 2 / 512.
    dx8 = dx * 8
    grid_start = math.pi - 100 * dx
    xv, yv = torch.meshgrid([torch.arange(0, grid_size), torch.arange(0, grid_size)])
    grid = torch.stack((xv, yv), dim=2) * dx  # grid[r, c] == (r * dx, c * dx)
    range_size = 10

    data_list = []
    target_list = []

    with torch.no_grad():
        for i in range(num_samples):
            grid_path = os.path.join(data_root, 'grid_data', 'N' + str(i + 1).zfill(5) + 't001.dat')
            target_grid = torch.from_numpy(_read_grid(grid_path, grid_size)).to(device)

            target_exist, target_coord = net(target_grid)
            target_exist = (target_exist > 0.5).cpu()
            target_coord = target_coord.cpu()

            n_true = int(num_particles[i])
            target_num_particle = int(target_exist.sum())
            if target_num_particle != n_true:
                continue

            # The refinement mixes the grid with CPU tensors, so it must run on
            # a CPU copy (the GPU tensor would raise a device-mismatch error).
            density = target_grid[0, 0].cpu()
            candidates = xyw1[i, :n_true, :2]
            target = torch.zeros(target_num_particle, 3)
            # nonzero() yields positions in row-major order, i.e. the same
            # order as nested j/k loops over the coarse lattice.
            for j, k in target_exist[0, 0, :grid_size8, :grid_size8].nonzero().tolist():
                # Assumes target_coord holds a per-cell offset in coarse-cell units.
                coord = (target_coord[0, :, j, k] + torch.tensor([j, k])) * dx8
                coord, value = _refine_particle(density, grid, coord, dx, range_size)
                # Swap (row, col) component order and shift into the
                # coordinate frame of xyw1.
                coord = coord[[1, 0]] + grid_start
                # Only the n_true real particles are candidates, so the index
                # always fits inside `target`.
                index = (coord.unsqueeze(0) - candidates).norm(dim=1).min(dim=0)[1]
                target[index, :-1] = coord
                target[index, -1] = value

            # A near-zero entry is a proxy for a row left unfilled because two
            # detections matched the same particle; non-finite values come
            # from a degenerate (zero-density) centroid window.
            if (target.abs() < 1e-4).any() or not torch.isfinite(target).all():
                continue
            data_list.append(xyw1[i, :n_true, :])
            target_list.append(target)

    num_samples = len(data_list)
    print(num_samples)

    out_root = os.path.join(script_dir, "real_dynamics_data")
    split = int(num_samples * 0.85)
    _write_split(os.path.join(out_root, "train.h5"), data_list, target_list, range(split))
    _write_split(os.path.join(out_root, "test.h5"), data_list, target_list, range(split, num_samples))


if __name__ == "__main__":
    # Generate the datasets only when run as a script, never on import.
    gen()
