import torch
import numpy as np
from scipy.spatial import Delaunay
from scipy.sparse import lil_matrix
from scipy.sparse.linalg import spsolve
from torch_geometric.data import Data
import matplotlib.pyplot as plt
import matplotlib.tri as mtri

# Default Dirichlet values (left, right, top, bottom) used as the default
# `u_bc` of mesh_and_solve / make_sln_graph: left edge at 1, the others at 0.
u_left = 1.
u_right = 0
u_top = 0
u_bottom = 0
# Default number of grid points along the shorter side of the rectangle.
resolution = 10

def mesh_and_solve(l, h, u_bc=(u_left, u_right, u_top, u_bottom), res=resolution):
    """Mesh the rectangle ``[0, l] x [0, h]`` and solve a Dirichlet problem with P1 FEM.

    The assembled problem is ``-div(k grad u) = f`` with ``k = 1`` and ``f = 0``
    (Laplace's equation), discretised with linear triangular elements on a
    Delaunay triangulation of a randomly jittered structured grid. The jitter
    uses ``torch.rand_like``, so the mesh depends on torch's global RNG state.

    Parameters
    ----------
    l, h : float
        Width (x extent) and height (y extent) of the rectangle.
    u_bc : tuple of 4 numbers
        Dirichlet values ``(left, right, top, bottom)``. Nodes on y = 0 / y = h
        are tested first, so corners take the bottom / top value.
    res : int
        Number of grid points along the shorter side; the longer side gets
        ``res`` scaled by the aspect ratio (truncated to an int).

    Returns
    -------
    u : ndarray, shape (num_nodes,)
        Nodal solution; boundary entries equal the prescribed Dirichlet values.
    points : ndarray, shape (num_nodes, 2)
        Node coordinates (float32).
    vertices : ndarray
        The same array object as ``points``.
    triangles : ndarray, shape (num_triangles, 3)
        Vertex indices of each triangle (``tri.simplices``).
    tri : scipy.spatial.Delaunay
        The triangulation.
    u_bc : tuple
        The boundary values that were used (passed through unchanged).

    Raises
    ------
    ValueError
        If the triangulation contains a zero-area triangle.
    """
    from scipy.sparse import coo_matrix  # scatter-add assembly (duplicates are summed)

    u_left, u_right, u_top, u_bottom = u_bc

    # --- Sample node positions -------------------------------------------------
    # `res` points along the shorter side, proportionally more along the longer one.
    if l > h:
        nx, ny = int(l / h * res), int(res)
    else:
        nx, ny = int(res), int(h / l * res)

    x, y = np.meshgrid(np.linspace(0, l, nx), np.linspace(0, h, ny))
    # pos has shape (2, ny, nx): pos[0] holds x coordinates, pos[1] holds y.
    pos = torch.stack((torch.FloatTensor(x), torch.FloatTensor(y)), dim=0)
    # Jitter x of interior columns and y of interior rows only, so nodes on the
    # rectangle's edges stay exactly on them. Each shift is non-negative and
    # below half a grid spacing, so jittered nodes stay inside the domain.
    pos[0, :, 1:-1] += l / (2 * nx) * torch.rand_like(pos[0, :, 1:-1])
    pos[1, 1:-1, :] += h / (2 * ny) * torch.rand_like(pos[1, 1:-1, :])

    points = pos.reshape(2, -1).T.numpy()  # shape: (num_nodes, 2)

    tri = Delaunay(points)
    triangles = tri.simplices
    vertices = points
    num_nodes = len(vertices)

    def k_func(x, y):
        return 1  # uniform conductivity

    def u_boundary(x, y):
        # Order matters: bottom/top are tested before left/right, so corners take them.
        if np.isclose(y, 0):
            return u_bottom
        elif np.isclose(y, h):
            return u_top
        elif np.isclose(x, 0):
            return u_left
        elif np.isclose(x, l):
            return u_right
        return 0

    def f_func(x, y):
        return 0.  # no source term
        # return 5/4.*np.sin(x)*np.sin(0.5*y)

    def local_stiffness_matrix(xy, k_vals):
        """Return the 3x3 P1 stiffness matrix and the (unsigned) area of one triangle."""
        xy = np.asarray(xy, dtype=float)  # float64 for the geometry computations
        x = xy[:, 0]
        y = xy[:, 1]
        # Twice the signed area: positive for counter-clockwise vertex order.
        twice_signed_area = (x[1] - x[0]) * (y[2] - y[0]) - (x[2] - x[0]) * (y[1] - y[0])
        area = 0.5 * abs(twice_signed_area)
        if area <= 0:
            raise ValueError("Non-positive area in triangle.")

        b = np.array([y[1] - y[2], y[2] - y[0], y[0] - y[1]])
        c = np.array([x[2] - x[1], x[0] - x[2], x[1] - x[0]])
        # grad[:, i] is the constant gradient of the i-th hat function. Dividing by
        # the *signed* doubled area is valid for either orientation; the sign
        # cancels in grad.T @ grad, and the |area| factor keeps ke positive semidefinite.
        grad = np.vstack((b, c)) / twice_signed_area

        k_avg = np.mean(k_vals)  # element-constant conductivity
        ke = k_avg * area * grad.T @ grad
        return ke, area

    # --- Assemble stiffness matrix and load vector -----------------------------
    local_matrices = np.empty((len(triangles), 3, 3))
    F = np.zeros(num_nodes)
    for t, tri_nodes in enumerate(triangles):
        xy = vertices[tri_nodes]
        k_vals = np.array([k_func(*p) for p in xy])
        ke, area = local_stiffness_matrix(xy, k_vals)
        local_matrices[t] = ke
        # One-point (centroid) quadrature, split equally between the 3 vertices.
        F[tri_nodes] += f_func(*np.mean(xy, axis=0)) * area / 3.0

    # Entry (i, j) of triangle t belongs at K[triangles[t, i], triangles[t, j]];
    # duplicate (row, col) pairs from neighbouring triangles are summed by tocsr().
    rows = np.repeat(triangles, 3, axis=1).ravel()
    cols = np.tile(triangles, (1, 3)).ravel()
    K = coo_matrix(
        (local_matrices.ravel(), (rows, cols)), shape=(num_nodes, num_nodes)
    ).tocsr()

    # --- Dirichlet boundary conditions -----------------------------------------
    boundary_nodes = np.unique(tri.convex_hull)
    u = np.zeros(num_nodes)
    u[boundary_nodes] = [u_boundary(*vertices[i]) for i in boundary_nodes]

    # Eliminate the known boundary values and solve only for interior unknowns:
    # K_II u_I = F_I - K_IB u_B. Slicing avoids writing zero rows/columns into
    # the CSR matrix, which would store them as explicit (dense) entries.
    rhs = F - K @ u
    interior = np.setdiff1d(np.arange(num_nodes), boundary_nodes)
    if interior.size:
        K_interior = K[interior][:, interior].tocsc()
        u[interior] = spsolve(K_interior, rhs[interior])

    return u, points, vertices, triangles, tri, u_bc


