import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch.nn.functional as F
import torch.nn as nn
import random
import h5py
import os
import math
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import MessagePassing
#from torch.utils.tensorboard import SummaryWriter
from scipy.io import loadmat

# Run on the first CUDA device when one is available; otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def to_np(x):
    """Return tensor ``x`` as a NumPy array, detached from autograd and moved to the CPU."""
    host = x.detach().cpu()
    return host.numpy()

def scatter_animation(pos, vor, minx=-1, maxx=25, miny=-2, maxy=2, pt=0.01,
                      tracer_size=0.01, vortex_size=350):
    """Play tracer/vortex snapshots back as a frame-by-frame scatter animation.

    Args:
        pos: sequence of tensors, one per frame; columns 0/1 are plotted as
            tracer x/y.
        vor: sequence of tensors indexed alongside ``pos`` (must be at least
            as long); columns 0/1 are plotted as vortex x/y.
        minx, maxx, miny, maxy: fixed axis limits; their spans also set the
            figure size in inches.
        pt: pause between frames, in seconds.
        tracer_size, vortex_size: marker areas passed to ``scatter``.

    Frames are refreshed with ``plt.pause`` (which draws without blocking), so
    one figure is reused for every frame. A single ``plt.show()`` at the end
    keeps the last frame on screen when matplotlib is not in interactive mode.
    """
    fig = plt.figure(figsize=(maxx - minx, maxy - miny))
    for i in range(len(pos)):
        posi = to_np(pos[i])
        vori = to_np(vor[i])
        fig.clf()
        ax = fig.add_subplot(111)
        ax.set_xlim(minx, maxx)
        ax.set_ylim(miny, maxy)
        ax.scatter(posi[:, 0], posi[:, 1], c='#2E86C1', s=tracer_size)
        ax.scatter(vori[:, 0], vori[:, 1], c='#21618C', s=vortex_size)
        # Redraw and run the GUI event loop for `pt` seconds. Unlike a per-frame
        # plt.show(), this does not block (and does not lose the figure when the
        # window would otherwise have to be closed to continue).
        plt.pause(pt)
    plt.show()

def get_data(data_type, batch_size, data_root="dynamics_data", n_train=3600, n_total=4000):
    """Return a shuffled DataLoader of 2-node graphs built from ``.mat`` files.

    Arrays are read from ``<data_root>/<name>.mat`` (stored under the key
    ``<name>``), flattened and cast to float32. In each file name the first
    digit selects the graph node (row) and the second digit selects input (0)
    or target (1). Sample ``i`` becomes a ``Data`` object with:

    * ``x``: float32 tensor (2, 3); row ``r`` is ``[x{r}0[i], y{r}0[i], w{r}0[i]]``
    * ``y``: float32 tensor (2, 2); row ``r`` is ``[x{r}1[i], y{r}1[i]]``
    * ``edge_index``: the two directed edges 0->1 and 1->0.

    Args:
        data_type: ``'train'`` selects samples ``[0, n_train)``; any other value
            selects ``[n_train, n_total)``.
        batch_size: graphs per batch.
        data_root: directory holding the ``.mat`` files (a relative path is
            resolved against the current working directory).
        n_train: number of leading samples used for training.
        n_total: exclusive end of the held-out range.

    Returns:
        A ``DataLoader`` over the selected samples with ``shuffle=True``.
    """
    def load_vector(name):
        # Each file holds one array stored under a key equal to the file stem.
        mat = loadmat(os.path.join(data_root, name + ".mat"))
        return mat[name].reshape(-1).astype(np.float32)

    def node_table(snapshot, channels):
        # Shape (num_samples, 2, len(channels)); row r of sample i reads files "<c><r><snapshot>".
        per_node = [np.stack([load_vector(c + node + snapshot) for c in channels], axis=1)
                    for node in "01"]
        return np.stack(per_node, axis=1)

    inputs = node_table("0", "xyw")   # -> Data.x, (num_samples, 2, 3)
    targets = node_table("1", "xy")   # -> Data.y, (num_samples, 2, 2)

    # Both directed edges of the 2-node graph (row 0: sources, row 1: targets).
    # The graph is identical for every sample, so it is built once and shared.
    edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)

    start, end = (0, n_train) if data_type == 'train' else (n_train, n_total)
    data_list = [
        Data(x=torch.tensor(inputs[i]), y=torch.tensor(targets[i]), edge_index=edge_index)
        for i in range(start, end)
    ]
    return DataLoader(data_list, batch_size=batch_size, shuffle=True)


