import os
import sys

os.environ['CUDA_VISIBLE_DEVICES']='3'
import os
import json
import torch
import torch.nn as nn
import numpy as np
import scipy.io as io
from sklearn.decomposition import PCA
from tqdm import tqdm
from typing import List
from sklearn.metrics import mutual_info_score

def to_numpy(input):
    """Return *input* as a NumPy array.

    NumPy arrays are returned unchanged (same object). Torch tensors are
    detached from the autograd graph, moved to the CPU and converted.

    Raises:
        TypeError: if *input* is neither a ``torch.Tensor`` nor an ``np.ndarray``.
    """
    if isinstance(input, np.ndarray):
        return input
    if not isinstance(input, torch.Tensor):
        raise TypeError(f'Unknown type of input, expected torch.Tensor or np.ndarray, but got {type(input)}')
    return input.detach().cpu().numpy()

class OurMethodModel(nn.Module):
    """Branch/branch/trunk operator network with extra square linear layers.

    Two branch MLPs (``f`` and ``f_bc``) are fused by element-wise product and
    contracted against a trunk MLP evaluated at ``x``. Inside each MLP, when a
    hidden width repeats the previous one (from the second stage on), an
    extra activation-free ``Linear`` layer is inserted after that stage.
    """

    def __init__(self, branch1_dim: List[int], branch2_dim: List[int], trunk_dim: List[int]):
        super().__init__()
        self.z_dim = trunk_dim[-1]
        self._branch1 = self._make_mlp(branch1_dim)
        self._branch2 = self._make_mlp(branch2_dim)
        self._trunk = self._make_mlp(trunk_dim)

    @staticmethod
    def _make_mlp(dims: List[int]) -> nn.Sequential:
        """Build the stacked Linear+Tanh stages for consecutive widths in *dims*."""
        layers = []
        for k in range(1, len(dims)):
            width_in, width_out = dims[k - 1], dims[k]
            layers.append(nn.Sequential(nn.Linear(width_in, width_out), nn.Tanh()))
            # From the second stage on, a width that equals its predecessor
            # gets an additional (activation-free) square Linear layer.
            if k > 1 and width_out == width_in:
                layers.append(nn.Linear(width_out, width_out))
        return nn.Sequential(*layers)

    def forward(self, f, f_bc, x):
        """Return ``(prediction, branch1_out, branch2_out, fused_branch)``.

        ``prediction[i, k]`` is the dot product of the fused branch features of
        row ``i`` with the trunk features of row ``k`` of ``x``.
        """
        branch_a = self._branch1(f)
        branch_b = self._branch2(f_bc)
        fused = branch_a * branch_b
        trunk_out = self._trunk(x)
        prediction = torch.einsum('ij,kj->ik', fused, trunk_out)
        return prediction, branch_a, branch_b, fused

    def loss(self, f, f_bc, x, y):
        """Squared error weighted by ``|y| + 1``, averaged over all entries."""
        prediction = self.forward(f, f_bc, x)[0]
        # Larger-magnitude targets receive proportionally larger weight.
        weight = torch.abs(y) + 1.0
        return (weight * (prediction - y) ** 2).mean()

class DIMONModel(nn.Module):
    """Baseline branch/branch/trunk operator network.

    Each of the two branch MLPs and the trunk MLP is a plain stack of
    ``Linear + Tanh`` stages. The branch outputs are fused by element-wise
    product and contracted against the trunk output; training uses an
    unweighted mean squared error.
    """

    def __init__(self, branch1_dim: List[int], branch2_dim: List[int], trunk_dim: List[int]):
        super().__init__()
        self.z_dim = trunk_dim[-1]
        self._branch1 = self._make_mlp(branch1_dim)
        self._branch2 = self._make_mlp(branch2_dim)
        self._trunk = self._make_mlp(trunk_dim)

    @staticmethod
    def _make_mlp(dims: List[int]) -> nn.Sequential:
        """Stack ``Linear(dims[k], dims[k+1]) + Tanh`` for every consecutive pair."""
        stages = [
            nn.Sequential(nn.Linear(width_in, width_out), nn.Tanh())
            for width_in, width_out in zip(dims[:-1], dims[1:])
        ]
        return nn.Sequential(*stages)

    def forward(self, f, f_bc, x):
        """Return ``(prediction, branch1_out, branch2_out, fused_branch)``."""
        branch_a = self._branch1(f)
        branch_b = self._branch2(f_bc)
        fused = branch_a * branch_b
        trunk_out = self._trunk(x)
        prediction = torch.einsum('ij,kj->ik', fused, trunk_out)
        return prediction, branch_a, branch_b, fused

    def loss(self, f, f_bc, x, y):
        """Unweighted mean squared error between prediction and *y*."""
        prediction = self.forward(f, f_bc, x)[0]
        return ((prediction - y) ** 2).mean()


