import torch
from torch import nn
import matplotlib.pyplot as plt
import torch.nn.functional as func
import numpy as np
import math
import h5py
import os
import scipy.io

# Module-level scratch list for ad-hoc debug captures.
Hall = []

# Step sizes (all 0.01):
#   epsT - sub-step length when integrating the analytic ground truth,
#   epsN - sub-step length for the learned model (also scales the learning rate),
#   epsD - time span between consecutive stored training snapshots.
epsT, epsN, epsD = 0.01, 0.01, 0.01

nsteps = 2          # transitions stored per training sample (nsteps + 1 snapshots)
n_particle = 2      # vortices per training sample
dimension = 2
N = n_particle * dimension

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

plot_points = 100                       # frames in the reference test trajectory
lt = 50                                 # total time span of that trajectory
dtp = torch.tensor([lt / plot_points])  # time between consecutive frames
hidden_dim = 64                         # hidden-layer width of the learned model

def to_np(x):
    """Return ``x`` as a NumPy array, detached from autograd and moved to host memory.

    For a tensor already on the CPU the result shares memory with ``x``.
    """
    host_tensor = x.detach().cpu()
    return host_tensor.numpy()




class KAnalysis(nn.Module):
    """Analytic point-vortex pair interaction used to generate ground-truth data.

    With particle i at (q_i, p_i), the interaction is
    ``K = sum_{i,j} gamma_i * gamma_j * log(r_ij) / (4 * pi)``. The identity matrix
    is added under the square root, so r_ii = 1 and the self terms contribute
    log(1) = 0. ``forward`` returns ``(-dK/dq, -dK/dp)``.
    """

    def __init__(self):
        super(KAnalysis, self).__init__()

    def forward(self, q, p, gamma):
        """Return ``(-dK/dq, -dK/dp)`` for positions ``q``/``p`` and strengths ``gamma``.

        The particle index is the last axis; shapes are presumably (batch, 1, n)
        as used elsewhere in this file. ``requires_grad_`` flips the flag on the
        caller's ``q``/``p`` tensors in place.
        """
        n_particles = q.shape[2]
        with torch.enable_grad():
            q = q.requires_grad_(True)
            p = p.requires_grad_(True)
            q1 = q.transpose(1, 2)
            p1 = p.transpose(1, 2)
            # Build the diagonal shift on q's device/dtype so CUDA inputs work
            # (a default torch.eye lives on the CPU).
            eye = torch.eye(n_particles, dtype=q.dtype, device=q.device).unsqueeze(0)
            r = torch.sqrt((q - q1) ** 2 + (p - p1) ** 2 + eye)
            r = r.unsqueeze(-1)
            K = torch.log(r) / (4 * math.pi)
            # Weight pair (i, j) by gamma_j (first factor) * gamma_i (second factor).
            K = K * (gamma.unsqueeze(-1)) * (gamma.squeeze(1).unsqueeze(-1).unsqueeze(-1))
            # One backward pass yields both gradients; the graph is not reused.
            dK_dq, dK_dp = torch.autograd.grad(K.sum(), (q, p))
            dq = -dK_dq
            dp = -dK_dp
        return dq, dp


class LinearBlock(nn.Module):
    """Affine layer followed by a sigmoid: ``sigmoid(W x + b)``."""

    def __init__(self, inchannel, outchannel):
        super().__init__()
        # The container must stay named ``left`` with the Linear at index 0:
        # saved checkpoints use the state_dict keys ``left.0.weight`` / ``left.0.bias``.
        layers = [nn.Linear(inchannel, outchannel), nn.Sigmoid()]
        self.left = nn.Sequential(*layers)

    def forward(self, x):
        return self.left(x)