class ResidualBlock(nn.Module):
    """MLP block with a skip connection: ``relu(left(x) + shortcut(x))``.

    ``left`` is Linear -> ReLU -> Linear. ``shortcut`` is the identity when the
    input and output widths match; otherwise it is a Linear projection so the
    sum is well-defined. Submodule names and indices (``left.0``, ``left.2``,
    ``shortcut.0``) are kept stable because the checkpoint saved and loaded
    elsewhere in this script is a ``state_dict`` keyed by them.
    """

    def __init__(self, inchannel, outchannel):
        super().__init__()
        layers = [
            nn.Linear(inchannel, outchannel),
            # nn.BatchNorm1d(outchannel),
            nn.ReLU(inplace=True),
            nn.Linear(outchannel, outchannel),
        ]
        self.left = nn.Sequential(*layers)
        # Project the skip path only when the widths differ; an empty Sequential is the identity.
        needs_projection = inchannel != outchannel
        self.shortcut = (nn.Sequential(nn.Linear(inchannel, outchannel))
                         if needs_projection else nn.Sequential())

    def forward(self, x):
        residual = self.shortcut(x)
        return F.relu(self.left(x) + residual)


class DisplacementLayer(MessagePassing):
    """Message-passing layer whose messages are learned pairwise displacements.

    Node features are read as ``[coordinates..., scalar]``. Every column except
    the last is treated as a position, and the last column as a per-node scalar
    that is fed to the MLP for the source node. Messages are summed
    (``aggr='add'``) at each target node, and ``update`` returns that sum
    unchanged.
    """

    def __init__(self):
        super().__init__(aggr='add')
        # Residual MLP mapping (|offset|, scalar_j) -> one scalar gain per edge.
        widths = ((2, 32), (32, 64), (64, 64), (64, 128), (128, 128))
        blocks = [ResidualBlock(n_in, n_out) for n_in, n_out in widths]
        self.dis = nn.Sequential(*blocks, nn.Linear(128, 1))
        # Learnable global rotation angle (radians) applied to every message; starts at 0.
        self.rot = nn.Parameter(torch.zeros(1), requires_grad=True)

        # Kaiming-normal weights for every Linear layer (biases keep PyTorch's default init).
        linears = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linears:
            nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')

    def forward(self, x, edge_index):
        num_nodes = x.size(0)
        return self.propagate(edge_index, size=(num_nodes, num_nodes), x=x)

    def message(self, x_i, x_j):
        # Offset of source node j relative to target node i, and its length.
        offset = x_j[:, :-1] - x_i[:, :-1]
        distance = offset.norm(dim=1, keepdim=True)
        scalar_j = x_j[:, -1].unsqueeze(1)
        gain = self.dis(torch.cat((distance, scalar_j), dim=1))
        # Scale the offset by the learned gain and a fixed factor of 1/5.
        displace = gain * offset / 5
        # Row-vector convention: [dx, dy] @ [[cos, sin], [-sin, cos]] rotates by +rot.
        cos_r, sin_r = torch.cos(self.rot), torch.sin(self.rot)
        rotation = torch.cat([cos_r, sin_r, -sin_r, cos_r]).reshape(2, 2).to(x_i)
        return displace @ rotation

    def update(self, aggr_out):
        # Identity: the summed displacement is the layer output.
        return aggr_out


class DynamicsNet(nn.Module):
    """Integrate node positions with classical RK4, using ``edge_conv`` as the velocity field.

    Node features are ``[position..., scalar]``. Only the position columns are
    integrated; the last column is re-attached unchanged at every RK4 stage.
    """

    def __init__(self):
        super().__init__()
        self.edge_conv = DisplacementLayer()

    def _velocity(self, positions, scalar, edge_index):
        # Re-attach the constant last column before querying the learned field.
        return self.edge_conv(torch.cat((positions, scalar), dim=1), edge_index)

    def forward(self, x0, edge_index, dt, length_t):
        """Advance ``x0`` over a time span ``length_t`` using steps no larger than ``dt``.

        Args:
            x0: node features; all columns but the last are positions.
            edge_index: graph connectivity passed through to ``edge_conv``.
            dt: maximum step size; the span is split into ``ceil(length_t / dt)``
                equal steps.
            length_t: total integration time. If it is <= 0, no step is taken
                and the input positions are returned.

        Returns:
            The positions (``x0`` without its last column) after integration.

        Raises:
            ValueError: if ``dt`` is not positive.
        """
        if dt <= 0:
            raise ValueError("dt must be positive, got {}".format(dt))
        pos, scalar = x0[:, :-1], x0[:, -1:]
        n_steps = int(np.ceil(length_t / dt))
        if n_steps <= 0:
            # Zero (or negative) span: nothing to integrate, and avoids dividing by n_steps == 0.
            return pos
        h = length_t / n_steps
        for _ in range(n_steps):
            k1 = self._velocity(pos, scalar, edge_index)
            k2 = self._velocity(pos + 0.5 * h * k1, scalar, edge_index)
            k3 = self._velocity(pos + 0.5 * h * k2, scalar, edge_index)
            k4 = self._velocity(pos + h * k3, scalar, edge_index)
            pos = pos + h * (k1 + 2 * k2 + 2 * k3 + k4) / 6
        return pos


