import torch
import time
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import itertools
import os

# Figures are written to a "figs" directory located next to this script.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
FIG_DIR = os.path.join(BASE_DIR, "figs")
os.makedirs(FIG_DIR, exist_ok=True)  # runs at import time; no error if the directory exists

# Device selection.
# NOTE(review): the CUDA-aware choice on the next line is unconditionally
# overridden two lines below, so every tensor ends up on the CPU, and
# torch.cuda.set_device is replaced by a no-op for the whole process.
# This looks like a deliberate "force CPU" switch -- confirm before removing.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.cuda.set_device = lambda x: None
device = torch.device("cpu")

# Rectangular domain stored as [[x_min, x_max], [y_min, y_max]].
bound = np.array([-0.5,1,-0.5,1.5]).reshape(2,2)
Re = 40
nu = 1/Re  # coefficient multiplying the Laplacian terms in FF and Loss
# Exponential rate used by the exact solution in UU and PP:
#   lam = 1/(2*nu) - sqrt(1/(4*nu^2) + 4*pi^2)
# NOTE(review): this matches the closed form usually quoted for Kovasznay flow
# (the plot title below says "K-flow") -- confirm against the reference used.
lam = 1/(2*nu) - np.sqrt(1/(4*nu**2) + 4*np.pi**2)

def UU(X,order,prob):
    """Evaluate the exact velocity field, or one of its partial derivatives.

    For ``prob == 1`` the field is

        u(x, y) = 1 - exp(lam*x) * cos(2*pi*y)
        v(x, y) = lam/(2*pi) * exp(lam*x) * sin(2*pi*y)

    Args:
        X: tensor whose first two columns are the x and y coordinates.
        order: ``[dx, dy]`` derivative order (list or tuple). Supported values
            are ``[0,0]``, ``[1,0]``, ``[0,1]``, ``[2,0]`` and ``[0,2]``.
        prob: problem selector; only ``1`` is implemented.

    Returns:
        Tensor of shape ``(X.shape[0], 2)`` holding the u-component in column 0
        and the v-component in column 1. It is allocated with ``torch.zeros``,
        i.e. with the default dtype on the default device.

    Raises:
        ValueError: if ``prob`` or ``order`` is not supported. (Previously the
            function fell through and returned ``None``, which only failed
            later in the caller.)
    """
    if prob != 1:
        raise ValueError(f"UU: unsupported prob={prob!r}; only prob == 1 is implemented")
    order = list(order)  # accept tuples as well as lists
    if order not in ([0,0], [1,0], [0,1], [2,0], [0,2]):
        raise ValueError(f"UU: unsupported derivative order {order!r}")

    eta = 2*np.pi
    # exp(lam*x) is common to every component, so evaluate it once.
    decay = torch.exp(lam*X[:,0])
    tmp = torch.zeros(X.shape[0],2)
    if order == [0,0]:
        tmp[:,0] = 1 - decay*torch.cos(X[:,1]*eta)
        tmp[:,1] = decay*torch.sin(X[:,1]*eta)*lam/(eta)
    elif order == [1,0]:
        tmp[:,0] = - decay*torch.cos(X[:,1]*eta)*lam
        tmp[:,1] = decay*torch.sin(X[:,1]*eta)*(lam**2)/(eta)
    elif order == [0,1]:
        tmp[:,0] = decay*torch.sin(X[:,1]*eta)*(eta)
        tmp[:,1] = decay*torch.cos(X[:,1]*eta)*lam
    elif order == [2,0]:
        tmp[:,0] = - decay*torch.cos(X[:,1]*eta)*lam*lam
        tmp[:,1] = decay*torch.sin(X[:,1]*eta)*(lam**3)/(eta)
    else:  # order == [0,2]
        tmp[:,0] = decay*torch.cos(X[:,1]*eta)*(eta)**2
        tmp[:,1] = -decay*torch.sin(X[:,1]*eta)*lam*(eta)
    return tmp

