import os
import sys

os.environ['CUDA_VISIBLE_DEVICES']='4'
import os
import torch
import torch.nn as nn
import numpy as np
import scipy.io as io
from sklearn.decomposition import PCA
from tqdm import tqdm

def to_numpy(input):
    """Return ``input`` as a NumPy array.

    NumPy arrays are passed through unchanged (same object). Torch tensors are
    detached from the autograd graph, copied to host memory and converted.

    Raises:
        TypeError: if ``input`` is neither a ``torch.Tensor`` nor an ``np.ndarray``.
    """
    if isinstance(input, np.ndarray):
        return input
    if isinstance(input, torch.Tensor):
        return input.detach().cpu().numpy()
    raise TypeError('Unknown type of input, expected torch.Tensor or np.ndarray, but got {}'.format(type(input)))

class OursOpNN(nn.Module):
    """DeepONet-style operator network with two branch nets and one trunk net.

    The two branch outputs are multiplied element-wise and then contracted with
    the trunk output over the shared latent dimension, giving one prediction per
    (branch sample, trunk point) pair.

    Each ``*_dim`` list gives layer widths, input width first. Every hidden layer
    is ``Linear -> Tanh``; whenever a layer other than the first keeps its width
    unchanged, an extra activation-free ``Linear(h, h)`` is appended after it.
    The last entries of ``branch1_dim``, ``branch2_dim`` and ``trunk_dim`` must be
    equal for the product and contraction in ``forward`` to be valid.
    """

    def __init__(self, branch1_dim, branch2_dim, trunk_dim):
        super().__init__()
        self.z_dim = trunk_dim[-1]
        # Built in this order (branch1, branch2, trunk) so parameter
        # initialisation under a fixed torch seed matches the layer creation order.
        self._branch1 = self._build_mlp(branch1_dim)
        self._branch2 = self._build_mlp(branch2_dim)
        self._trunk = self._build_mlp(trunk_dim)

    @staticmethod
    def _build_mlp(dims):
        """Build the ``Linear -> Tanh`` stack (plus same-width extra Linear layers) for ``dims``."""
        modules = []
        in_channels = dims[0]
        for i, h_dim in enumerate(dims[1:]):
            modules.append(nn.Sequential(
                nn.Linear(in_channels, h_dim), nn.Tanh()
            ))
            # in_channels == dims[i] here; for every layer after the first whose
            # width is unchanged, append an extra (activation-free) linear map.
            if i > 0 and h_dim == in_channels:
                modules.append(nn.Linear(h_dim, h_dim))
            in_channels = h_dim
        return nn.Sequential(*modules)

    def forward(self, f, f_bc, x):
        """Predict ``u`` for every (sample, point) pair.

        Args:
            f: 2-D tensor, one row per sample, ``branch1_dim[0]`` columns.
            f_bc: 2-D tensor, one row per sample, ``branch2_dim[0]`` columns.
            x: 2-D tensor, one row per query point, ``trunk_dim[0]`` columns.

        Returns:
            Tensor of shape ``(num_samples, num_points)``.
        """
        y_br1 = self._branch1(f)
        y_br2 = self._branch2(f_bc)
        y_br = y_br1 * y_br2
        y_tr = self._trunk(x)
        # Contract over the latent dimension: out[i, k] = sum_j y_br[i, j] * y_tr[k, j].
        y_out = torch.einsum("ij,kj->ik", y_br, y_tr)
        return y_out

    def weighted_loss(self, f, f_bc, x, y):
        """Squared error weighted by ``|y| + 1`` (larger-magnitude targets count more).

        Returns:
            ``(mean_loss, per_element_loss)`` where ``per_element_loss`` has the shape of ``y``.
        """
        y_out = self.forward(f, f_bc, x)
        weights = torch.abs(y) + 1.0
        per_element = weights * (y_out - y) ** 2
        return per_element.mean(), per_element

    def loss(self, f, f_bc, x, y):
        """Unweighted squared error.

        Returns ``(mean_loss, per_element_loss)``, matching the return contract of
        ``weighted_loss`` so ``experiment_train_collect(..., weighted_loss=False)``,
        which calls ``model.loss``, also works for this model.
        """
        per_element = (self.forward(f, f_bc, x) - y) ** 2
        return per_element.mean(), per_element

    def uniform_loss(self, f, f_bc, x, y):
        """Mean unweighted squared error (scalar only)."""
        return self.loss(f, f_bc, x, y)[0]

