# -*- coding: utf-8 -*-
"""
Created on Thu May  7 05:22:09 2020

@author: syxiong
"""
import torch
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
import math
import struct

#for i in range(ldata):
#    plt.clf()    
#    plt.contourf(meshx, meshy, data[i], cmap='RdYlBu', levels=levels, extend='both')
#    plt.xlim(0., lx)
#    plt.ylim(0., lx)
#    plt.colorbar()
#    plt.show()
#    plt.pause(1)
def to_np(x):
    """Convert tensor ``x`` to a NumPy array, detached from autograd and moved to the CPU."""
    detached = x.detach()
    on_cpu = detached.cpu()
    return on_cpu.numpy()
class cal_drw(nn.Module):
    """Time derivatives for a system of 2-D point vortices.

    ``forward(r0, w0)`` returns ``(velocity, dw)``: the velocity induced at each
    vortex by all the others (Biot-Savart sum, scaled by 1/(2*pi)) and a zero
    rate of change for the circulations ``w0``.
    """

    def __init__(self, device='cpu'):
        super(cal_drw, self).__init__()
        # Kept for backward compatibility. forward() now allocates its helper
        # tensors on r0's own device, so this value no longer has to match the input.
        self.device = device

    def forward(self, r0, w0):
        """Compute vortex velocities.

        Args:
            r0: positions, shape (nbatch, nparticle, dimension); only components
                0 (x) and 1 (y) of the last axis are used.
            w0: circulations, shape (nbatch, nparticle, 1).

        Returns:
            (vel, dw): ``vel`` has shape (nbatch, nparticle, 2) holding (u, v);
            ``dw`` is ``w0 * 0.`` (circulations do not change).
        """
        nbatch, nparticle, dimension = r0.shape[0], r0.shape[1], r0.shape[2]
        row = r0.reshape(nbatch, 1, -1, dimension)   # (B, 1, N, D)
        col = row.transpose(1, 2)                     # (B, N, 1, D)
        diff = col - row                              # diff[b, i, j] = r_i - r_j
        rr = torch.norm(diff, dim=-1, keepdim=True) ** 2
        # Adding 1 on the diagonal avoids 0/0 for the self-interaction term.
        # Its numerator is zero, so that term contributes nothing.
        # Build the mask directly on r0's device and dtype.
        eye = torch.eye(nparticle, dtype=r0.dtype, device=r0.device)
        eye = eye.reshape(1, nparticle, nparticle, 1)
        rt = diff / (rr + eye)
        # u_i = -sum_j w_j (y_i - y_j) / |r_i - r_j|^2,  v_i = sum_j w_j (x_i - x_j) / |r_i - r_j|^2
        u = -torch.matmul(rt[:, :, :, 1], w0).reshape(nbatch, -1, 1)
        v = torch.matmul(rt[:, :, :, 0], w0).reshape(nbatch, -1, 1)
        return torch.cat([u, v], dim=2) / 2. / math.pi, w0 * 0.

def RK4(r0, w0, dt, lt, cal):
    """Advance the state (r0, w0) over a time interval ``lt`` with classical RK4.

    The interval is split into ``ceil(|lt / dt|)`` equal substeps, so each step
    is no longer than ``|dt|`` and integration ends exactly at ``lt``. The sign
    of ``lt`` sets the direction; a negative ``lt`` integrates backward in time.

    Args:
        r0, w0: initial state (anything supporting ``+`` and scalar ``*``).
        dt: maximum step size (must be non-zero).
        lt: total time to integrate.
        cal: callable ``cal(r, w) -> (dr, dw)`` giving the time derivatives.

    Returns:
        The state ``(r, w)`` after time ``lt``. If ``lt`` is zero, returns
        ``(r0, w0)`` unchanged.
    """
    # abs(): previously an lt/dt sign mismatch gave a negative count and the
    # input was returned silently without integrating.
    n_steps = int(np.ceil(abs(lt / dt)))
    if n_steps == 0:
        # lt == 0: nothing to integrate. Previously this hit 0/0 just below.
        return r0, w0
    h = lt / n_steps
    r1 = r0
    w1 = w0
    for _ in range(n_steps):
        dr1, dw1 = cal(r1, w1)
        dr2, dw2 = cal(r1 + h / 2 * dr1, w1 + h / 2 * dw1)
        dr3, dw3 = cal(r1 + h / 2 * dr2, w1 + h / 2 * dw2)
        dr4, dw4 = cal(r1 + h * dr3, w1 + h * dw3)
        r1 = r1 + h / 6 * (dr1 + 2 * dr2 + 2 * dr3 + dr4)
        w1 = w1 + h / 6 * (dw1 + 2 * dw2 + 2 * dw3 + dw4)
    return r1, w1

