import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch
import math
import h5py
import os

# Use float64 as the default floating-point dtype for tensors created without
# an explicit dtype. `torch.set_default_tensor_type` is deprecated (PyTorch
# >= 2.1); `torch.set_default_dtype` is the supported replacement.
torch.set_default_dtype(torch.float64)
# Regularization length of the vortex velocity kernel: it appears as
# exp(-r^2 / EPS^2) in compute_vort_uv.
EPS = 4.0 / 128


def compute_vort_uv(vort_pos, vort_value, eps=None):
    """
    Compute the velocity induced at each point vortex by all the others.

    Vortex j contributes to vortex i (i != j) the regularized velocity

        u_i += w_j * (y_j - y_i) * K(r_ij^2)
        v_i += w_j * (x_i - x_j) * K(r_ij^2)
        K(r^2) = (1 - exp(-r^2 / eps^2)) / (2 * pi * r^2)

    All pairs are evaluated at once with broadcasting instead of a Python
    double loop. Two distinct vortices at the same position (r_ij = 0)
    contribute zero, which is the limit of the regularized kernel, rather
    than producing NaN.

    :param vort_pos: tensor(num_vortex, 2), (x, y) position of each vortex
    :param vort_value: tensor(num_vortex), strength w of each vortex
    :param eps: kernel regularization length; defaults to the module-level EPS
    :return: uv(num_vortex, 2) with the same dtype/device as vort_pos
    """
    if eps is None:
        eps = EPS
    # diff[i, j] = vort_pos[j] - vort_pos[i]
    diff = vort_pos.unsqueeze(0) - vort_pos.unsqueeze(1)
    r2 = (diff ** 2).sum(dim=-1)
    # Self-interaction (diagonal) and coincident pairs have r2 == 0. Substitute
    # 1 for their r2 so the division stays finite, then zero them out.
    nonzero = r2 > 0
    safe_r2 = torch.where(nonzero, r2, torch.ones_like(r2))
    # -expm1(-x) == 1 - exp(-x), but stays accurate for small x.
    kernel = -torch.expm1(-safe_r2 / eps ** 2) / (2 * math.pi * safe_r2)
    kernel = torch.where(nonzero, kernel, torch.zeros_like(kernel))
    # coef[i, j] = w_j * K(r_ij^2)
    coef = vort_value.unsqueeze(0) * kernel
    u = (coef * diff[..., 1]).sum(dim=1)
    v = -(coef * diff[..., 0]).sum(dim=1)
    return torch.stack((u, v), dim=1).to(vort_pos)


def rk4_vort(vort_pos0, vort_value, dt, length_t):
    """
    Advance vortex positions by ``length_t`` with the classical RK4 scheme.

    The interval is split into ``ceil(length_t / dt)`` equal steps, so the
    step size h never exceeds ``dt`` and the end time is reached exactly.
    Velocities come from ``compute_vort_uv``; strengths stay constant.

    :param vort_pos0: tensor(num_vortex, 2), initial positions
    :param vort_value: tensor(num_vortex), vortex strengths
    :param dt: maximum step size; must be positive
    :param length_t: total integration time; if it yields no steps
        (``ceil(length_t / dt) <= 0``) the input positions are returned as-is
    :return: tensor(num_vortex, 2), positions after ``length_t``
    :raises ValueError: if ``dt`` is not positive
    """
    if dt <= 0:
        raise ValueError(f"dt must be positive, got {dt}")
    n_steps = int(np.ceil(length_t / dt))
    if n_steps <= 0:
        # Nothing to integrate; also avoids dividing by zero when length_t == 0.
        return vort_pos0
    h = length_t / n_steps
    pos = vort_pos0
    for _ in range(n_steps):
        k1 = compute_vort_uv(pos, vort_value)
        k2 = compute_vort_uv(pos + 0.5 * h * k1, vort_value)
        k3 = compute_vort_uv(pos + 0.5 * h * k2, vort_value)
        k4 = compute_vort_uv(pos + h * k3, vort_value)
        pos = pos + h * (k1 + 2 * k2 + 2 * k3 + k4) / 6
    return pos


def gen(data_type, num_samples=None, dt=0.01, length_t=0.1, data_root=None):
    """
    Generate a random point-vortex dataset and write it to ``<data_root>/<data_type>.h5``.

    Each sample has 2-6 vortices.
    - Positions are drawn uniformly from [0, 5]^2.
    - Strength magnitudes are drawn from [0.5, 2], and every vortex in a
      sample shares one random sign.
    - The input is the (num_vortex, 3) array [x, y, strength].
    - The target is the (num_vortex, 2) array of positions after integrating
      for ``length_t`` with ``rk4_vort``.

    Both arrays are stored as float32 under the keys ``data_<i>`` and
    ``target_<i>``.

    :param data_type: split name, used as the file stem. If ``num_samples``
        is not given, 'train' produces 10000 samples and anything else 2000.
    :param num_samples: number of samples; overrides the split default
    :param dt: maximum RK4 step size
    :param length_t: integration time for the target positions
    :param data_root: output directory; defaults to ``dynamics_data`` next to
        this file
    """
    if num_samples is None:
        num_samples = 10000 if data_type == 'train' else 2000

    # Generate everything before opening the file, so a failure mid-generation
    # does not truncate an existing dataset.
    data = []
    target = []
    for _ in range(num_samples):
        num_vortex = np.random.randint(2, 7)
        # A single random sign is shared by all vortices of the sample.
        vortex_value = torch.DoubleTensor(num_vortex).uniform_(0.5, 2) * np.random.choice([-1, 1])
        vortex_pos = torch.DoubleTensor(num_vortex, 2).uniform_(0, 5)
        sample = torch.cat((vortex_pos, vortex_value.unsqueeze(1)), dim=1)
        data.append(sample.numpy().astype(np.float32))
        target.append(rk4_vort(vortex_pos, vortex_value, dt, length_t).numpy().astype(np.float32))

    if data_root is None:
        data_root = os.path.join(os.path.dirname(os.path.realpath(__file__)), "dynamics_data")
    os.makedirs(data_root, exist_ok=True)

    # The context manager closes the file even if a write raises.
    with h5py.File(os.path.join(data_root, data_type + ".h5"), "w") as hf:
        for i, (sample_data, sample_target) in enumerate(zip(data, target)):
            hf.create_dataset('data_' + str(i), data=sample_data)
            hf.create_dataset('target_' + str(i), data=sample_target)


if __name__ == '__main__':
    # Build the training split first, then the test split.
    for split in ('train', 'test'):
        gen(split)