import os
import sys

os.environ['CUDA_VISIBLE_DEVICES']='3'
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import scipy.io as sio
from sklearn.decomposition import PCA
from tqdm import trange

# Output directory for the saved statistics; created at import time.
SAVE_DIR = './draw_data'
# exist_ok=True replaces the separate os.path.exists() test, which could race
# with another process creating the directory between the check and the call.
# If the path exists but is not a directory this raises FileExistsError.
os.makedirs(SAVE_DIR, exist_ok=True)

# --- Experiment settings ---
RANDOM_SEEDS = [123, 456, 789]   # one repetition of the experiment per seed
BATCH_SIZE = 16                  # rows drawn for each gradient sample
EPOCHS = 2                       # not referenced by the code visible in this file
GRAD_NOISE_ESTIM_BATCHES = 30    # gradient samples collected per model

# --- Data / architecture dimensions ---
PODMode = 10                     # PCA components kept per coordinate of dx
num_bc = 68                      # input width of the second branch network
skip = 3                         # column stride applied to the "u_bc" arrays
# Layer widths, input first. branch1 takes 2 * PODMode inputs because the PCA
# coefficients of both coordinates are concatenated.
branch1_dim = [PODMode*2, 96, 64]
branch2_dim = [num_bc, 112, 80, 64]
trunk_dim = [2, 64, 60, 64]

def all_torch(*args, device):
    """Convert every positional argument to a float32 tensor on ``device``.

    Args:
        *args: objects accepted by ``torch.tensor`` (lists, arrays, ...).
        device: keyword-only target device for all created tensors.

    Returns:
        A tuple holding one tensor per input, in the order given.
    """
    converted = []
    for item in args:
        converted.append(torch.tensor(item, dtype=torch.float32, device=device))
    return tuple(converted)


def load_data(dev='cpu'):
    """Load the three Laplace ``.mat`` files and build the training tensors.

    The files are concatenated along axis 0. ``dx`` is "x_mesh_data" minus the
    "x_uni" array of the first file. For each of the two components
    ``dx[:, :, 0]`` and ``dx[:, :, 1]`` of the first 3300 samples, a PCA with
    ``PODMode`` components is fitted on the column-centred data and the
    resulting coefficients are concatenated into ``f_train``.

    Only the training split is built; nothing else is returned.

    NOTE(review): the PCA objects are created without ``random_state``. If
    scikit-learn picks a randomized solver for this data size the coefficients
    may differ between runs -- confirm whether that matters here.

    Args:
        dev: device on which the returned tensors are created.

    Returns:
        Tuple ``(f_train, f_bc_train, x, u_train)`` of float32 tensors:
        PCA coefficients (x-component first, then y-component), the "u_bc"
        rows sub-sampled with column stride ``skip``, the "x_uni" array, and
        the "u_data" rows.
    """
    data_paths = (
        "Laplace/dataset/Laplace_data.mat",
        "Laplace/dataset/Laplace_data_supp.mat",
        "Laplace/dataset/Laplace_data_supp2000.mat",
    )
    mats = [sio.loadmat(path) for path in data_paths]

    x = mats[0]["x_uni"]
    dx = np.concatenate([m["x_mesh_data"] - x for m in mats], axis=0)
    u = np.concatenate([m["u_data"] for m in mats], axis=0)
    u_bc = np.concatenate([m["u_bc"][:, ::skip] for m in mats], axis=0)

    num_train = 3300
    dx_train = dx[:num_train]

    # One PCA per component, fitted in the original order (component 0, then 1).
    coeffs = []
    for component in (0, 1):
        values = dx_train[:, :, component]
        # Centre once and reuse for both fit and transform. fit + transform is
        # kept (rather than fit_transform) so the coefficients are computed
        # exactly as before.
        centered = values - values.mean(axis=0)
        pca = PCA(n_components=PODMode)
        pca.fit(centered)
        coeffs.append(pca.transform(centered))
    f_train = np.concatenate(coeffs, axis=1)

    def to_tensor32(array):
        # Cast in NumPy first, then copy onto the requested device.
        return torch.tensor(np.asarray(array, dtype=np.float32),
                            dtype=torch.float32, device=dev)

    return (to_tensor32(f_train),
            to_tensor32(u_bc[:num_train, :]),
            to_tensor32(x),
            to_tensor32(u[:num_train]))