def Delta(X,prob):
    """Return the Laplacian of the exact velocity field at the points ``X``.

    The result is the sum of the second x-derivative and the second
    y-derivative of ``UU``; it has the same two-column layout as ``UU``.
    """
    second_x = UU(X, [2,0], prob)
    second_y = UU(X, [0,2], prob)
    return second_x + second_y

def PP(X,order,prob):
    """Evaluate the exact pressure, or one of its first partial derivatives.

    For ``prob == 1`` the pressure is ``p(x, y) = 0.5 * (1 - exp(2*lam*x))``.

    Args:
        X: tensor whose first column is the x coordinate.
        order: ``[dx, dy]`` derivative order (list or tuple). Supported values
            are ``[0,0]``, ``[1,0]`` and ``[0,1]``.
        prob: problem selector; only ``1`` is implemented.

    Returns:
        1-D tensor with one value per row of ``X``.

    Raises:
        ValueError: if ``prob`` or ``order`` is not supported. (Previously the
            function fell through and returned ``None``.)
    """
    if prob != 1:
        raise ValueError(f"PP: unsupported prob={prob!r}; only prob == 1 is implemented")
    order = list(order)  # accept tuples as well as lists
    if order == [0,0]:
        return 0.5*(1 - torch.exp(2*lam*X[:,0]))
    if order == [1,0]:
        return - lam*torch.exp(2*lam*X[:,0])
    if order == [0,1]:
        # The pressure does not depend on y; 0*x keeps the shape and dtype of the x column.
        return 0*X[:,0]
    raise ValueError(f"PP: unsupported derivative order {order!r}")

def FF(X,prob):
    """Return the momentum forcing term produced by the exact solution.

    Component-wise this is ``-nu * Laplacian(U) + (U . grad) U + grad p``
    evaluated with the closed-form fields ``UU``/``PP``, so that the exact
    solution satisfies the momentum equations with this right-hand side.

    Args:
        X: tensor of points (columns x, y).
        prob: problem selector forwarded to ``UU``, ``Delta`` and ``PP``.

    Returns:
        Tensor of shape ``(X.shape[0], 2)``: column 0 is the x-momentum
        forcing, column 1 the y-momentum forcing.
    """
    # Evaluate every exact field once instead of once per use.
    vel = UU(X,[0,0],prob)
    vel_dx = UU(X,[1,0],prob)
    vel_dy = UU(X,[0,1],prob)
    lap = Delta(X,prob)

    tmp = torch.zeros(X.shape[0],2)
    # The summation order matches the original expression term for term.
    tmp[:,0] = -nu*lap[:,0] + (vel[:,0])*(vel_dx[:,0]) + \
    (vel[:,1])*(vel_dy[:,0]) + PP(X,[1,0],prob)
    tmp[:,1] = -nu*lap[:,1] + (vel[:,0])*(vel_dx[:,1]) + \
    (vel[:,1])*(vel_dy[:,1]) + PP(X,[0,1],prob)
    return tmp

class INSET():
    """Interior collocation set: cell centres of a uniform grid over ``bound``.

    Attributes set by the constructor:
        dim:  spatial dimension (2).
        hx:   cell sizes ``[hx, hy]``.
        size: number of points, ``nx[0] * nx[1]``.
        X:    ``(size, 2)`` tensor of points; row ``i*nx[1] + j`` is the centre
              of cell ``(i, j)``.
        uu, vv: exact velocity components at ``X``, each of shape ``(size, 1)``.
        ff:   forcing term ``FF(X, prob)`` of shape ``(size, 2)``.
    """
    def __init__(self, bound, nx, prob):
        """Build the grid.

        Args:
            bound: 2x2 array indexable as ``bound[i, j]``, laid out as
                ``[[x_min, x_max], [y_min, y_max]]``.
            nx: ``[nx, ny]`` number of cells per direction (integers).
            prob: problem selector forwarded to ``UU`` and ``FF``.
        """
        self.dim = 2
        self.hx = [(bound[0,1] - bound[0,0])/nx[0], (bound[1,1] - bound[1,0])/nx[1]]
        self.size = nx[0]*nx[1]

        # Cell-centre coordinates, computed in float64 and cast on assignment
        # (same arithmetic as filling the tensor entry by entry, without the
        # Python-level double loop).
        xs = float(bound[0,0]) + (np.arange(nx[0], dtype=np.float64) + 0.5)*float(self.hx[0])
        ys = float(bound[1,0]) + (np.arange(nx[1], dtype=np.float64) + 0.5)*float(self.hx[1])
        self.X = torch.zeros(self.size, self.dim)
        self.X[:,0] = torch.from_numpy(np.repeat(xs, nx[1]))  # x varies slowest
        self.X[:,1] = torch.from_numpy(np.tile(ys, nx[0]))    # y varies fastest

        exact = UU(self.X, [0,0], prob)  # evaluated once, split into components
        self.uu = exact[:,0:1]
        self.vv = exact[:,1:2]
        self.ff = FF(self.X, prob)

