import os
import sys
os.environ['CUDA_VISIBLE_DEVICES']='0'
import os
import numpy as np
import torch
import torch.nn as nn
import scipy.io as io
from sklearn.decomposition import PCA
from tqdm import tqdm

from typing import List

def our_static_forward(model, f, f_bc, x):
    """Evaluate the two-branch / trunk network.

    The outputs of ``model._branch1(f)`` and ``model._branch2(f_bc)`` are
    multiplied element-wise, and the product is contracted with
    ``model._trunk(x)`` over their shared last axis.

    Args:
        model: Object exposing callable ``_branch1``, ``_branch2`` and
            ``_trunk`` attributes.
        f: Input handed to ``_branch1``.
        f_bc: Input handed to ``_branch2``.
        x: Input handed to ``_trunk``.

    Returns:
        2-D tensor ``out`` with
        ``out[i, k] = sum_j branch[i, j] * trunk[k, j]``.
    """
    # Sub-networks are evaluated in the order branch1, branch2, trunk.
    branch_features = model._branch1(f) * model._branch2(f_bc)
    trunk_features = model._trunk(x)
    return torch.einsum("ij,kj->ik", branch_features, trunk_features)

def our_static_loss(
    model: nn.Module,
    f: torch.Tensor,
    f_bc: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
) -> torch.Tensor:
    """Return the magnitude-weighted mean squared error of ``model`` against ``y``.

    Each squared residual is scaled by ``|y| + 1`` before averaging, so entries
    where the target has a larger absolute value contribute more to the loss.

    Args:
        model: Object whose ``forward(f, f_bc, x)`` yields the prediction.
        f: First branch input, forwarded unchanged.
        f_bc: Second branch input, forwarded unchanged.
        x: Trunk input, forwarded unchanged.
        y: Target tensor compared element-wise with the prediction.

    Returns:
        Scalar tensor holding the weighted mean squared error.
    """
    prediction = model.forward(f, f_bc, x)
    squared_residual = (prediction - y) ** 2
    emphasis = 1.0 + torch.abs(y)
    return torch.mean(emphasis * squared_residual)

def _our_build_mlp(layer_dims: List[int]) -> nn.Sequential:
    """Build one sub-network from a list of layer widths.

    Every consecutive pair of widths produces a ``Linear -> Tanh`` stage. From
    the second stage onwards, a stage whose output width equals its input
    width is followed by one additional ``Linear(width, width)`` layer with no
    activation of its own.

    Args:
        layer_dims: Widths ``[input, hidden..., output]``. A single-element
            list yields an empty ``nn.Sequential``.

    Returns:
        The assembled ``nn.Sequential``.

    Raises:
        IndexError: If ``layer_dims`` is empty.
    """
    stages = []
    width_in = layer_dims[0]
    for position, width_out in enumerate(layer_dims[1:]):
        stages.append(nn.Sequential(nn.Linear(width_in, width_out), nn.Tanh()))
        # width_in is the width feeding this stage, i.e. layer_dims[position].
        if position > 0 and width_out == width_in:
            stages.append(nn.Linear(width_out, width_out))
        width_in = width_out
    return nn.Sequential(*stages)


def our_static_init(model: nn.Module, branch1_dim: List[int], branch2_dim: List[int], trunk_dim: List[int]) -> None:
    """Attach the two branch networks and the trunk network to ``model``.

    Sets ``model.z_dim`` to the last trunk width and creates
    ``model._branch1``, ``model._branch2`` and ``model._trunk`` (in that
    order, so parameter initialisation consumes the RNG in a fixed sequence).

    Args:
        model: Module that receives the sub-networks as attributes.
        branch1_dim: Layer widths of the first branch network.
        branch2_dim: Layer widths of the second branch network.
        trunk_dim: Layer widths of the trunk network.

    Raises:
        IndexError: If any of the width lists is empty.
    """
    model.z_dim = trunk_dim[-1]
    model._branch1 = _our_build_mlp(branch1_dim)
    model._branch2 = _our_build_mlp(branch2_dim)
    model._trunk = _our_build_mlp(trunk_dim)