class DIMONOpNN(nn.Module):
    """Baseline DeepONet-style operator network with two branch nets and one trunk net.

    Each ``*_dim`` list gives layer widths, input width first; every layer is
    ``Linear -> Tanh``. The two branch outputs are multiplied element-wise and
    contracted with the trunk output over the shared latent dimension, so the last
    entries of ``branch1_dim``, ``branch2_dim`` and ``trunk_dim`` must be equal.
    """

    def __init__(self, branch1_dim, branch2_dim, trunk_dim):
        super().__init__()
        self.z_dim = trunk_dim[-1]
        # Built in this order so parameter initialisation under a fixed torch
        # seed follows the branch1 -> branch2 -> trunk creation order.
        self._branch1 = self._build_mlp(branch1_dim)
        self._branch2 = self._build_mlp(branch2_dim)
        self._trunk = self._build_mlp(trunk_dim)

    @staticmethod
    def _build_mlp(dims):
        """Build a ``Linear -> Tanh`` stack with the widths listed in ``dims``."""
        modules = []
        in_channels = dims[0]
        for h_dim in dims[1:]:
            modules.append(nn.Sequential(
                nn.Linear(in_channels, h_dim), nn.Tanh()
            ))
            in_channels = h_dim
        return nn.Sequential(*modules)

    def forward(self, f, f_bc, x):
        """Predict ``u`` for every (sample, point) pair.

        Args:
            f: 2-D tensor, one row per sample, ``branch1_dim[0]`` columns.
            f_bc: 2-D tensor, one row per sample, ``branch2_dim[0]`` columns.
            x: 2-D tensor, one row per query point, ``trunk_dim[0]`` columns.

        Returns:
            Tensor of shape ``(num_samples, num_points)``.
        """
        y_br1 = self._branch1(f)
        y_br2 = self._branch2(f_bc)
        y_br = y_br1 * y_br2
        y_tr = self._trunk(x)
        # out[i, k] = sum_j y_br[i, j] * y_tr[k, j]
        y_out = torch.einsum("ij,kj->ik", y_br, y_tr)
        return y_out

    def loss(self, f, f_bc, x, y):
        """Unweighted squared error.

        Returns:
            ``(mean_loss, per_element_loss)`` where ``per_element_loss`` has the shape of ``y``.
        """
        per_element = (self.forward(f, f_bc, x) - y) ** 2
        return per_element.mean(), per_element