class BDSET():
    """Boundary collocation set: midpoints of uniform segments on the four edges.

    Points are stored edge by edge in this order: bottom (``y = y_min``),
    right (``x = x_max``), top (``y = y_max``), left (``x = x_min``).

    Attributes set by the constructor:
        dim:  spatial dimension (2).
        hx:   segment lengths ``[hx, hy]``.
        size: number of points, ``2 * (nx[0] + nx[1])``.
        X:    ``(size, 2)`` tensor of boundary points.
        uu, vv: exact velocity components at ``X``, each of shape ``(size, 1)``.
    """
    def __init__(self, bound, nx, prob):
        """Build the boundary points.

        Args:
            bound: 2x2 array indexable as ``bound[i, j]``, laid out as
                ``[[x_min, x_max], [y_min, y_max]]``.
            nx: ``[nx, ny]`` number of segments on horizontal / vertical edges.
            prob: problem selector forwarded to ``UU``.
        """
        self.dim = 2
        self.hx = [(bound[0,1] - bound[0,0])/nx[0], (bound[1,1] - bound[1,0])/nx[1]]
        self.size = 2*(nx[0] + nx[1])

        x_lo, x_hi = float(bound[0,0]), float(bound[0,1])
        y_lo, y_hi = float(bound[1,0]), float(bound[1,1])
        # Segment midpoints along each direction (float64, cast on assignment).
        x_mid = x_lo + (np.arange(nx[0], dtype=np.float64) + 0.5)*float(self.hx[0])
        y_mid = y_lo + (np.arange(nx[1], dtype=np.float64) + 0.5)*float(self.hx[1])

        # Edge order: bottom, right, top, left.
        edge_x = np.concatenate([x_mid, np.full(nx[1], x_hi), x_mid, np.full(nx[1], x_lo)])
        edge_y = np.concatenate([np.full(nx[0], y_lo), y_mid, np.full(nx[0], y_hi), y_mid])

        self.X = torch.zeros(self.size, self.dim)
        self.X[:,0] = torch.from_numpy(edge_x)
        self.X[:,1] = torch.from_numpy(edge_y)

        exact = UU(self.X, [0,0], prob)  # evaluated once, split into components
        self.uu = exact[:,0:1]
        self.vv = exact[:,1:2]

