import os
import sys

# Expose only GPU index 0 to CUDA. This is set before `torch` is imported
# further down, so it is in place before any CUDA initialisation can happen.
# NOTE(review): the script later selects torch.device('cpu'), so this setting
# only matters if that device choice is changed.
os.environ['CUDA_VISIBLE_DEVICES']='0'

import os
import json
import math
import time
import copy
import random
import numpy as np
import torch
import torch.nn as nn
from typing import List
from tqdm import trange

# Seed Python's, NumPy's and PyTorch's global random generators with the same
# value. Everything below draws from these global generators, so the order of
# random calls in this file determines the generated data and initial weights.
SEED = 2025
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# All result files written by this script go into this directory; it is
# created on import if it does not exist yet.
OUT_DIR = './experimental_result_data'
os.makedirs(OUT_DIR, exist_ok=True)

from dataclasses import dataclass

def static_forward_ours(model, f, f_bc, x):
    """Evaluate the two-branch / one-trunk operator network.

    The outputs of ``model._branch1(f)`` and ``model._branch2(f_bc)`` are
    multiplied elementwise, and the product is contracted with
    ``model._trunk(x)`` over their shared last axis.

    Returns:
        A 2-D tensor ``out`` with ``out[i, k] = sum_j branch[i, j] * trunk[k, j]``,
        i.e. one row per branch input and one column per trunk input.
    """
    branch_code = model._branch1(f) * model._branch2(f_bc)
    trunk_code = model._trunk(x)
    return torch.einsum("ij,kj->ik", branch_code, trunk_code)