# ---- Training / rollout configuration ---------------------------------------
n_epoch = 2000        # number of training epochs in train()
batch_size = 64       # graphs per mini-batch
lr = 1e-3             # initial Adam learning rate
lr_gamma = 0.8        # StepLR multiplicative decay factor
lr_step = 20          # epochs between StepLR decays
dt = 0.001            # maximum RK4 step size
num_tracer = 80000    # passive tracer particles used in test()
length_t = 1          # time span integrated per model call / per rollout frame
nt = 100              # number of rollout frames produced by test()

net = DynamicsNet().to(device)
# NOTE: despite its name, this is the L1 (mean absolute error) loss, not MSE.
mse_loss = F.l1_loss


def train(checkpoint="dynamics.pt"):
    """Fit ``net`` on the training split and checkpoint the best weights by test loss.

    Uses the module-level settings (``n_epoch``, ``batch_size``, ``lr``,
    ``lr_step``, ``lr_gamma``, ``dt``, ``length_t``) and ``mse_loss``. Each epoch
    prints the per-sample mean train and test loss. ``net.state_dict()`` is
    saved to ``checkpoint`` whenever the test loss reaches a new minimum.

    Raises:
        ValueError: if either data split is empty.
    """
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step, gamma=lr_gamma)
    data_loader = get_data('train', batch_size)
    test_data_loader = get_data('test', batch_size)
    if len(data_loader) == 0 or len(test_data_loader) == 0:
        raise ValueError("get_data returned an empty train or test split")
    lowest_test_loss = float('inf')

    def batch_loss(data_batch):
        # Move one batch to the device, roll the dynamics forward and score against the targets.
        data = data_batch.x.to(device)
        edge_index = data_batch.edge_index.to(device)
        target = data_batch.y.to(device)
        pred = net(data, edge_index, dt, length_t)
        return mse_loss(pred, target)

    for epoch in range(n_epoch):
        net.train()
        train_total, train_count = 0.0, 0
        for data_batch in data_loader:
            optimizer.zero_grad()
            loss = batch_loss(data_batch)
            loss.backward()
            optimizer.step()
            # Weight each batch mean by its graph count so a short final batch
            # does not skew the epoch average.
            train_total += loss.item() * data_batch.num_graphs
            train_count += data_batch.num_graphs
        scheduler.step()

        net.eval()
        test_total, test_count = 0.0, 0
        with torch.no_grad():
            for data_batch in test_data_loader:
                loss = batch_loss(data_batch)
                test_total += loss.item() * data_batch.num_graphs
                test_count += data_batch.num_graphs

        train_mean = train_total / train_count
        test_mean = test_total / test_count
        print("Epoch: {0} | Train Loss: {1} | Test Loss: {2}".format(
            epoch + 1, format(train_mean, '.2e'), format(test_mean, '.2e')))
        if test_mean < lowest_test_loss:
            torch.save(net.state_dict(), checkpoint)
            lowest_test_loss = test_mean