class TESET():
    """Test set: the nodes of a uniform grid over ``bound``, edges included.

    Attributes set by the constructor:
        bound, nx, prob: the constructor arguments, stored as given.
        hx:   grid spacings ``[hx, hy]``.
        size: number of nodes, ``(nx[0] + 1) * (nx[1] + 1)``.
        X:    ``(size, 2)`` tensor of nodes; x varies slowest, y fastest.
        uu, vv: exact velocity components at ``X``, each of shape ``(size, 1)``.
        pp:   exact pressure at ``X``, shape ``(size,)``.
    """
    def __init__(self, bound, nx, prob):
        """Build the node grid.

        Args:
            bound: 2x2 array indexable as ``bound[i, j]``, laid out as
                ``[[x_min, x_max], [y_min, y_max]]``.
            nx: ``[nx, ny]`` number of intervals per direction (integers).
            prob: problem selector forwarded to ``UU`` and ``PP``.
        """
        self.bound = bound
        self.nx = nx
        self.hx = [(self.bound[0,1]-self.bound[0,0])/self.nx[0],
                   (self.bound[1,1]-self.bound[1,0])/self.nx[1]]
        self.prob = prob
        self.size = (self.nx[0] + 1)*(self.nx[1] + 1)

        # Node coordinates, computed in float64 and cast on assignment (same
        # arithmetic as filling the tensor entry by entry).
        xs = float(self.bound[0,0]) + np.arange(self.nx[0] + 1, dtype=np.float64)*float(self.hx[0])
        ys = float(self.bound[1,0]) + np.arange(self.nx[1] + 1, dtype=np.float64)*float(self.hx[1])
        self.X = torch.zeros(self.size,2)
        self.X[:,0] = torch.from_numpy(np.repeat(xs, self.nx[1] + 1))
        self.X[:,1] = torch.from_numpy(np.tile(ys, self.nx[0] + 1))

        exact = UU(self.X, [0,0], prob)  # evaluated once, split into components
        self.uu = exact[:,0:1]
        self.vv = exact[:,1:2]
        self.pp = PP(self.X, [0,0], prob)

class CauchyActivation(nn.Module):
    """Element-wise rational activation with trainable per-feature parameters.

    For an input ``x`` the output is
    ``(lambda1 * x + lambda2) / (x**2 + d**2)``, where ``lambda1``,
    ``lambda2`` and ``d`` are learnable vectors of length ``hidden_dim``
    that broadcast against the last dimension of ``x``.
    """
    def __init__(self, hidden_dim):
        super().__init__()

        def _near(center):
            # Small Gaussian perturbation (std 0.001) around ``center``.
            return torch.randn(hidden_dim) * 0.001 + center

        # The draws are made in the order lambda1, lambda2, d.
        self.lambda1 = nn.Parameter(_near(0.05))
        self.lambda2 = nn.Parameter(_near(0.0))
        self.d = nn.Parameter(torch.abs(_near(0.1)))

    def forward(self, x):
        numerator = self.lambda1 * x + self.lambda2
        denominator = x.square() + self.d.square()
        return numerator / denominator

class SimpleLinear_XNet(nn.Module):
    """Affine layer followed by a ``CauchyActivation`` of matching width."""
    def __init__(self, input_dim, output_dim):
        super().__init__()
        # The linear layer is created before the activation parameters.
        self.dense = nn.Linear(input_dim, output_dim, bias=True)
        self.cauchy_activation = CauchyActivation(output_dim)

    def forward(self, x):
        pre_activation = self.dense(x)
        return self.cauchy_activation(pre_activation)

class Linear_XNet(nn.Module):
    """Affine read-out that maps ``input_dim`` features to a single output."""
    def __init__(self, input_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1, bias=True)

    def forward(self, x):
        readout = self.linear(x)
        return readout

class XNet(nn.Module):
    """One-hidden-layer network: linear -> Cauchy activation -> linear read-out.

    The sub-modules keep the attribute names ``SimpleLinear_XNet`` and
    ``Linear_XNet``, so parameter names in ``state_dict`` are unchanged.
    """
    def __init__(self, input_dim=2, hidden_dim=50):
        super().__init__()
        # Hidden block first, read-out second (creation order is preserved).
        self.SimpleLinear_XNet = SimpleLinear_XNet(input_dim, hidden_dim)
        self.Linear_XNet = Linear_XNet(hidden_dim)

    def forward(self, x):
        hidden = self.SimpleLinear_XNet(x)
        return self.Linear_XNet(hidden)

