import os
import sys
os.environ['CUDA_VISIBLE_DEVICES']='3'
import torch
import torch.nn as nn
import numpy as np
import scipy.io as io
from sklearn.decomposition import PCA
from tqdm import tqdm

class OPNN(nn.Module):
    """Operator network made of two branch nets and one trunk net.

    The outputs of the two branches are multiplied element-wise and the
    product is contracted with the trunk output over the last (latent) axis.
    Training uses a squared error that is re-weighted by ``|y| + 1``.
    """

    def __init__(self, branch1_dim, branch2_dim, trunk_dim):
        """Create the three sub-networks.

        Args:
            branch1_dim: Layer widths of the first branch, input width first.
            branch2_dim: Layer widths of the second branch, input width first.
            trunk_dim: Layer widths of the trunk, input width first. Its last
                entry is stored as ``z_dim``.
        """
        super().__init__()
        self.z_dim = trunk_dim[-1]
        # Sub-networks are created in the order branch1, branch2, trunk so
        # that parameter initialisation under a fixed seed is reproducible.
        self._branch1 = self._build_branch(branch1_dim, True)
        self._branch2 = self._build_branch(branch2_dim, True)
        self._trunk = self._build_branch(trunk_dim, True)

    def _build_branch(self, dims, residual):
        """Assemble a stack of ``Linear -> Tanh`` stages from ``dims``.

        Args:
            dims: Sequence of layer widths; ``dims[0]`` is the input width.
            residual: When true, every stage after the first whose input and
                output widths are equal is followed by one extra square
                ``nn.Linear``. Note that this is an additional layer, not a
                skip connection.

        Returns:
            An ``nn.Sequential`` holding the stages in order.
        """
        widths = [dims[0], *dims[1:]]
        stack = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            stack.append(nn.Sequential(nn.Linear(fan_in, fan_out), nn.Tanh()))
            same_width_hidden = idx > 0 and fan_out == fan_in
            if residual and same_width_hidden:
                stack.append(nn.Linear(fan_out, fan_out))
        return nn.Sequential(*stack)

    def forward(self, f, f_bc, x):
        """Evaluate the network.

        Args:
            f: Input of the first branch.
            f_bc: Input of the second branch.
            x: Input of the trunk.

        Returns:
            A 2-D tensor whose entry ``[i, k]`` is the sum over ``j`` of
            ``(branch1(f) * branch2(f_bc))[i, j] * trunk(x)[k, j]``.
        """
        branch_product = self._branch1(f) * self._branch2(f_bc)
        trunk_out = self._trunk(x)
        # The einsum fixes both operands to rank 2 and contracts their last axis.
        return torch.einsum('ij,kj->ik', branch_product, trunk_out)

    def loss(self, f, f_bc, x, y):
        """Return the mean of ``(|y| + 1) * (prediction - y)**2``.

        Args:
            f: Input of the first branch.
            f_bc: Input of the second branch.
            x: Input of the trunk.
            y: Target values, compared element-wise with the prediction.
        """
        error = self.forward(f, f_bc, x) - y
        # Targets with larger magnitude contribute more to the loss.
        emphasis = 1.0 + torch.abs(y)
        return (emphasis * error**2).mean()

class DIMON_OPNN(nn.Module):
    """Operator network with two plain branch nets and one plain trunk net.

    Every sub-network is a stack of ``Linear -> Tanh`` stages. The branch
    outputs are multiplied element-wise and contracted with the trunk output
    over the last (latent) axis. Training uses an unweighted mean squared
    error.
    """

    def __init__(self, branch1_dim, branch2_dim, trunk_dim):
        """Create the three sub-networks.

        Args:
            branch1_dim: Layer widths of the first branch, input width first.
            branch2_dim: Layer widths of the second branch, input width first.
            trunk_dim: Layer widths of the trunk, input width first. Its last
                entry is stored as ``z_dim``.
        """
        super().__init__()
        self.z_dim = trunk_dim[-1]
        # Sub-networks are created in the order branch1, branch2, trunk so
        # that parameter initialisation under a fixed seed is reproducible.
        self._branch1 = self._build_branch(branch1_dim, False)
        self._branch2 = self._build_branch(branch2_dim, False)
        self._trunk = self._build_branch(trunk_dim, False)

    def _build_branch(self, dims, residual):
        """Assemble a stack of ``Linear -> Tanh`` stages from ``dims``.

        Args:
            dims: Sequence of layer widths; ``dims[0]`` is the input width.
            residual: Accepted but not used by this class.

        Returns:
            An ``nn.Sequential`` with one stage per consecutive width pair.
        """
        widths = [dims[0], *dims[1:]]
        stages = [
            nn.Sequential(nn.Linear(fan_in, fan_out), nn.Tanh())
            for fan_in, fan_out in zip(widths, widths[1:])
        ]
        return nn.Sequential(*stages)

    def forward(self, f, f_bc, x):
        """Evaluate the network.

        Args:
            f: Input of the first branch.
            f_bc: Input of the second branch.
            x: Input of the trunk.

        Returns:
            A 2-D tensor whose entry ``[i, k]`` is the sum over ``j`` of
            ``(branch1(f) * branch2(f_bc))[i, j] * trunk(x)[k, j]``.
        """
        branch_product = self._branch1(f) * self._branch2(f_bc)
        trunk_out = self._trunk(x)
        # The einsum fixes both operands to rank 2 and contracts their last axis.
        return torch.einsum('ij,kj->ik', branch_product, trunk_out)

    def loss(self, f, f_bc, x, y):
        """Return the mean squared error between the prediction and ``y``.

        Args:
            f: Input of the first branch.
            f_bc: Input of the second branch.
            x: Input of the trunk.
            y: Target values, compared element-wise with the prediction.
        """
        error = self.forward(f, f_bc, x) - y
        return (error**2).mean()