class KTrained(nn.Module):
    """Learned pair interaction used in place of the analytic point-vortex kernel.

    ``cal_H`` is an MLP (five sigmoid ``LinearBlock`` layers plus a linear head)
    mapping a pair displacement ``(dx, dy)`` to a scalar. The total interaction is
    ``K = sum_{i,j} gamma_i * gamma_j * cal_H(pos_j - pos_i)``; diagonal terms are
    included with zero displacement. Both forward methods return
    ``(-dK/dq, -dK/dp)``.
    """

    def __init__(self, hidden_dim):
        super(KTrained, self).__init__()
        self.cal_H = nn.Sequential(LinearBlock(2, hidden_dim),
                                   LinearBlock(hidden_dim, hidden_dim),
                                   LinearBlock(hidden_dim, hidden_dim),
                                   LinearBlock(hidden_dim, hidden_dim),
                                   LinearBlock(hidden_dim, hidden_dim),
                                   nn.Linear(hidden_dim, 1))
        # Weights ~ U(-sqrt(6 / fan_in), sqrt(6 / fan_in)); biases keep the nn.Linear default.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                bound = math.sqrt(6. / m.in_features)
                m.weight.data.uniform_(-bound, bound)

    def _forces(self, q, p, gamma, create_graph):
        """Return ``(-dK/dq, -dK/dp)``, keeping a differentiable graph when ``create_graph``.

        ``requires_grad_`` flips the flag on the caller's ``q``/``p`` in place.
        Shapes are presumably (batch, 1, n) for q, p and gamma, as used in this file.
        """
        with torch.enable_grad():
            q = q.requires_grad_(True)
            p = p.requires_grad_(True)
            # r[b, i, j] = (q_j - q_i, p_j - p_i)
            r = torch.cat([(q - q.transpose(1, 2)).unsqueeze(-1),
                           (p - p.transpose(1, 2)).unsqueeze(-1)], dim=3)
            K = self.cal_H(r)
            K = K * (gamma.unsqueeze(-1)) * (gamma.squeeze(1).unsqueeze(-1).unsqueeze(-1))
            # A single backward pass for both inputs; retain_graph defaults to create_graph.
            dK_dq, dK_dp = torch.autograd.grad(K.sum(), (q, p), create_graph=create_graph)
            # Negate inside enable_grad so the training graph survives even if the
            # caller is running under torch.no_grad().
            dq = -dK_dq
            dp = -dK_dp
        return dq, dp

    def forward_train(self, q, p, gamma):
        """Forces with a differentiable graph, so a loss on them can train ``cal_H``."""
        return self._forces(q, p, gamma, create_graph=True)

    def forward(self, q, p, gamma):
        """Forces for inference (no graph is kept through the gradient)."""
        return self._forces(q, p, gamma, create_graph=False)



def Nonsep_SymInt(q, p, x, y, dt, K_t, eps, gamma, w=20):
    """Advance (q, p) and the auxiliary copy (x, y) by ``dt`` with an explicit symplectic scheme.

    This is the second-order extended-phase-space splitting for nonseparable
    Hamiltonians (Tao, 2016). Each sub-step runs a half flow of H(q, y), a half
    flow of H(x, p), an exact rotation binding the copies with strength ``w``,
    then the two half flows again in reverse order. Every force and velocity
    term is divided by ``gamma``.

    Args:
        q, p, x, y: state tensors; presumably (batch, 1, n) as elsewhere in this file.
        dt: tensor of time spans, presumably shape (batch,); it is unsqueezed twice
            to broadcast against the state.
        K_t: callable ``K_t(a, b, gamma)``, expected to return ``(-dK/da, -dK/db)``.
        eps: target sub-step length. The sub-step count is
            ``round(max|dt| / eps)``, but never less than 1.
        gamma: per-particle weights, broadcastable against ``q``.
        w: binding strength between the copies (default 20).

    Returns:
        ``(q, p, x, y)`` after integration.
    """
    # Take at least one sub-step. When max|dt| < eps / 2 the rounding gives 0,
    # which would return the state unintegrated (and make h = dt / 0).
    n_steps = max(1, int(np.round((torch.abs(dt) / eps).max().item())))
    h = dt / n_steps
    h = h.unsqueeze(1).unsqueeze(1)
    # The binding rotation angle is identical for every sub-step.
    angle = 2 * w * h / gamma
    cos_a = torch.cos(angle)
    sin_a = torch.sin(angle)
    for _ in range(n_steps):
        # Half step of H(q, y): moves p and x.
        fa, fb = K_t(q, y, gamma)
        p = p - fa * h * 0.5 / gamma
        x = x + fb * h * 0.5 / gamma
        # Half step of H(x, p): moves q and y.
        fa, fb = K_t(x, p, gamma)
        q = q + fb * h * 0.5 / gamma
        y = y - fa * h * 0.5 / gamma
        # Full step of the binding term: rotate the half-differences, keep the means.
        half_dq = 0.5 * (q - x)
        half_dp = 0.5 * (p - y)
        rot_q = cos_a * half_dq + sin_a * half_dp
        rot_p = -sin_a * half_dq + cos_a * half_dp
        mean_q = 0.5 * (q + x)
        mean_p = 0.5 * (p + y)
        q = mean_q + rot_q
        p = mean_p + rot_p
        x = mean_q - rot_q
        y = mean_p - rot_p
        # Mirror the first two half steps so the composition is symmetric.
        fa, fb = K_t(x, p, gamma)
        q = q + fb * h * 0.5 / gamma
        y = y - fa * h * 0.5 / gamma
        fa, fb = K_t(q, y, gamma)
        p = p - fa * h * 0.5 / gamma
        x = x + fb * h * 0.5 / gamma
    return q, p, x, y


