import os
import numpy as np
import scipy.io as sio
import torch
import torch.nn as nn
from torch.autograd.functional import jacobian

def load_mat_list(datadir, varname, lyr, n, rd):
    """
    Load a MATLAB cell array stored in ``{varname}_lyr{lyr}n{n}test{rd}.mat``.

    Args:
        datadir: directory containing the .mat file.
        varname: name of the cell-array variable; also used as the filename
            prefix.
        lyr, n, rd: values spliced into the filename.

    Returns:
        list of np.ndarray: one array per cell element, each with
        ``ndim >= 1`` (scalar elements become shape ``(1,)``).

    Raises:
        KeyError: if ``varname`` is not a variable in the file.
        ValueError: if the variable is not a cell array (object-dtype ndarray
            or Python list).
        Errors from ``scipy.io.loadmat`` (e.g. a missing file) propagate.
    """
    fname = f"{varname}_lyr{lyr}n{n}test{rd}.mat"
    path = os.path.join(datadir, fname)
    # squeeze_me drops MATLAB's singleton dimensions, so a 1xK cell comes back
    # as a flat object array and 1x1 numeric entries may come back as scalars.
    mat = sio.loadmat(path, squeeze_me=True)
    if varname not in mat:
        raise KeyError(f"'{varname}' not found in {fname}. Keys: {list(mat.keys())}")

    arr = mat[varname]
    # unwrap any object-dtype ndarray or Python list into a flat element list
    if isinstance(arr, list):
        elements = arr
    elif isinstance(arr, np.ndarray) and arr.dtype == object:
        elements = arr.flatten().tolist()
    else:
        raise ValueError(f"Expected a MATLAB cell‐array for '{varname}', got type={type(arr)}")

    # np.asarray converts without copying when possible and never refuses to
    # copy; np.array(..., copy=False) raises under NumPy >= 2.0 whenever a copy
    # is required (e.g. for the Python scalars produced by squeeze_me).
    # atleast_1d turns 0-d results into shape (1,).
    return [np.atleast_1d(np.asarray(e)) for e in elements]

class FFNet(nn.Module):
    """
    A feed-forward net built from pre-loaded weight & bias lists.

    Every layer except the last computes ``act(Ws[i] @ a + bs[i])``; the last
    layer is affine (no activation).

    Args:
        Ws: sequence of 2-D weight tensors; ``Ws[i]`` must have as many
            columns as the previous layer has outputs.
        bs: sequence of bias tensors, one per weight.  Column vectors
            ``(m, 1)`` and flat ``(m,)`` vectors are both accepted; they are
            reshaped to columns in ``forward``.
        activation: module applied between layers.  When omitted (``None``),
            each network gets its own fresh ``nn.Tanh()``.

    Raises:
        ValueError: if ``Ws`` is empty or ``Ws`` and ``bs`` differ in length.
    """

    def __init__(self, Ws, bs, activation=None):
        super().__init__()
        Ws = list(Ws)
        bs = list(bs)
        # zip() in forward would silently misalign layers on a length mismatch,
        # and an empty list would only fail later at Ws[-1].
        if not Ws or len(Ws) != len(bs):
            raise ValueError(
                f"FFNet needs matching, non-empty weight/bias lists; "
                f"got {len(Ws)} weights and {len(bs)} biases"
            )
        self.Ws = nn.ParameterList([nn.Parameter(W) for W in Ws])
        self.bs = nn.ParameterList([nn.Parameter(b) for b in bs])
        # A module instance used as a default argument would be shared by every
        # FFNet built without an explicit activation; build one per network.
        self.act = nn.Tanh() if activation is None else activation

    def forward(self, x):
        """Evaluate the network on one input, flattened into a column vector.

        Returns a 1-D tensor holding the final layer's outputs.
        """
        a = x.reshape(-1, 1)  # column vector
        for W, b in zip(self.Ws[:-1], self.bs[:-1]):
            # reshape so a flat (m,) bias adds row-wise instead of
            # broadcasting (m, 1) + (m,) into an (m, m) matrix
            a = self.act(W @ a + b.reshape(-1, 1))
        a = self.Ws[-1] @ a + self.bs[-1].reshape(-1, 1)
        return a.reshape(-1)  # flatten

if __name__ == "__main__":
    # 1) Data directory and the (layers, neurons) grid of networks to evaluate.
    #    Each network is read from "<data_dir>/lyr{lyr}n{n}" via
    #    scipy.io.loadmat, which appends ".mat" by default.
    data_dir = "./datasets/random"
    lyrs = [5, 30, 60]
    neurons = [128]

    # 2) Evaluation point, identical for every network, so build it once.
    #    jacobian() handles gradient tracking itself; no requires_grad needed.
    x0_np = np.array([0.4, 1.8, -0.5, -1.3, 0.9], dtype=float)
    x0 = torch.tensor(x0_np, dtype=torch.float32)

    for lyr in lyrs:
        for n in neurons:
            network_id = f'lyr{lyr}n{n}'
            fname = os.path.join(data_dir, network_id)
            data = sio.loadmat(fname)  # load the file once

            # "weights"/"biases" are iterated as cell arrays, one entry per layer.
            # NOTE(review): squeeze() would collapse a (m, 1) or (1, k) weight
            # to 1-D; fine for the 5-D input used here, but confirm for
            # single-neuron layers.
            weights = [W.squeeze() for W in data["weights"].flat]
            biases = [b.squeeze() for b in data["biases"].flat]

            Ws = [torch.tensor(W, dtype=torch.float32) for W in weights]
            bs = [torch.tensor(b, dtype=torch.float32).view(-1, 1) for b in biases]

            net = FFNet(Ws, bs, activation=nn.LeakyReLU())

            # 3) Spectral norm of the Jacobian at x0.  This is a *local*
            #    estimate: it lower-bounds the network's global L2 Lipschitz
            #    constant rather than certifying it.
            J = jacobian(net, x0)  # shape: (output_dim, input_dim)
            lip_const = torch.linalg.norm(J.detach(), ord=2)

            print(f"layer{lyr} neuron {n}")
            print("Lipschitz constant", lip_const.detach().cpu().numpy())