def less_first(a, b):
    """Return the pair as a two-element list with the smaller value first.

    When ``a < b`` is false (including ties) the result is ``[b, a]``.
    """
    if not a < b:
        return [b, a]
    return [a, b]

def delaunay2edges(tri):
    """Return the unique undirected edges of a 2-D triangulation and their lengths.

    Parameters
    ----------
    tri : scipy.spatial.Delaunay
        Triangulation whose ``simplices`` are triangles (shape ``(ntri, 3)``).

    Returns
    -------
    array_of_edges : ndarray of int, shape (num_edges, 2)
        Each edge exactly once, lower vertex index first, rows in the
        lexicographic order produced by ``np.unique(..., axis=0)``.
    array_of_lengths : ndarray of float, shape (num_edges,)
        Euclidean length of each edge, computed from ``tri.points``.
    """
    simplices = np.asarray(tri.simplices)
    # The three sides of every triangle: (v0, v1), (v1, v2), (v2, v0).
    sides = np.concatenate(
        (simplices[:, [0, 1]], simplices[:, [1, 2]], simplices[:, [2, 0]]), axis=0
    )
    # Sort each pair so a side shared by two triangles appears identically, then dedupe.
    array_of_edges = np.unique(np.sort(sides, axis=1), axis=0)

    delta = tri.points[array_of_edges[:, 0]] - tri.points[array_of_edges[:, 1]]
    array_of_lengths = np.sqrt(delta[:, 0] ** 2 + delta[:, 1] ** 2)

    return array_of_edges, array_of_lengths