class MLP(nn.Module):
    """Fully connected tanh network.

    The stack is ``Linear(input_dim, hidden_dim) -> Tanh``, followed by
    ``depth - 1`` further ``Linear(hidden_dim, hidden_dim) -> Tanh`` pairs,
    and a final ``Linear(hidden_dim, output_dim)`` without activation.
    """
    def __init__(self, input_dim=2, hidden_dim=100, depth=4, output_dim=1):
        super().__init__()
        # Widths seen by the activated layers; there is always at least one.
        widths = [input_dim] + [hidden_dim] * max(depth, 1)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.Tanh()]
        stack.append(nn.Linear(hidden_dim, output_dim))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        return self.model(x)

def pred_u(netu, X):
    """Evaluate the u-velocity model ``netu`` at the points ``X``."""
    prediction = netu(X)
    return prediction
def pred_v(netv, X):
    """Evaluate the v-velocity model ``netv`` at the points ``X``."""
    prediction = netv(X)
    return prediction
def pred_p(netp, X):
    """Evaluate the pressure model ``netp`` at the points ``X``."""
    prediction = netp(X)
    return prediction

def loadcuda(netu, netv, netp, inset, bdset, teset):
    """Move the three networks and every tensor attribute of the data sets to ``device``.

    Despite the name, the target is whatever the module-level ``device`` is.
    The interior points ``inset.X`` are additionally flagged with
    ``requires_grad`` so that derivatives with respect to them can be taken.
    All objects are modified in place; nothing is returned.
    """
    for model in (netu, netv, netp):
        model.to(device)
    for dataset in (inset, bdset, teset):
        # Only existing attributes are re-bound, so the dict size is stable
        # while it is being iterated.
        for attr, value in vars(dataset).items():
            if not isinstance(value, torch.Tensor):
                continue
            moved = value.to(device)
            if attr == "X" and dataset is inset:
                moved.requires_grad_(True)
            setattr(dataset, attr, moved)

def _grad_wrt(outputs, inputs):
    """Return d(sum(outputs))/d(inputs), keeping the graph for higher derivatives."""
    grad, = torch.autograd.grad(outputs, inputs, create_graph=True,
                                grad_outputs=torch.ones_like(outputs))
    return grad


def Loss(netu, netv, netp, inset, bdset):
    """Compute the physics-informed training loss.

    The PDE part is the mean over the interior points of the squared
    residuals of

        -nu * lap(u) + u*u_x + v*u_y + p_x - f_1
        -nu * lap(v) + u*v_x + v*v_y + p_y - f_2
        u_x + v_y

    and the boundary part is the mean squared mismatch of ``u`` and ``v``
    against the exact values stored in ``bdset``.

    Side effect: the network outputs at the interior points are stored on
    ``inset`` as ``inset.u``, ``inset.v`` and ``inset.p``.

    Args:
        netu, netv, netp: models for the two velocity components and pressure.
        inset: interior set; ``inset.X`` must have ``requires_grad`` enabled.
        bdset: boundary set providing ``X``, ``uu`` and ``vv``.

    Returns:
        ``(total_loss, pde_loss, bc_loss)`` with
        ``total_loss = 0.5 * pde_loss + bc_loss``.
    """
    inset.u = pred_u(netu, inset.X)
    inset.v = pred_v(netv, inset.X)
    inset.p = pred_p(netp, inset.X)

    # First derivatives: each gradient has columns (d/dx, d/dy).
    grad_u = _grad_wrt(inset.u, inset.X)
    grad_v = _grad_wrt(inset.v, inset.X)
    grad_p = _grad_wrt(inset.p, inset.X)
    u_x, u_y = grad_u[:,0:1], grad_u[:,1:2]
    v_x, v_y = grad_v[:,0:1], grad_v[:,1:2]
    p_x, p_y = grad_p[:,0:1], grad_p[:,1:2]

    # Laplacians: d/dx of the x-derivative plus d/dy of the y-derivative.
    u_lap = _grad_wrt(u_x, inset.X)[:,0:1] + _grad_wrt(u_y, inset.X)[:,1:2]
    v_lap = _grad_wrt(v_x, inset.X)[:,0:1] + _grad_wrt(v_y, inset.X)[:,1:2]

    res_u = (-nu*u_lap + inset.u*u_x + inset.v*u_y + p_x - inset.ff[:,0:1])**2
    res_v = (-nu*v_lap + inset.u*v_x + inset.v*v_y + p_y - inset.ff[:,1:2])**2
    res_div = (u_x + v_y)**2
    pde_loss = torch.mean(res_u + res_v + res_div)

    bc_u_loss = torch.mean((pred_u(netu, bdset.X) - bdset.uu)**2)
    bc_v_loss = torch.mean((pred_v(netv, bdset.X) - bdset.vv)**2)
    bc_loss = bc_u_loss + bc_v_loss

    total_loss = 0.5 * pde_loss + bc_loss
    return total_loss, pde_loss, bc_loss