class OurModel(nn.Module):
    """Two-branch / trunk network wired up through the ``our_static_*`` helpers.

    Construction, the forward pass and the loss are all delegated to the
    module-level functions ``our_static_init``, ``our_static_forward`` and
    ``our_static_loss``.
    """

    def __init__(self, branch1_dim, branch2_dim, trunk_dim):
        """Create the sub-networks from the three lists of layer widths."""
        super().__init__()
        our_static_init(
            model=self,
            branch1_dim=branch1_dim,
            branch2_dim=branch2_dim,
            trunk_dim=trunk_dim,
        )

    def forward(self, f, f_bc, x):
        """Return the network output for branch inputs ``f``, ``f_bc`` and trunk input ``x``."""
        prediction = our_static_forward(model=self, f=f, f_bc=f_bc, x=x)
        return prediction

    def loss(self, f, f_bc, x, y):
        """Return the training loss of the prediction against target ``y``."""
        value = our_static_loss(model=self, f=f, f_bc=f_bc, x=x, y=y)
        return value



def dimon_static_forward(model, f, f_bc, x):
    """Run the branch and trunk networks and combine their outputs.

    Args:
        model: Object exposing callable ``_branch1``, ``_branch2`` and
            ``_trunk`` attributes.
        f: Input for ``_branch1``.
        f_bc: Input for ``_branch2``.
        x: Input for ``_trunk``.

    Returns:
        2-D tensor whose entry ``[i, k]`` is the dot product, over the last
        axis, of the element-wise product of the two branch outputs for row
        ``i`` with the trunk output for row ``k``.
    """
    first_branch = model._branch1(f)
    second_branch = model._branch2(f_bc)
    trunk = model._trunk(x)
    combined_branch = first_branch * second_branch
    return torch.einsum("ij,kj->ik", combined_branch, trunk)

def dimon_static_loss(
    model: nn.Module,
    f: torch.Tensor,
    f_bc: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
) -> torch.Tensor:
    """Return the plain (unweighted) mean squared error of ``model`` against ``y``.

    Args:
        model: Object whose ``forward(f, f_bc, x)`` yields the prediction.
        f: First branch input, forwarded unchanged.
        f_bc: Second branch input, forwarded unchanged.
        x: Trunk input, forwarded unchanged.
        y: Target tensor compared element-wise with the prediction.

    Returns:
        Scalar tensor holding the mean squared error.
    """
    residual = model.forward(f, f_bc, x) - y
    return torch.mean(residual ** 2)

def _dimon_build_mlp(layer_dims: List[int]) -> nn.Sequential:
    """Build a stack of ``Linear -> Tanh`` stages from a list of layer widths.

    Args:
        layer_dims: Widths ``[input, hidden..., output]``. A single-element
            list yields an empty ``nn.Sequential``.

    Returns:
        ``nn.Sequential`` holding one ``nn.Sequential(Linear, Tanh)`` per
        consecutive pair of widths.

    Raises:
        IndexError: If ``layer_dims`` is empty.
    """
    stages = []
    width_in = layer_dims[0]
    for width_out in layer_dims[1:]:
        stages.append(nn.Sequential(nn.Linear(width_in, width_out), nn.Tanh()))
        width_in = width_out
    return nn.Sequential(*stages)