def relative_l2_diff(a: np.ndarray, b: np.ndarray) -> float:
    """Return ``||a - b|| / (||b|| + 1e-12)`` using the Frobenius/2-norm.

    The ``1e-12`` term keeps the ratio finite when *b* is all zeros.
    """
    reference_norm = np.linalg.norm(b) + 1e-12
    difference_norm = np.linalg.norm(a - b)
    return difference_norm / reference_norm


def mutual_information(x: np.ndarray, y: np.ndarray, bins=20) -> float:
    """Estimate the mutual information of two 1-D samples via a 2-D histogram.

    The samples are binned with ``np.histogram2d(x, y, bins)`` and the joint
    count table is passed to sklearn's ``mutual_info_score`` as a contingency
    matrix.
    """
    joint_counts, _x_edges, _y_edges = np.histogram2d(x, y, bins)
    return mutual_info_score(None, None, contingency=joint_counts)


def prepare_dataset():
    """Load the Laplace datasets, split them, and build float32 tensors.

    The three ``.mat`` files below are concatenated along axis 0, and every
    third column of ``u_bc`` is kept. Mesh displacements
    ``x_mesh_data - x_uni`` are reduced per coordinate to ``PODMode``
    principal-component (POD) coefficients. The PCA is fitted on the training
    split only, and the same projection (with the training mean) is applied to
    the validation and test splits.

    Samples are shuffled with ``np.random.seed(42)``. The first 70 % go to
    train, the next 15 % to validation, and the remainder to test.

    Returns:
        dict with keys ``'train'``, ``'val'``, ``'test'`` (each a tuple
        ``(f, f_bc, u)`` of float32 tensors on ``device``), ``'x'`` (``x_uni``
        as a tensor) and ``'device'``.
    """
    dataset_main = io.loadmat('Laplace/dataset/Laplace_data.mat')
    dataset_supp = io.loadmat('Laplace/dataset/Laplace_data_supp.mat')
    dataset_supp2 = io.loadmat('Laplace/dataset/Laplace_data_supp2000.mat')
    datasets = (dataset_main, dataset_supp, dataset_supp2)

    x = dataset_main['x_uni']
    x_mesh = np.concatenate([d['x_mesh_data'] for d in datasets], axis=0)
    u = np.concatenate([d['u_data'] for d in datasets], axis=0)
    # Keep every third boundary value.
    u_bc = np.concatenate([d['u_bc'][:, ::3] for d in datasets], axis=0)

    # Displacement of each mesh node from the reference grid; indexed below as
    # dx[:, :, 0] / dx[:, :, 1], one slice per spatial coordinate.
    dx = x_mesh - x[np.newaxis, :, :]

    num_total = dx.shape[0]
    num_train = int(num_total * 0.7)
    num_val = int(num_total * 0.15)

    indices = np.arange(num_total)
    # NOTE: this seeds NumPy's *global* RNG, which also affects any later
    # global-RNG use (e.g. a randomized PCA solver); kept to preserve the split.
    np.random.seed(42)
    np.random.shuffle(indices)
    train_idx = indices[:num_train]
    val_idx = indices[num_train:num_train + num_val]
    test_idx = indices[num_train + num_val:]

    PODMode = 10
    dx_train = dx[train_idx]
    mean_x = dx_train[:, :, 0].mean(axis=0)
    mean_y = dx_train[:, :, 1].mean(axis=0)

    pca_x = PCA(n_components=PODMode)
    pca_x.fit(dx_train[:, :, 0] - mean_x)
    pca_y = PCA(n_components=PODMode)
    pca_y.fit(dx_train[:, :, 1] - mean_y)

    def geometry_features(idx):
        """POD coefficients (x then y) of the displacements of samples *idx*."""
        dx_split = dx[idx]
        coeff_x = pca_x.transform(dx_split[:, :, 0] - mean_x)
        coeff_y = pca_y.transform(dx_split[:, :, 1] - mean_y)
        return np.concatenate((coeff_x, coeff_y), axis=1)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def as_tensor(arr):
        return torch.tensor(arr, dtype=torch.float32).to(device)

    def make_split(idx):
        # Every split uses its *own* geometry features, so rows of f, f_bc and
        # u always describe the same samples (and have matching lengths).
        return (as_tensor(geometry_features(idx)), as_tensor(u_bc[idx]), as_tensor(u[idx]))

    return {
        'train': make_split(train_idx),
        'val': make_split(val_idx),
        'test': make_split(test_idx),
        'x': as_tensor(x),
        'device': device
    }