def _read_float32_block(f, count):
    """Read ``count`` native-endian float32 values from the open binary file ``f``.

    Returns a tuple of Python floats. Raises ``struct.error`` if the file ends early.
    """
    # A count-prefixed format ('1048576f') is equivalent to repeating 'f' but
    # avoids building a million-character format string.
    return struct.unpack('%df' % count, f.read(4 * count))


def _read_vorticity_field(path, ngrid, n_header=57):
    """Read mesh x, mesh y and vorticity from ``path``, each reshaped to (ngrid, ngrid).

    The file is read as raw float32 values: ``n_header`` values that are
    skipped, then the x-mesh, y-mesh and vorticity arrays stored back to back.
    """
    shape = (ngrid, ngrid)
    with open(path, 'rb') as f:  # closed even if a read fails
        _read_float32_block(f, n_header)  # header values are not used
        meshx = np.asarray(_read_float32_block(f, ngrid * ngrid)).reshape(shape)
        meshy = np.asarray(_read_float32_block(f, ngrid * ngrid)).reshape(shape)
        field = np.asarray(_read_float32_block(f, ngrid * ngrid)).reshape(shape)
    return meshx, meshy, field


def scatter_animation(pos, vor, minx=-1, maxx=25, miny=-2, maxy=2, pt=0.01,
                      path='tecplot_vor_wo_force.dat', ngrid=1024):
    """Plot vortex positions frame by frame, then overlay a vorticity contour.

    Args:
        pos: sequence of position frames. Each frame goes through ``to_np``, and
            its columns 0 and 1 are plotted as x and y.
        vor: sequence of circulation frames indexed like ``pos``. Column 0 of
            each frame sets the point colours.
        minx, maxx, miny, maxy: currently unused; both axes are fixed to
            [2*pi/3, 4*pi/3]. Kept for interface compatibility.
        pt: pause in seconds after each frame.
        path: binary file holding the background vorticity field.
        ngrid: side length of that field (ngrid * ngrid values per array).
    """
    plt.figure(figsize=(10, 8))
    lx = 2. * math.pi
    mx = 2. * math.pi / 3.
    plt.xlim(0. + mx, lx - mx)
    plt.ylim(0. + mx, lx - mx)
    for i, pos_frame in enumerate(pos):
        xy = to_np(pos_frame)
        strength = to_np(vor[i])
        plt.scatter(xy[:, 0], xy[:, 1], c=strength[:, 0])
        plt.show()
        plt.pause(pt)
    # Use distinct names so the ``vor`` argument is not shadowed by the file data.
    meshx, meshy, field = _read_vorticity_field(path, ngrid)
    d = 5
    levels = np.arange(-50, 50 + d, d)
    plt.contourf(meshx, meshy, field, cmap='RdYlBu', levels=levels, extend='both', alpha=0.1)
    plt.colorbar()
    plt.show()
def test_leapfrog():
    """Integrate four point vortices (circulations +/-0.75) and plot the trajectory.

    Before each of ``nt`` RK4 advances of length ``lt`` (substep ``dt``), this
    prints the step index and the largest and smallest pairwise distances of the
    current configuration. The recorded frames go to ``scatter_animation``.
    """
    device = 'cuda:0'
    device = torch.device(device) if torch.cuda.is_available() else torch.device('cpu')
    pos = torch.tensor([2.741593, 2.541593, 3.541593, 2.541593,
                        3.541593, 3.741593, 2.741593, 3.741593]).reshape(1, 4, 2).to(device)
    vor = 0.75 * torch.tensor([1, -1, 1., -1]).reshape(1, 4, 1).to(device)
    dt = 0.01
    lt = 1
    nt = 7
    nparticle = pos.shape[1]

    # Loop invariants, built once: the velocity model, and a diagonal mask
    # that excludes each vortex's zero self-distance from the minimum.
    cal = cal_drw(device)
    self_mask = torch.eye(nparticle, dtype=torch.bool, device=device)
    self_mask = self_mask.reshape(1, nparticle, nparticle, 1)

    pos_all = [pos.reshape(nparticle, 2)]
    vor_all = [vor.reshape(nparticle, 1)]

    for it in range(nt):
        print(it)
        pos0 = pos.reshape(1, 1, -1, 2)
        posr = torch.norm(pos0.transpose(1, 2) - pos0, dim=-1, keepdim=True)
        # +inf on the diagonal: the old "+4*I" trick gave a wrong minimum
        # whenever every pair was farther apart than 4.
        print(torch.max(posr), torch.min(posr.masked_fill(self_mask, float('inf'))))
        pos, vor = RK4(pos, vor, dt, lt, cal)
        pos_all.append(pos.reshape(nparticle, 2))
        vor_all.append(vor.reshape(nparticle, 1))
    scatter_animation(pos_all, vor_all)
# Run the demo only when executed as a script, not when the module is imported
# (importing would otherwise launch the full simulation and plotting).
if __name__ == '__main__':
    test_leapfrog()