def _min_pair_separation(q, p):
    """Return the smallest distance between two distinct vortices of one sample.

    ``q`` and ``p`` hold the vortices' x and y coordinates along their last axis
    (shape (1, 1, n), as drawn in ``_sample_initial_state``).
    """
    xs = q[0, 0]
    ys = p[0, 0]
    n = xs.shape[0]
    if n < 2:
        return float('inf')
    dist = torch.sqrt((xs[:, None] - xs[None, :]) ** 2 + (ys[:, None] - ys[None, :]) ** 2)
    off_diagonal = ~torch.eye(n, dtype=torch.bool)
    return dist[off_diagonal].min().item()


def _sample_initial_state():
    """Draw one random training configuration ``(q0, p0, gamma)``, each of shape (1, 1, n_particle).

    Positions are uniform in [0, 40)^2. The y coordinates are re-drawn until all
    vortices are at least 0.2 apart. Strengths are uniform in [-1, 1) and re-drawn
    until every |gamma| >= 0.5.
    """
    q0 = 40. * torch.rand(1, 1, n_particle)
    p0 = 40. * torch.rand(1, 1, n_particle)
    # q0 / p0 are the x / y coordinates of the vortices (the interaction pairs
    # entries along the last axis). The separation therefore compares entries
    # within each tensor, not q0 against p0.
    while _min_pair_separation(q0, p0) < 0.2:
        p0 = 40. * torch.rand(1, 1, n_particle)
    gamma = 2 * torch.rand(1, 1, n_particle) - 1
    while bool((torch.abs(gamma) < 0.5).any()):
        gamma = 2 * torch.rand(1, 1, n_particle) - 1
    return q0, p0, gamma


def _simulate_test_trajectory(f_true):
    """Integrate the 4-vortex reference configuration for plot_points frames of length dtp.

    Returns a tensor of shape (plot_points + 1, 5, 4) whose rows are q, p, x, y, gamma.
    """
    qT = torch.tensor([[[0., 0., 0., 0.]]])
    pT = torch.tensor([[[1., -1, 0.3, -0.3]]])
    xT = torch.tensor([[[0., 0., 0., 0.]]])
    yT = torch.tensor([[[1., -1, 0.3, -0.3]]])
    gamma = torch.tensor([[[1., -1., 1., -1.]]]) * 0.8
    frames = [torch.cat([qT, pT, xT, yT, gamma], dim=1)]
    with torch.no_grad():
        for _ in range(plot_points):
            qT, pT, xT, yT = Nonsep_SymInt(qT, pT, xT, yT, dtp, f_true.forward, epsT, gamma)
            frames.append(torch.cat([qT, pT, xT, yT, gamma], dim=1))
    return torch.cat(frames)