def calculate_mse(netu, netv, netp, teset):
    """Mean squared error of the three networks against the exact test values.

    Evaluation runs under ``torch.no_grad()``. The exact pressure
    ``teset.pp`` is unsqueezed to a column before comparison.

    Returns:
        ``(total_mse, u_mse, v_mse, p_mse)`` as 0-dim tensors, where
        ``total_mse`` is the sum of the three component errors.
    """
    def _mse(prediction, target):
        return torch.mean((prediction - target)**2)

    with torch.no_grad():
        u_mse = _mse(pred_u(netu, teset.X), teset.uu)
        v_mse = _mse(pred_v(netv, teset.X), teset.vv)
        p_mse = _mse(pred_p(netp, teset.X), teset.pp.unsqueeze(1))
        total_mse = u_mse + v_mse + p_mse

    return total_mse, u_mse, v_mse, p_mse

def train_LBFGS(netu, netv, netp, inset, bdset, teset, epochs=50, lr=0.5):
    """Train the three networks jointly with L-BFGS (strong-Wolfe line search).

    Each epoch is one ``optim.step`` call, which may run up to 100 inner
    iterations. Every 10th epoch the loss is re-evaluated and printed.
    ``teset`` is accepted for signature parity with ``train_adam`` but is
    not used. Returns ``None``.
    """
    trainable = itertools.chain(netu.parameters(), netv.parameters(), netp.parameters())
    optim = torch.optim.LBFGS(trainable, lr=lr, max_iter=100,
                              tolerance_grad=1e-14, tolerance_change=1e-14,
                              history_size=2500, line_search_fn='strong_wolfe')

    def closure():
        # Re-evaluates the loss and its gradients; called repeatedly by L-BFGS.
        optim.zero_grad()
        objective, _, _ = Loss(netu, netv, netp, inset, bdset)
        objective.backward()
        return objective

    print("Training...")
    start_time = time.time()
    for epoch in range(epochs):
        optim.step(closure)
        if epoch % 10 != 0:
            continue
        total_loss_display, pde_loss_display, bc_loss_display = Loss(netu, netv, netp, inset, bdset)
        elapsed = time.time()-start_time
        print(f"Epoch {epoch}, Loss: {total_loss_display.item():.4e}, "
              f"PDE: {pde_loss_display.item():.4e}, BC: {bc_loss_display.item():.4e}, Time: {elapsed:.2f}s")
    return