def get_datasets(PODMode=10, num_bc=68, num_train=3300, num_test=200, skip=3, device='cuda',
                 data_dir="Laplace/dataset"):
    """Load the Laplace data, build PCA (POD) features and return train/test tensors.

    Three MATLAB files in ``data_dir`` are read in this order and concatenated
    along the first (sample) axis: ``Laplace_data.mat`` (which also supplies the
    reference grid ``x_uni``), ``Laplace_data_supp.mat`` and
    ``Laplace_data_supp2000.mat``. From each file the displacement
    ``x_mesh_data - x_uni``, the solution ``u_data`` and every ``skip``-th column
    of ``u_bc`` are taken.

    Each displacement component (last-axis index 0 and 1) is reduced to
    ``PODMode`` PCA coefficients. The PCA is fit on the first ``num_train`` samples
    after subtracting their mean, and the same mean and basis are applied to the
    last ``num_test`` samples.

    Args:
        PODMode: number of PCA components kept per displacement component.
        num_bc: not used by this function; kept for interface compatibility.
        num_train: the first ``num_train`` samples form the training split.
        num_test: the last ``num_test`` samples form the test split (0 gives an
            empty test split).
        skip: column stride applied to ``u_bc``.
        device: torch device for the returned tensors.
        data_dir: directory holding the three ``.mat`` files.

    Returns:
        dict with float32 tensors ``f_train``/``f_test`` (``2 * PODMode`` PCA
        coefficients per sample), ``f_bc_train``/``f_bc_test``,
        ``u_train``/``u_test``, ``x_network`` (the ``x_uni`` grid), and the NumPy
        arrays ``u_train_np``/``u_test_np``.

    Warns:
        UserWarning: if ``num_train + num_test`` exceeds the number of samples, in
        which case the train and test splits overlap.
    """
    import warnings

    x = None
    dx_parts, u_parts, u_bc_parts = [], [], []
    for fname in ("Laplace_data.mat", "Laplace_data_supp.mat", "Laplace_data_supp2000.mat"):
        dataset = io.loadmat(os.path.join(data_dir, fname))
        if x is None:
            # Only the primary file's grid is used; every file's mesh is taken
            # relative to it.
            x = dataset["x_uni"]
        dx_parts.append(dataset["x_mesh_data"] - x)
        u_parts.append(dataset["u_data"])
        u_bc_parts.append(dataset["u_bc"][:, ::skip])
    dx = np.concatenate(dx_parts, axis=0)
    u = np.concatenate(u_parts, axis=0)
    u_bc = np.concatenate(u_bc_parts, axis=0)

    num_total = dx.shape[0]
    if num_train + num_test > num_total:
        warnings.warn(
            f"num_train ({num_train}) + num_test ({num_test}) exceeds the {num_total} "
            "available samples; the train and test splits overlap.",
            stacklevel=2,
        )

    def _test_split(arr):
        # Explicit start index: `arr[-num_test:]` would select every sample when num_test == 0.
        return arr[max(arr.shape[0] - num_test, 0):]

    dx_train = dx[:num_train]
    dx_test = _test_split(dx)

    def _pod_coeffs(component):
        """PCA coefficients of one displacement component, fit on the training split."""
        train_comp = dx_train[:, :, component]
        train_mean = train_comp.mean(axis=0)
        pca = PCA(n_components=PODMode)
        pca.fit(train_comp - train_mean)
        coeff_train = pca.transform(train_comp - train_mean)
        test_comp = dx_test[:, :, component] - train_mean
        if test_comp.shape[0] == 0:
            # PCA.transform rejects zero-row input.
            coeff_test = np.empty((0, coeff_train.shape[1]), dtype=coeff_train.dtype)
        else:
            coeff_test = pca.transform(test_comp)
        return coeff_train, coeff_test

    # Component order (x, then y) is kept so any PCA randomness follows the
    # same global RNG sequence.
    coeff_x_train, coeff_x_test = _pod_coeffs(0)
    coeff_y_train, coeff_y_test = _pod_coeffs(1)
    f_train = np.concatenate((coeff_x_train, coeff_y_train), axis=1)
    f_test = np.concatenate((coeff_x_test, coeff_y_test), axis=1)
    u_train = u[:num_train]
    u_test = _test_split(u)
    f_bc_train = u_bc[:num_train, :]
    f_bc_test = _test_split(u_bc)

    def _to_tensor(arr):
        return torch.tensor(arr, dtype=torch.float).to(device)

    return dict(
        f_train=_to_tensor(f_train), f_bc_train=_to_tensor(f_bc_train), u_train=_to_tensor(u_train),
        f_test=_to_tensor(f_test), f_bc_test=_to_tensor(f_bc_test), u_test=_to_tensor(u_test),
        x_network=_to_tensor(x),
        u_train_np=u_train, u_test_np=u_test
    )

def get_label_bins(labels, n_bins=4):
    """Group the rows of ``labels`` into ``n_bins`` quantile bins of their L2 norm.

    The bin edges are the percentiles at ``0, 100/n_bins, ..., 100`` of the row
    norms. Only the interior edges are used for assignment, so the smallest norm
    falls in group 0 and the largest in group ``n_bins - 1``.

    Returns:
        ``(group_ids, edges)``: an integer group index per row, and the
        ``n_bins + 1`` percentile edges.
    """
    quantile_levels = np.linspace(0, 100, n_bins + 1)
    row_norms = np.linalg.norm(labels, axis=1)
    edges = np.percentile(row_norms, quantile_levels)
    interior_edges = edges[1:-1]
    group_ids = np.digitize(row_norms, interior_edges, right=False)
    return group_ids, edges