def Gen_Data():
    """Generate the training set (data.h5) and the reference test trajectory (test.dat).

    Training set: 2048 random samples of n_particle vortices. Each sample is
    integrated with the analytic ``KAnalysis`` model for ``nsteps`` transitions of
    length epsD, using sub-step epsT. ``data.h5`` then holds:
      * ``u``  - float32, shape (2048, 5, n_particle, nsteps + 1), with rows
                 q, p, q, p, gamma. The auxiliary copies are reset to the
                 physical state at every stored snapshot.
      * ``dt`` - float32, shape (2048,), the transition length of each sample.
    Both files go to the current working directory, which is where ``Dataset``
    and the plotting routines read them from.
    """
    n_sample = 2048
    f_true = KAnalysis()  # parameter-free, so one instance serves every sample
    f_true.eval()
    datau = [[] for _ in range(nsteps + 1)]
    datat = []
    for _ in range(n_sample):
        t1_t2 = torch.tensor([epsD])
        q0, p0, gamma = _sample_initial_state()
        x0, y0 = q0, p0
        datau[0].append(torch.cat([q0, p0, q0, p0, gamma], dim=1).unsqueeze(-1))
        with torch.no_grad():
            for j in range(nsteps):
                q1, p1, x1, y1 = Nonsep_SymInt(q0, p0, x0, y0, t1_t2, f_true.forward, epsT, gamma)
                # Restart the auxiliary copies from the physical state.
                q0, p0, x0, y0 = q1, p1, q1, p1
                datau[j + 1].append(torch.cat([q1, p1, q1, p1, gamma], dim=1).unsqueeze(-1))
        datat.append(t1_t2)
    # Each snapshot stacks to (2048, 5, n, 1); joining them gives (2048, 5, n, nsteps + 1).
    datau = torch.cat([torch.cat(step) for step in datau], dim=-1).float()
    # Concatenating the (1,)-shaped dt tensors gives an explicit (2048,) vector.
    datat = torch.cat(datat).float()
    with h5py.File("data.h5", "w") as hf:
        hf.create_dataset('u', data=datau.numpy())
        hf.create_dataset('dt', data=datat.numpy())
    print('test data')
    torch.save(_simulate_test_trajectory(f_true), 'test.dat')
    


class Dataset(torch.utils.data.Dataset):
    """Train/test split of the trajectories stored by Gen_Data in an HDF5 file.

    Items are ``(u[index], dt[index])``. The first 90% of rows form the 'train'
    split; any other ``data_type`` selects the remaining rows. Tensors are moved
    to ``device`` once, at construction.
    """

    def __init__(self, data_type, data_path='data.h5'):
        # Use explicit read-only mode. Before h5py 3.0 the default mode was 'a',
        # which silently creates an empty file when ``data_path`` is missing.
        with h5py.File(data_path, 'r') as f:
            u = f['u'][:]
            dt = f['dt'][:]
        split = int(u.shape[0] * 0.9)
        rows = slice(None, split) if data_type == 'train' else slice(split, None)
        self.u = torch.from_numpy(u[rows]).to(device)
        self.dt = torch.from_numpy(dt[rows]).to(device)

    def __getitem__(self, index):
        return self.u[index], self.dt[index]

    def __len__(self):
        return self.u.shape[0]



# Shared learned interaction model, trained by train() and used by the rollout routines.
f_neur = KTrained(hidden_dim).to(device)
# Training/evaluation criterion: mean absolute error (func.mse_loss is a drop-in alternative).
loss_func = func.l1_loss


def _batch_loss(qpxyT, dt, force):
    """Teacher-forced one-step loss for one batch of stored trajectories.

    For every stored snapshot j < nsteps, this integrates from the TRUE state at j
    over ``dt`` with ``force`` and compares the prediction with the true q, p, x, y
    at j + 1.

    Args:
        qpxyT: batch indexed (batch, row, particle, snapshot), with rows q, p, x, y,
            gamma as written by Gen_Data.
        dt: per-sample transition lengths.
        force: callable with the ``K_t`` signature expected by ``Nonsep_SymInt``.
    """
    preds = []
    for j in range(nsteps):
        state = qpxyT[..., j]
        q0T, p0T, x0T, y0T, gamma = (state[:, k:k + 1] for k in range(5))
        q1N, p1N, x1N, y1N = Nonsep_SymInt(q0T, p0T, x0T, y0T, dt, force, epsN, gamma)
        preds.append(torch.cat([q1N, p1N, x1N, y1N], dim=1).unsqueeze(-1))
    return loss_func(qpxyT[:, 0:4, :, 1:], torch.cat(preds, dim=-1))