def train_adam(netu, netv, netp, inset, bdset, teset, epochs=2000, lr=1e-2):
    """Train the three networks jointly with Adam and a step learning-rate decay.

    The learning rate is halved every 250 epochs. The training loss is
    recorded every epoch; every 100th epoch the test MSE is computed,
    recorded and a progress line is printed.

    Returns:
        ``(loss_history, mse_history)`` where ``loss_history`` holds one float
        per epoch and ``mse_history`` holds ``[total, u, v, p]`` MSE lists for
        the logged epochs.
    """
    all_params = [p for net in (netu, netv, netp) for p in net.parameters()]
    optimizer = torch.optim.Adam(all_params, lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=250, gamma=0.5)

    loss_history = []
    mse_history = []
    print("Training...")
    start_time = time.time()

    for epoch in range(epochs):
        optimizer.zero_grad()
        total_loss, pde_loss, bc_loss = Loss(netu, netv, netp, inset, bdset)
        total_loss.backward()
        optimizer.step()
        scheduler.step()

        loss_history.append(total_loss.item())

        if epoch % 100:
            continue  # only every 100th epoch is evaluated and logged

        total_mse, u_mse, v_mse, p_mse = calculate_mse(netu, netv, netp, teset)
        mse_history.append([m.item() for m in (total_mse, u_mse, v_mse, p_mse)])

        current_lr = optimizer.param_groups[0]['lr']
        elapsed = time.time() - start_time
        print(f"Epoch {epoch}, Loss: {total_loss.item():.4e}, "
              f"PDE: {pde_loss.item():.4e}, BC: {bc_loss.item():.4e}, "
              f"MSE: {total_mse.item():.4e}, LR: {current_lr:.2e}, Time: {elapsed:.2f}s")

    print(f"Time: {time.time()-start_time:.2f}s")
    return loss_history, mse_history

def calculate_diagnostics(netu, netv, netp, inset, bdset, teset):
    """Collect final accuracy figures for a trained (u, v, p) network triple.

    Returns a dict of Python floats with the keys
    ``u_mse``, ``v_mse``, ``p_mse``, ``total_mse`` (test-set mean squared
    errors), ``u_rel``, ``v_rel`` (error norm divided by exact-value norm),
    ``pde_rms`` (square root of the PDE loss from ``Loss``) and ``bc_loss``.
    """
    points = teset.X
    with torch.no_grad():
        err_u = pred_u(netu, points) - teset.uu
        err_v = pred_v(netv, points) - teset.vv
        err_p = pred_p(netp, points) - teset.pp.unsqueeze(1)

        u_mse = torch.mean(err_u**2)
        v_mse = torch.mean(err_v**2)
        p_mse = torch.mean(err_p**2)
        total_mse = u_mse + v_mse + p_mse

        u_rel = torch.norm(err_u) / torch.norm(teset.uu)
        v_rel = torch.norm(err_v) / torch.norm(teset.vv)

    # The PDE residual needs autograd, so it is evaluated outside no_grad.
    _, pde_loss, bc_loss = Loss(netu, netv, netp, inset, bdset)
    pde_rms = torch.sqrt(pde_loss)

    report = {}
    report["u_mse"] = u_mse.item()
    report["v_mse"] = v_mse.item()
    report["p_mse"] = p_mse.item()
    report["total_mse"] = total_mse.item()
    report["u_rel"] = u_rel.item()
    report["v_rel"] = v_rel.item()
    report["pde_rms"] = pde_rms.item()
    report["bc_loss"] = bc_loss.item()
    return report

# ---- Data sets: interior cells, boundary midpoints (twice as dense), test nodes ----
nx = [80,80]; nx_te = [15,15]; prob = 1
inset = INSET(bound,nx,prob); bdset = BDSET(bound,[n*2 for n in nx],prob); teset = TESET(bound,nx_te,prob)

# ---- XNet: one network each for u, v and p ----
netu,netv,netp = XNet(2,100),XNet(2,100),XNet(2,100)
loadcuda(netu,netv,netp,inset,bdset,teset)
train_adam(netu, netv, netp, inset, bdset, teset, epochs=1000, lr=1e-2)

# Keep the XNet diagnostics under their own name and report them; they used to
# be overwritten by the MLP diagnostics without ever being shown.
xnet_diag = calculate_diagnostics(netu, netv, netp, inset, bdset, teset)
print("XNet diagnostics: " + ", ".join(f"{key}={val:.4e}" for key, val in xnet_diag.items()))

# ---- Baseline MLP trained on the same data sets ----
mlpu, mlpv, mlpp = MLP(2,100,4,1), MLP(2,100,4,1), MLP(2,100,4,1)
# Make sure every tensor attribute of the data sets is float32 before reuse.
for D in [inset,bdset,teset]:
    for k,v in D.__dict__.items():
        if isinstance(v, torch.Tensor): setattr(D,k,v.type(torch.float32))