class OursNet(nn.Module):
    """Operator network with two branch MLPs, one trunk MLP and two losses.

    All sub-networks are ``Linear -> Tanh`` stacks whose widths come from the
    module-level ``branch1_dim``, ``branch2_dim`` and ``trunk_dim`` lists.
    The output is
    ``out[i, k] = sum_j branch1(f)[i, j] * branch2(f_bc)[i, j] * trunk(x)[k, j]``.
    """

    @staticmethod
    def _tanh_mlp(widths):
        """Build ``Linear -> Tanh`` pairs for consecutive entries of ``widths``."""
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(nn.Tanh())
        return nn.Sequential(*modules)

    def __init__(self):
        super().__init__()
        # Built in the order branch1, branch2, trunk so that parameter names
        # and seeded initialisation stay the same. The slices pin the depth
        # to two / three / three linear layers.
        self._branch1 = self._tanh_mlp(branch1_dim[:3])
        self._branch2 = self._tanh_mlp(branch2_dim[:4])
        self._trunk = self._tanh_mlp(trunk_dim[:4])

    def forward(self, f, f_bc, x):
        """Return the ``(rows, points)`` prediction for inputs ``f``, ``f_bc``, ``x``."""
        feat1 = self._branch1(f)
        feat2 = self._branch2(f_bc)
        rows1, rows2 = feat1.shape[0], feat2.shape[0]
        if rows1 != rows2:
            # Mismatched batch sizes are cut down to the shorter one.
            keep = min(rows1, rows2)
            feat1, feat2 = feat1[:keep], feat2[:keep]
        basis = self._trunk(x)
        return torch.einsum('ij,kj->ik', feat1 * feat2, basis)

    def weighted_loss(self, f, f_bc, x, y):
        """Mean squared error with each entry weighted by ``|y| + 1``."""
        residual = self.forward(f, f_bc, x) - y
        emphasis = torch.abs(y) + 1.0
        return torch.mean(emphasis * residual.pow(2))

    def plain_loss(self, f, f_bc, x, y):
        """Unweighted mean squared error between the prediction and ``y``."""
        residual = self.forward(f, f_bc, x) - y
        return torch.mean(residual.pow(2))

class DIMONNet(nn.Module):
    """Operator network with two branch MLPs, one trunk MLP and a plain MSE loss.

    The sub-networks are ``Linear -> Tanh`` stacks sized by the module-level
    ``branch1_dim``, ``branch2_dim`` and ``trunk_dim`` lists. The two branch
    outputs are multiplied element-wise and contracted with the trunk output
    over the feature axis.
    """

    @staticmethod
    def _stack(sizes):
        """Return a ``Sequential`` of ``Linear -> Tanh`` pairs along ``sizes``."""
        layers = []
        for n_in, n_out in zip(sizes, sizes[1:]):
            layers += [nn.Linear(n_in, n_out), nn.Tanh()]
        return nn.Sequential(*layers)

    def __init__(self):
        super().__init__()
        # Order (branch1, branch2, trunk) and depths (2, 3, 3 linear layers)
        # are fixed so parameter names and seeded initialisation are stable.
        self._branch1 = self._stack(branch1_dim[:3])
        self._branch2 = self._stack(branch2_dim[:4])
        self._trunk = self._stack(trunk_dim[:4])

    def forward(self, f, f_bc, x):
        """Return ``out[i, k] = sum_j (b1 * b2)[i, j] * trunk(x)[k, j]``."""
        b1 = self._branch1(f)
        b2 = self._branch2(f_bc)
        if b1.shape[0] != b2.shape[0]:
            # Unequal batch sizes: keep only the leading rows both share.
            shared = min(b1.shape[0], b2.shape[0])
            b1 = b1[:shared]
            b2 = b2[:shared]
        combined = b1 * b2
        trunk_out = self._trunk(x)
        return torch.einsum('ij,kj->ik', combined, trunk_out)

    def plain_loss(self, f, f_bc, x, y):
        """Unweighted mean squared error between the prediction and ``y``."""
        error = self.forward(f, f_bc, x) - y
        return torch.mean(error.pow(2))