def experiment_train_collect(
    model, optimizer, f, f_bc, x, u, epochs, batch_size=128, weighted_loss=True, 
    device='cuda', sample_history_epochs=(0,499,999,2999,4999,9999,19999,29999,39999,49999)):
    """Train ``model`` full-batch and snapshot per-sample losses at chosen epochs.

    Each of the ``epochs`` iterations performs one optimizer step on the whole
    batch, using ``model.weighted_loss`` if ``weighted_loss`` is true and
    ``model.loss`` otherwise. Both must return ``(scalar_loss, per_element_loss)``.
    After the step at each epoch listed in ``sample_history_epochs``, the loss is
    re-evaluated without gradients and averaged over dim 1, giving one value per
    sample (row).

    Args:
        model: module exposing ``weighted_loss`` and/or ``loss`` as described above.
        optimizer: optimizer over ``model``'s parameters.
        f, f_bc, x, u: inputs and targets passed straight to the loss method. They
            are not moved, so they must already be on ``device``.
        epochs: number of optimizer steps.
        batch_size: currently unused (training is full-batch); kept for interface
            compatibility.
        weighted_loss: pick ``model.weighted_loss`` (True) or ``model.loss`` (False).
        device: device the model is moved to.
        sample_history_epochs: 0-based epoch indices at which to record losses.

    Returns:
        dict mapping each recorded epoch (in the order given) to a NumPy array of
        per-sample losses. Epochs that training never reaches (``>= epochs``) are
        omitted rather than reported as all-zero arrays.
    """
    model = model.to(device)
    model.train()
    loss_fn = model.weighted_loss if weighted_loss else model.loss
    # Keys pre-created in the caller's order; None marks "not recorded yet".
    history = dict.fromkeys(sample_history_epochs)
    record_at = set(history)
    for epoch in tqdm(range(epochs)):
        optimizer.zero_grad()
        loss, _ = loss_fn(f, f_bc, x, u)
        loss.backward()
        optimizer.step()
        if epoch in record_at:
            # Post-step losses for this epoch, without building a graph.
            with torch.no_grad():
                _, per_element = loss_fn(f, f_bc, x, u)
            history[epoch] = to_numpy(per_element.mean(dim=1)).copy()
    return {ep: losses for ep, losses in history.items() if losses is not None}

def _write_sample_history(loss_hist, bins, out_fn):
    """Write one CSV row per (recorded epoch, sample): sample_idx, epoch, loss_val, label_group."""
    import pandas as pd

    records = [
        {
            'sample_idx': idx,
            'epoch': int(epoch),
            'loss_val': float(lossval),
            'label_group': int(bins[idx]),
        }
        for epoch, losses in loss_hist.items()
        for idx, lossval in enumerate(losses)
    ]
    pd.DataFrame(records).to_csv(out_fn, index=False)


def _write_grouped_stats(loss_hist, bins, n_bins, out_fn):
    """Write mean/std/var/count of sample losses per (epoch, label_group) to a CSV.

    Empty groups get NaN statistics explicitly instead of triggering NumPy's
    empty-slice warnings.
    """
    import pandas as pd

    masks = [bins == binid for binid in range(n_bins)]
    rows = []
    for epoch, losses in loss_hist.items():
        losses = np.asarray(losses)
        for binid, mask in enumerate(masks):
            group_losses = losses[mask]
            count = int(np.sum(mask))
            if count:
                mean_l = float(np.mean(group_losses))
                std_l = float(np.std(group_losses))
                var_l = float(np.var(group_losses))
            else:
                mean_l = std_l = var_l = float('nan')
            rows.append({
                'epoch': int(epoch),
                'label_group': int(binid),
                'mean_loss': mean_l,
                'std_loss': std_l,
                'var_loss': var_l,
                'count': count,
            })
    pd.DataFrame(rows).to_csv(out_fn, index=False)