def train_model(model, train_data, val_data, x_tensor, epochs=300, lr=0.001):
    """Full-batch Adam training of *model* using its own ``loss`` method.

    Args:
        model: module exposing ``loss(f, f_bc, x, y)``; moved to
            ``train_data['device']``.
        train_data: dict providing ``'device'`` and ``'train'`` -> ``(f, f_bc, u)``.
        val_data: dict providing ``'val'`` -> ``(f, f_bc, u)``.
        x_tensor: query coordinates passed to every loss call.
        epochs: number of optimizer steps (one per epoch, whole batch).
        lr: Adam learning rate.

    Returns:
        ``(train_losses, val_losses)``: per-epoch Python floats. The
        validation loss is evaluated after each step under ``torch.no_grad``.
    """
    model = model.to(train_data['device'])
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history_train, history_val = [], []

    f_tr, bc_tr, u_tr = train_data['train']
    f_va, bc_va, u_va = val_data['val']

    model.train()
    for _ in tqdm(range(epochs), desc='Training'):
        optimizer.zero_grad()
        step_loss = model.loss(f_tr, bc_tr, x_tensor, u_tr)
        step_loss.backward()
        optimizer.step()
        history_train.append(step_loss.item())

        model.eval()
        with torch.no_grad():
            history_val.append(model.loss(f_va, bc_va, x_tensor, u_va).item())
        model.train()

    return history_train, history_val


def perturb_branch_input(branch_input: torch.Tensor, perturb_type: str, noise_level: float, device) -> torch.Tensor:
    """Return a noisy copy of *branch_input* on *device*.

    Adds i.i.d. Gaussian noise ``N(0, noise_level**2)`` drawn with
    ``torch.randn_like``. Both branch types use the same noise model;
    *perturb_type* only states which branch input the caller is perturbing and
    is validated. The input tensor itself is not modified.

    Raises:
        ValueError: if *perturb_type* is not ``'branch1'`` or ``'branch2'``.
    """
    if perturb_type not in ('branch1', 'branch2'):
        raise ValueError('perturb_type must be branch1 or branch2')
    noise = torch.randn_like(branch_input) * noise_level
    return (branch_input + noise).to(device)