def compute_layerwise_grad_noise(model, loss_fn, f_tr, fbc_tr, x_uni, u_tr, device):
    """Estimate per-parameter gradient noise from random mini-batches.

    ``GRAD_NOISE_ESTIM_BATCHES`` times, ``BATCH_SIZE`` rows are drawn without
    replacement (``torch.randperm``), ``loss_fn`` is evaluated and
    back-propagated, and the flattened gradient of every trainable parameter
    is copied to the CPU. The samples of each parameter are then reduced to
    two scalars.

    Args:
        model: module whose ``named_parameters()`` are inspected.
        loss_fn: callable ``(f_b, f_bc_b, x_uni, u_b) -> scalar loss`` whose
            backward pass fills the gradients of ``model``.
        f_tr, fbc_tr, u_tr: tensors indexed along dim 0 by the sampled rows.
        x_uni: passed to ``loss_fn`` unchanged on every call.
        device: unused; kept so existing callers keep working.

    Returns:
        Dict mapping parameter name to ``(layer_std, layer_mean)`` where
        ``layer_std`` is the element-wise standard deviation across batches,
        averaged over elements, and ``layer_mean`` is the absolute value of
        the element-wise mean gradient, averaged over elements. Parameters
        that never received a gradient map to ``(0.0, 0.0)``.

    The model is switched to eval mode while sampling and its previous
    train/eval mode is restored afterwards. Gradients from the last sampled
    batch are left on the parameters.
    """
    layers = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    grads_per_layer = {n: [] for n, _ in layers}
    num_rows = f_tr.shape[0]

    was_training = model.training
    model.eval()
    try:
        for _ in range(GRAD_NOISE_ESTIM_BATCHES):
            idx = torch.randperm(num_rows)[:BATCH_SIZE]
            f_b = f_tr[idx]
            f_bc_b = fbc_tr[idx]
            u_b = u_tr[idx]
            model.zero_grad()
            loss = loss_fn(f_b, f_bc_b, x_uni, u_b)
            loss.backward()
            for n, p in layers:
                if p.grad is not None:
                    grads_per_layer[n].append(p.grad.detach().cpu().clone().flatten())
    finally:
        # Do not leave the caller's model silently stuck in eval mode.
        model.train(was_training)

    grad_noise_stats = {}
    for n, samples in grads_per_layer.items():
        if not samples:
            grad_noise_stats[n] = (0.0, 0.0)
            continue
        grads = torch.stack(samples, dim=0)  # (num_batches, num_elements)
        layer_std = grads.std(dim=0).mean().item()
        # Average over batches first (dim=0) and only then take magnitudes.
        # Reducing over every axis at once lets opposite-signed elements
        # cancel and makes the trailing mean a no-op.
        layer_mean = grads.mean(dim=0).abs().mean().item()
        grad_noise_stats[n] = (layer_std, layer_mean)
    return grad_noise_stats


def aggregate_grads(grads_list, keys):
    """Regroup per-run ``(std, mean)`` pairs into per-key lists.

    Args:
        grads_list: iterable of dicts mapping each key to a ``(std, mean)`` pair.
        keys: names to collect from every dict.

    Returns:
        ``(out_std, out_mean)``: two dicts mapping each key to the list of its
        std values and mean values respectively, in the order of ``grads_list``.
    """
    collected = {name: ([], []) for name in keys}
    for run_stats in grads_list:
        for name in keys:
            std_value, mean_value = run_stats[name]
            std_bucket, mean_bucket = collected[name]
            std_bucket.append(std_value)
            mean_bucket.append(mean_value)
    out_std = {name: pair[0] for name, pair in collected.items()}
    out_mean = {name: pair[1] for name, pair in collected.items()}
    return out_std, out_mean

def main():
    """Run the gradient-noise comparison for every seed and save a summary.

    For each seed in ``RANDOM_SEEDS`` a fresh ``OursNet`` (weighted loss) and
    ``DIMONNet`` (plain loss) are created and passed to
    ``compute_layerwise_grad_noise``. The per-layer std values are then
    reduced across seeds (mean and standard deviation) and written to
    ``layerwise_gradient_noise_statistics.npz`` inside ``SAVE_DIR``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    f_tr, fbc_tr, x_uni, u_tr = load_data(dev=device)
    tensors = (f_tr, fbc_tr, x_uni, u_tr)

    per_seed_ours = []
    per_seed_dimon = []
    for seed in RANDOM_SEEDS:
        torch.manual_seed(seed)
        np.random.seed(seed)
        # Both networks are created before either is measured.
        net_ours = OursNet().to(device)
        net_dimon = DIMONNet().to(device)
        stats_ours = compute_layerwise_grad_noise(
            net_ours, net_ours.weighted_loss, *tensors, device)
        stats_dimon = compute_layerwise_grad_noise(
            net_dimon, net_dimon.plain_loss, *tensors, device)
        per_seed_ours.append(stats_ours)
        per_seed_dimon.append(stats_dimon)

    # Parameter names are read from a throwaway OursNet; the same names are
    # used to index the DIMONNet results.
    layer_names = [name for name, _ in OursNet().named_parameters()]
    ours_std, _ = aggregate_grads(per_seed_ours, layer_names)
    dimon_std, _ = aggregate_grads(per_seed_dimon, layer_names)

    def per_layer(reducer, table):
        # Reduce the per-seed values of each layer to one Python float.
        return [float(reducer(table[name])) for name in layer_names]

    data = {
        'layer_names': layer_names,
        'ours_mean_std': per_layer(np.mean, ours_std),
        'ours_std_std': per_layer(np.std, ours_std),
        'dimon_mean_std': per_layer(np.mean, dimon_std),
        'dimon_std_std': per_layer(np.std, dimon_std),
    }

    out_path = os.path.join(SAVE_DIR, 'layerwise_gradient_noise_statistics.npz')
    np.savez(out_path, **data)

# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    main()