loadcuda(mlpu, mlpv, mlpp, inset, bdset, teset)
train_adam(mlpu, mlpv, mlpp, inset, bdset, teset, epochs=1000, lr=1e-2)

mlp_diag = calculate_diagnostics(mlpu, mlpv, mlpp, inset, bdset, teset)
print("MLP diagnostics: " + ", ".join(f"{key}={val:.4e}" for key, val in mlp_diag.items()))

# As before, final_diag ends up holding the diagnostics of the last model trained.
final_diag = mlp_diag

# ---- Evaluation grid for the comparison figure (64 x 64 points, edges included) ----
nx_te_in = [64,64]
x_train = np.linspace(bound[0,0],bound[0,1],nx_te_in[0])
y_train = np.linspace(bound[1,0],bound[1,1],nx_te_in[1])
x0, x1 = np.meshgrid(x_train,y_train)
xx = np.hstack((x0.reshape(-1,1), x1.reshape(-1,1)))
xx = torch.from_numpy(xx).type(torch.float32)

with torch.no_grad():
    # Exact solution: evaluated once, then split into its two components.
    uv_acc = UU(xx,[0,0],prob)
    u_acc = uv_acc[:,0:1].numpy()
    v_acc = uv_acc[:,1:2].numpy()
    # The networks live on `device`, so the inputs must be placed there too.
    xx_dev = xx.to(device)
    u_mlp = pred_u(mlpu, xx_dev).cpu().numpy(); v_mlp = pred_v(mlpv, xx_dev).cpu().numpy()
    u_xnet = pred_u(netu, xx_dev).cpu().numpy(); v_xnet = pred_v(netv, xx_dev).cpu().numpy()

# Signed point-wise errors (prediction minus exact).
u_diff_mlp = u_mlp - u_acc; v_diff_mlp = v_mlp - v_acc
u_diff_xnet = u_xnet - u_acc; v_diff_xnet = v_xnet - v_acc

# ---- Figure: rows are u / v; columns are exact, MLP error, XNet error ----
fig, ax = plt.subplots(2,3,figsize=(14,6))
x0f = x0.flatten(); x1f = x1.flatten(); num_line = 20

c0 = ax[0,0].tricontourf(x0f,x1f,u_acc.flatten(),num_line,cmap="rainbow")
fig.colorbar(c0,ax=ax[0,0]); ax[0,0].set_title("Exact u")
c3 = ax[1,0].tricontourf(x0f,x1f,v_acc.flatten(),num_line,cmap="rainbow")
fig.colorbar(c3,ax=ax[1,0]); ax[1,0].set_title("Exact v")

c1 = ax[0,1].tricontourf(x0f,x1f,u_diff_mlp.flatten(),num_line,cmap="rainbow")
fig.colorbar(c1,ax=ax[0,1]); ax[0,1].set_title("MLP Error u")
c4 = ax[1,1].tricontourf(x0f,x1f,v_diff_mlp.flatten(),num_line,cmap="rainbow")
fig.colorbar(c4,ax=ax[1,1]); ax[1,1].set_title("MLP Error v")

c2 = ax[0,2].tricontourf(x0f,x1f,u_diff_xnet.flatten(),num_line,cmap="rainbow")
fig.colorbar(c2,ax=ax[0,2]); ax[0,2].set_title("XNet Error u")
c5 = ax[1,2].tricontourf(x0f,x1f,v_diff_xnet.flatten(),num_line,cmap="rainbow")
fig.colorbar(c5,ax=ax[1,2]); ax[1,2].set_title("XNet Error v")

plt.suptitle("Exact vs MLP vs XNet (K-flow)",fontsize=16)
plt.tight_layout()
fig.savefig(os.path.join(FIG_DIR,"compare_baseline_xnet.png"),dpi=300)
plt.close(fig)

# The message used to be a bare ":" label; say what the path refers to.
print("Figure saved to:", os.path.join(FIG_DIR,"compare_baseline_xnet.png"))