def test(checkpoint="dynamics.pt"):
    """Roll out the trained model on a 4-vortex configuration and advect passive tracers.

    Loads ``net`` weights from ``checkpoint`` onto the CPU. The initial state has
    four vortices at (0, 1), (0, -1), (0, 0.3) and (0, -0.3) with strengths
    0.8, -0.8, 0.8 and -0.8. It also has ``num_tracer`` tracers drawn uniformly
    from [-0.5, 0.5] x [-1.5, 1.5].

    For each of ``nt`` frames:

    * the tracers are advanced with RK4 through the Gaussian-regularised
      point-vortex (Biot-Savart) velocity of the current vortices;
    * the vortices are advanced with ``net``.

    Both advance over ``length_t`` with step ``dt``.

    Returns:
        ``(xys, vorts)``: lists of ``nt + 1`` CPU tensors, starting with the
        initial state. Tracer positions have shape ``(num_tracer, 2)`` and
        vortex positions have shape ``(4, 2)``.
    """
    net.load_state_dict(torch.load(checkpoint, map_location='cpu'))
    net.to('cpu')

    # Squared core radius of the Gaussian regularisation in the tracer velocity kernel.
    eps2 = (math.pi / 64) ** 2

    def compute_tracer_uv(xy, vort):
        """Velocity induced at tracer positions by a set of regularised point vortices.

        :param xy: tensor (num_particle, 2) of tracer positions
        :param vort: tensor (num_vortex, 3) with columns [x, y, strength]
        :return: tensor (num_particle, 2) of velocities (u, v)
        """
        vort_pos = vort[:, :2]
        vort_value = vort[:, 2]
        offset = xy.unsqueeze(1) - vort_pos            # (num_particle, num_vortex, 2): tracer - vortex
        r_ij2 = (offset ** 2).sum(dim=2)               # (num_particle, num_vortex)
        # Avoid 0/0 if a tracer sits exactly on a vortex; the kernel has a finite limit there
        # and the offset factor below is 0, so the contribution correctly becomes 0.
        r_ij2 = r_ij2.clamp_min(torch.finfo(r_ij2.dtype).tiny)
        # -expm1(-a) equals 1 - exp(-a) but keeps precision when a is small.
        kernel = vort_value * (-torch.expm1(-r_ij2 / eps2)) / (2 * math.pi * r_ij2)
        u = (kernel * -offset[..., 1]).sum(dim=1, keepdim=True)
        v = (kernel * offset[..., 0]).sum(dim=1, keepdim=True)
        return torch.cat((u, v), dim=1)

    def rk4_tracer(xy0, vort, dt, length_t):
        # Classical RK4 with the vortex configuration `vort` held fixed for the whole span.
        n_steps = int(np.ceil(length_t / dt))
        if n_steps <= 0:
            return xy0
        h = length_t / n_steps
        for _ in range(n_steps):
            uv0 = compute_tracer_uv(xy0, vort)
            uv1 = compute_tracer_uv(xy0 + 0.5 * h * uv0, vort)
            uv2 = compute_tracer_uv(xy0 + 0.5 * h * uv1, vort)
            uv3 = compute_tracer_uv(xy0 + h * uv2, vort)
            xy0 = xy0 + h * (uv0 + 2 * uv1 + 2 * uv2 + uv3) / 6
        return xy0

    pos_x = torch.empty(num_tracer, 1, dtype=torch.float32).uniform_(-0.5, 0.5)
    pos_y = torch.empty(num_tracer, 1, dtype=torch.float32).uniform_(-1.5, 1.5)
    xy = torch.cat((pos_x, pos_y), dim=1)
    vort_pos = torch.tensor([[0, 1],
                             [0, -1],
                             [0, 0.3],
                             [0, -0.3]])
    vort_value = torch.tensor([1, -1, 1, -1.0]).reshape(4, 1) * 0.8
    x = torch.cat((vort_pos, vort_value), dim=1)
    # Fully connected vortex graph without self-loops.
    num_vortex = vort_pos.shape[0]
    edge_index = torch.tensor([[j, k]
                               for j in range(num_vortex)
                               for k in range(num_vortex)
                               if j != k]).t().contiguous()

    xys = [xy]
    vorts = [vort_pos]
    net.eval()
    with torch.no_grad():
        for i in range(nt):
            print(i)
            # NOTE(review): tracers see the vortices frozen at their start-of-frame positions
            # for the whole length_t span, while the vortices themselves move during it.
            # Confirm this approximation is acceptable.
            xy = rk4_tracer(xy, x, dt, length_t)
            vort_pos = net(x, edge_index, dt, length_t)
            x = torch.cat((vort_pos, vort_value), dim=1)
            xys.append(xy)
            vorts.append(vort_pos)
    return xys, vorts


def _write_zones(path, frames):
    """Write each frame in ``frames`` to ``path`` as one ZONE block.

    Each block starts with a header line ``ZONE T = "zone<tab><tab><j>"``
    (``j`` = frame index). It is followed by one line per row holding columns
    0 and 1 (as ``str`` of the NumPy scalars), separated by two tabs.
    """
    with open(path, 'w') as f:
        for j, frame in enumerate(frames):
            f.write('ZONE T = "zone\t\t' + str(j) + '"\n')
            # Convert the whole frame once, not once per row.
            f.writelines(str(row[0]) + '\t\t' + str(row[1]) + '\n' for row in to_np(frame))


if __name__ == '__main__':
    # Call train() first to create the checkpoint that test() loads.
    xys, vorts = test()
    # Write every frame, including the final one: test() returns nt + 1 frames.
    _write_zones('vor.dat', xys)
    _write_zones('vorp.dat', vorts)
    scatter_animation(xys, vorts, -1, 10, -2, 2, 0.01)