def _train_model(model, name, inputs, n_epochs=50000, lr=0.001):
    """Fit ``model`` with Adam on one fixed full batch.

    Args:
        model: Module exposing ``loss(f, f_bc, x, y)`` and ``parameters()``.
        name: Label shown in the progress bar.
        inputs: Tuple ``(f, f_bc, x, y)`` of tensors already on the model's
            device; it is passed unchanged to ``model.loss`` every epoch.
        n_epochs: Number of optimisation steps.
        lr: Adam learning rate.

    Returns:
        The loss of the last iteration as a Python float (evaluated before
        that iteration's optimiser step), or ``None`` if ``n_epochs`` is 0.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss = None
    for _ in tqdm(range(n_epochs), desc=f'Training {name}'):
        optimizer.zero_grad()
        loss = model.loss(*inputs)
        loss.backward()
        optimizer.step()
    return None if loss is None else loss.item()


def run_experiment():
    """Train OPNN and DIMON_OPNN on the Laplace dataset and save their losses.

    Steps:
        1. Load ``Laplace/dataset/Laplace_data.mat``.
        2. Reduce the two components of the mesh displacement
           (``x_mesh_data - x_uni``) of the first 3300 samples with PCA and
           concatenate the coefficients as the first-branch input.
        3. Train both models on the first 3300 samples.
        4. Write each model's own final loss to ``./draw_data``.

    The result files hold a dict ``{'loss': float}``; ``np.save`` stores it as
    a pickled object array, so it must be read with ``allow_pickle=True``.
    """
    out_dir = './draw_data'
    # Create the output directory up front so a missing folder cannot make
    # the run fail only after the full training has finished.
    os.makedirs(out_dir, exist_ok=True)

    # Load dataset
    dataset = io.loadmat('Laplace/dataset/Laplace_data.mat')
    x = dataset['x_uni']
    x_mesh = dataset['x_mesh_data']
    dx = x_mesh - x
    u = dataset['u_data']
    u_bc = dataset['u_bc'][:, ::3]  # keep every third boundary column

    n_train = 3300
    PODMode = 10
    pca_x = PCA(n_components=PODMode)
    pca_y = PCA(n_components=PODMode)
    dx_x = dx[:n_train, :, 0]
    dx_y = dx[:n_train, :, 1]
    coeff_x_train = pca_x.fit_transform(dx_x - dx_x.mean(axis=0))
    coeff_y_train = pca_y.fit_transform(dx_y - dx_y.mean(axis=0))
    f_train = np.concatenate((coeff_x_train, coeff_y_train), axis=1)

    dim_br1 = [PODMode*2, 100, 100, 100]
    dim_br2 = [68, 150, 150, 150, 100]
    dim_tr = [2, 100, 100, 100]

    device = 'cuda'
    model_opnn = OPNN(dim_br1, dim_br2, dim_tr).to(device)
    model_dimon = DIMON_OPNN(dim_br1, dim_br2, dim_tr).to(device)

    # The training batch never changes, so convert and upload it once
    # instead of on every epoch.
    train_inputs = (
        torch.tensor(f_train, dtype=torch.float).to(device),
        torch.tensor(u_bc[:n_train], dtype=torch.float).to(device),
        torch.tensor(x, dtype=torch.float).to(device),
        torch.tensor(u[:n_train], dtype=torch.float).to(device),
    )

    # Keep one final loss per model so each result file gets its own value.
    final_losses = {}
    for model, name in zip([model_opnn, model_dimon], ['OPNN', 'DIMON']):
        final_losses[name] = _train_model(model, name, train_inputs)

    # Save results
    np.save('./draw_data/opnn_results.npy',
            {'loss': final_losses['OPNN']})
    np.save('./draw_data/dimon_results.npy',
            {'loss': final_losses['DIMON']})

# Script entry point: start the experiment only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    run_experiment()