def train():
    """Fit ``f_neur`` for 100 epochs and save the best test-loss weights to model.pt.

    Uses Adam (lr = 5 * epsN), with the rate multiplied by 0.8 every 10 epochs.
    Each epoch prints the epoch index, the mean train batch loss and the mean
    test batch loss.
    """
    optimizer = torch.optim.Adam(f_neur.parameters(), lr=5. * epsN)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
    train_data_loader = torch.utils.data.DataLoader(Dataset('train'), batch_size=256, shuffle=True)
    # Evaluation order is irrelevant; a fixed order keeps the per-epoch test
    # metric (and hence model selection) deterministic.
    test_data_loader = torch.utils.data.DataLoader(Dataset('test'), batch_size=256, shuffle=False)
    lowest_test_loss = float('inf')
    for epoch in range(100):
        f_neur.train()
        train_loss, train_batches = 0., 0
        for qpxyT, dt in train_data_loader:
            optimizer.zero_grad()
            # forward_train keeps the graph through the forces so cal_H gets gradients.
            loss = _batch_loss(qpxyT, dt, f_neur.forward_train)
            train_loss += loss.item()
            train_batches += 1
            loss.backward()
            optimizer.step()
        scheduler.step()

        f_neur.eval()
        test_loss, test_batches = 0., 0
        with torch.no_grad():
            for qpxyT, dt in test_data_loader:
                # No parameter gradients are needed here, so skip the double-backward graph.
                test_loss += _batch_loss(qpxyT, dt, f_neur.forward).item()
                test_batches += 1
        mean_test_loss = test_loss / test_batches
        print(epoch, train_loss / train_batches, mean_test_loss)

        if mean_test_loss < lowest_test_loss:
            torch.save(f_neur.state_dict(), "model.pt")
            lowest_test_loss = mean_test_loss


def _learned_rollout_of_test_trajectory():
    """Integrate ``f_neur`` from the first frame of test.dat.

    Takes plot_points steps of length dtp (sub-step epsN) with ``f_neur.forward``.

    Returns:
        ``(qpxyN, qpxyT)`` as NumPy arrays. The first is the learned rollout,
        shape (plot_points + 1, 4, n), with rows q, p, x, y. The second is the
        reference frames exactly as loaded from test.dat.
    """
    qpxyT = torch.load('test.dat')
    qN, pN, xN, yN, gamma = (qpxyT[0:1, k:k + 1, :].to(device) for k in range(5))
    dt = dtp.to(device)
    frames = [torch.cat([qN, pN, xN, yN], dim=1).detach().cpu()]
    with torch.no_grad():
        for _ in range(plot_points):
            qN, pN, xN, yN = Nonsep_SymInt(qN, pN, xN, yN, dt, f_neur.forward, epsN, gamma)
            frames.append(torch.cat([qN, pN, xN, yN], dim=1).detach().cpu())
    return to_np(torch.cat(frames)), to_np(qpxyT)


def _plot_rollout(qpxyN, qpxyT, mark_reference_points):
    """Plot learned positions (red squares) over the reference paths (blue lines) and save to 'out'."""
    plt.clf()
    for k in range(qpxyN.shape[2]):
        plt.scatter(qpxyN[:, 0, k], qpxyN[:, 1, k], s=80, c="r", marker="s")
    for k in range(qpxyT.shape[2]):
        plt.plot(qpxyT[:, 0, k], qpxyT[:, 1, k], c="b")
    if mark_reference_points:
        for k in range(qpxyT.shape[2]):
            plt.scatter(qpxyT[:, 0, k], qpxyT[:, 1, k], s=40, c="b")
    plt.draw()
    # Save before show(). With a blocking backend the figure is closed once its
    # window is dismissed, and a later savefig would write a new, empty figure.
    plt.savefig('out')
    plt.show()
    plt.pause(1.)


def test():
    """Roll out the current in-memory ``f_neur`` on the reference test case and plot it."""
    f_neur.eval()
    qpxyN, qpxyT = _learned_rollout_of_test_trajectory()
    _plot_rollout(qpxyN, qpxyT, mark_reference_points=False)


def validation():
    """Load model.pt, roll it out on the reference test case, plot the result and write qp.dat.

    qp.dat has one line per frame with the learned and reference positions of
    vortex 0: ``qN pN qT pT``, separated by double tabs.
    """
    f_neur.load_state_dict(torch.load("model.pt", map_location=lambda storage, location: storage))
    f_neur.to(device)
    f_neur.eval()
    qpxyN, qpxyT = _learned_rollout_of_test_trajectory()
    _plot_rollout(qpxyN, qpxyT, mark_reference_points=True)
    # Write every frame; each trajectory has plot_points + 1 of them.
    n_frames = min(qpxyN.shape[0], qpxyT.shape[0])
    with open('qp.dat', 'w') as f:
        for i in range(n_frames):
            values = (qpxyN[i, 0, 0], qpxyN[i, 1, 0], qpxyT[i, 0, 0], qpxyT[i, 1, 0])
            # str() keeps NumPy's shortest float32 representation.
            f.write('\t\t'.join(str(v) for v in values))
            f.write('\n')