def dimon_static_init(model: nn.Module, branch1_dim: List[int], branch2_dim: List[int], trunk_dim: List[int]) -> None:
    """Attach the two branch networks and the trunk network to ``model``.

    Sets ``model.z_dim`` to the last trunk width and creates
    ``model._branch1``, ``model._branch2`` and ``model._trunk`` (in that
    order, so parameter initialisation consumes the RNG in a fixed sequence).

    Args:
        model: Module that receives the sub-networks as attributes.
        branch1_dim: Layer widths of the first branch network.
        branch2_dim: Layer widths of the second branch network.
        trunk_dim: Layer widths of the trunk network.

    Raises:
        IndexError: If any of the width lists is empty.
    """
    model.z_dim = trunk_dim[-1]
    model._branch1 = _dimon_build_mlp(branch1_dim)
    model._branch2 = _dimon_build_mlp(branch2_dim)
    model._trunk = _dimon_build_mlp(trunk_dim)

class DIMONModel(nn.Module):
    """Two-branch / trunk network wired up through the ``dimon_static_*`` helpers.

    Construction, the forward pass and the loss are all delegated to the
    module-level functions ``dimon_static_init``, ``dimon_static_forward`` and
    ``dimon_static_loss``.
    """

    def __init__(self, branch1_dim, branch2_dim, trunk_dim):
        """Create the sub-networks from the three lists of layer widths."""
        super().__init__()
        dimon_static_init(
            model=self,
            branch1_dim=branch1_dim,
            branch2_dim=branch2_dim,
            trunk_dim=trunk_dim,
        )

    def forward(self, f, f_bc, x):
        """Return the network output for branch inputs ``f``, ``f_bc`` and trunk input ``x``."""
        prediction = dimon_static_forward(model=self, f=f, f_bc=f_bc, x=x)
        return prediction

    def loss(self, f, f_bc, x, y):
        """Return the training loss of the prediction against target ``y``."""
        value = dimon_static_loss(model=self, f=f, f_bc=f_bc, x=x, y=y)
        return value




def load_and_prepare_data(paths=None):
    """Load one or more ``.mat`` files and stack their samples.

    From every file the variables ``x_mesh_data``, ``x_uni``, ``u_data`` and
    ``u_bc`` are read. The per-file arrays are concatenated along axis 0.

    Args:
        paths: Iterable of ``.mat`` file paths. When ``None`` (the default)
            the three Laplace dataset files under ``Laplace/dataset/`` are
            used.

    Returns:
        Tuple ``(x_mesh, dx, u, u_bc, x)`` where

        * ``x_mesh`` is the stacked ``x_mesh_data``,
        * ``dx`` is the stacked ``x_mesh_data - x_uni`` (computed per file),
        * ``u`` is the stacked ``u_data``,
        * ``u_bc`` is the stacked ``u_bc``,
        * ``x`` is ``x_uni`` taken from the LAST file in ``paths``.

    Raises:
        ValueError: If ``paths`` is empty.
    """
    if paths is None:
        paths = ["Laplace/dataset/Laplace_data.mat",
                 "Laplace/dataset/Laplace_data_supp.mat",
                 "Laplace/dataset/Laplace_data_supp2000.mat"]
    paths = list(paths)
    if not paths:
        raise ValueError("load_and_prepare_data requires at least one .mat path")

    x_meshs = []
    dxs = []
    us = []
    u_bcs = []
    x = None
    for path in paths:
        d = io.loadmat(path)
        mesh = d["x_mesh_data"]
        # NOTE(review): only the last file's x_uni is returned; this assumes
        # every file shares the same x_uni -- confirm against the dataset.
        x = d["x_uni"]
        x_meshs.append(mesh)
        dxs.append(mesh - x)
        us.append(d["u_data"])
        u_bcs.append(d["u_bc"])

    x_mesh = np.concatenate(x_meshs, axis=0)
    dx = np.concatenate(dxs, axis=0)
    u = np.concatenate(us, axis=0)
    u_bc = np.concatenate(u_bcs, axis=0)
    return x_mesh, dx, u, u_bc, x