def static_loss_ours(model: nn.Module, f: torch.Tensor, f_bc: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the magnitude-weighted mean squared error of ``model`` against ``y``.

    Each squared residual is scaled by ``|y| + 1`` before averaging, so targets
    with larger absolute value contribute more to the loss.
    """
    residual = model.forward(f, f_bc, x) - y
    emphasis = torch.abs(y) + 1.0
    return torch.mean(emphasis * residual ** 2)


def _ours_layer_stack(dims):
    """Build one sub-network of the 'ours' model from a list of layer widths.

    Every consecutive pair of widths becomes a ``Linear -> Tanh`` block. From
    the second block onwards, a block whose output width equals its input
    width is followed by one extra ``Linear`` layer of that width.
    """
    width_in = dims[0]
    blocks = []
    for position, width_out in enumerate(dims[1:]):
        blocks.append(nn.Sequential(nn.Linear(width_in, width_out), nn.Tanh()))
        if position > 0 and width_out == width_in:
            blocks.append(nn.Linear(width_out, width_out))
        width_in = width_out
    return nn.Sequential(*blocks)


def static_init_ours(model: nn.Module, branch1_dim: List[int], branch2_dim: List[int], trunk_dim: List[int]) -> None:
    """Attach the two branch networks and the trunk network to ``model``.

    Sets ``model.z_dim`` to the last trunk width and creates ``model._branch1``,
    ``model._branch2`` and ``model._trunk`` in that order. The order matters:
    layer initialisation draws from the global torch RNG.
    """
    model.z_dim = trunk_dim[-1]
    model._branch1 = _ours_layer_stack(branch1_dim)
    model._branch2 = _ours_layer_stack(branch2_dim)
    model._trunk = _ours_layer_stack(trunk_dim)



def DIMON_static_forward(model, f, f_bc, x):
    """Forward pass of the baseline operator network.

    Multiplies the outputs of ``model._branch1(f)`` and ``model._branch2(f_bc)``
    elementwise and contracts the result with ``model._trunk(x)`` over the
    shared last axis, giving ``out[i, k] = sum_j branch[i, j] * trunk[k, j]``.
    """
    first_branch = model._branch1(f)
    second_branch = model._branch2(f_bc)
    trunk_out = model._trunk(x)
    return torch.einsum("ij,kj->ik", first_branch * second_branch, trunk_out)


def DIMON_static_loss(model: nn.Module, f: torch.Tensor, f_bc: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the plain (unweighted) mean squared error of ``model`` against ``y``."""
    residual = model.forward(f, f_bc, x) - y
    return torch.mean(residual ** 2)


def _dimon_layer_stack(dims):
    """Build a chain of ``Linear -> Tanh`` blocks, one per consecutive width pair."""
    widths = [dims[0], *dims[1:]]
    blocks = [
        nn.Sequential(nn.Linear(width_in, width_out), nn.Tanh())
        for width_in, width_out in zip(widths, widths[1:])
    ]
    return nn.Sequential(*blocks)


def DIMON_static_init(model: nn.Module, branch1_dim: List[int], branch2_dim: List[int], trunk_dim: List[int]) -> None:
    """Attach the baseline's two branch networks and trunk network to ``model``.

    Sets ``model.z_dim`` to the last trunk width, then creates
    ``model._branch1``, ``model._branch2`` and ``model._trunk`` in that order
    (layer initialisation consumes the global torch RNG, so order matters).
    """
    model.z_dim = trunk_dim[-1]
    model._branch1 = _dimon_layer_stack(branch1_dim)
    model._branch2 = _dimon_layer_stack(branch2_dim)
    model._trunk = _dimon_layer_stack(trunk_dim)


class OurOpNN(nn.Module):
    """Operator network built, evaluated and scored by the ``static_*_ours`` functions."""

    def __init__(self, branch1_dim, branch2_dim, trunk_dim):
        super().__init__()
        static_init_ours(self, branch1_dim, branch2_dim, trunk_dim)

    def forward(self, f, f_bc, x):
        """Return the network prediction for branch inputs ``f``, ``f_bc`` and trunk input ``x``."""
        return static_forward_ours(self, f, f_bc, x)

    def loss(self, f, f_bc, x, y):
        """Return the weighted training loss of the prediction against ``y``."""
        return static_loss_ours(self, f, f_bc, x, y)

class DIMONOpNN(nn.Module):
    """Baseline operator network built, evaluated and scored by the ``DIMON_static_*`` functions."""

    def __init__(self, branch1_dim, branch2_dim, trunk_dim):
        super().__init__()
        DIMON_static_init(self, branch1_dim, branch2_dim, trunk_dim)

    def forward(self, f, f_bc, x):
        """Return the network prediction for branch inputs ``f``, ``f_bc`` and trunk input ``x``."""
        return DIMON_static_forward(self, f, f_bc, x)

    def loss(self, f, f_bc, x, y):
        """Return the mean-squared-error loss of the prediction against ``y``."""
        return DIMON_static_loss(self, f, f_bc, x, y)


def generate_synthetic_dataset(num_cases=3500, num_points=100, PODMode=10, num_bc=68):
    """Draw a random synthetic dataset from the global NumPy RNG.

    The random draws happen in a fixed order (point coordinates, displacement
    modes, the two coefficient sets, boundary samples, mixing matrix ``B``,
    field modes ``phi``, then per-case noise), so results are reproducible for
    a given ``np.random.seed``.

    Returns:
        dict with keys
        ``'x_uni'`` (num_points, 2) uniform points in [0, 1),
        ``'dx'`` (num_cases, num_points, 2) displacements built from random modes,
        ``'u'`` (num_cases, num_points) target fields,
        ``'u_bc'`` (num_cases, num_bc) correlated boundary samples,
        ``'bc_coords'`` (num_bc, 1) evenly spaced boundary coordinates,
        ``'B'`` (num_bc, num_points) and ``'phi'`` (PODMode, num_points).
    """
    x_uni = np.random.uniform(low=0.0, high=1.0, size=(num_points, 2)).astype(np.float32)

    # Displacement fields: random linear combinations of shared random modes.
    base_modes = np.random.randn(PODMode, num_points).astype(np.float32)
    coeffs_x = np.random.randn(num_cases, PODMode).astype(np.float32) * 0.2
    coeffs_y = np.random.randn(num_cases, PODMode).astype(np.float32) * 0.2
    dx = np.zeros((num_cases, num_points, 2), dtype=np.float32)
    for case, (cx, cy) in enumerate(zip(coeffs_x, coeffs_y)):
        dx[case, :, 0] = (cx @ base_modes).astype(np.float32)
        dx[case, :, 1] = (cy @ base_modes).astype(np.float32)

    # Boundary samples: Gaussian draws coloured by the Cholesky factor of a
    # squared-exponential kernel on evenly spaced coordinates.
    bc_coords = np.linspace(0, 1, num_bc).reshape(-1, 1)
    lengthscale = 0.5
    K = np.zeros((num_bc, num_bc), dtype=np.float32)
    for row in range(num_bc):
        sq_dist = np.sum((bc_coords[row:row + 1, :] - bc_coords) ** 2, axis=1)
        K[row, :] = np.exp(-sq_dist / (2 * lengthscale ** 2))
    K += 1e-6 * np.eye(num_bc)  # small jitter on the diagonal before factorising
    chol = np.linalg.cholesky(K)
    u_bc = np.random.randn(num_cases, num_bc).astype(np.float32) @ chol.T

    # Target field: linear in both coefficient sets and the boundary samples,
    # plus a small Gaussian noise term drawn separately for every case.
    B = np.random.randn(num_bc, num_points).astype(np.float32) * 0.5
    phi = np.random.randn(PODMode, num_points).astype(np.float32)
    half_phi = phi * 0.5
    u = np.zeros((num_cases, num_points), dtype=np.float32)
    for case in range(num_cases):
        noise = 0.01 * np.random.randn(num_points)
        u[case] = coeffs_x[case] @ phi + coeffs_y[case] @ half_phi + u_bc[case] @ B + noise

    return {
        'x_uni': x_uni,
        'dx': dx,
        'u': u,
        'u_bc': u_bc,
        'bc_coords': bc_coords,
        'B': B,
        'phi': phi,
    }

DATA = generate_synthetic_dataset(num_cases=3500, num_points=100, PODMode=10, num_bc=68)


# Random 70 % / 15 % / remainder split of the case indices.
num_cases = DATA['u'].shape[0]
indices = np.arange(num_cases)
np.random.shuffle(indices)

n_train = int(0.70 * num_cases)
n_val = int(0.15 * num_cases)

train_idx, val_idx, test_idx = np.split(indices, [n_train, n_train + n_val])

from sklearn.decomposition import PCA
PODMode = 10
dx1 = DATA['dx'][:, :, 0]
dx2 = DATA['dx'][:, :, 1]


def _pod_project(field):
    """Fit a PCA on the training rows of ``field`` (centred on the training
    mean) and return the fitted PCA together with the coefficients of all rows."""
    train_mean = field[train_idx].mean(axis=0)
    pod = PCA(n_components=PODMode)
    pod.fit(field[train_idx] - train_mean)
    return pod, pod.transform(field - train_mean)


# Fit order (x first, then y) is kept fixed.
pca_x, coeff_x = _pod_project(dx1)
pca_y, coeff_y = _pod_project(dx2)

# First-branch input: x- and y-coefficients side by side for every case.
f_all = np.hstack((coeff_x, coeff_y)).astype(np.float32)

u_all = DATA['u'].astype(np.float32)

u_bc_all = DATA['u_bc'].astype(np.float32)

x_tensor = torch.tensor(DATA['x_uni'], dtype=torch.float32)


def _split_to_tensors(values):
    """Return float32 tensors of the train, validation and test rows of ``values``."""
    return tuple(
        torch.tensor(values[rows], dtype=torch.float32)
        for rows in (train_idx, val_idx, test_idx)
    )


f_train, f_val, f_test = _split_to_tensors(f_all)

fbc_train, fbc_val, fbc_test = _split_to_tensors(u_bc_all)

u_train, u_val, u_test = _split_to_tensors(u_all)

# The same evaluation points are used as trunk input for every split.
x_train = x_tensor



def flatten_tensors(tensors):
    """Concatenate the given tensors into a single 1-D tensor, in iteration order."""
    flat_pieces = []
    for tensor in tensors:
        flat_pieces.append(tensor.reshape(-1))
    return torch.cat(flat_pieces)


def get_param_vector(model):
    """Return all parameter values of ``model`` as one flat 1-D tensor.

    Values are taken from ``p.data`` in ``model.parameters()`` order.
    """
    per_param = []
    for weights in model.parameters():
        per_param.append(weights.data.view(-1))
    return flatten_tensors(per_param)


def get_grad_vector(model):
    """Return the gradients of all parameters of ``model`` as one flat 1-D tensor.

    Parameters without a gradient (``p.grad is None``) contribute zeros, so the
    layout always matches ``model.parameters()`` order.
    """
    pieces = []
    for param in model.parameters():
        if param.grad is None:
            pieces.append(torch.zeros_like(param.data).view(-1))
        else:
            pieces.append(param.grad.view(-1))
    return flatten_tensors(pieces)


def grad_stats_from_model(model):
    """Summarise the current gradient of every parameter tensor.

    Returns:
        One ``{'norm': float, 'var': float}`` dict per parameter, in
        ``model.parameters()`` order. A parameter with no gradient is treated
        as having an all-zero gradient.
    """
    summaries = []
    for param in model.parameters():
        grad = torch.zeros_like(param.data) if param.grad is None else param.grad
        summaries.append({
            'norm': torch.norm(grad).item(),
            'var': torch.var(grad).item(),
        })
    return summaries


def cosine_similarity(a, b):
    """Return the cosine of the angle between two NumPy arrays, flattened.

    Returns ``nan`` when either argument is ``None`` or when the product of
    the two norms is exactly zero.
    """
    if any(vec is None for vec in (a, b)):
        return float('nan')
    flat_a, flat_b = a.reshape(-1), b.reshape(-1)
    scale = np.linalg.norm(flat_a) * np.linalg.norm(flat_b)
    if scale != 0:
        return float(np.dot(flat_a, flat_b) / scale)
    return float('nan')

def curvature_proxy(model, loss_fn, f_batch, fbc_batch, x, y_batch, n_samples=5, eps=1e-3):
    """Estimate directional curvature of the loss along random unit directions.

    For each of ``n_samples`` random unit vectors ``v`` in parameter space the
    quantity ``v^T H v`` is approximated by the finite difference
    ``v . (g(theta + eps * v) - g(theta)) / eps``, where ``g`` is the flattened
    loss gradient.

    Args:
        model: module whose parameters are probed.
        loss_fn: callable ``loss_fn(model, f_batch, fbc_batch, x, y_batch)``
            returning a scalar loss tensor.
        f_batch, fbc_batch, x, y_batch: passed through to ``loss_fn`` unchanged.
        n_samples: number of random directions to probe.
        eps: finite-difference step length.

    Returns:
        List of ``n_samples`` Python floats, one estimate per direction.

    The model is left exactly as it was found. Parameters are restored from a
    saved copy rather than by subtracting the perturbation (add-then-subtract
    is not exact in floating point and would slowly change the weights being
    trained), and the restore also runs if ``loss_fn`` raises. Gradients are
    obtained with ``torch.autograd.grad``, so the parameters' ``.grad`` fields
    are not read or modified.
    """
    params = list(model.parameters())
    trainable = [p for p in params if p.requires_grad]
    n_total = sum(p.numel() for p in params)

    def _flat_gradient():
        # Flattened gradient over *all* parameters, in model.parameters()
        # order; frozen or unused parameters contribute zeros.
        loss = loss_fn(model, f_batch, fbc_batch, x, y_batch)
        found = torch.autograd.grad(loss, trainable, allow_unused=True)
        grad_of = {id(p): g for p, g in zip(trainable, found)}
        pieces = []
        for p in params:
            g = grad_of.get(id(p))
            pieces.append((torch.zeros_like(p) if g is None else g).reshape(-1))
        return torch.cat(pieces).detach().cpu().numpy()

    g0 = _flat_gradient()
    snapshot = [p.detach().clone() for p in params]
    vHv = []
    try:
        for _ in range(n_samples):
            # Random unit direction, drawn from the global NumPy RNG.
            v = np.random.randn(n_total).astype(np.float32)
            v = v / (np.linalg.norm(v) + 1e-12)

            # Move every parameter to snapshot + eps * (its slice of v).
            offset = 0
            with torch.no_grad():
                for p, saved in zip(params, snapshot):
                    numel = p.numel()
                    v_slice = torch.from_numpy(v[offset:offset + numel]).view_as(p).to(p.device)
                    p.copy_(saved + eps * v_slice)
                    offset += numel

            g1 = _flat_gradient()
            vHv.append(float(np.dot(g1 - g0, v) / eps))
    finally:
        with torch.no_grad():
            for p, saved in zip(params, snapshot):
                p.copy_(saved)
    return vHv



# Network layout. Each list gives the layer widths of one sub-network, input
# width first. All three sub-networks end with the same width (80), which the
# forward pass needs because it contracts branch and trunk outputs over their
# last axis.
PODMode = 10
num_bc = 68
dim_br1 = [PODMode*2, 80, 80]  # first branch input: 2 * PODMode coefficients per case
dim_br2 = [num_bc, 120, 120, 80]  # second branch input: num_bc boundary values per case
dim_tr = [2, 80, 80]  # trunk input: 2 coordinates per point

# Training configuration shared by both models.
N_EPOCHS = 200
BATCH_SIZE = 64
LR = 0.001
NUM_RUNS = 3  # number of independent training runs
DEVICE = torch.device('cpu')



def loss_fn_ours(model, f, f_bc, x, y):
    """Return the weighted loss of ``model`` on ``(f, f_bc, x)`` against ``y``.

    Thin adapter around ``static_loss_ours`` with the same five-argument
    signature. The target ``y`` must be forwarded: ``static_loss_ours`` takes
    it as a required argument, so omitting it raises ``TypeError``.
    """
    return static_loss_ours(model, f, f_bc, x, y)



# Positional indices 0..n_train-1 and 0..n_val-1 into the already-split
# training / validation tensors.
# NOTE(review): neither name is referenced anywhere else in the code visible
# here — confirm they are still needed.
train_indices = np.arange(n_train)
val_indices = np.arange(n_val)


def batch_generator(n_samples, batch_size, shuffle=True):
    """Yield index arrays that together cover ``0 .. n_samples - 1`` once.

    Every yielded array holds at most ``batch_size`` indices; the final one may
    be shorter. With ``shuffle=True`` the indices are permuted first using the
    global NumPy RNG.
    """
    order = np.arange(n_samples)
    if shuffle:
        np.random.shuffle(order)
    yield from (
        order[start:start + batch_size]
        for start in range(0, n_samples, batch_size)
    )


# Top-level record of the whole experiment: 'meta' stores the configuration
# used, 'runs' receives one entry per training run and is written to
# results_summary.json once all runs have finished.
results_summary = {
    'meta': {
        'seed': SEED,
        'n_epochs': N_EPOCHS,
        'batch_size': BATCH_SIZE,
        'lr': LR,
        'num_runs': NUM_RUNS
    },
    'runs': []
}

# Per-epoch statistics recorded for each model (each key maps to a list with
# one entry per epoch).
_STAT_KEYS = (
    'epoch',
    'train_loss',
    'val_loss',
    'layer_grad_norms',
    'layer_grad_vars',
    'grad_flat_cosine_prev',
    'param_update_norms',
    'curvature_vHv_mean',
    'curvature_vHv_max',
)


def _as_numpy(tensor):
    """Detach ``tensor`` and return it as a host NumPy array (works on any device)."""
    return tensor.detach().cpu().numpy()


def _loss_via_model(m, f_, fb_, x_, y_):
    """Adapter with curvature_proxy's ``loss_fn`` signature that uses the model's own loss."""
    return m.loss(f_, fb_, x_, y_)


def _epoch_diagnostics(model, prev_grad, prev_params, f_sub, fbc_sub, x, y_sub):
    """Compute the gradient / update / curvature diagnostics for one model.

    Args:
        model: network exposing ``loss(f, f_bc, x, y)``.
        prev_grad: flat gradient from the previous epoch, or ``None``.
        prev_params: flat parameter vector from the previous epoch.
        f_sub, fbc_sub, x, y_sub: diagnostic subset to evaluate the loss on.

    Returns:
        ``(diagnostics, grad_vec, cur_params)``: a dict keyed like the stats
        lists, plus the new flat gradient and parameter vectors, which become
        the "previous" values for the next epoch.
    """
    model.zero_grad()
    model.loss(f_sub, fbc_sub, x, y_sub).backward()
    grad_vec = _as_numpy(get_grad_vector(model))
    layer_stats = grad_stats_from_model(model)

    # Parameter snapshot is taken before the curvature probe runs.
    cur_params = _as_numpy(get_param_vector(model))

    model.zero_grad()
    vHv = curvature_proxy(model, _loss_via_model, f_sub, fbc_sub, x, y_sub, n_samples=4, eps=1e-3)

    diagnostics = {
        'layer_grad_norms': [ls['norm'] for ls in layer_stats],
        'layer_grad_vars': [ls['var'] for ls in layer_stats],
        # cosine_similarity returns nan when prev_grad is None (first epoch).
        'grad_flat_cosine_prev': cosine_similarity(grad_vec, prev_grad),
        'param_update_norms': float(np.linalg.norm(cur_params - prev_params)),
        'curvature_vHv_mean': float(np.mean(vHv)),
        'curvature_vHv_max': float(np.max(vHv)),
    }
    return diagnostics, grad_vec, cur_params


for run in range(NUM_RUNS):
    run_entry = {'run_id': run, 'ours': {}, 'baseline': {}}
    print(f"Starting run {run+1}/{NUM_RUNS}")

    # Construction order (ours, then baseline) fixes how the torch RNG is consumed.
    model_ours = OurOpNN(dim_br1, dim_br2, dim_tr).to(DEVICE).float()
    model_base = DIMONOpNN(dim_br1, dim_br2, dim_tr).to(DEVICE).float()

    opt_ours = torch.optim.Adam(model_ours.parameters(), lr=LR)
    opt_base = torch.optim.Adam(model_base.parameters(), lr=LR)

    f_train_dev = f_train.to(DEVICE)
    fbc_train_dev = fbc_train.to(DEVICE)
    u_train_dev = u_train.to(DEVICE)

    f_val_dev = f_val.to(DEVICE)
    fbc_val_dev = fbc_val.to(DEVICE)
    u_val_dev = u_val.to(DEVICE)

    # The trunk input must live on the same device as the models and batches.
    x_dev = x_train.to(DEVICE)

    stats_ours = {key: [] for key in _STAT_KEYS}
    stats_base = {key: [] for key in _STAT_KEYS}

    prev_grad_ours = None
    prev_grad_base = None

    prev_params_ours = _as_numpy(get_param_vector(model_ours))
    prev_params_base = _as_numpy(get_param_vector(model_base))

    for epoch in trange(N_EPOCHS, desc=f"Run {run+1}"):
        # One optimisation step per model on every mini-batch; both models see
        # exactly the same batches.
        for batch_idx in batch_generator(n_train, BATCH_SIZE, shuffle=True):
            f_batch = f_train_dev[batch_idx]
            fbc_batch = fbc_train_dev[batch_idx]
            y_batch = u_train_dev[batch_idx]

            for model, opt in ((model_ours, opt_ours), (model_base, opt_base)):
                model.train()
                opt.zero_grad()
                model.loss(f_batch, fbc_batch, x_dev, y_batch).backward()
                opt.step()

        # Full-split losses; each model is scored with its own loss function.
        model_ours.eval()
        model_base.eval()
        with torch.no_grad():
            train_loss_ours_epoch = model_ours.loss(f_train_dev, fbc_train_dev, x_dev, u_train_dev).item()
            val_loss_ours_epoch = model_ours.loss(f_val_dev, fbc_val_dev, x_dev, u_val_dev).item()

            train_loss_base_epoch = model_base.loss(f_train_dev, fbc_train_dev, x_dev, u_train_dev).item()
            val_loss_base_epoch = model_base.loss(f_val_dev, fbc_val_dev, x_dev, u_val_dev).item()

        # Diagnostics use a fresh random subset of the training set each epoch.
        subset_idx = np.random.choice(n_train, size=min(128, n_train), replace=False)
        f_sub = f_train_dev[subset_idx]
        fbc_sub = fbc_train_dev[subset_idx]
        y_sub = u_train_dev[subset_idx]

        # Ours first, then baseline: the curvature probe draws random
        # directions from the global NumPy RNG, so this order is kept fixed.
        diag_ours, prev_grad_ours, prev_params_ours = _epoch_diagnostics(
            model_ours, prev_grad_ours, prev_params_ours, f_sub, fbc_sub, x_dev, y_sub)
        diag_base, prev_grad_base, prev_params_base = _epoch_diagnostics(
            model_base, prev_grad_base, prev_params_base, f_sub, fbc_sub, x_dev, y_sub)

        for stats, train_loss, val_loss, diag in (
            (stats_ours, train_loss_ours_epoch, val_loss_ours_epoch, diag_ours),
            (stats_base, train_loss_base_epoch, val_loss_base_epoch, diag_base),
        ):
            stats['epoch'].append(epoch)
            stats['train_loss'].append(train_loss)
            stats['val_loss'].append(val_loss)
            for key, value in diag.items():
                stats[key].append(value)

    # Per-case relative L2 error on the held-out test split.
    model_ours.eval()
    model_base.eval()
    f_test_dev = f_test.to(DEVICE)
    fbc_test_dev = fbc_test.to(DEVICE)
    u_test_true = u_test.numpy()
    u_test_norm = np.linalg.norm(u_test_true, axis=1) + 1e-12
    with torch.no_grad():
        u_test_pred_ours = _as_numpy(model_ours(f_test_dev, fbc_test_dev, x_dev))
        rel_l2_err_ours = np.linalg.norm(u_test_true - u_test_pred_ours, axis=1) / u_test_norm

        u_test_pred_base = _as_numpy(model_base(f_test_dev, fbc_test_dev, x_dev))
        rel_l2_err_base = np.linalg.norm(u_test_true - u_test_pred_base, axis=1) / u_test_norm

    fname_ours = f'run{run}_ours_stats.npz'
    fname_base = f'run{run}_base_stats.npz'
    np.savez_compressed(os.path.join(OUT_DIR, fname_ours), **stats_ours)
    np.savez_compressed(os.path.join(OUT_DIR, fname_base), **stats_base)

    summary = {
        'train_loss_ours': stats_ours['train_loss'],
        'val_loss_ours': stats_ours['val_loss'],
        'train_loss_base': stats_base['train_loss'],
        'val_loss_base': stats_base['val_loss'],
        'mean_rel_l2_test_ours': float(rel_l2_err_ours.mean()),
        'mean_rel_l2_test_base': float(rel_l2_err_base.mean()),
        'rel_l2_errs_ours': rel_l2_err_ours.tolist(),
        'rel_l2_errs_base': rel_l2_err_base.tolist(),
    }
    fname_summary = f'run{run}_summary.json'
    with open(os.path.join(OUT_DIR, fname_summary), 'w') as fh:
        json.dump(summary, fh)

    run_entry['ours']['stats_file'] = fname_ours
    run_entry['ours']['mean_rel_l2_test'] = float(rel_l2_err_ours.mean())
    run_entry['baseline']['stats_file'] = fname_base
    run_entry['baseline']['mean_rel_l2_test'] = float(rel_l2_err_base.mean())
    run_entry['summary_file'] = fname_summary
    results_summary['runs'].append(run_entry)

with open(os.path.join(OUT_DIR, 'results_summary.json'), 'w') as fh:
    json.dump(results_summary, fh)


def _mean_stat(archive, key):
    """Return the float32 mean over every value stored under ``key`` in ``archive``."""
    return float(np.mean(np.asarray(archive[key], dtype=np.float32)))


# Collapse every run's per-epoch statistics into one row of scalar metrics.
all_metrics = []
for run in range(NUM_RUNS):
    stats_path = os.path.join(OUT_DIR, f'run{run}_ours_stats.npz')
    base_path = os.path.join(OUT_DIR, f'run{run}_base_stats.npz')
    sum_path = os.path.join(OUT_DIR, f'run{run}_summary.json')
    # Skip a run unless all three of its files exist; the summary file is
    # read below, so it is checked together with the two stats archives.
    if not all(os.path.exists(path) for path in (stats_path, base_path, sum_path)):
        continue

    with open(sum_path, 'r') as fh:
        s = json.load(fh)

    entry = {'run': run}
    for tag, path in (('ours', stats_path), ('base', base_path)):
        # The archives hold plain numeric arrays, so pickle support is not
        # enabled; the context manager closes the underlying file handle.
        with np.load(path) as archive:
            entry[f'mean_grad_norm_{tag}'] = _mean_stat(archive, 'layer_grad_norms')
            entry[f'mean_curv_{tag}'] = _mean_stat(archive, 'curvature_vHv_mean')
            entry[f'mean_update_{tag}'] = _mean_stat(archive, 'param_update_norms')
    entry['mean_rel_l2_test_ours'] = s['mean_rel_l2_test_ours']
    entry['mean_rel_l2_test_base'] = s['mean_rel_l2_test_base']
    all_metrics.append(entry)

with open(os.path.join(OUT_DIR, 'correlation_metrics.json'), 'w') as fh:
    json.dump(all_metrics, fh)


print('\nExperiment finished. Results saved under:', OUT_DIR)