def output_cal_H():
    """Tabulate the learned pair function ``f_neur.cal_H`` against separation into rH.dat.

    Loads model.pt, then evaluates ``cal_H`` on the pair displacement (0, -d) for
    d = 0.01 * i, i = 1..1001. This is vortex 0 at the origin and vortex 1 at
    (0, d). Each line of rH.dat is ``d`` and ``cal_H``, separated by double tabs.
    """
    f_neur.load_state_dict(torch.load("model.pt", map_location=lambda storage, location: storage))
    f_neur.to(device)
    f_neur.eval()
    rpoint = 1001
    # Evaluate cal_H directly: the module-level Hall buffer is never filled, so
    # indexing it raised IndexError.
    dist = torch.tensor([i * 0.01 for i in range(1, rpoint + 1)])
    # Displacement (x_0 - x_1, y_0 - y_1), in the same (dx, dy) layout KTrained feeds to cal_H.
    disp = torch.stack([torch.zeros_like(dist), -dist], dim=-1).to(device)
    with torch.no_grad():
        H = f_neur.cal_H(disp).squeeze(-1)
    with open('rH.dat', 'w') as f:
        for r, h in zip(to_np(dist), to_np(H)):
            f.write(str(r))
            f.write('\t\t')
            f.write(str(h))
            f.write('\n')
            
def Taylor(save_mat=False):
    """Advect a two-patch vortex field with the learned interaction model.

    Point vortices sit on a 25 x 25 grid over [-1.5, 1.5]^2. Their strengths come
    from two radial profiles centred at (0.4, 0) and (-0.4, 0). The field takes
    301 steps of length epsN with ``f_neur`` (weights loaded from model.pt) and is
    plotted and saved to 'out' every 100 steps.

    Args:
        save_mat: when True, write the current positions and weights to
            ``xyw#####.mat`` (via scipy.io.savemat) before each step.
    """
    num_x = 25
    num_y = 25
    lx = 3.
    ly = lx * num_y / num_x
    dx = lx / (num_x - 1.)
    dy = ly / (num_y - 1.)
    # Flattened grid: entry k = i * num_y + j is the node (i * dx - lx/2, j * dy - ly/2).
    q0 = torch.tensor([[[i * dx - lx / 2.
                         for i in range(num_x)
                         for j in range(num_y)]]]).to(device)
    p0 = torch.tensor([[[j * dy - ly / 2.
                         for i in range(num_x)
                         for j in range(num_y)]]]).to(device)
    x0 = q0
    y0 = p0
    # Squared distances to the two patch centres.
    r1 = (q0 - 0.4) ** 2 + p0 ** 2
    r2 = (q0 + 0.4) ** 2 + p0 ** 2
    gamma = 1. / 0.3 * (2. - r1 / 0.09) * torch.exp(0.5 * (1 - r1 / 0.09))
    gamma = gamma + 1. / 0.3 * (2. - r2 / 0.09) * torch.exp(0.5 * (1 - r2 / 0.09))
    # Scale by twice the cell area dx * dy.
    gamma = 2 * gamma * dx * dy
    f_neur.load_state_dict(torch.load("model.pt", map_location=lambda storage, location: storage))
    f_neur.to(device)
    f_neur.eval()
    rpoint = 301
    t1_t2 = torch.tensor([epsN]).to(device)
    for i in range(rpoint):
        print(i)
        if save_mat:
            # Five-digit zero-padded step index, e.g. xyw00042.mat.
            scipy.io.savemat(f'xyw{i % 100000:05d}.mat',
                             mdict={'x': q0[0, 0, :].detach().tolist(),
                                    'y': p0[0, 0, :].detach().tolist(),
                                    'w': gamma[0, 0, :].detach().tolist()})
        with torch.no_grad():
            q0, p0, x0, y0 = Nonsep_SymInt(q0, p0, x0, y0, t1_t2, f_neur.forward, epsN, gamma)

        if i % 100 == 0:
            q0np = to_np(q0)
            p0np = to_np(p0)
            gammanp = to_np(gamma)
            plt.clf()
            plt.scatter(q0np[0, 0, :], p0np[0, 0, :], s=80, c=gammanp[0, 0, :], marker="s")
            plt.draw()
            # Save before show(); a blocking show() closes the figure when dismissed.
            plt.savefig('out')
            plt.show()
            plt.pause(1.)

if __name__ == '__main__':
    # Pipeline: synthesize the data, fit the network, then run the Taylor experiment.
    for stage in (Gen_Data, train, Taylor):
        stage()

    