def _pod_coefficients(train_field, test_field, n_modes):
    """Fit a PCA on one mean-centred training component and project both splits.

    The test data is centred with the TRAINING mean before projection.

    Returns:
        Tuple ``(train_coefficients, test_coefficients)``.
    """
    # Compute the training mean once and reuse it for fit and both projections.
    train_mean = train_field.mean(axis=0)
    centred_train = train_field - train_mean
    pca = PCA(n_components=n_modes)
    pca.fit(centred_train)
    return pca.transform(centred_train), pca.transform(test_field - train_mean)


def prepare_pca_features(dx_train, dx_test, PODMode=10):
    """Turn displacement fields into PCA coefficient features.

    The components ``[:, :, 0]`` and ``[:, :, 1]`` of the inputs are handled
    independently: for each, a PCA with ``PODMode`` components is fitted on the
    mean-centred training data, and both splits are projected onto it. The
    coefficients of the two components are concatenated along axis 1.

    Args:
        dx_train: Training array indexed as ``[sample, point, component]``
            with at least two components.
        dx_test: Test array with the same layout.
        PODMode: Number of PCA components kept per displacement component.

    Returns:
        Tuple ``(f_train, f_test)`` of feature arrays; the coefficients of
        component 0 come first, followed by those of component 1.
    """
    # Component 0 is fitted before component 1, matching the original order.
    coeff_x_train, coeff_x_test = _pod_coefficients(dx_train[:, :, 0], dx_test[:, :, 0], PODMode)
    coeff_y_train, coeff_y_test = _pod_coefficients(dx_train[:, :, 1], dx_test[:, :, 1], PODMode)
    f_train = np.concatenate((coeff_x_train, coeff_y_train), axis=1)
    f_test = np.concatenate((coeff_x_test, coeff_y_test), axis=1)
    return f_train, f_test


def generate_perturbed_domains(base_shape_coeff, perturb_percents):
    """Return noisy copies of a coefficient vector, one per perturbation level.

    For each level ``perc`` zero-mean Gaussian noise with standard deviation
    ``perc * ||base_shape_coeff||`` is drawn from NumPy's global random state
    and added to the base coefficients.

    Args:
        base_shape_coeff: Array-like base coefficients.
        perturb_percents: Iterable of noise levels, each a fraction of the
            norm of the base coefficients.

    Returns:
        List of arrays, in the same order as ``perturb_percents``.
    """
    reference = np.array(base_shape_coeff)
    return [
        reference + np.random.normal(0, level * np.linalg.norm(reference), size=reference.shape)
        for level in perturb_percents
    ]


def to_tensor(np_array, device):
    """Copy ``np_array`` into a new single-precision tensor placed on ``device``."""
    tensor_options = {"dtype": torch.float32, "device": device}
    return torch.tensor(np_array, **tensor_options)


def compute_relative_l2_error(pred, base):
    """Return ``||pred - base|| / ||base||`` using ``np.linalg.norm``.

    No guard is applied for a zero-norm ``base``; the division is performed
    as-is.
    """
    deviation_norm = np.linalg.norm(pred - base)
    reference_norm = np.linalg.norm(base)
    return deviation_norm / reference_norm