def make_sln_graph(l, h, u_bc=(u_left, u_right, u_top, u_bottom), res=10):
    """Solve the rectangle Dirichlet problem and package it as a PyG ``Data`` graph.

    ``l``, ``h``, ``u_bc`` and ``res`` are forwarded to ``mesh_and_solve``.

    Returns
    -------
    torch_geometric.data.Data with fields
        pos : float32 (num_nodes, 2) node coordinates.
        edge_index : int64 (2, num_edges) unique triangulation edges, each stored
            once with the lower index first (not duplicated in both directions).
        edge_attr : float32 (num_edges,) edge lengths.
        face : int64 (3, num_faces) triangle vertex indices.
        x : float32 (num_nodes, 1) FEM solution values.
        ubc : float32 Dirichlet values (left, right, top, bottom).
        lh : float32 (2,) domain size (l, h).
    """
    u, points, _vertices, _triangles, tri, u_bc = mesh_and_solve(l, h, u_bc, res=res)
    edges, lengths = delaunay2edges(tri)

    data = Data(
        pos=torch.tensor(points, dtype=torch.float),  # Node positions
        edge_index=torch.tensor(edges.T, dtype=torch.int64),  # Delaunay edges (one direction)
        edge_attr=torch.tensor(lengths, dtype=torch.float),  # Edge lengths
        face=torch.tensor(tri.simplices.T, dtype=torch.long),  # shape: (3, num_faces)
        x=torch.tensor(u, dtype=torch.float).reshape(-1, 1),  # Steady-state temperatures
        ubc=torch.tensor(u_bc, dtype=torch.float),  # Dirichlet BC values
        lh=torch.tensor([l, h], dtype=torch.float),  # Domain size
    )
    return data

    
    
import torch
import matplotlib.pyplot as plt
import matplotlib.tri as mtri

def plot_slngraph(data, u, save=None, ObsIdx=None):
    """Plot a nodal field as filled contours over the triangular mesh in ``data``.

    Parameters
    ----------
    data : graph object with ``pos`` (num_nodes, 2) and ``face`` (3, num_faces)
        tensors, e.g. the output of ``make_sln_graph``.
    u : array-like or torch.Tensor
        One value per node. Tensors are detached and moved to the CPU, and a
        (num_nodes, 1) column such as ``data.x`` is flattened.
    save : str or path-like, optional
        If given, the figure is saved there at 300 dpi before being shown.
    ObsIdx : index, optional
        Nodes to highlight with red markers (any index accepted by ``data.pos``).
    """
    points = data.pos.cpu().detach().numpy()
    faces = data.face.cpu().detach().numpy()
    # One triangulation shared by the contour fill, the iso-lines and the mesh overlay.
    triang = mtri.Triangulation(points[:, 0], points[:, 1], triangles=faces.T)

    if torch.is_tensor(u):
        u = u.detach().cpu().numpy()
    values = np.ravel(u)  # tricontour* require z with the same length as x / y

    # Filled contours plus iso-contour lines on top
    contourf = plt.tricontourf(triang, values, levels=20, cmap='viridis')
    plt.colorbar(contourf)
    plt.tricontour(triang, values, levels=20, colors='k', linewidths=0.5)

    # Faint mesh overlay
    plt.triplot(triang, color='white', alpha=0.2)

    # Plot observed points if provided
    if ObsIdx is not None:
        plt.scatter(
            data.pos[:, 0][ObsIdx].detach().cpu().numpy(),
            data.pos[:, 1][ObsIdx].detach().cpu().numpy(),
            c='r', s=50, alpha=1.0, edgecolors='k',
        )

    # Final formatting
    plt.xlabel("$x_1$")
    plt.ylabel("$x_2$")
    plt.tight_layout()
    plt.gca().set_aspect('equal')

    if save is not None:
        plt.savefig(save, dpi=300, bbox_inches='tight')

    plt.show()
    plt.close()