def activation_and_output_under_perturbation(model, f_orig, f_bc_orig, x_tensor, perturb_branch, noise_levels, device):
    """Probe how additive noise on one branch input propagates through *model*.

    For each value in *noise_levels*, the selected input (``f_orig`` for
    ``'branch1'``, ``f_bc_orig`` for ``'branch2'``) is perturbed with
    ``perturb_branch_input`` while the other input is left unchanged. The
    perturbed branch activations and model output are compared against a
    noise-free baseline pass using ``relative_l2_diff``.

    The model is switched to eval mode and left in eval mode.

    Args:
        model: callable returning ``(y_out, y_br1, y_br2, y_br)``.
        f_orig, f_bc_orig: unperturbed branch-1 / branch-2 inputs.
        x_tensor: trunk query coordinates.
        perturb_branch: ``'branch1'`` or ``'branch2'``.
        noise_levels: iterable of noise standard deviations.
        device: device the perturbed input is moved to.

    Returns:
        dict. The scalar series ``'noise_level'``,
        ``'branch1_activation_change'``, ``'branch2_activation_change'`` and
        ``'postfusion_output_deviation'`` are float64 arrays. The per-noise-level
        lists ``'branch1_activation'``, ``'branch2_activation'`` and
        ``'postfusion_output'`` hold NumPy arrays.

    Raises:
        ValueError: if *perturb_branch* is not ``'branch1'`` or ``'branch2'``.
    """
    # Validate once up front rather than on every noise level.
    if perturb_branch not in ('branch1', 'branch2'):
        raise ValueError('perturb_branch must be branch1 or branch2')

    results = {
        'noise_level': [],
        'branch1_activation_change': [],
        'branch2_activation_change': [],
        'postfusion_output_deviation': [],
        'branch1_activation': [],
        'branch2_activation': [],
        'postfusion_output': []
    }

    model.eval()
    with torch.no_grad():
        y_out_base, y_br1_base, y_br2_base, _ = model(f_orig, f_bc_orig, x_tensor)
    y_br1_base_np = to_numpy(y_br1_base)
    y_br2_base_np = to_numpy(y_br2_base)
    y_out_base_np = to_numpy(y_out_base)

    for noise_level in noise_levels:
        if perturb_branch == 'branch1':
            f_perturbed = perturb_branch_input(f_orig, 'branch1', noise_level, device)
            f_bc_perturbed = f_bc_orig
        else:
            f_perturbed = f_orig
            f_bc_perturbed = perturb_branch_input(f_bc_orig, 'branch2', noise_level, device)

        with torch.no_grad():
            y_out_p, y_br1_p, y_br2_p, _ = model(f_perturbed, f_bc_perturbed, x_tensor)
        y_br1_p_np = to_numpy(y_br1_p)
        y_br2_p_np = to_numpy(y_br2_p)
        y_out_p_np = to_numpy(y_out_p)

        results['noise_level'].append(noise_level)
        results['branch1_activation_change'].append(relative_l2_diff(y_br1_p_np, y_br1_base_np))
        results['branch2_activation_change'].append(relative_l2_diff(y_br2_p_np, y_br2_base_np))
        results['postfusion_output_deviation'].append(relative_l2_diff(y_out_p_np, y_out_base_np))
        results['branch1_activation'].append(y_br1_p_np)
        results['branch2_activation'].append(y_br2_p_np)
        results['postfusion_output'].append(y_out_p_np)

    for key in ['noise_level', 'branch1_activation_change', 'branch2_activation_change', 'postfusion_output_deviation']:
        results[key] = np.array(results[key], dtype=np.float64)

    return results


def compute_feature_decoupling_metrics(branch1_activations, branch2_activations):
    """Measure dependence between paired branch-1 / branch-2 activation sets.

    For each pair ``(act1, act2)`` of 2-D arrays (rows are samples, columns
    are features, as implied by the transposes below):

    * correlation: mean absolute value of the cross block of
      ``np.corrcoef`` between every act1 feature and every act2 feature;
    * mutual information: ``mutual_information`` between the per-sample
      feature means of act1 and act2.

    Returns:
        ``(correlations, mutual_infos)`` as float64 arrays, one entry per pair.
    """
    mean_abs_cross_corr = []
    mi_values = []

    for act1, act2 in zip(branch1_activations, branch2_activations):
        n_feat1 = act1.shape[1]
        n_feat2 = act2.shape[1]
        # Transpose so corrcoef treats each feature (column) as a variable.
        full_corr = np.corrcoef(act1.T, act2.T)
        cross_block = full_corr[:n_feat1, n_feat1:n_feat1 + n_feat2]
        mean_abs_cross_corr.append(np.mean(np.abs(cross_block)))

        sample_mean1 = np.mean(act1, axis=1)
        sample_mean2 = np.mean(act2, axis=1)
        mi_values.append(mutual_information(sample_mean1, sample_mean2))

    return (
        np.array(mean_abs_cross_corr, dtype=np.float64),
        np.array(mi_values, dtype=np.float64),
    )


def evaluate_pde_errors(u_true, u_pred):
    """Per-sample error statistics between reference and predicted solutions.

    Errors are reduced along axis 1, so each row is treated as one sample.

    Args:
        u_true: reference solutions (array-like, 2-D).
        u_pred: predictions with the same shape as *u_true*.

    Returns:
        ``(mean_abs_error, rel_l2_errors)``. The first is the per-row mean
        absolute error. The second is the per-row relative L2 error
        ``||u_true - u_pred|| / (||u_true|| + 1e-12)``.
    """
    u_true = np.asarray(u_true)
    u_pred = np.asarray(u_pred)
    residual = u_true - u_pred
    mean_abs_error = np.mean(np.abs(residual), axis=1)
    rel_l2_errors = np.linalg.norm(residual, axis=1) / (np.linalg.norm(u_true, axis=1) + 1e-12)
    return mean_abs_error, rel_l2_errors