def _group_final_epoch_losses(loss_hist, bins, n_bins):
    """Split the last recorded epoch's per-sample losses by label group, ready for ``np.savez``."""
    final_epoch = max(loss_hist)
    losses = np.asarray(loss_hist[final_epoch])
    groups = [losses[bins == i] for i in range(n_bins)]
    if len({g.shape[0] for g in groups}) == 1:
        # Equal-sized groups: a regular (n_bins, group_size) array.
        return np.stack(groups)
    # Ragged groups cannot form a regular ndarray (np.savez on the plain list
    # raises on NumPy >= 1.24), so store an object array instead; loading it
    # requires np.load(..., allow_pickle=True).
    ragged = np.empty(n_bins, dtype=object)
    for i, group in enumerate(groups):
        ragged[i] = group
    return ragged


def run_experiment():
    """Train OursOpNN (weighted loss) and DIMONOpNN (plain MSE) and export per-sample loss data.

    Writes to ``./draw_data``:
      * ``sample_loss_history_{ours,baseline}.csv``: per (recorded epoch, training sample) loss and label group.
      * ``label_group_bin_thresholds.npz``: ``bin_i`` -> [lower, upper] label-norm percentile edges.
      * ``loss_grouped_stats_{ours,baseline}.csv``: per (epoch, label group) mean/std/var/count.
      * ``boxplot_grouped_final_epoch_losses.npz``: per method, the last recorded epoch's losses split by group.
    """
    outdir = "./draw_data"
    os.makedirs(outdir, exist_ok=True)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.manual_seed(42)
    np.random.seed(42)
    PODMode = 10
    num_bc = 68
    n_bins = 4  # number of label-norm groups, used for binning and every grouped output
    dim_br1 = [PODMode*2, 96, 96, 72]
    dim_br2 = [num_bc, 120, 150, 96, 72]
    dim_tr = [2, 48, 72, 72]
    num_train = 3300
    num_test = 200
    epochs = 50000
    batch_size = 128
    data = get_datasets(
        PODMode=PODMode, num_bc=num_bc, num_train=num_train, num_test=num_test, device=device
    )
    f_train, f_bc_train, u_train, x_all = data['f_train'], data['f_bc_train'], data['u_train'], data['x_network']
    bins, percentiles = get_label_bins(data['u_train_np'], n_bins=n_bins)
    # Model creation/training order is kept so RNG-dependent initialisation is unchanged.
    model_ours = OursOpNN(dim_br1, dim_br2, dim_tr).float().to(device)
    optimizer_ours = torch.optim.Adam(model_ours.parameters(), lr=0.001)
    sample_ep_list = [0,499,999,2999,4999,9999,19999,29999,39999,49999]
    our_sample_losses = experiment_train_collect(
        model_ours, optimizer_ours, f_train, f_bc_train, x_all, u_train, epochs=epochs, batch_size=batch_size,
        weighted_loss=True, device=device, sample_history_epochs=sample_ep_list)
    model_base = DIMONOpNN(dim_br1, dim_br2, dim_tr).float().to(device)
    optimizer_base = torch.optim.Adam(model_base.parameters(), lr=0.001)
    base_sample_losses = experiment_train_collect(
        model_base, optimizer_base, f_train, f_bc_train, x_all, u_train, epochs=epochs, batch_size=batch_size,
        weighted_loss=False, device=device, sample_history_epochs=sample_ep_list)

    histories = {"ours": our_sample_losses, "baseline": base_sample_losses}
    for method, loss_hist in histories.items():
        _write_sample_history(loss_hist, bins, os.path.join(outdir, f'sample_loss_history_{method}.csv'))
    bininfo = {f"bin_{i}": [float(percentiles[i]), float(percentiles[i+1])] for i in range(len(percentiles)-1)}
    np.savez(os.path.join(outdir, "label_group_bin_thresholds.npz"), **bininfo)
    for method, loss_hist in histories.items():
        _write_grouped_stats(loss_hist, bins, n_bins, os.path.join(outdir, f'loss_grouped_stats_{method}.csv'))
    boxplot_outs = {
        method: _group_final_epoch_losses(loss_hist, bins, n_bins)
        for method, loss_hist in histories.items()
    }
    np.savez(os.path.join(outdir, "boxplot_grouped_final_epoch_losses.npz"), **boxplot_outs)

# Run the full train-and-export experiment only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    run_experiment()