def run_experiment():
    """Train both operator networks, then measure their sensitivity to input noise.

    Steps:
      1. Load the data with ``load_and_prepare_data`` and use the first 3300
         samples for training and the last 200 samples for evaluation.
      2. Build the shape features with ``prepare_pca_features`` and the
         boundary features from every third column of ``u_bc``.
      3. Train ``OurModel`` and then ``DIMONModel`` full-batch with Adam
         (lr=0.001) for 15000 steps each.
      4. For 10 randomly chosen evaluation samples and each perturbation level
         (1%, 2%, 5%), add Gaussian noise to the inputs and record the relative
         L2 change of each model's prediction, plus that change divided by the
         relative size of the shape-feature perturbation.
      5. Write the per-level results as text files and a summary dictionary as
         ``.npy`` under ``./experimental_result_data``.

    Returns:
        None. All results are written to disk.
    """
    import warnings  # local import: only needed for the split sanity check

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Network and training configuration.
    PODMode = 10
    num_bc = 68
    dim_br1 = [PODMode * 2, 100, 100, 100]
    dim_br2 = [num_bc, 150, 150, 150, 100]
    dim_tr = [2, 100, 100, 100]
    epochs = 15000
    num_train = 3300
    num_test = 200
    output_dir = "./experimental_result_data"

    _, dx, u, u_bc, x_ref = load_and_prepare_data()

    # The evaluation samples are taken from the tail of the data set; warn when
    # that tail reaches into the training head instead of silently leaking data.
    num_total = dx.shape[0]
    if num_total - num_test < num_train:
        warnings.warn(
            f"Only {num_total} samples available: the last {num_test} evaluation "
            f"samples overlap the first {num_train} training samples.",
            RuntimeWarning,
        )

    dx_train = dx[:num_train]
    dx_test = dx[-num_test:]
    f_train, f_test = prepare_pca_features(dx_train, dx_test, PODMode)

    u_train = u[:num_train]
    # Keep every third boundary column.
    # NOTE(review): this presumably yields num_bc (= 68) columns -- confirm
    # against the width of u_bc in the dataset.
    f_bc_train = u_bc[:num_train, ::3]
    f_bc_test = u_bc[-num_test:, ::3]

    f_train_t = to_tensor(f_train, device)
    f_bc_train_t = to_tensor(f_bc_train, device)
    u_train_t = to_tensor(u_train, device)
    x_tensor = to_tensor(x_ref, device)

    # NOTE(review): no torch seed is set in this function, so the weight
    # initialisation (and hence the trained models) varies between runs.
    our_model = OurModel(dim_br1, dim_br2, dim_tr).to(device)
    dimon_model = DIMONModel(dim_br1, dim_br2, dim_tr).to(device)

    opt_our = torch.optim.Adam(our_model.parameters(), lr=0.001)
    opt_dimon = torch.optim.Adam(dimon_model.parameters(), lr=0.001)

    def train_epoch(model, optimizer, f, f_bc, x, y):
        """Run one full-batch optimisation step and return the loss value."""
        model.train()
        optimizer.zero_grad()
        loss = model.loss(f, f_bc, x, y)
        loss.backward()
        optimizer.step()
        return loss.item()

    for _ in tqdm(range(epochs), desc="Training our_model"):
        train_epoch(our_model, opt_our, f_train_t, f_bc_train_t, x_tensor, u_train_t)

    for _ in tqdm(range(epochs), desc="Training dimon_model"):
        train_epoch(dimon_model, opt_dimon, f_train_t, f_bc_train_t, x_tensor, u_train_t)

    our_model.eval()
    dimon_model.eval()

    def predict_single(model, f_row_t, f_bc_row_t):
        """Return the model output for a single-sample batch as a 1-D NumPy array."""
        # No gradients are needed for evaluation, so skip building the graph.
        with torch.no_grad():
            return model(f_row_t, f_bc_row_t, x_tensor).cpu().numpy()[0]

    num_base_domains = 10
    # Seeds only the evaluation-time sampling and noise below.
    np.random.seed(42)
    base_indices = np.random.choice(range(f_test.shape[0]), num_base_domains, replace=False)
    perturb_percents = [0.01, 0.02, 0.05]  # 1%, 2%, 5%

    # Per-level result lists; one entry is appended per base sample.
    stability_data = {
        'perturbation_magnitude': perturb_percents,
        'our_method_relative_L2': {p: [] for p in perturb_percents},
        'baseline_relative_L2': {p: [] for p in perturb_percents},
        'our_method_stability_coefs': {p: [] for p in perturb_percents},
        'baseline_stability_coefs': {p: [] for p in perturb_percents}
    }

    for base_idx in base_indices:
        base_f = f_test[base_idx]
        base_f_bc = f_bc_test[base_idx]

        base_f_t = to_tensor(base_f[None, :], device)
        base_f_bc_t = to_tensor(base_f_bc[None, :], device)

        # Unperturbed reference predictions.
        our_base_pred = predict_single(our_model, base_f_t, base_f_bc_t)
        dimon_base_pred = predict_single(dimon_model, base_f_t, base_f_bc_t)

        for p in perturb_percents:
            # The draw order (shape noise first, then boundary noise) is part
            # of the seeded sequence; do not reorder.
            noise = np.random.normal(scale=p * np.linalg.norm(base_f), size=base_f.shape)
            perturbed_f = base_f + noise
            # NOTE(review): the boundary input always receives 1% noise,
            # independent of p, and is not included in input_pert_mag below.
            noise_bc = np.random.normal(scale=0.01 * np.linalg.norm(base_f_bc), size=base_f_bc.shape)
            perturbed_f_bc = base_f_bc + noise_bc

            perturbed_f_t = to_tensor(perturbed_f[None, :], device)
            perturbed_f_bc_t = to_tensor(perturbed_f_bc[None, :], device)

            our_pert_pred = predict_single(our_model, perturbed_f_t, perturbed_f_bc_t)
            dimon_pert_pred = predict_single(dimon_model, perturbed_f_t, perturbed_f_bc_t)

            # Output change relative to each model's own unperturbed prediction.
            our_rel_l2 = compute_relative_l2_error(our_pert_pred, our_base_pred)
            dimon_rel_l2 = compute_relative_l2_error(dimon_pert_pred, dimon_base_pred)

            input_pert_mag = np.linalg.norm(perturbed_f - base_f) / np.linalg.norm(base_f)

            # Stability coefficient = relative output change / relative input
            # change; reported as 0 when the input change is negligible.
            our_stab_coef = our_rel_l2 / input_pert_mag if input_pert_mag > 1e-10 else 0
            dimon_stab_coef = dimon_rel_l2 / input_pert_mag if input_pert_mag > 1e-10 else 0

            stability_data['our_method_relative_L2'][p].append(our_rel_l2)
            stability_data['baseline_relative_L2'][p].append(dimon_rel_l2)
            stability_data['our_method_stability_coefs'][p].append(our_stab_coef)
            stability_data['baseline_stability_coefs'][p].append(dimon_stab_coef)

    os.makedirs(output_dir, exist_ok=True)

    # File-name prefix -> key in stability_data.
    series = {
        "our_relL2": 'our_method_relative_L2',
        "baseline_relL2": 'baseline_relative_L2',
        "our_stabcoef": 'our_method_stability_coefs',
        "baseline_stabcoef": 'baseline_stability_coefs',
    }
    for p in perturb_percents:
        # round() rather than plain int(): p * 100 can land just below an
        # integer in floating point, and truncation would mislabel the file.
        pct = int(round(p * 100))
        for prefix, key in series.items():
            np.savetxt(f"{output_dir}/{prefix}_{pct}pct.txt", np.array(stability_data[key][p]))

    our_coefs = stability_data['our_method_stability_coefs']
    base_coefs = stability_data['baseline_stability_coefs']
    summary = {
        'perturbation_magnitude': perturb_percents,
        'mean_stability_coefficients_our': [np.mean(our_coefs[p]) for p in perturb_percents],
        'max_stability_coefficients_our': [np.max(our_coefs[p]) for p in perturb_percents],
        'mean_stability_coefficients_baseline': [np.mean(base_coefs[p]) for p in perturb_percents],
        'max_stability_coefficients_baseline': [np.max(base_coefs[p]) for p in perturb_percents]
    }
    # A dict is stored as a pickled object array; reading it back requires
    # np.load(..., allow_pickle=True).item().
    np.save(f"{output_dir}/summary_stability_coefficients.npy", summary)

# Run the full training and stability experiment only when this file is
# executed as a script, not when it is imported.
if __name__ == "__main__":
    run_experiment()