def _save_results_npz(results, out_dir='./experimental_result_data'):
    """Write one ``<model>_<branch>_data.npz`` file per entry of *results*."""
    for model_name, per_branch in results.items():
        for perturbed_branch, data_dict in per_branch.items():
            npz_path = f'{out_dir}/{model_name}_{perturbed_branch}_data.npz'
            np.savez(npz_path, **{key: np.array(values) for key, values in data_dict.items()})


def experiment_pipeline():
    """Train both models, probe them with branch perturbations, save metrics.

    Both ``OurMethodModel`` and ``DIMONModel`` are trained on the output of
    ``prepare_dataset``. Each model is then perturbed on each branch at 11
    noise levels in ``[0, 0.1]`` on the test split. Activation changes,
    feature-decoupling metrics and PDE errors are recorded and written to
    ``./experimental_result_data`` as ``.npz`` files.

    Returns:
        nested dict ``results[model_name][perturbed_branch]`` mapping metric
        names to lists of floats (one entry per noise level).
    """
    out_dir = './experimental_result_data'
    os.makedirs(out_dir, exist_ok=True)

    epochs = 300
    lr = 0.001

    data = prepare_dataset()
    device = data['device']
    x_tensor = data['x']
    f_train_t, f_bc_train_t, _ = data['train']

    # Input widths come from the prepared tensors instead of being hard-coded,
    # so the networks always match what prepare_dataset produced.
    dim_br1 = [f_train_t.shape[1], 80, 80, 80]
    dim_br2 = [f_bc_train_t.shape[1], 120, 120, 120, 80]
    dim_tr = [x_tensor.shape[-1], 80, 80, 80]

    our_model = OurMethodModel(dim_br1, dim_br2, dim_tr).to(device)
    baseline_model = DIMONModel(dim_br1, dim_br2, dim_tr).to(device)

    train_model(our_model, data, data, x_tensor, epochs=epochs, lr=lr)
    train_model(baseline_model, data, data, x_tensor, epochs=epochs, lr=lr)

    noise_levels = np.linspace(0, 0.1, 11)

    f_test_t, f_bc_test_t, u_test_t = data['test']
    # Loop-invariant: convert the reference solution once.
    u_test_np = to_numpy(u_test_t)

    results = {}
    for model_name, model in [('our_method', our_model), ('baseline', baseline_model)]:
        results[model_name] = {}
        for perturbed_branch in ['branch1', 'branch2']:
            res = activation_and_output_under_perturbation(model, f_test_t, f_bc_test_t, x_tensor, perturbed_branch, noise_levels, device)
            corr, mi = compute_feature_decoupling_metrics(res['branch1_activation'], res['branch2_activation'])

            pde_abs_errors = []
            pde_rel_l2_errors = []
            for u_pred in res['postfusion_output']:
                mean_abs_err, rel_l2_errs = evaluate_pde_errors(u_test_np, u_pred)
                pde_abs_errors.append(float(np.mean(mean_abs_err)))
                pde_rel_l2_errors.append(float(np.mean(rel_l2_errs)))

            results[model_name][perturbed_branch] = {
                'noise_level': res['noise_level'].tolist(),
                'branch1_activation_change': res['branch1_activation_change'].tolist(),
                'branch2_activation_change': res['branch2_activation_change'].tolist(),
                'postfusion_output_deviation': res['postfusion_output_deviation'].tolist(),
                'feature_correlation': corr.tolist(),
                'feature_mutual_information': mi.tolist(),
                'pde_mean_abs_error': pde_abs_errors,
                'pde_mean_rel_l2_error': pde_rel_l2_errors
            }

    _save_results_npz(results, out_dir)

    return results

def convert_np_to_float(obj):
    """Recursively convert NumPy arrays/scalars into JSON-serializable objects.

    Arrays become nested lists. Any NumPy scalar (float, int, bool, ...)
    becomes the matching Python scalar via ``.item()``. Dicts, lists and
    tuples are traversed, with tuples emitted as lists. Other values are
    returned unchanged.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, dict):
        return {k: convert_np_to_float(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_np_to_float(i) for i in obj]
    return obj


if __name__ == '__main__':
    results = experiment_pipeline()
    results_clean = convert_np_to_float(results)
    with open('./experimental_result_data/summary_results.json', 'w') as f:
        json.dump(results_clean, f, indent=4)
